我试图在TensorFlow中执行稀疏矩阵-密集矩阵乘法,其中两个矩阵都有一个领先的批处理维度(即秩3)。我知道TensorFlow提供了tf.sparse。sparse_dense_matmul函数秩2矩阵,但我正在寻找一种方法来处理秩3矩阵。在TensorFlow中是否有一个内置的函数或方法可以有效地处理这种情况,而不需要昂贵的重塑或切片操作?性能在我的应用程序中是至关重要的。
为了说明我的问题,考虑下面的示例代码:
import tensorflow as tf
# Define sparse and dense matrices with leading batch dimension
sparse_tensor = tf.SparseTensor(indices=[[0, 1, 1], [0, 0, 1], [1, 1, 1], [1, 2, 1], [2, 1, 1]],
values=[1, 1, 1, 1, 1],
dense_shape=[3, 3, 2])
dense_matrix = tf.constant([[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
[[0.9, 0.10, 0.11, 0.12], [0.13, 0.14, 0.15, 0.16]],
[[0.17, 0.18, 0.19, 0.20], [0.21, 0.22, 0.23, 0.24]]], dtype=tf.float32)
# Perform sparse x dense matrix multiplication
result = tf.???(sparse_tensor, dense_matrix) # Result should have shape [3, 3, 4]
在TF中,Sparse和Dense乘法只向Sparse广播Dense。否则,batch_sparse_dense_matmul
可以简单地通过
tf.sparse.reduce_sum(tf.sparse.expand_dims(sparse_tensor,-1)*tf.expand_dims(dense_matrix,1), 2)
#[3,3,2,1] * [3,1,2,4] and reduce sum along dim=2
# the above throws error
# because the last dim of sparse tensor [1] cannot be broadcasted to [4]
为了解决上述问题,我们需要将稀疏张量的最后一个维度tile
使其为4。
k = 4
tf.sparse.concat(-1,[tf.sparse.expand_dims(sparse_tensor, -1)]*k)
##[3,3,2,4]
放在一起,
tf.sparse.reduce_sum(tf.sparse.concat(-1,[tf.sparse.expand_dims(sparse_tensor, -1)]*k)*tf.expand_dims(dense_matrix,1), 2)
#timeit:1.99ms
另一种方法是使用tf.map_fn
,
tf.map_fn(
lambda x: tf.sparse.sparse_dense_matmul(x[0], x[1]),
elems=(tf.sparse.reorder(tf.cast(sparse_tensor, tf.float32)),dense_matrix ), fn_output_signature=tf.float32
)
#timeit:4.42ms