我想用一个大数据集训练一个CNN。目前我将所有数据加载到tf中。常量,然后在tf.Session()中循环使用一个小的Batch大小。这对于数据集的一小部分很有效,但是当我增加输入大小时,我得到了错误:
ValueError: Cannot create a tensor proto whose content is larger than 2GB.
我怎样才能避免呢?
不要将数据加载到常量中,它将成为你计算图的一部分。
你应该:
- 创建一个以流方式加载数据的op
- 在python部分加载数据,并使用feed_dict将批处理传递到图
对于TensorFlow 1。x和Python 3,这是我的简单解决方案:
X_init = tf.placeholder(tf.float32, shape=(m_input, n_input))
X = tf.Variable(X_init)
sess.run(tf.global_variables_initializer(), feed_dict={X_init: data_for_X})
在实践中,你将主要指定Graph和Session进行连续计算,下面的代码将帮助你:my_graph = tf.Graph()
sess = tf.Session(graph=my_graph)
with my_graph.as_default():
X_init = tf.placeholder(tf.float32, shape=(m_input, n_input))
X = tf.Variable(X_init)
sess.run(tf.global_variables_initializer(), feed_dict={X_init: data_for_X})
.... # build your graph with X here
.... # Do some other things here
with my_graph.as_default():
output_y = sess.run(your_graph_output, feed_dict={other_placeholder: other_data})