如何将此numpy转换为与tf.function兼容的代码



我正在尝试将numpy转换为tensorflow等效代码,以便与tf.function兼容。。。

给定一个(32, 6)numpy数组target_values,如下所示:

array([[-0.01656106,  0.04762066,  0.05735449, -0.0284767 , -0.02237438,
-0.00042562],
[-0.01420249,  0.0477839 ,  0.0563598 , -0.02971786, -0.02367548,
0.00001262],
[-0.01695916,  0.04826669,  0.05893629, -0.03067053, -0.02261235,
0.00345904],
[-0.01953977,  0.04540274,  0.05829531, -0.02759781, -0.02390759,
-0.00487727],
[-0.01708016,  0.04894669,  0.0606699 , -0.02576046, -0.02461138,
-0.00068538],
[-0.01604217,  0.04770135,  0.05761468, -0.02858265, -0.02624938,
-0.00084356],
[-0.01527106,  0.04699571,  0.05959677, -0.02956396, -0.02510098,
-0.00223234],
[-0.01448676,  0.04620824,  0.05775366, -0.03008122, -0.02655901,
-0.00159649],
[-0.0172577 ,  0.04814827,  0.05807308, -0.02916523, -0.02367857,
-0.00100602],
[-0.01690523,  0.0484785 ,  0.05807881, -0.02960616, -0.02560546,
-0.00065042],
[-0.0166171 ,  0.0488232 ,  0.05776291, -0.03231864, -0.02132723,
-0.00033605],
[-0.01541627,  0.04840397,  0.0580376 , -0.02927143, -0.02461101,
0.00121263],
[-0.01685588,  0.047661  ,  0.05873172, -0.02989979, -0.02574112,
-0.00126612],
[-0.01333553,  0.05043796,  0.05915743, -0.02990219, -0.02657976,
-0.0007656 ],
[-0.01531163,  0.04781894,  0.05637252, -0.02968849, -0.02225551,
-0.00151382],
[-0.01357749,  0.04807179,  0.05955081, -0.02748637, -0.02498721,
-0.00040934],
[-0.01606943,  0.04768877,  0.05455931, -0.03136749, -0.02475093,
0.00245846],
[-0.01609829,  0.04687681,  0.05982678, -0.02886578, -0.02608151,
0.00015348],
[-0.01503662,  0.04740106,  0.05958583, -0.03141545, -0.02522127,
-0.00063602],
[-0.01697148,  0.04910276,  0.05744712, -0.02858391, -0.02481578,
-0.00072039],
[-0.01503395,  0.04843756,  0.05773868, -0.03061879, -0.02586869,
-0.00025573],
[-0.0152991 ,  0.04847359,  0.05739099, -0.0299796 , -0.02552593,
-0.00334571],
[-0.01324895,  0.04529134,  0.05534273, -0.03109139, -0.02304241,
-0.00143186],
[-0.01280282,  0.05004944,  0.05856398, -0.0314032 , -0.02394999,
-0.00030306],
[-0.01677033,  0.04876196,  0.05794405, -0.02888608, -0.02658239,
-0.00015171],
[-0.01572544,  0.04779808,  0.05939355, -0.03048976, -0.02896303,
-0.00090334],
[-0.01542805,  0.04709881,  0.05839922, -0.02894112, -0.02240603,
-0.00188624],
[-0.01493233,  0.0476524 ,  0.0581631 , -0.0297201 , -0.02485022,
-0.00087418],
[-0.01804641,  0.04739738,  0.06070606, -0.02981704, -0.02543145,
-0.00115484],
[-0.01518638,  0.04843838,  0.05744548, -0.02980216, -0.02420005,
0.00036349],
[-0.01442349,  0.04673778,  0.05804737, -0.03062913, -0.02476445,
-0.00066772],
[-0.01598305,  0.04622466,  0.0588723 , -0.03096713, -0.02364032,
-0.00005574]])

给定另一个(32,)索引数组actions,其值在(5(范围内(包括(:

array([0, 2, 5, 5, 1, 1, 3, 4, 0, 5, 4, 3, 4, 5, 1, 0, 3, 0, 0, 2, 2, 2,
0, 1, 4, 1, 4, 4, 0, 4, 1, 0])

我期待这个结果:

array([-0.01656106,  0.0563598 ,  0.00345904, -0.00487727,  0.04894669,
0.04770135, -0.02956396, -0.02655901, -0.0172577 , -0.00065042,
-0.02132723, -0.02927143, -0.02574112, -0.0007656 ,  0.04781894,
-0.01357749, -0.03136749, -0.01609829, -0.01503662,  0.05744712,
0.05773868,  0.05739099, -0.01324895,  0.05004944, -0.02658239,
0.04779808, -0.02240603, -0.02485022, -0.01804641, -0.02420005,
0.04673778, -0.01598305], dtype=float32)

对于self.batch_size == 32,我可以使用在numpy中实现我需要的内容

state_action_values = target_values[np.arange(self.batch_size), actions]

对于target_value_update是另一个新值的(32,)数组,我需要使用将新值分配给这个切片

target_values[np.arange(self.batch_size), actions] = target_value_update

然而,在tf.function下的tensorflow中,这是不可能的,我得到了以下错误:

TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

所以我尝试:

target_values = tf.Variable(target_values)
state_action_values = tf.gather(target_values, actions, axis=1)

然而,这里是state_action_values的值,它应该是(32,)而不是(32, 32)

Tensor("GatherV2:0", shape=(32, 32), dtype=float32)

使用gather_nd():

a = tf.range(32)[:, tf.newaxis]
a = tf.concat((a, actions[:, tf.newaxis]), -1)
output = tf.gather_nd(target_values, a)

相关内容

  • 没有找到相关文章

最新更新