Is there a PyTorch implementation of softmax segmented along a dim (blocky softmax)?



For example, given logits, dim, and boundary:

import torch

boundary = torch.tensor([[0, 3, 4, 8, 0],
                         [1, 3, 5, 7, 9]])
# representing sections that look like:
#    [[00012222_]
#     [_00112233]]
#  in shape: (2, 9)
# (the sections cannot be sliced out directly, since block lengths differ per row)
logits = torch.rand(2, 9, 100)
result = blocky_softmax(logits, dim = 1, boundary = boundary)
# result[:, :, 0] may look like:
#   [[0.33, 0.33, 0.33, 1.00, 0.25, 0.25, 0.25, 0.25, 0.0 ]
#    [0.0,  0.50, 0.50, 0.50, 0.50, 0.50, 0.50, 0.50, 0.50]]
# the other 99 slices look similar, with each block summing to 1.
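
As an aside, the boundary rows can be expanded into the per-position section ids drawn in the comments above; here is a minimal sketch of that mapping (the helper name boundary_to_sections is made up for illustration):

def boundary_to_sections(boundary, seq_len):
    # expand consecutive [start, end) pairs into per-position section ids,
    # with -1 marking positions covered by no block
    batch_size, n_splits = boundary.shape
    pos = torch.arange(seq_len, device = boundary.device)[None, :]
    ids = torch.full((batch_size, seq_len), -1, device = boundary.device)
    for nid in range(1, n_splits):
        start  = boundary[:, nid - 1, None]
        end    = boundary[:, nid,     None]
        inside = (pos >= start) & (pos < end)
        ids    = torch.where(inside, torch.full_like(ids, nid - 1), ids)
    return ids

boundary_to_sections(boundary, 9)
# tensor([[ 0,  0,  0,  1,  2,  2,  2,  2, -1],
#         [-1,  0,  0,  1,  1,  2,  2,  3,  3]])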

I want softmax applied along dim = 1, but with that dim also segmented into blocks. My current PyTorch implementation uses a for loop; it is slow, takes too much memory, and looks like this:

def blocky_softmax(logits, splits, map_inf_to = None):
    _, batch_len, _ = logits.shape
    exp_logits    = logits.exp() # [2, 9, 100]
    batch_seq_idx = torch.arange(batch_len, device = logits.device)[None, :]
    base          = torch.zeros_like(logits)
    _, n_blocks   = splits.shape
    for nid in range(1, n_blocks):
        start = splits[:, nid - 1, None]
        end   = splits[:, nid,     None]
        area  = batch_seq_idx >= start
        area &= batch_seq_idx < end
        area.unsqueeze_(dim = 2)                                 # (2, 9, 1) mask of the current block
        blocky_z = area * exp_logits                             # exp logits inside this block only
        blocky_z = blocky_z.sum(dim = 1, keepdim = True) * area  # block sum, broadcast back
        base = base + blocky_z
    if map_inf_to is not None:
        good_base  = base > 0                                    # positions covered by some block
        ones       = torch.ones_like(base)
        base       = torch.where(good_base, base, ones)
        exp_logits = torch.where(good_base, exp_logits, ones * map_inf_to)
    return exp_logits / base
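
With the example tensors above, this is called as follows; splits here plays the role of boundary in the desired API, and map_inf_to = 0.0 fills the positions outside every block (matching the 0.0 entries in the expected output):

result = blocky_softmax(logits, boundary, map_inf_to = 0.0)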

This implementation is slowed down, and its memory use blown up, by a factor of n_blocks, even though the blocks could be processed in parallel. If no ready-made function exists for this, should I write a CUDA/C++ extension? I hope you can help me with this.

To generalize a step further, I would like to allow discontiguous blocks in boundary/sections:

sections = torch.tensor([[ 0,  0,  0, -1,  2,  3,  2,  3,  0,  3],
                         [-1,  0,  0,  1,  2,  1,  2,  1, -1,  1]])
# [[000_232303]
#  [_0012121_1]]

Thanks for reading :(

Latest update

I realized that scatter_add and gather solve this problem perfectly.
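
For future readers, here is a minimal sketch of that scatter_add/gather idea, assuming the per-position sections format above with -1 marking positions outside every block (this reconstruction is mine, not necessarily the author's final code):

def blocky_softmax(logits, sections):
    # sections: (batch, seq_len) integer block ids; -1 = outside any block
    valid      = (sections >= 0)[:, :, None]                    # (B, S, 1)
    ids        = sections.clamp(min = 0)[:, :, None].expand_as(logits)
    exp_logits = logits.exp() * valid                           # zero out masked positions
    # note: no per-block max subtraction here, so very large logits can overflow exp()
    n_sections = int(sections.max()) + 1
    base = torch.zeros(logits.shape[0], n_sections, logits.shape[2],
                       device = logits.device, dtype = logits.dtype)
    base.scatter_add_(1, ids, exp_logits)      # per-block partition sums in one pass
    base = base.gather(1, ids)                 # broadcast each block's sum back to its positions
    return exp_logits / base.clamp(min = 1e-20)

result = blocky_softmax(logits, sections)      # handles discontiguous blocks as well

One scatter_add and one gather replace the O(n_blocks) loop, and discontiguous blocks come for free, since positions are grouped by id rather than by contiguous ranges.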