检测python中意外的类型转换



我有一段复杂的Python代码,涉及使用32位数值(为了节省内存和带宽)。但后来我发现,在一些高级函数中,许多32位数字被隐式地转换为64位。例如sum函数,默认情况下,可以将32位数组转换为64位数字。

In [152]: x32
Out[152]:
array([  0.      ,   1.010101,   2.020202,   3.030303,   4.040404,
5.050505,   6.060606,   7.070707,   8.080808,   9.090909,
10.10101 ,  11.111111,  12.121212,  13.131313,  14.141414,
15.151515,  16.161615,  17.171717,  18.181818,  19.19192 ,
20.20202 ,  21.212122,  22.222221,  23.232323,  24.242424,
25.252525,  26.262627,  27.272728,  28.282827,  29.292929,
30.30303 ,  31.313131,  32.32323 ,  33.333332,  34.343433,
35.353535,  36.363636,  37.373737,  38.38384 ,  39.39394 ,
40.40404 ,  41.414143,  42.424244,  43.434345,  44.444443,
45.454544,  46.464645,  47.474747,  48.484848,  49.49495 ,
50.50505 ,  51.515152,  52.525253,  53.535355,  54.545456,
55.555557,  56.565655,  57.575756,  58.585857,  59.59596 ,
60.60606 ,  61.61616 ,  62.626263,  63.636364,  64.64646 ,
65.65656 ,  66.666664,  67.676765,  68.68687 ,  69.69697 ,
70.70707 ,  71.71717 ,  72.72727 ,  73.73737 ,  74.747475,
75.757576,  76.76768 ,  77.77778 ,  78.78788 ,  79.79798 ,
80.80808 ,  81.818184,  82.828285,  83.83839 ,  84.84849 ,
85.85859 ,  86.86869 ,  87.878784,  88.888885,  89.89899 ,
90.90909 ,  91.91919 ,  92.92929 ,  93.93939 ,  94.94949 ,
95.959595,  96.969696,  97.9798  ,  98.9899  , 100.      ],
dtype=float32)
In [153]: sum(x32)
Out[153]: 4999.999972701073
In [154]: type(sum(x32))
Out[154]: numpy.float64

在这种情况下sum(x32)是64位的原因应该来自sum的默认累加器,0,如下所示:

In [156]: type(sum(x32, start=np.float32(0)))
Out[156]: numpy.float32  

上面,我使用sum函数作为例子,来解释如果我使用32位作为输入,类型转换无处不在。我已经更改了sum部分,以避免这种隐式类型转换。但我想知道,如果在我的库调用内部,有任何其他意外的32位->64位转换。是否有一个通用的编程语言解决方案来监控任何可能的类型转换?例如,我可以用一些特殊的调试工具运行我的python代码,以便从32位到64位的任何类型转换都会触发警报或被记录吗?

说实话,我觉得你快成功了。

original_dtype = x32.dtype
new_dtype = sum(x32, start=np.float32(0))).dtype
assert new_dtype == original_dtype, f"dtypes differ, {new_dtype=} != {original_dtype=}"

要全局使用此方法,可以这样写:

def type_checker_func(func,input_array,*args):
dtype_orig = input_array.dtype
result = func(input_array,*args)
dtype_new = result.dtype
if dtype_new != dtype_orig:
print(f"dtypes differ, {dtype_new=} != {dtype_orig=}")
return result
my_answer = type_checker_func(sum,x32,start=np.float32(0))

但是我不确定你如何最好地处理多个返回值(考虑np.histogram),各种参数等。

我也不确定如何全局/隐式地调用type_checker_func(如果仅用于numpy fns)。

更新:我发布了一个github问题,询问如何使用line_profiler为每个函数调用执行此操作-参见https://github.com/pyutils/line_profiler/issues/188 -手指交叉。

最新更新