如何阻止 numpy 网格将默认数据类型设置为 int64



>我必须使用numpy网格创建一个非常大的网格。 为了节省内存,我使用 int8 作为我尝试网格化的数组的 dtype。 但是,meshgrid不断将类型更改为int64,这会占用大量内存。 这是问题的简单示例...

import numpy
grids = [numpy.arange(1, 4, dtype=numpy.int8), numpy.arange(1, 5, dtype=numpy.int8)]
print grids
print grids[0].dtype, grids[0].nbytes
x1, y1 = numpy.meshgrid(*grids)
print x1.dtype, x1.nbytes

此脚本打印

[array([1, 2, 3], dtype=int8), array([1, 2, 3, 4], dtype=int8)]
int8 3
int64 96

为什么网格会这样做? 有什么办法可以阻止它吗? 我需要创建一个巨大的数组,这样我就不能使用 meshgrid,除非我可以控制输出的数据类型。 这是预期行为还是 numpy 错误? 我在 numpy 中使用的所有其他函数都保留数据类型或允许您使用 dtype 参数更改它。 网格功能似乎不允许这样做。

您可以将可选的 copy 参数 numpy.meshgrid() 设置为 False(但请注意,它有一些约束(:

meshgrid(*xi, **kwargs)

copybool , 可选

如果False,则返回原始数组的视图,以便 节省内存。 默认值为 True 。 请注意,sparse=Falsecopy=False可能会返回不连续的数组。 此外 广播阵列的多个元素可以引用单个 内存位置。 如果需要写入数组,请制作副本 第一。

证明它有效:

>>> import numpy
>>> 
>>> grids = [numpy.arange(1, 4, dtype=numpy.int8), numpy.arange(1, 5, dtype=numpy.int8)]
>>> 
>>> print grids
[array([1, 2, 3], dtype=int8), array([1, 2, 3, 4], dtype=int8)]
>>> print grids[0].dtype, grids[0].nbytes
int8 3
>>>
>>> x1, y1 = numpy.meshgrid(*grids, copy=False)
>>>                        #        ^^^^^^^^^^
>>> print x1.dtype, x1.nbytes
int8 12

最新更新