无法将feed_dict键解释为张量:张量不是此图的元素



试图预测gen的类型,但出现了一些错误,你能指出哪里出了问题吗?任何帮助都将不胜感激。在其他情况下,在ZOLANDO数据集中预测衣服的类型是有效的。但在其他情况下,我被卡住了:(

#some code of gen1, gen2 and merged dataFrames

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
y_train = y_train['y'].values.reshape(1,14000).astype('int32')
labels_ = np.zeros((14000,2))
labels_[np.arange(14000), y_train] = 1        
X_train = np.array(X_train)
X_train = X_train.transpose()
X_train = np.where(X_train<0, X_train ** 2, X_train)        
n_dim = X_train.shape[0]
ops.reset_default_graph()
tf.compat.v1.disable_eager_execution()
n1 = 2
n2 = 2
cost_history = np.empty(shape=[1], dtype = float)
learning_rate = tf.compat.v1.placeholder(tf.float32, shape=())
X = tf.compat.v1.placeholder(tf.float32, shape=(n_dim, None))
#X = np.array([n_dim, None],dtype="float32")
Y = tf.compat.v1.placeholder(tf.float32, shape=(n2, None))
W1 = tf.Variable(tf.random.truncated_normal([n1,n_dim], stddev=.1))
b1 = tf.Variable(tf.zeros([n1,1]))
W2 = tf.Variable(tf.random.truncated_normal([n2,n1], stddev=.1))
b2 = tf.Variable(tf.zeros([n2,1]))


Z1 = tf.nn.relu(tf.matmul(W1,X) + b1)
Z2 = tf.nn.relu(tf.matmul(W2,Z1) + b2)
y_ = tf.nn.softmax(Z2,0)
cost =  - tf.reduce_mean(Y * tf.math.log(y_) + (1-Y) * tf.math.log(1-y_) )
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.compat.v1.global_variables_initializer()

ops.reset_default_graph()
tf.compat.v1.disable_eager_execution()

sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())
training_epochs = 100

cost_history = []
for epoch in range(training_epochs+1):
sess.run(optimizer, feed_dict = {X: X_train, Y: labels_, learning_rate: 0.001})
cost_ = sess.run(cost, feed_dict = { X: X_train, Y: labels_, learning_rate: 0.001})
cost_history = np.append(cost_history, cost_)

if (epoch % 10 == 0):
print("Reached epoch",epoch,"cost J =", cost_)

获得:

类型错误:无法将feed_dict键解释为张量:张量张量("占位符_1:0",shape=(2,None(,dtype=float32(不是此图的元素。

graph = tf.Graph()    
with graph.as_default():
(your code of tf)

最新更新