py_function:替换卷积操作,但保持梯度



我正在尝试在 TensorFlow 2 中训练 LeNet-5 模型,同时将所有 Dense 的矩阵乘法和 Conv2D 的卷积替换为用 C 编写的自定义卷积。更准确地说,我想保持这些梯度在默认情况下是这样的,但使用我的这些操作的实现而不是 TensorFlow 的默认操作。此外,我无法使用 TensorFlow 执行我的自定义卷积和矩阵乘法,我必须通过我的 C 代码,该代码通过CTypes调用。有没有办法这样做?

到目前为止,我尝试的是使用 TensorFlow 的@tf.experimental.dispatch_for_api来调用使用tf.py_function的函数,而 反过来又调用我的 C 代码。然而,这样做似乎会丢失梯度,并且无法训练模型。还有其他方法吗?

@tf.experimental.dispatch_for_api(tf.matmul,
{'a': ApproximateTensor},
{'b': ApproximateTensor},
{'a': tf.Tensor        , 'b': tf.Tensor        },
{'a': ApproximateTensor, 'b': tf.Tensor        },
{'a': tf.Tensor        , 'b': ApproximateTensor},
{'a': ApproximateTensor, 'b': ApproximateTensor},
)
def custom_matmul(a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False, a_is_sparse=False, b_is_sparse=False, output_type=None):
# tf.print('MATMUL')
if not isinstance(a, ApproximateTensor):
a = ApproximateTensor(a)
if not isinstance(b, ApproximateTensor):
b = ApproximateTensor(b)
_, out_size = b.shape
return ApproximateTensor(process.linear(a.values, b.values, tf.zeros(out_size)))

然后process.linear执行以下操作:

def linear(input, kernel, bias):
global c_approx
def compute(input, kernel, bias):
global c_approx
# Extract dimensions
batch_size, in_size = input.shape
_, out_size = kernel.shape
if batch_size is None:
batch_size = 1
# Create output
output = np.zeros((batch_size, out_size), dtype=np.float32)
# Compute
for b in range(batch_size):
output[b] = c_approx.custom_matmul(input[b], kernel[i]) + bias
return tf.convert_to_tensor(output)
output = tf.py_function(compute, [input, kernel, bias], input.dtype)
# Manually set output dimensions
batch_size, _ = input.shape
_, out_size = kernel.shape
output.set_shape((batch_size, out_size))
return output

换句话说,我想要与以下代码相反的功能:

@tf.custom_gradient
def custom_conv(x):
def grad_fn(dy):
return dy
return tf.nn.conv2d(x), grad_fn

我想重新定义 Conv2D,同时保持他的默认渐变。

提前致谢

由于这个答案,我实际上想通了:https://stackoverflow.com/a/43952168/9675442

这个想法是这样做的:

y = tf.matmul(a, b) + tf.stop_gradient(compute(a, b) - tf.matmul(a, b))

我希望这会帮助其他人

相关内容

  • 没有找到相关文章

最新更新