我正在自定义数据集中训练pytorch-yolov3。我准备了所有需要的txt,数据和名称文件。
运行以下命令时:
python3 train.py --model_def config/yolov3.cfg --data_config config/custom.data
我得到以下错误:
Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (expandTensors at /pytorch/aten/src/ATen/native/IndexingUtils.h:20)
Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (expandTensors at /pytorch/aten/src/ATen/native/IndexingUtils.h:20)
Traceback (most recent call last):
File "train.py", line 136, in <module>
logger.list_of_scalars_summary(tensorboard_log, batches_done)
File "/home/sudip/torch/PyTorch-YOLOv3/utils/logger.py", line 16, in list_of_scalars_summary
summary = tf.summary(value=[tf.summary.Value(tag=tag, simple_value=value) for tag, value in tag_value_pairs])
File "/home/sudip/torch/PyTorch-YOLOv3/utils/logger.py", line 16, in <listcomp>
summary = tf.summary(value=[tf.summary.Value(tag=tag, simple_value=value) for tag, value in tag_value_pairs])
AttributeError: module 'tensorboard.summary._tf.summary' has no attribute 'Value'
这是logger.py
文件:
import tensorflow as tf
class Logger(object):
def __init__(self, log_dir):
self.writer = tf.summary.create_file_writer(log_dir)
def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
summary = tf.summary(value=[tf.summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
def list_of_scalars_summary(self, tag_value_pairs, step):
"""Log scalar variables."""
summary = tf.summary(value=[tf.summary.Value(tag=tag, simple_value=value) for tag, value in tag_value_pairs])
self.writer.add_summary(summary, step)
有什么解决这个问题的想法或建议吗?
如有任何帮助,我们将不胜感激。
谢谢
更改
summary = tf.summary(value=[tf.summary.Value(tag=tag, simple_value=value)])
至
summary = tf.summary.scalar(tag=tag, simple_value=value)
对我来说,这个问题是由于代码在TF 1.x中,而我有TF 2.x。因此,用tf.compat.v1.Summary
替换self.tf.Summary
(正如dspencer所建议的(就解决了这个问题。
将logger.py
文件更新为实际从train.py
调用的版本后,错误
AttributeError: module 'tensorboard.summary._tf.summary' has no attribute 'Value'
发生。这可能是因为您使用的是tensorflow 2.1.0,而来自开源项目的logger.py
脚本使用的是带有不同API的tensorflow的早期版本。
在最新版本的TensorFlow中,您应该如下修改logger.py:
import tensorflow as tf
class Logger(object):
def __init__(self, log_dir):
"""Create a summary writer logging to log_dir."""
self.writer = tf.summary.create_file_writer(log_dir)
def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
with self.writer.as_default():
tf.summary.scalar(tag, value, step=step)
self.writer.flush()
def list_of_scalars_summary(self, tag_value_pairs, step):
"""Log scalar variables."""
with self.writer.as_default():
for tag, value in tag_value_pairs:
tf.summary.scalar(tag, value, step=step)
self.writer.flush()