我正在为我的机器学习课程研究这个国际象棋算法,但我不确定哪里出了问题。我正在看一个视频,但是当我试图适应我的模型时,似乎一切都出错了。我附上了下面的代码,它设置了一个棋盘,然后是一个卷积网络。我一直得到错误:
InvalidArgumentError: Graph Execution Error which points to model.fit(x_train, y_train).
The size of x_train is (150000, 14, 8, 8) while y_train is (150000, )
代码:
def random_board(max_depth=200):
board = chess.Board()
depth = random.randrange(0, max_depth)
for _ in range(depth):
all_moves = list(board.legal_moves)
random_move = random.choice(all_moves)
board.push(random_move)
if board.is_game_over():
break
return board
squares_index = {
'a': 0,
'b': 1,
'c': 2,
'd': 3,
'e': 4,
'f': 5,
'g': 6,
'h': 7
}
# example: h3 -> 17
def square_to_index(square):
letter = chess.square_name(square)
return 8 - int(letter[1]), squares_index[letter[0]]
def split_dims(board):
# create empty 3d matrix for board
board3d = numpy.zeros((14, 8, 8), dtype=numpy.int8)
# here we add the pieces's view on the matrix
for piece in chess.PIECE_TYPES:
for square in board.pieces(piece, chess.WHITE):
idx = numpy.unravel_index(square, (8, 8))
board3d[piece - 1][7 - idx[0]][idx[1]] = 1
for square in board.pieces(piece, chess.BLACK):
idx = numpy.unravel_index(square, (8, 8))
board3d[piece + 5][7 - idx[0]][idx[1]] = 1
# add attacks and valid moves too
# so the network knows what is being attacked
aux = board.turn
board.turn = chess.WHITE
for move in board.legal_moves:
i, j = square_to_index(move.to_square)
board3d[12][i][j] = 1
board.turn = chess.BLACK
for move in board.legal_moves:
i, j = square_to_index(move.to_square)
board3d[13][i][j] = 1
board.turn = aux
return board3d
import tensorflow.keras.models as models
import tensorflow.keras.layers as layers
import tensorflow.keras.utils as utils
import tensorflow.keras.optimizers as optimizers
def build_model(conv_size, conv_depth):
board3d = layers.Input(shape=(14, 8, 8))
#convolutional layers
x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', data_format='channels_first')(board3d)
for _ in range(conv_depth):
previous = x
x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', data_format='channels_first')(x)
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x)
x = layers.Conv2D(filters=conv_size, kernel_size=3, padding='same', data_format='channels_first')(x)
x = layers.BatchNormalization()(x)
x = layers.Add()([x, previous])
x = layers.Activation('relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(1, 'sigmoid')(x)
return models.Model(inputs=board3d, outputs=x)
model = build_model(32, 4)
utils.plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=False)
import tensorflow.keras.callbacks as callbacks
def get_dataset():
container = numpy.load('dataset\dataset.npz')
b, v = container['b'], container['v']
v = numpy.asarray(v / abs(v).max()/2 + 0.5, dtype=numpy.float32) #normalize
return b, v
x_train, y_train = get_dataset()
model.compile(optimizer=optimizers.Adam(5e-4), loss='mean_squared_error')
model.summary()
model.fit(x_train, y_train,
batch_size=2048,
epochs=1000,
verbose=1,
validation_split=0.1,
callbacks=[callbacks.ReduceLROnPlateau(monitor='loss', patience=10),
callbacks.EarlyStopping(monitor='loss', patience=15, min_delta=1e-4)])
model.save('model.h5')
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
~AppDataLocalTemp/ipykernel_15172/3395566405.py in <module>
1 model.compile(optimizer=optimizers.Adam(5e-4), loss='mean_squared_error')
2 model.summary()
----> 3 model.fit(x_train, y_train,
4 batch_size=2048,
5 epochs=1000,
C:ProgramDataAnaconda3libsite-packageskerasutilstraceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
C:ProgramDataAnaconda3libsite-packagestensorflowpythoneagerexecute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
52 try:
53 ctx.ensure_initialized()
---> 54 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
55 inputs, attrs, num_outputs)
56 except core._NotOkStatusException as e:
InvalidArgumentError: Graph execution error:
编辑:model.summary()的输出日志
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 14, 8, 8)] 0 []
conv2d (Conv2D) (None, 32, 8, 8) 4064 ['input_1[0][0]']
conv2d_1 (Conv2D) (None, 32, 8, 8) 9248 ['conv2d[0][0]']
batch_normalization (BatchNorm (None, 32, 8, 8) 32 ['conv2d_1[0][0]']
alization)
activation (Activation) (None, 32, 8, 8) 0 ['batch_normalization[0][0]']
conv2d_2 (Conv2D) (None, 32, 8, 8) 9248 ['activation[0][0]']
batch_normalization_1 (BatchNo (None, 32, 8, 8) 32 ['conv2d_2[0][0]']
rmalization)
add (Add) (None, 32, 8, 8) 0 ['batch_normalization_1[0][0]',
'conv2d[0][0]']
activation_1 (Activation) (None, 32, 8, 8) 0 ['add[0][0]']
conv2d_3 (Conv2D) (None, 32, 8, 8) 9248 ['activation_1[0][0]']
batch_normalization_2 (BatchNo (None, 32, 8, 8) 32 ['conv2d_3[0][0]']
rmalization)
activation_2 (Activation) (None, 32, 8, 8) 0 ['batch_normalization_2[0][0]']
conv2d_4 (Conv2D) (None, 32, 8, 8) 9248 ['activation_2[0][0]']
batch_normalization_3 (BatchNo (None, 32, 8, 8) 32 ['conv2d_4[0][0]']
rmalization)
add_1 (Add) (None, 32, 8, 8) 0 ['batch_normalization_3[0][0]',
'activation_1[0][0]']
activation_3 (Activation) (None, 32, 8, 8) 0 ['add_1[0][0]']
conv2d_5 (Conv2D) (None, 32, 8, 8) 9248 ['activation_3[0][0]']
batch_normalization_4 (BatchNo (None, 32, 8, 8) 32 ['conv2d_5[0][0]']
rmalization)
activation_4 (Activation) (None, 32, 8, 8) 0 ['batch_normalization_4[0][0]']
conv2d_6 (Conv2D) (None, 32, 8, 8) 9248 ['activation_4[0][0]']
batch_normalization_5 (BatchNo (None, 32, 8, 8) 32 ['conv2d_6[0][0]']
rmalization)
add_2 (Add) (None, 32, 8, 8) 0 ['batch_normalization_5[0][0]',
'activation_3[0][0]']
activation_5 (Activation) (None, 32, 8, 8) 0 ['add_2[0][0]']
conv2d_7 (Conv2D) (None, 32, 8, 8) 9248 ['activation_5[0][0]']
batch_normalization_6 (BatchNo (None, 32, 8, 8) 32 ['conv2d_7[0][0]']
rmalization)
activation_6 (Activation) (None, 32, 8, 8) 0 ['batch_normalization_6[0][0]']
conv2d_8 (Conv2D) (None, 32, 8, 8) 9248 ['activation_6[0][0]']
batch_normalization_7 (BatchNo (None, 32, 8, 8) 32 ['conv2d_8[0][0]']
rmalization)
add_3 (Add) (None, 32, 8, 8) 0 ['batch_normalization_7[0][0]',
'activation_5[0][0]']
activation_7 (Activation) (None, 32, 8, 8) 0 ['add_3[0][0]']
flatten (Flatten) (None, 2048) 0 ['activation_7[0][0]']
dense (Dense) (None, 1) 2049 ['flatten[0][0]']
==================================================================================================
Total params: 80,353
Trainable params: 80,225
Non-trainable params: 128
__________________________________________________________________________________________________
data_format = channels_first
是罪魁祸首,在删除代码后执行完美(但AI并非完美)。