我正在使用TensorFlow rs构建一个TensorFlow图,无法从ops:split中获取单个操作



我正在尝试完成这个函数:

///Split layers take in a single layer and splits it into a 
///vector of layers. Since all tensors are two dimensional,
///we can split with a single usize on axis=0.
fn split<O1: Into<Output>>(
input: O1,
num_splits: usize,
scope: &mut Scope,
) -> Result<Vec<Output>, Status> {
let num_splits_op = ops::constant(num_splits.into(), scope)?;
let outputs = vec![];
let split_outputs = ops::split(num_splits_op, input.into(), scope)?;
//TODO: get vector of Outputs.
Ok(outputs)
}

我的问题是split_output类型是Operation。这是有道理的,因为我正在构建图,但我不能索引Operation;我必须提取会话运行参数来检索Tensor类型的TensorArray,然后对提取的对象进行索引并返回每个索引。

我在C++API中找不到索引操作(Rust crate镜像(。有没有这样的运算,或者有没有另一种技术可以给我一个由分裂运算产生的给定张量中的每个子张量的运算?

我需要为每个子张量创建一个Output或Operation,本质上是沿着轴=0返回每个条目,当我完成函数的编写时,轴=0应该是一个长度为num_splits的Operations向量。

我最终通过我编写的提交以及TF rs SIG小组的帮助和代码审查解决了这个问题。请考虑可以在自动生成的文档中找到的操作实例。

https://github.com/tensorflow/rust/commit/8bbf9f9234906a06a9635607d45b319d52897918

简写函数当前已损坏。我已经提出了我遇到的问题,但我解决了以下问题:

let axis_zero = ops::constant(0, scope)?;
let split_operation =
ops::Split::new()
.num_split(num_splits)
.build(axis_one.clone(), input.0, scope)?;

如果您不能用简写函数构建操作,那么很可能没有通过简写函数或简写函数正在调用的build_impl来分配参数。您可以通过从ops:yourop::new((方法链接构建器,然后显式调用不在build_impl或简写函数中的每个参数构建器来解决此问题,就像我在这里为ops:Split所做的那样。

为了解决我的特定操作索引问题,我从拆分操作中的每个索引显式地创建了一个输出类型:

// Create an output object for each split for graph connectionism
let mut outputs = vec![];
for i in 0..num_splits {
outputs.push(Output {
operation: split_operation.clone(),
index: i as i32,
});
}
Ok(outputs)

这明确地避免了写入TensorArray,并为每个";子张量";通过拆分操作进行拆分,这正是我所需要的。

相关内容

  • 没有找到相关文章

最新更新