如何对 tf.Tensor 进行重塑并用零填充



下面的代码试图在 TensorFlow 中把一个张量转换为形状为 (x, y) 的二维张量:元素多于 x*y 时截断,少于 x*y 时在末尾补零。

用这段代码可以把 "a" 转换为 "b",但 "c" 却无法转换(见下方报错)。

下面是测试代码:
def reshape_array(old_array, x, y):
    """Flatten ``old_array`` and fit it into an ``(x, y)`` tensor.

    Intended behavior: keep the first ``x * y`` elements when there are too
    many, and zero-pad the tail when there are too few. As written, the
    padding branch fails; see the NOTE(review) on the ``tf.pad`` call.
    """
    # Flatten to rank 1 so the element count can be compared with x * y.
    new_array = tf.reshape(old_array, [-1])

    current_size = tf.size(new_array)
    reshape_size = tf.math.multiply(x, y)

    # Positive diff: elements are missing; negative diff: elements to drop.
    diff = tf.math.subtract(reshape_size, current_size)
    # tf.constant([0]) has shape (1,), so this yields a one-element boolean
    # tensor; the `if` relies on its truthiness in eager mode.
    if tf.greater_equal(diff, tf.constant([0])):
        # NOTE(review): BUG - new_array is rank 1 here, but this paddings spec
        # has two [before, after] rows, i.e. it describes a rank-2 input.
        # tf.pad therefore raises InvalidArgumentError ("The first dimension
        # of paddings must be the rank of inputs"). For a rank-1 tensor the
        # spec should be [[0, diff]].
        new_array = tf.pad(new_array, [[0,0],[0, diff]], mode='CONSTANT', constant_values=0)
        new_array = tf.reshape(new_array, (x, y))
    else:
        # Too many elements: keep only the first x * y of them.
        new_array = tf.slice(new_array, begin=[0], size=[reshape_size])
        new_array = tf.reshape(new_array, (x, y))

    return tf.cast(new_array, old_array.dtype)
# Demo for the question's reshape_array.
# `a` has 256*192 elements, more than 28*28, so reshape_array takes the
# truncating tf.slice branch.
a = tf.zeros(256*192*1)
print("a.shape: {}".format(a.shape))
b = reshape_array(a, 28, 28)
print("b.shape: {}".format(b.shape))
# `c` has only 6 elements, so reshape_array takes the tf.pad branch. That is
# the call that raises InvalidArgumentError in the output shown below.
c = tf.constant([1, 2, 3, 4, 5, 6])
print("c.shape: {}".format(c.shape))
d = reshape_array(c, 28, 28)
print("d.shape: {}".format(d.shape))

输出如下:

a.shape: (49152,)
b.shape: (28, 28)
c.shape: (6,)
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/tmp/ipykernel_7071/4036910860.py in <cell line: 26>()
24 c = tf.constant([1, 2, 3, 4, 5, 6])
25 print("c.shape: {}".format(c.shape))
---> 26 d = reshape_array(c, 28, 28)
27 print("d.shape: {}".format(d.shape))
/tmp/ipykernel_7071/4036910860.py in reshape_array(old_array, x, y)
9     diff = tf.math.subtract(reshape_size, current_size)
10     if tf.greater_equal(diff, tf.constant([0])):
---> 11         new_array = tf.pad(new_array, [[0,0],[0, diff]], mode='CONSTANT', constant_values=0)
12         new_array = tf.reshape(new_array, (x, y))
13     else:
/usr/local/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
151     except Exception as e:
152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
154     finally:
155       del filtered_tb
/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
52   try:
53     ctx.ensure_initialized()
---> 54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
55                                         inputs, attrs, num_outputs)
56   except core._NotOkStatusException as e:
InvalidArgumentError: The first dimension of paddings must be the rank of inputs[2,2] [6] [Op:Pad]

我的代码有什么问题,如何修复?

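问题出在 tf.pad 的 paddings 参数上。new_array 经过 tf.reshape(old_array, [-1]) 之后总是一维张量,因此 paddings 只能有一行,即 [[0, diff]],而不是描述二维张量的 [[0,0],[0, diff]]。第一个例子之所以没有报错,只是因为 "a" 的元素多于 28×28,走的是切片分支,根本没有调用 tf.pad。所以试试: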

import tensorflow as tf
def reshape_array(old_array, x, y):
    """Flatten ``old_array`` and fit it into an ``(x, y)`` tensor.

    If there are more than ``x * y`` elements, only the first ``x * y`` are
    kept. If there are fewer, the tail is padded with zeros.

    Args:
        old_array: Tensor (or tensor-convertible value) of any shape.
        x: Number of rows of the result.
        y: Number of columns of the result.

    Returns:
        A tensor of shape ``(x, y)``. Its dtype is that of ``old_array``,
        because reshape, pad and slice all preserve dtype.
    """
    # After this reshape the tensor is always rank 1, so tf.pad needs exactly
    # one [before, after] row in its paddings spec.
    new_array = tf.reshape(old_array, [-1])

    current_size = tf.size(new_array)
    reshape_size = tf.math.multiply(x, y)

    # Positive diff: elements are missing; negative diff: elements to drop.
    diff = tf.math.subtract(reshape_size, current_size)
    # Compare with a scalar so the condition is a scalar boolean tensor.
    if diff >= 0:
        new_array = tf.pad(new_array, [[0, diff]], mode='CONSTANT', constant_values=0)
    else:
        new_array = tf.slice(new_array, begin=[0], size=[reshape_size])

    # Both branches yield exactly x * y elements, so one reshape suffices.
    return tf.reshape(new_array, (x, y))
# Two demo cases: `a` is longer than 28*28 (gets truncated) and `c` is
# shorter (gets zero-padded). Both end up with shape (28, 28).
a = tf.zeros(256 * 192 * 1)
print(f"a.shape: {a.shape}")
b = reshape_array(a, 28, 28)
print(f"b.shape: {b.shape}")

c = tf.constant([1, 2, 3, 4, 5, 6])
print(f"c.shape: {c.shape}")
d = reshape_array(c, 28, 28)
print(f"d.shape: {d.shape}")

对于你的情况,我通常更倾向于用 tf.concat 拼接一段零来实现填充:

def reshape_array(old_array, x, y):
    """Flatten ``old_array`` and fit it into an ``(x, y)`` tensor.

    If there are more than ``x * y`` elements, only the first ``x * y`` are
    kept. If there are fewer, zeros are concatenated onto the tail. This is an
    alternative to ``tf.pad``.

    Args:
        old_array: Tensor (or tensor-convertible value) of any shape.
        x: Number of rows of the result.
        y: Number of columns of the result.

    Returns:
        A tensor of shape ``(x, y)`` with the same dtype as ``old_array``.
    """
    new_array = tf.reshape(old_array, [-1])

    current_size = tf.size(new_array)
    reshape_size = tf.math.multiply(x, y)

    # Positive diff: elements are missing; negative diff: elements to drop.
    diff = tf.math.subtract(reshape_size, current_size)
    if diff >= 0:
        # The filler must share new_array's dtype. tf.concat does not promote
        # dtypes, so an int32 filler such as tf.repeat([0], ...) makes every
        # non-int32 input fail here, even when diff == 0.
        zeros = tf.zeros([diff], dtype=new_array.dtype)
        new_array = tf.concat([new_array, zeros], axis=0)
    else:
        new_array = tf.slice(new_array, begin=[0], size=[reshape_size])

    # Both branches yield exactly x * y elements, so one reshape suffices.
    return tf.reshape(new_array, (x, y))

最新更新