Python TensorFlow将经典for循环重写为tf.while_oop



我有一个函数

def multivariate_data(dataset, target, start_index, end_index, history_size,
target_size, step, single_step=False):
data = []
labels = []
start_index = start_index + history_size
if end_index is None:
end_index = len(dataset) - target_size
#print(history_size)
for i in range(start_index, end_index):
indices = range(i-history_size, i, step)
data.append(dataset[indices])
if single_step:
labels.append(target[i+target_size])
else:
labels.append(target[i:i+target_size])
return np.array(data), np.array(labels)

我想在GPU上进行计算。但只有张量运算可以在GPU上运行。所以我需要重写我的函数。for循环必须更改为tf.while_oop。我所有的numpy数组都必须改为张量。如何将for循环的函数重写为tf.while_oop?

II没有使用GPU进行测试,我假设输入数据是秩1张量,并删除了一些参数。我没有使用标签。也没有异常处理,但这是可以重构的。

我将张量连接到self上_数据,但还有其他有效的"附加"方法。

self._data = tf.concat([self._data,tf.gather(dataset, tf.range(1, 3, 1))],0)

这一行只是显示了一个范围可以用来从一个张量中选取值,并将其附加到另一个张量。由于数据是固定的,因此不会处理异常。

import tensorflow as tf


class MultiVariate():
def __init__(self):
self._data = None
self._labels = None
def multivariate_data(self,
dataset,
start_index,
end_index,
history_size,
target_size,
single_step=False):
start_index = start_index + history_size
tf.print("end_index ", end_index)
tf.print("start_index ", start_index)
if self._data is None:
self._data = tf.cast(tf.Variable(tf.reshape((), (0,))),dtype=tf.int32)
if self._labels is None:
self._labels = tf.cast(tf.Variable(tf.reshape((), (0,))),dtype=tf.int32)
if end_index is None:
end_index = len(dataset) - target_size
def cond(i, j):
return tf.less(i, j)
def body(i, j):
#A range of values are gathered
self._data = tf.concat([self._data,[tf.gather(dataset, i)]],0)
if ( i == start_index ): #Showing how A range of values are gathered and appended
self._data = tf.concat([self._data,tf.gather(dataset, tf.range(1, 3, 1))],0)
return tf.add( i , 1 ), j
_,_ = tf.while_loop(cond, body, [start_index,end_index],shape_invariants=[start_index.get_shape(), end_index.get_shape()])
return self._data
mv = MultiVariate()
d =    mv.multivariate_data(
tf.constant([1,88,99,4,5,6,7,8,9]),
tf.constant(2),
tf.constant(8),
tf.constant(1),
tf.constant(2),
tf.constant(2))
print("print ",d)

打印tf张量([4 88 99 5 6 7 8],shape=(7,(,dtype=int32(

最新更新