如何将张量 dtype=tf.float32_ref 转换为 dtype=tf.float32



我想通过函数tf.cast()使用从float32_reffloat32的修改word_embeddings dtype:

   word_embeddings_modify=tf.cast(word_embeddings,dtype=tf.float32)

但它没有按预期工作,word_embeddings_modify dtype 仍然tf.float32_ref。

   word_embeddings = tf.scatter_nd_update(var_output, error_word_f,sum_all)
   word_embeddings_modify=tf.cast(word_embeddings,dtype=tf.float32)
   word_embeddings_dropout = tf.nn.dropout(word_embeddings_2, dropout_pl)
您可以使用

tf.identity取消引用_ref类型

word_embeddings = tf.identity(word_embeddings)

相关内容

  • 没有找到相关文章

最新更新