Minimizing the number of calls to a costly function when its arguments stay the same (Python)



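Suppose there is a function costly_function_a(x) such that:

  1. it is very expensive in terms of execution time;
  2. it returns the same output whenever the same x is fed to it; and
  3. it performs no "additional tasks" besides returning an output.

Under these conditions, instead of calling the function twice in a row with the same x, we can store the result in a temporary variable and use that variable in the calculations.

Now suppose that some functions (f(x), g(x) and h(x) in the example below) call costly_function_a(x), and that some of these functions may call one another (in the example below, g(x) and h(x) both call f(x)). In this case, the simple method mentioned above still results in repeated calls to costly_function_a(x) with the same x (see OkayVersion below). I did find a way to minimize the number of calls, but it is "ugly" (see FastVersion below). Is there a better way to do this?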

#Dummy functions representing extremely slow code.
#The goal is to call these costly functions as rarely as possible.
def costly_function_a(x):
    print("costly_function_a has been called.")
    return x #Dummy operation.
def costly_function_b(x):
    print("costly_function_b has been called.")
    return 5.*x #Dummy operation.
#Simplest (but slowest) implementation.
class SlowVersion:
    def __init__(self,a,b):
        self.a = a
        self.b = b
    def f(self,x): #Dummy operation.
        return self.a(x) + 2.*self.a(x)**2
    def g(self,x): #Dummy operation.
        return self.f(x) + 0.7*self.a(x) + .1*x
    def h(self,x): #Dummy operation.
        return self.f(x) + 0.5*self.a(x) + self.b(x) + 3.*self.b(x)**2
#Equivalent to SlowVersion, but call the costly functions less often.
class OkayVersion:
    def __init__(self,a,b):
        self.a = a
        self.b = b
    def f(self,x): #Same result as SlowVersion.f(x)
        a_at_x = self.a(x)
        return a_at_x + 2.*a_at_x**2
    def g(self,x): #Same result as SlowVersion.g(x)
        return self.f(x) + 0.7*self.a(x) + .1*x
    def h(self,x): #Same result as SlowVersion.h(x)
        a_at_x = self.a(x)
        b_at_x = self.b(x)
        return self.f(x) + 0.5*a_at_x + b_at_x + 3.*b_at_x**2
#Equivalent to SlowVersion, but calls the costly functions even less often.
#Is this the simplest way to do it? I am aware that this code is highly
#redundant. One could simplify it by defining some factory functions...
class FastVersion:
    def __init__(self,a,b):
        self.a = a
        self.b = b
    def f(self, x, _at_x=None): #Same result as SlowVersion.f(x)
        if _at_x is None:
            _at_x = dict()
        if 'a' not in _at_x:
            _at_x['a'] = self.a(x)
        return _at_x['a'] + 2.*_at_x['a']**2
    def g(self, x, _at_x=None): #Same result as SlowVersion.g(x)
        if _at_x is None:
            _at_x = dict()
        if 'a' not in _at_x:
            _at_x['a'] = self.a(x)
        return self.f(x,_at_x) + 0.7*_at_x['a'] + .1*x
    def h(self,x,_at_x=None): #Same result as SlowVersion.h(x)
        if _at_x is None:
            _at_x = dict()
        if 'a' not in _at_x:
            _at_x['a'] = self.a(x)
        if 'b' not in _at_x:
            _at_x['b'] = self.b(x)
        return self.f(x,_at_x) + 0.5*_at_x['a'] + _at_x['b'] + 3.*_at_x['b']**2
if __name__ == '__main__':
    slow = SlowVersion(costly_function_a,costly_function_b)
    print("Using slow version.")
    print("f(2.) = " + str(slow.f(2.)))
    print("g(2.) = " + str(slow.g(2.)))
    print("h(2.) = " + str(slow.h(2.)) + "\n")
    okay = OkayVersion(costly_function_a,costly_function_b)
    print("Using okay version.")
    print("f(2.) = " + str(okay.f(2.)))
    print("g(2.) = " + str(okay.g(2.)))
    print("h(2.) = " + str(okay.h(2.)) + "\n")
    fast = FastVersion(costly_function_a,costly_function_b)
    print("Using fast version 'casually'.")
    print("f(2.) = " + str(fast.f(2.)))
    print("g(2.) = " + str(fast.g(2.)))
    print("h(2.) = " + str(fast.h(2.)) + "\n")
    print("Using fast version 'optimally'.")
    _at_x = dict()
    print("f(2.) = " + str(fast.f(2.,_at_x)))
    print("g(2.) = " + str(fast.g(2.,_at_x)))
    print("h(2.) = " + str(fast.h(2.,_at_x)))
    #Of course, one must "clean up" _at_x before using a different x...

The output of this code is:

Using slow version.
costly_function_a has been called.
costly_function_a has been called.
f(2.) = 10.0
costly_function_a has been called.
costly_function_a has been called.
costly_function_a has been called.
g(2.) = 11.6
costly_function_a has been called.
costly_function_a has been called.
costly_function_a has been called.
costly_function_b has been called.
costly_function_b has been called.
h(2.) = 321.0
Using okay version.
costly_function_a has been called.
f(2.) = 10.0
costly_function_a has been called.
costly_function_a has been called.
g(2.) = 11.6
costly_function_a has been called.
costly_function_b has been called.
costly_function_a has been called.
h(2.) = 321.0
Using fast version 'casually'.
costly_function_a has been called.
f(2.) = 10.0
costly_function_a has been called.
g(2.) = 11.6
costly_function_a has been called.
costly_function_b has been called.
h(2.) = 321.0
Using fast version 'optimally'.
costly_function_a has been called.
f(2.) = 10.0
g(2.) = 11.6
costly_function_b has been called.
h(2.) = 321.0

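Note that I do not want to "store" the results for every value of x used in the past (because that would require too much memory). Also, I do not want a function that returns a tuple of the form (f,g,h), because in some cases I only want f (and then there is no need to evaluate costly_function_b).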

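What you are looking for is an LRU cache: only the most recently used items are cached, which bounds memory use and lets you balance the cost of a call against memory requirements.

When the costly function is called with different values of x, up to a fixed number of return values are cached (one per unique value of x); once the cache is full, the least recently used result is discarded.

As of Python 3.2, the standard library ships a decorator implementation, @functools.lru_cache():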

from functools import lru_cache
@lru_cache(16)  # cache 16 different `x` return values
def costly_function_a(x):
    print("costly_function_a has been called.")
    return x #Dummy operation.
@lru_cache(32)  # cache 32 different `x` return values
def costly_function_b(x):
    print("costly_function_b has been called.")
    return 5.*x #Dummy operation.

A backport is available for earlier versions, or you can pick one of the other libraries on PyPI that implement LRU caches.
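
To see how many real calls are left once the cache is in place, the statistics exposed by the decorated function can be inspected. The following is only a minimal sketch: f and g are trimmed-down, module-level versions of the question's methods, and the `calls` counter is a hypothetical helper added purely for illustration.

```python
from functools import lru_cache

calls = {"a": 0}  # hypothetical counter, only used to observe real executions

@lru_cache(maxsize=16)  # keep results for up to 16 different `x` values
def costly_function_a(x):
    calls["a"] += 1
    return x  # Dummy operation.

def f(x):
    return costly_function_a(x) + 2. * costly_function_a(x) ** 2

def g(x):
    return f(x) + 0.7 * costly_function_a(x) + .1 * x

f(2.)
g(2.)

info = costly_function_a.cache_info()
print(calls["a"])              # 1: the function body ran only once
print(info.hits, info.misses)  # 4 1

# Drop every cached result, for example to release memory early.
costly_function_a.cache_clear()
```

Note that `lru_cache` requires all arguments to be hashable, so this approach fits calls with numbers, strings or tuples rather than lists or dicts.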

If you only ever need to cache the single most recent item, create your own decorator:

from functools import wraps
def cache_most_recent(func):
    cache = [None, None]
    @wraps(func)
    def wrapper(*args, **kw):
        if (args, kw) == cache[0]:
            return cache[1]
        # Call first, so that an exception leaves the previous cache entry intact.
        result = func(*args, **kw)
        cache[0] = args, kw
        cache[1] = result
        return result
    return wrapper
@cache_most_recent
def costly_function_a(x):
    print("costly_function_a has been called.")
    return x #Dummy operation.
@cache_most_recent
def costly_function_b(x):
    print("costly_function_b has been called.")
    return 5.*x #Dummy operation.

This simpler decorator has less overhead than the more full-featured functools.lru_cache().
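
As a rough check of what a one-entry cache buys in the question's scenario, the sketch below wires such a decorator into the unmodified SlowVersion class. The decorator is re-declared so the block runs on its own, and the `calls` list and the `after_same_x` snapshot are hypothetical names used only to record which costly functions really ran.

```python
from functools import wraps

def cache_most_recent(func):
    """One-entry cache: remembers only the latest arguments and result."""
    cache = [None, None]
    @wraps(func)
    def wrapper(*args, **kw):
        if (args, kw) == cache[0]:
            return cache[1]
        result = func(*args, **kw)
        cache[0] = args, kw
        cache[1] = result
        return result
    return wrapper

calls = []  # hypothetical log of real executions

@cache_most_recent
def costly_function_a(x):
    calls.append("a")
    return x  # Dummy operation.

@cache_most_recent
def costly_function_b(x):
    calls.append("b")
    return 5. * x  # Dummy operation.

class SlowVersion:
    def __init__(self, a, b):
        self.a = a
        self.b = b
    def f(self, x):
        return self.a(x) + 2. * self.a(x) ** 2
    def g(self, x):
        return self.f(x) + 0.7 * self.a(x) + .1 * x
    def h(self, x):
        return self.f(x) + 0.5 * self.a(x) + self.b(x) + 3. * self.b(x) ** 2

slow = SlowVersion(costly_function_a, costly_function_b)
slow.f(2.)
slow.g(2.)
slow.h(2.)
after_same_x = list(calls)
print(after_same_x)  # ['a', 'b']: one real call each, as in the "optimal" run

slow.f(3.)  # a new x replaces the single cached entry for costly_function_a
slow.f(2.)  # so the earlier x has to be recomputed
print(calls)  # ['a', 'b', 'a', 'a']
```

The second half shows the trade-off: alternating between two values of x defeats a one-entry cache, which is the case where a larger `lru_cache` would be the better fit.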

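I accepted @MartijnPieters's solution because it is probably the right approach for 99% of the people who will run into a similar problem. In my particular case, however, I only need a "cache of 1", so the fancy @lru_cache(1) decorator is a bit overkill. I ended up writing my own decorator (thanks to this awesome stackoverflow answer), which I provide below. Note that I am new to Python, so this code may not be perfect.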

from functools import wraps
def last_cache(func):
    """A decorator caching the last value returned by a function.
    If the decorated function is called twice (or more) in a row with exactly
    the same parameters, then this decorator will return a cached value of the
    decorated function's last output instead of calling it again. This may
    speed up execution if the decorated function is costly to call.
    The decorated function must respect the following conditions:
    1.  Repeated calls return the same value if the same parameters are used.
    2.  The function's only "task" is to return a value.
    """
    _first_call = [True]
    _last_args = [None]
    _last_kwargs = [None]
    _last_value = [None]
    @wraps(func)
    def _last_cache_wrapper(*args, **kwargs):
        if _first_call[0] or (args!=_last_args[0]) or (kwargs!=_last_kwargs[0]):
            # Call first, so that an exception leaves the cached state untouched.
            value = func(*args, **kwargs)
            _first_call[0] = False
            _last_args[0] = args
            _last_kwargs[0] = kwargs
            _last_value[0] = value
        return _last_value[0]
    return _last_cache_wrapper
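
One caveat worth keeping in mind with an equality-based cache like this: it accepts unhashable arguments such as lists (which `lru_cache` rejects), but it cannot notice when such an argument is mutated in place. The sketch below illustrates this under stated assumptions: the decorator is a compact re-statement of the same idea rather than the exact code above, and `total`, `total_lru` and `data` are hypothetical names.

```python
from functools import lru_cache, wraps

def last_cache(func):
    """Compact one-entry cache keyed on argument equality (illustrative)."""
    last = {}  # stays empty until the first successful call
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not last or last["key"] != (args, kwargs):
            value = func(*args, **kwargs)
            last["key"] = (args, kwargs)
            last["value"] = value
        return last["value"]
    return wrapper

@last_cache
def total(values):
    return sum(values)

@lru_cache(maxsize=1)
def total_lru(values):
    return sum(values)

data = [1, 2, 3]
first = total(data)   # computed: 6
data.append(4)        # in-place mutation of the object stored as the cache key
stale = total(data)   # still 6: the stored key changed together with `data`
recomputed = total(tuple(data))  # a tuple never equals a list, so this recomputes: 10

try:
    total_lru(data)   # lists are unhashable, so lru_cache raises TypeError
    lru_accepts_lists = True
except TypeError:
    lru_accepts_lists = False

print(first, stale, recomputed, lru_accepts_lists)  # 6 6 10 False
```

Passing immutable arguments (numbers, strings, tuples) sidesteps the stale-result problem with either decorator.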
