numpy与python:修补round函数



我有一个文件,它定义了

def foo(x):
return round(x)

这里的round是python的built-in函数。

现在,我想用一个numpy数组来调用这个函数。Numpy也有一个round函数。不幸的是,存在问题(例如#111557(

import numpy as np
foo(7.6)
foo(np.array([7.6]))       # works with python2 + np_1.14.0, not with 3.6.9 + np_1.14.5
foo(np.array([7.6, 8.9]))  # TypeError: type numpy.ndarray doesn't define __round__ method

是否有可能在foo函数中用np.round替换函数round?导入前的round = np.round或导入后的foo.round = np.round之类的补丁?

编辑:我正在寻找一个不修改文件的解决方案。

是的,当然可以。只需使用静态检查x的类即可

import numpy as np
def foo(x):
if isinstance(x, np.ndarray):
return x.round()
return round(x)

print(foo(7.6))
print(foo(np.array([7.6])))
print(foo(np.array([7.6, 8.9])))

子类是一种变通方法(正如@Anton Pomieshchenko所评论的(

class Array(np.ndarray):
def __new__(cls, array, **kwargs):
return np.asarray(array, **kwargs).view(cls)
def __round__(self):
return np.round(self.data)
foo(Array([7.6, 8.9]))

适用于python 3,而在python 2中我得到了TypeError: only size-1 arrays can be converted to Python scalars

我也尝试了np.__round__ = np.round(TypeError:只有size-1数组可以转换为Python标量(或np.ndarray.__round__ = np.round(TypeError:无法设置内置/扩展类型"numpy.ndarray"的属性(,但没有成功。

这需要修改文件,但文件仍然独立于numpy。修改后的文件bar.py

def foo(x):
return foo.round(x)
foo.round = round

然后导入后可以替换round功能

from bar import foo
import numpy as np
foo(7.6)
foo.round = np.round
foo([7.6, 8.9])       # 

灵感来自https://stackoverflow.com/a/338420/11769765.

或者,使用

from builtins import round
def foo(x):
return round(x)

对于bar.py,使圆形方法可访问

from bar import foo
import bar
import numpy as np
foo(7.6)           # 8
bar.round = np.round
foo([7.6, 8.9])    # array([8., 9.])

最新更新