Tensorflow py_function设置返回值shape



我正在字符串中设置从tf.py_function调用返回的numpy数组的形状。函数返回数组及其形状的元组,因此我可以在tf.function中设置其形状。但是,我不能在set_shape调用中使用返回的shape数组。我尝试将它强制转换为int32或通过索引等访问它,但所有操作都会导致TypeError: Dimension value must be integer or None or have an __index__ method, got value '<tf.Tensor 'strided_slice:0' shape=<unknown> dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'错误。

@tf.function
def samples(img, mask) -> tuple:
xs, xs_shape = tf.py_function(process,[128, 128], [tf.float32, tf.int32])
xs.set_shape(tf.TensorShape(xs_shape))
return xs
ds = ds.map(samples)

这是一个精心设计的例子。我想这可能是你的要求。

def f(x):
print("print ", x)
# Include Logic using 'x' which is a tensor.
# Sample  return values
return [tf.convert_to_tensor([1,2]), tf.convert_to_tensor([1,2]).shape]

@tf.function
def samples(x) -> tuple:
xs, xs_shape = f(x)
xs.set_shape(xs_shape)
print( xs_shape)
return xs
dataset = tf.data.Dataset.range(1, 6)
dataset = dataset.map(  samples )

最新更新