CNTK:有条件执行



是否可以在CNTK中创建一个"条件"网络并仅根据另一个输入变量应用于输入之一?请参阅以下代码:

a_in = ct.input_variable(shape=[16,16])
b_in = ct.input_variable(shape=[16,16])
flag = ct.input_variable(shape=[])
a_branch = ct.layers.Sequential([...])
b_branch = ct.layers.Sequential([...])
sel_branch = ct.element_select(flag, a_branch, b_branch)
out = sel_branch(a_in, b_in)

howerer,这是不起作用的,因为sel_branch期望3个参数,而不是a_branchb_branch要求的参数(这是完全正确的(因为我在这里我以错误的方式使用element_select(

请记住,目的是避免执行两个分支,

答案是否定的,目前CNTK中没有有条件执行。一般情况下,标志是向量/张量,其某些元素为0,而其他元素将为1。但是,即使实现了sel_branch的签名,仍然需要3个参数,因为这是一个"编译时间"属性,而上述优化只能在运行时确定。即使在您的情况下,当标志为标量时,它可能是一个批量的0,另一批也可能为1,而sel_branch的签名也无法从批处理变为批量。

最新更新