TensorFlow 2.3: AttributeError: 'Tensor' object has no attribute 'numpy'



I want to load a text file borrowed from here, where each line is a JSON string like this:

{"overall": 2.0, "verified": true, "reviewTime": "02 4, 2014", "reviewerID": "A1M117A53LEI8", "asin": "7508492919", "reviewerName": "Sharon Williams", "reviewText": "DON'T CARE FOR IT.  GAVE IT AS A GIFT AND THEY WERE OKAY WITH IT.  JUST NOT WHAT I EXPECTED.", "summary": "CASE", "unixReviewTime": 1391472000}

I want to use TensorFlow to extract only the reviewText and overall features from the dataset, but I get the following error:

AttributeError: in user code:
<ipython-input-4-419019a35c5e>:9 None  *
line_dataset = line_dataset.map(lambda row: transform(row))
<ipython-input-4-419019a35c5e>:2 transform  *
str_example = example.numpy().decode("utf-8")
AttributeError: 'Tensor' object has no attribute 'numpy'

My code snippet is as follows:

def transform(example):
    str_example = example.numpy().decode("utf-8")
    json_example = json.loads(str_example)
    overall = json_example.get('overall', None)
    text = json_example.get('reviewText', None)
    return (overall, text)

line_dataset = tf.data.TextLineDataset(filenames = [file_path])
line_dataset = line_dataset.map(lambda row: transform(row))
for example in line_dataset.take(5):
    print(example)

I am using TensorFlow 2.3.0.

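The dataset's input pipeline is always traced into a graph (as if you had used @tf.function) to make it faster, which means, among other things, that you cannot call .numpy() inside it. However, you can use tf.numpy_function to access the data as NumPy arrays inside the graph: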

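import json

import numpy as np
import tensorflow as tf

def transform(example):
    # example is now a NumPy bytes object, not a Tensor
    str_example = example.decode("utf-8")
    json_example = json.loads(str_example)
    # cast to float32 so the value matches the declared output type;
    # use defaults instead of None, which cannot be converted to a tensor
    overall = np.float32(json_example.get('overall', 0.0))
    text = json_example.get('reviewText', '')
    return (overall, text)

line_dataset = tf.data.TextLineDataset(filenames = [file_path])
line_dataset = line_dataset.map(
    # the inputs to tf.numpy_function must be passed as a list
    lambda row: tf.numpy_function(transform, [row], (tf.float32, tf.string)))
for example in line_dataset.take(5):
    print(example)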
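To check this without the original file, here is a small self-contained sketch that feeds two made-up JSON lines (standing in for the file) through tf.data.Dataset.from_tensor_slices instead of TextLineDataset. It assumes that a missing key should become 0.0 or an empty string, and the parse_line helper name is hypothetical. The explicit np.float32 cast matters because tf.numpy_function turns a plain Python float into a float64 array, which would not match the declared tf.float32 output.

```python
import json

import numpy as np
import tensorflow as tf

# Two hypothetical review lines standing in for the lines of the text file.
lines = [
    '{"overall": 2.0, "reviewText": "JUST NOT WHAT I EXPECTED."}',
    '{"overall": 5.0, "verified": true}',  # this one has no reviewText key
]

def transform(example):
    # Inside tf.numpy_function a scalar string arrives as a bytes object.
    record = json.loads(example.decode("utf-8"))
    # Cast explicitly: a plain Python float would come back as float64.
    overall = np.float32(record.get("overall", 0.0))
    # Assumed default: an empty string when reviewText is missing.
    text = record.get("reviewText", "").encode("utf-8")
    return overall, text

def parse_line(row):
    overall, text = tf.numpy_function(transform, [row], (tf.float32, tf.string))
    # numpy_function cannot infer output shapes; both outputs are scalars.
    overall.set_shape([])
    text.set_shape([])
    return overall, text

dataset = tf.data.Dataset.from_tensor_slices(lines).map(parse_line)
results = [(float(o), t.decode("utf-8")) for o, t in dataset.as_numpy_iterator()]
print(results)
```

The set_shape([]) calls are optional; they only restore the static shape information that tf.numpy_function cannot infer, which can help later steps that need known shapes.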

It's a bit verbose, but you can try it like this:

import json

import tensorflow as tf

def transform(example):
    # example is an eager tensor here, so .numpy() works
    str_example = example.numpy().decode("utf-8")
    json_example = json.loads(str_example)
    # use defaults instead of None, which cannot be converted to a tensor
    overall = json_example.get('overall', 0.0)
    text = json_example.get('reviewText', '')
    return (overall, text)

line_dataset = tf.data.TextLineDataset(filenames = [file_path])
line_dataset = line_dataset.map(
    lambda row: tf.py_function(transform, [row], (tf.float32, tf.string)))
for example in line_dataset.take(5):
    print(example)

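This snippet works for any Python function, not just NumPy functions: tf.py_function runs the wrapped function eagerly, so the example it receives is an eager tensor and .numpy() works on it. That means you can also use it if you need functions such as print(input). You don't need to know all the details, but ask me if you're interested.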
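One way to see both points at once (the graph tracing described in the first answer and the per-element behaviour of tf.py_function) is the sketch below, which uses tf.data.Dataset.range rather than the review file. The list names and the show_element and map_fn helpers are made up for illustration. The plain Python line in map_fn runs only while tf.data traces the function, where tf.executing_eagerly() returns False, while the function wrapped in tf.py_function runs once per element with eager tensors, so .numpy() and print work there.

```python
import tensorflow as tf

trace_calls = []    # filled by plain Python code in the mapped function
element_calls = []  # filled by the Python function wrapped in tf.py_function

def show_element(x):
    # Runs eagerly once per element: x is an eager tensor, so .numpy() works.
    element_calls.append(x.numpy())
    print("element:", x.numpy())
    return x

def map_fn(x):
    # Plain Python here runs only while tf.data traces map_fn into a graph.
    trace_calls.append(tf.executing_eagerly())
    y = tf.py_function(show_element, [x], tf.int64)
    # py_function does not know its output shape; restore it (a scalar).
    y.set_shape(x.shape)
    return y

dataset = tf.data.Dataset.range(3).map(map_fn)
values = list(dataset.as_numpy_iterator())
print(values)
```

Any Python side effect you want to happen for every element therefore has to live inside the wrapped function, not in the mapped lambda itself.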
