I am trying to convert NumPy code to its TensorFlow equivalent so that it is compatible with tf.function.
Given a (32, 6) NumPy array target_values, as follows:
array([[-0.01656106, 0.04762066, 0.05735449, -0.0284767 , -0.02237438,
-0.00042562],
[-0.01420249, 0.0477839 , 0.0563598 , -0.02971786, -0.02367548,
0.00001262],
[-0.01695916, 0.04826669, 0.05893629, -0.03067053, -0.02261235,
0.00345904],
[-0.01953977, 0.04540274, 0.05829531, -0.02759781, -0.02390759,
-0.00487727],
[-0.01708016, 0.04894669, 0.0606699 , -0.02576046, -0.02461138,
-0.00068538],
[-0.01604217, 0.04770135, 0.05761468, -0.02858265, -0.02624938,
-0.00084356],
[-0.01527106, 0.04699571, 0.05959677, -0.02956396, -0.02510098,
-0.00223234],
[-0.01448676, 0.04620824, 0.05775366, -0.03008122, -0.02655901,
-0.00159649],
[-0.0172577 , 0.04814827, 0.05807308, -0.02916523, -0.02367857,
-0.00100602],
[-0.01690523, 0.0484785 , 0.05807881, -0.02960616, -0.02560546,
-0.00065042],
[-0.0166171 , 0.0488232 , 0.05776291, -0.03231864, -0.02132723,
-0.00033605],
[-0.01541627, 0.04840397, 0.0580376 , -0.02927143, -0.02461101,
0.00121263],
[-0.01685588, 0.047661 , 0.05873172, -0.02989979, -0.02574112,
-0.00126612],
[-0.01333553, 0.05043796, 0.05915743, -0.02990219, -0.02657976,
-0.0007656 ],
[-0.01531163, 0.04781894, 0.05637252, -0.02968849, -0.02225551,
-0.00151382],
[-0.01357749, 0.04807179, 0.05955081, -0.02748637, -0.02498721,
-0.00040934],
[-0.01606943, 0.04768877, 0.05455931, -0.03136749, -0.02475093,
0.00245846],
[-0.01609829, 0.04687681, 0.05982678, -0.02886578, -0.02608151,
0.00015348],
[-0.01503662, 0.04740106, 0.05958583, -0.03141545, -0.02522127,
-0.00063602],
[-0.01697148, 0.04910276, 0.05744712, -0.02858391, -0.02481578,
-0.00072039],
[-0.01503395, 0.04843756, 0.05773868, -0.03061879, -0.02586869,
-0.00025573],
[-0.0152991 , 0.04847359, 0.05739099, -0.0299796 , -0.02552593,
-0.00334571],
[-0.01324895, 0.04529134, 0.05534273, -0.03109139, -0.02304241,
-0.00143186],
[-0.01280282, 0.05004944, 0.05856398, -0.0314032 , -0.02394999,
-0.00030306],
[-0.01677033, 0.04876196, 0.05794405, -0.02888608, -0.02658239,
-0.00015171],
[-0.01572544, 0.04779808, 0.05939355, -0.03048976, -0.02896303,
-0.00090334],
[-0.01542805, 0.04709881, 0.05839922, -0.02894112, -0.02240603,
-0.00188624],
[-0.01493233, 0.0476524 , 0.0581631 , -0.0297201 , -0.02485022,
-0.00087418],
[-0.01804641, 0.04739738, 0.06070606, -0.02981704, -0.02543145,
-0.00115484],
[-0.01518638, 0.04843838, 0.05744548, -0.02980216, -0.02420005,
0.00036349],
[-0.01442349, 0.04673778, 0.05804737, -0.03062913, -0.02476445,
-0.00066772],
[-0.01598305, 0.04622466, 0.0588723 , -0.03096713, -0.02364032,
-0.00005574]])
Given another (32,) index array actions, whose values lie in the range 0 to 5 (inclusive):
array([0, 2, 5, 5, 1, 1, 3, 4, 0, 5, 4, 3, 4, 5, 1, 0, 3, 0, 0, 2, 2, 2,
0, 1, 4, 1, 4, 4, 0, 4, 1, 0])
I expect this result:
array([-0.01656106, 0.0563598 , 0.00345904, -0.00487727, 0.04894669,
0.04770135, -0.02956396, -0.02655901, -0.0172577 , -0.00065042,
-0.02132723, -0.02927143, -0.02574112, -0.0007656 , 0.04781894,
-0.01357749, -0.03136749, -0.01609829, -0.01503662, 0.05744712,
0.05773868, 0.05739099, -0.01324895, 0.05004944, -0.02658239,
0.04779808, -0.02240603, -0.02485022, -0.01804641, -0.02420005,
0.04673778, -0.01598305], dtype=float32)
For self.batch_size == 32, I can get what I need in NumPy with:
state_action_values = target_values[np.arange(self.batch_size), actions]
With target_value_update being another (32,) array of new values, I need to assign the new values to this slice with:
target_values[np.arange(self.batch_size), actions] = target_value_update
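To make the indexing pattern concrete, here is a small sketch on a made-up (3, 4) array rather than the real data; values, rows, cols, and update are illustrative names only. Pairing np.arange(n) with the column indices selects one element per row, and the same index expression can be used on the left-hand side of an assignment.

```python
import numpy as np

# Toy stand-ins for target_values and actions (hypothetical data).
values = np.arange(12, dtype=np.float32).reshape(3, 4)
cols = np.array([1, 3, 0])
rows = np.arange(values.shape[0])

# Read: element (i, cols[i]) for every row i.
picked = values[rows, cols]

# Write: assign one new value per row at the same positions.
update = np.array([-1.0, -2.0, -3.0], dtype=np.float32)
values[rows, cols] = update
```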
However, this is not possible in TensorFlow under tf.function, and I get the following error:
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
So I tried:
target_values = tf.Variable(target_values)
state_action_values = tf.gather(target_values, actions, axis=1)
However, this is the resulting state_action_values, whose shape should be (32,) rather than (32, 32):
Tensor("GatherV2:0", shape=(32, 32), dtype=float32)
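The (32, 32) shape follows from how tf.gather works: with axis=1 it takes every requested column from every row, so 32 indices applied to 32 rows give a 32 x 32 result, and the wanted per-row elements sit on its diagonal. One possible alternative, sketched below on toy data (params and idx are illustrative names, not from the original code), is to pass batch_dims=1 so that each row is paired with its own index.

```python
import tensorflow as tf

# Toy stand-ins for target_values and actions (hypothetical data).
params = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
idx = tf.constant([1, 3, 0])

# Plain gather along axis 1: every index is applied to every row.
full = tf.gather(params, idx, axis=1)                    # shape (3, 3)

# batch_dims=1 pairs row i with idx[i], giving one value per row.
per_row = tf.gather(params, idx, axis=1, batch_dims=1)   # shape (3,)
```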
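Use tf.gather_nd():
# Row indices 0..31 as a column vector of shape (32, 1).
a = tf.range(32)[:, tf.newaxis]
# Pair each row index with its action, shape (32, 2). If actions is an int64 tensor, cast it to int32 first.
a = tf.concat((a, actions[:, tf.newaxis]), -1)
# Pick target_values[i, actions[i]] for every row i, shape (32,).
output = tf.gather_nd(target_values, a)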
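As a fuller sketch of the gather_nd approach, the snippet below wraps both the read and the write in tf.function. The write uses tf.tensor_scatter_nd_update, which returns a new tensor rather than modifying one in place; treating it as the counterpart of the NumPy assignment is an assumption on my part. The function names and toy data are hypothetical, and the indices are cast to int32 so that tf.stack does not see mixed integer dtypes.

```python
import tensorflow as tf

@tf.function
def select_per_row(values, actions):
    # Build (row, column) index pairs of shape (batch, 2).
    actions = tf.cast(actions, tf.int32)
    rows = tf.range(tf.shape(values)[0])
    indices = tf.stack([rows, actions], axis=1)
    return tf.gather_nd(values, indices)

@tf.function
def assign_per_row(values, actions, updates):
    actions = tf.cast(actions, tf.int32)
    rows = tf.range(tf.shape(values)[0])
    indices = tf.stack([rows, actions], axis=1)
    # Returns a copy of values with updates written at the index pairs.
    return tf.tensor_scatter_nd_update(values, indices, updates)

# Toy stand-ins for target_values, actions, target_value_update.
values = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
actions = tf.constant([1, 3, 0], dtype=tf.int64)
picked = select_per_row(values, actions)
updated = assign_per_row(values, actions, tf.constant([-1.0, -2.0, -3.0]))
```

If target_values is kept as a tf.Variable, the returned tensor can be written back with target_values.assign(updated).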