我试图使用我在线获得的以下代码找出我的模型使用的 FLOPS 数量:
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
# We use the Keras session graph in the call to the profiler.
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops # Prints the "flops" of the model.
# .... Define your model here ....
print(get_flops(model))
但是,运行此代码会给我此错误:
Traceback (most recent call last):
File "/Users/Desktop/FYP/Code/Python/code/main.py", line 243, in <module>
print(get_flops(model))
File "/Users/Desktop/FYP/Code/Python/code/main.py", line 232, in get_flops
run_meta = tf.RunMetadata()
AttributeError: module 'tensorflow' has no attribute 'RunMetadata'
如何绕过此错误?我已经在线阅读,我得到的唯一帮助是更新我的张量流版本。但是,这是最新版本。
AttributeError:模块'tensorflow'没有属性'RunMetadata'
您收到此错误是因为您执行的代码仅与 Tensorflow 1.x 兼容。
我已经成功执行了您的代码,而无需对TF 1.x进行任何更改
%tensorflow_version 1.x
import tensorflow as tf
from keras import backend as K
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
print(tf.__version__)
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
# We use the Keras session graph in the call to the profiler.
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops # Prints the "flops" of the model.
# .... Define your model here ....
model = Sequential()
model.add(Dense(8, activation = 'softmax'))
print(get_flops(model))
输出:
Using TensorFlow backend.
1.15.2
0
在TF 2.x中,您必须使用tf.compat.v1.RunMetadata
而不是tf.RunMetadata
为了在TF 2.1.0中工作您的代码,我已经进行了符合TF 2.x的所有必要更改
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
print(tf.__version__)
def get_flops(model):
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
# We use the Keras session graph in the call to the profiler.
flops = tf.compat.v1.profiler.profile(graph=tf.compat.v1.keras.backend.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops # Prints the "flops" of the model.
# .... Define your model here ....
model = Sequential()
model.add(Dense(8, activation = 'softmax'))
print(get_flops(model))
输出:
2.1.0
0
请参考 Tensorflow 2 的所有兼容符号
这里将run_meta替换为以下内容
run_meta= tf.compat.v1.RunMetadata()
查看此链接以获取更多信息:https://www.tensorflow.org/guide/migrate