我有一段复杂的Python代码,涉及使用32位数值(为了节省内存和带宽)。但后来我发现,在一些高级函数中,许多32位数字被隐式地转换为64位。例如sum函数,默认情况下,可以将32位数组转换为64位数字。
In [152]: x32
Out[152]:
array([ 0. , 1.010101, 2.020202, 3.030303, 4.040404,
5.050505, 6.060606, 7.070707, 8.080808, 9.090909,
10.10101 , 11.111111, 12.121212, 13.131313, 14.141414,
15.151515, 16.161615, 17.171717, 18.181818, 19.19192 ,
20.20202 , 21.212122, 22.222221, 23.232323, 24.242424,
25.252525, 26.262627, 27.272728, 28.282827, 29.292929,
30.30303 , 31.313131, 32.32323 , 33.333332, 34.343433,
35.353535, 36.363636, 37.373737, 38.38384 , 39.39394 ,
40.40404 , 41.414143, 42.424244, 43.434345, 44.444443,
45.454544, 46.464645, 47.474747, 48.484848, 49.49495 ,
50.50505 , 51.515152, 52.525253, 53.535355, 54.545456,
55.555557, 56.565655, 57.575756, 58.585857, 59.59596 ,
60.60606 , 61.61616 , 62.626263, 63.636364, 64.64646 ,
65.65656 , 66.666664, 67.676765, 68.68687 , 69.69697 ,
70.70707 , 71.71717 , 72.72727 , 73.73737 , 74.747475,
75.757576, 76.76768 , 77.77778 , 78.78788 , 79.79798 ,
80.80808 , 81.818184, 82.828285, 83.83839 , 84.84849 ,
85.85859 , 86.86869 , 87.878784, 88.888885, 89.89899 ,
90.90909 , 91.91919 , 92.92929 , 93.93939 , 94.94949 ,
95.959595, 96.969696, 97.9798 , 98.9899 , 100. ],
dtype=float32)
In [153]: sum(x32)
Out[153]: 4999.999972701073
In [154]: type(sum(x32))
Out[154]: numpy.float64
在这种情况下sum(x32)是64位的原因应该来自sum的默认累加器,0,如下所示:
In [156]: type(sum(x32, start=np.float32(0)))
Out[156]: numpy.float32
上面,我使用sum函数作为例子,来解释如果我使用32位作为输入,类型转换无处不在。我已经更改了sum
部分,以避免这种隐式类型转换。但我想知道,如果在我的库调用内部,有任何其他意外的32位->64位转换。是否有一个通用的编程语言解决方案来监控任何可能的类型转换?例如,我可以用一些特殊的调试工具运行我的python代码,以便从32位到64位的任何类型转换都会触发警报或被记录吗?
说实话,我觉得你快成功了。
original_dtype = x32.dtype
new_dtype = sum(x32, start=np.float32(0))).dtype
assert new_dtype == original_dtype, f"dtypes differ, {new_dtype=} != {original_dtype=}"
要全局使用此方法,可以这样写:
def type_checker_func(func,input_array,*args):
dtype_orig = input_array.dtype
result = func(input_array,*args)
dtype_new = result.dtype
if dtype_new != dtype_orig:
print(f"dtypes differ, {dtype_new=} != {dtype_orig=}")
return result
my_answer = type_checker_func(sum,x32,start=np.float32(0))
但是我不确定你如何最好地处理多个返回值(考虑np.histogram
),各种参数等。
我也不确定如何全局/隐式地调用type_checker_func
(如果仅用于numpy fns)。
更新:我发布了一个github问题,询问如何使用line_profiler为每个函数调用执行此操作-参见https://github.com/pyutils/line_profiler/issues/188 -手指交叉。