更新类的属性时清除某些方法的lru_cache?



我有一个带有方法/属性multiplier的对象。此方法在我的程序中被调用了很多次,所以我决定在其上使用lru_cache()来提高执行速度。正如预期的那样,它要快得多:

以下代码显示了该问题:

from functools import lru_cache
class MyClass(object):
def __init__(self):
self.current_contract = 201706
self.futures = {201706: {'multiplier': 1000},
201712: {'multiplier': 25}}
@property
@lru_cache()
def multiplier(self):
return self.futures[self.current_contract]['multiplier']
CF = MyClass()
assert CF.multiplier == 1000
CF.current_contract = 201712
assert CF.multiplier == 25

第 2assert失败,因为缓存的值为 1000,因为lru_cache()不知道基础属性current_contract已更改。

有没有办法在更新self.current_contract时清除缓存?

谢谢!

是的,很简单:使current_contract成为读/写属性并清除属性的 setter 中的缓存:

from functools import lru_cache
class MyClass(object):
def __init__(self):
self.futures = {201706: {'multiplier': 1000},
201712: {'multiplier': 25}}
self.current_contract = 201706
@property
def current_contract(self):
return self._current_contract
@current_contract.setter
def current_contract(self, value):
self._current_contract = value
type(self).multiplier.fget.cache_clear()
@property
@lru_cache()
def multiplier(self):
return self.futures[self.current_contract]['multiplier']

注意:我假设您的实际用例涉及昂贵的计算,而不仅仅是字典查找 - 否则lru_cache可能有点矫枉过正;)

简答题

更新self.current_contract时不要清除缓存。 这是对缓存起作用并丢弃信息。

相反,只需添加用于__eq____hash__的方法。 这将告诉缓存(或任何其他映射(哪些属性对于影响结果很重要。

计算出的示例

在这里,我们将__eq____hash__添加到您的代码中。 这告诉缓存(或任何其他映射(current_contract是相关的自变量:

from functools import lru_cache
class MyClass(object):
def __init__(self):
self.current_contract = 201706
self.futures = {201706: {'multiplier': 1000},
201712: {'multiplier': 25}}
def __hash__(self):
return hash(self.current_contract)
def __eq__(self, other):
return self.current_contract == other.current_contract
@property
@lru_cache()
def multiplier(self):
return self.futures[self.current_contract]['multiplier']

一个直接的优点是,当您在合同编号之间切换时,之前的结果会保存在缓存中。 尝试在 201706 和 201712 之间切换一百次,您将获得 98 次缓存命中和 2 次缓存未命中:

cf = MyClass()
for i in range(50):
cf.current_contract = 201712
assert cf.multiplier == 25
cf.current_contract = 201706 
assert cf.multiplier == 1000
print(vars(MyClass)['multiplier'].fget.cache_info())

这将打印:

CacheInfo(hits=98, misses=2, maxsize=128, currsize=2)

相关内容

  • 没有找到相关文章

最新更新