I'm trying to compile a stateful graph with TensorFlow Lite. I'm on TensorFlow 2.8.1. As an example, I'm using a cumulative sum:
from dataclasses import dataclass, field
import tensorflow as tf

@dataclass
class CumulativeSum:
    cumsum: tf.Variable = field(default_factory=lambda: tf.Variable(initial_value=0., dtype=tf.float64))

    def add(self, x):
        self.cumsum.assign(self.cumsum + x)
        return self.cumsum.value()
But when I try to convert it to a TF Lite model...
# Make concrete function
cumsummer = CumulativeSum()
concrete_func = tf.function(
    input_signature=[tf.TensorSpec(shape=(), dtype=tf.float64)],
)(cumsummer.add).get_concrete_function()

# Check that concrete function works
assert [concrete_func(float(x)) for x in range(4)] == [0, 1, 3, 6]

# Save tflite model
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.SELECT_TF_OPS,    # enable TensorFlow ops
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops
]
serialized_model = converter.convert()
# ^^^ ABOVE LINE THROWS:
# ValueError: Input 0 of node AssignVariableOp was passed double from ReadVariableOp/resource:0 incompatible with expected resource.
I get the error
ValueError: Input 0 of node AssignVariableOp was passed double from ReadVariableOp/resource:0 incompatible with expected resource.
The full test code to reproduce this is in this Colab notebook:
https://colab.research.google.com/drive/1KPjHhlCMVs2oFxodrFAO7YcPrfBADC2k?usp=sharing
Can stateful graphs work in TFLite at all?
Figured it out. The key is in the documentation of TFLiteConverter.experimental_enable_resource_variables:
experimental_enable_resource_variables: Experimental flag, subject to
change. Enables resource variables to be converted by this converter. This
is only allowed if from_saved_model interface is used. (default False)
I needed to create the converter with tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) for the variable assignment to work.
This means:
- Make the class a subclass of tf.Module, as in class CumulativeSum(tf.Module)
- Save the model:
import os

# Save as a SavedModel (the .tflite conversion happens in the next step)
cumsummer = CumulativeSum()  # Re-instantiate so that we start from a blank state
concrete_func = tf.function(
    input_signature=[tf.TensorSpec(shape=(), dtype=tf.float64)],
)(cumsummer.add).get_concrete_function()
saved_model_dir = os.path.expanduser('~/Downloads/test_save_model')
tf.saved_model.save(obj=cumsummer, export_dir=saved_model_dir, signatures={"add": concrete_func})
- Create the converter with:
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
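Putting the three steps together, here is a minimal end-to-end sketch that also runs the converted model to check that the running total persists between calls. Several details are assumptions rather than part of the answer above: the class is a plain `tf.Module` with an explicit `__init__` instead of a dataclass, the SavedModel is written to a temporary directory, `experimental_enable_resource_variables` is set explicitly in case it defaults to `False` in your version, and the signature input/output names `x` and `output_0` are the names TensorFlow normally generates for this function (check `interpreter.get_signature_list()` if yours differ).

```python
import tempfile

import numpy as np
import tensorflow as tf


class CumulativeSum(tf.Module):
    """Plain tf.Module (assumption: no dataclass) holding a float64 running total."""

    def __init__(self):
        super().__init__()
        self.cumsum = tf.Variable(0.0, dtype=tf.float64)

    @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float64)])
    def add(self, x):
        self.cumsum.assign(self.cumsum + x)
        return self.cumsum.value()


module = CumulativeSum()
saved_model_dir = tempfile.mkdtemp()  # assumption: temp dir instead of ~/Downloads

# Step 1: export as a SavedModel with a named signature.
tf.saved_model.save(module, saved_model_dir,
                    signatures={"add": module.add.get_concrete_function()})

# Step 2: convert from the SavedModel, the path that supports resource variables.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.experimental_enable_resource_variables = True  # set explicitly, in case the default is False
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,  # fall back to TF ops for anything without a builtin kernel
]
tflite_model = converter.convert()

# Step 3: run the "add" signature; the variable lives inside the interpreter between calls.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
runner = interpreter.get_signature_runner("add")
# assumption: default signature names are "x" (input) and "output_0" (output)
results = [float(runner(x=np.array(float(i)))["output_0"]) for i in range(4)]
print(results)
```

Because each call goes through the same `runner`, the assignment from one call should be visible to the next, which is exactly what `from_concrete_functions` could not convert.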
Colab notebook with the complete working example:
https://colab.research.google.com/drive/1Ud3lcyweOIeHRUgsoqr7-kBxndZSSBcH?usp=sharing
Here is another Colab notebook with useful helper functions for saving and loading models, so the complete code is just:
from dataclasses import dataclass, field
import tensorflow as tf

# save_model_function_to_tflite and load_tflite_model_func are the helper
# functions from the Colab notebook mentioned above (not reproduced here).

@dataclass
class CumulativeSum(tf.Module):
    cumsum: tf.Variable = field(default_factory=lambda: tf.Variable(initial_value=0., dtype=tf.float64))

    def add(self, x):
        self.cumsum.assign(self.cumsum + x)
        return self.cumsum.value()

cumsummer = CumulativeSum()
model_path = '~/Downloads/my_test_model.tflite'
save_model_function_to_tflite(
    cumsummer.add,
    input_signature=[tf.TensorSpec((), tf.float64)],
    path=model_path
)
cumsum_add = load_tflite_model_func(model_path)
assert [cumsum_add(float(i)) for i in range(4)] == [0, 1, 3, 6]
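The helpers themselves live in the linked notebook and are not shown here, so the following is only a hypothetical sketch of what `save_model_function_to_tflite` and `load_tflite_model_func` might look like; it is not the notebook's code. It assumes the function passed in is a bound method of a `tf.Module` (so `func.__self__` is the module whose variables must be saved), that the model has a single signature with a single input and output, and it uses a plain `tf.Module` subclass and a temporary path in the usage part.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def save_model_function_to_tflite(func, input_signature, path):
    """Hypothetical helper: export a bound tf.Module method as a .tflite file."""
    module = func.__self__  # assumption: func is a bound method of a tf.Module
    concrete = tf.function(func, input_signature=input_signature).get_concrete_function()
    with tempfile.TemporaryDirectory() as saved_model_dir:
        tf.saved_model.save(module, saved_model_dir, signatures={"serving_default": concrete})
        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
        converter.experimental_enable_resource_variables = True
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,
            tf.lite.OpsSet.SELECT_TF_OPS,
        ]
        tflite_model = converter.convert()
    with open(os.path.expanduser(path), "wb") as f:
        f.write(tflite_model)


def load_tflite_model_func(path):
    """Hypothetical helper: wrap the model's only signature as a Python callable."""
    interpreter = tf.lite.Interpreter(model_path=os.path.expanduser(path))
    signatures = interpreter.get_signature_list()
    (key,) = signatures.keys()  # assumption: exactly one signature
    input_names = signatures[key]["inputs"]
    output_name = signatures[key]["outputs"][0]  # assumption: single output
    runner = interpreter.get_signature_runner(key)

    def call(*args):
        # Map positional args onto signature input names (fine for a single input).
        feeds = {name: np.asarray(arg) for name, arg in zip(input_names, args)}
        return runner(**feeds)[output_name]

    return call


class CumulativeSum(tf.Module):
    def __init__(self):
        super().__init__()
        self.cumsum = tf.Variable(0.0, dtype=tf.float64)

    def add(self, x):
        self.cumsum.assign(self.cumsum + x)
        return self.cumsum.value()


cumsummer = CumulativeSum()
model_path = os.path.join(tempfile.mkdtemp(), "my_test_model.tflite")
save_model_function_to_tflite(
    cumsummer.add,
    input_signature=[tf.TensorSpec((), tf.float64)],
    path=model_path,
)
cumsum_add = load_tflite_model_func(model_path)
results = [float(cumsum_add(float(i))) for i in range(4)]
print(results)
```

Keeping one interpreter inside the returned closure is what preserves the running total across calls; loading the file again would create a separate interpreter with its own copy of the variable.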