我正在尝试为交易创建一个自定义的 TF-Agents 环境。当我调用 `utils.validate_py_environment(environment, episodes=1)` 来验证它时,会得到一个 `ValueError`,提示 `time_step` 与 `time_step_spec` 不匹配(完整错误信息见下文)。
我已经花了一段时间试图找出两者的区别,但似乎找不到。我是不是遗漏了什么?
观测规范
# Observation spec: a nested dict of array specs. Every observation the
# environment emits must match these leaf shapes AND dtypes exactly.
self._observation_spec = dict(
    # self._market_history rows, 11 feature columns:
    # [o, h, l, c, sma_5, ema_13, volume, trades, rsi, macd_latest, macd_signal]
    visible_market_data=array_spec.ArraySpec(
        shape=(self._market_history, 11),
        dtype=np.float32,
        name='visible_market_data',
    ),
    current_trade=dict(
        # 0 = no position, 1 = long, 2 = short
        trade_type=array_spec.BoundedArraySpec(
            shape=(),
            dtype=np.int32,
            minimum=0,
            maximum=2,
            name='trade_type',
        ),
        # Scalar int32; no bounds are declared for it.
        open_intervals=array_spec.ArraySpec(
            shape=(),
            dtype=np.int32,
            name='open_intervals',
        ),
    ),
    # One entry per action [do nothing, buy, sell, close], each 0 or 1.
    action_mask=array_spec.BoundedArraySpec(
        shape=(4,),
        dtype=np.int32,
        minimum=0,
        maximum=1,
        name='action_mask',
    ),
)
我如何存储观测值(observation)
# Current observation. Every leaf must match self._observation_spec in both
# shape and dtype; otherwise utils.validate_py_environment raises
# "ValueError: Given `time_step` ... does not match expected `time_step_spec`".
self._observation = {
    # The spec declares float32, but self._data[0] is float64 (numpy's
    # default -- its repr in the validation error shows no dtype), so build
    # the array explicitly as float32. np.array copies, so later in-place
    # edits of the observation cannot write back into self._data.
    "visible_market_data": np.array(self._data[0], dtype=np.float32),
    # Start with no position (trade_type 0) and zero open intervals.
    "current_trade": {
        "trade_type": np.array(0, dtype=np.int32),
        "open_intervals": np.array(0, dtype=np.int32),
    },
    # Mask order is [do nothing, buy, sell, close]; "close" is 0 initially.
    "action_mask": np.array([1, 1, 1, 0], dtype=np.int32),
}
错误消息
ValueError: Given `time_step`: TimeStep(
{'discount': array(1., dtype=float32),
'observation': {'action_mask': array([1, 1, 1, 0]),
'current_trade': {'open_intervals': array(0),
'trade_type': array(0)},
'visible_market_data': array([[0.35, 0.42, 0.33, 0.38, 0.41, 0.53, 0.34, 0.27, 0.31, 0.43, 0.5 ],
[0.38, 0.39, 0.29, 0.31, 0.39, 0.52, 0.42, 0.36, 0.28, 0.42, 0.47],
[0.31, 0.33, 0.2 , 0.2 , 0.34, 0.5 , 0.62, 0.52, 0.2 , 0.37, 0.44],
[0.2 , 0.35, 0.2 , 0.35, 0.32, 0.48, 0.32, 0.35, 0.2 , 0.36, 0.4 ],
[0.35, 0.41, 0.32, 0.39, 0.33, 0.48, 0.33, 0.32, 0.28, 0.29, 0.36],
[0.37, 0.42, 0.31, 0.34, 0.32, 0.47, 0.29, 0.3 , 0.21, 0.23, 0.3 ],
[0.34, 0.48, 0.31, 0.4 , 0.34, 0.46, 0.46, 0.46, 0.29, 0.22, 0.25],
[0.42, 0.53, 0.41, 0.48, 0.39, 0.46, 0.51, 0.29, 0.43, 0.2 , 0.22],
[0.48, 0.53, 0.42, 0.44, 0.41, 0.47, 0.35, 0.29, 0.39, 0.22, 0.23],
[0.44, 0.44, 0.38, 0.44, 0.42, 0.45, 0.3 , 0.27, 0.43, 0.25, 0.24],
[0.44, 0.44, 0.39, 0.39, 0.43, 0.46, 0.4 , 0.21, 0.37, 0.3 , 0.26],
[0.39, 0.4 , 0.29, 0.33, 0.42, 0.44, 0.39, 0.37, 0.33, 0.29, 0.25],
[0.33, 0.5 , 0.3 , 0.5 , 0.42, 0.44, 0.46, 0.33, 0.44, 0.32, 0.24],
[0.49, 0.51, 0.43, 0.5 , 0.43, 0.45, 0.25, 0.28, 0.48, 0.31, 0.25],
[0.5 , 0.51, 0.46, 0.49, 0.44, 0.45, 0.2 , 0.22, 0.48, 0.3 , 0.25],
[0.5 , 0.55, 0.44, 0.45, 0.45, 0.46, 0.34, 0.3 , 0.5 , 0.36, 0.27],
[0.46, 0.55, 0.43, 0.5 , 0.49, 0.44, 0.49, 0.25, 0.62, 0.37, 0.29],
[0.48, 0.5 , 0.34, 0.37, 0.46, 0.46, 0.37, 0.42, 0.43, 0.43, 0.34],
[0.37, 0.45, 0.34, 0.37, 0.44, 0.43, 0.45, 0.3 , 0.39, 0.43, 0.37],
[0.37, 0.38, 0.31, 0.33, 0.41, 0.44, 0.38, 0.29, 0.4 , 0.5 , 0.41],
[0.33, 0.35, 0.27, 0.31, 0.38, 0.41, 0.4 , 0.3 , 0.33, 0.47, 0.44],
[0.27, 0.43, 0.27, 0.38, 0.36, 0.42, 0.3 , 0.3 , 0.32, 0.53, 0.47],
[0.42, 0.42, 0.33, 0.38, 0.36, 0.41, 0.24, 0.22, 0.35, 0.5 , 0.5 ],
[0.4 , 0.53, 0.38, 0.43, 0.37, 0.41, 0.31, 0.28, 0.39, 0.53, 0.53],
[0.5 , 0.5 , 0.43, 0.44, 0.39, 0.41, 0.32, 0.22, 0.45, 0.56, 0.59],
[0.44, 0.49, 0.42, 0.45, 0.42, 0.41, 0.31, 0.25, 0.52, 0.6 , 0.62],
[0.49, 0.5 , 0.42, 0.44, 0.43, 0.41, 0.26, 0.26, 0.33, 0.63, 0.65],
[0.44, 0.47, 0.4 , 0.4 , 0.43, 0.42, 0.21, 0.22, 0.29, 0.66, 0.68],
[0.41, 0.42, 0.39, 0.4 , 0.43, 0.41, 0.2 , 0.2 , 0.29, 0.7 , 0.7 ],
[0.4 , 0.45, 0.38, 0.4 , 0.42, 0.41, 0.24, 0.21, 0.34, 0.68, 0.68],
[0.4 , 0.54, 0.4 , 0.54, 0.44, 0.41, 0.39, 0.3 , 0.46, 0.69, 0.68],
[0.52, 0.77, 0.52, 0.67, 0.48, 0.43, 0.8 , 0.8 , 0.72, 0.68, 0.68],
[0.67, 0.74, 0.65, 0.73, 0.55, 0.44, 0.44, 0.54, 0.75, 0.68, 0.69],
[0.72, 0.73, 0.68, 0.68, 0.6 , 0.47, 0.35, 0.38, 0.74, 0.74, 0.72],
[0.71, 0.8 , 0.68, 0.78, 0.68, 0.47, 0.5 , 0.43, 0.8 , 0.77, 0.71],
[0.78, 0.8 , 0.69, 0.73, 0.72, 0.51, 0.37, 0.45, 0.71, 0.8 , 0.7 ]])},
'reward': array(0., dtype=float32),
'step_type': array(0)})
does not match expected `time_step_spec`:
TimeStep(
{'discount': BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0,
maximum=1.0),
'observation': {'action_mask': BoundedArraySpec(shape=(4,), dtype=dtype('int32'), name='action_mask', minimum=0, maximum=1),
'current_trade': {'open_intervals': ArraySpec(shape=(), dtype=dtype('int32'), name='open_intervals'),
'trade_type': BoundedArraySpec(shape=(), dtype=dtype('int32'), name='trade_type', minimum=0, maximum=2)},
'visible_market_data': ArraySpec(shape=(36, 11), dtype=dtype('float32'), name='visible_market_data')},
'reward': ArraySpec(shape=(), dtype=dtype('float32'), name='reward'),
'step_type': ArraySpec(shape=(), dtype=dtype('int32'), name='step_type')})
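我找到了问题所在:必须显式地用 `dtype=np.float32` 重新创建 `visible_market_data` 这个 numpy 数组。从错误信息中可以看出原因:`discount` 显示为 `array(1., dtype=float32)`,而 `visible_market_data` 的数组没有显示 dtype,说明它是 numpy 默认的 float64,与规范中的 float32 不一致。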