加速pytorch自定义消息退出操作



我试图在PyTorch Geometric中实现自定义messageppassing卷积中的消息dropout。消息丢失包括随机忽略图中p%的边。我的想法是从forward()的输入edge_index中随机删除p%。

edge_index是形状为(2, num_edges)的张量,其中第一维为来自;节点ID,第二个是"节点ID"。所以我认为我可以做的是选择range(N)的随机样本,然后用它来掩盖其余的索引:

    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # TODO: this is way too slow (4-5 times slower than without it)
            # message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
            random_keep_inx = random.sample(range(edge_index.shape[1]), int((1.0 - self.message_dropout) * edge_index.shape[1]))
            edge_index_to_use = edge_index[:, random_keep_inx]
            edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr
        ...

然而,它太慢了,它使一个epoch变为5'而不是1'(慢5倍)。在PyTorch中是否有更快的方法来做到这一点?

编辑:瓶颈似乎是random.sample()调用,而不是屏蔽。所以我想我应该要求的是更快的替代方案。

我设法使用PyTorch的Dropout from function创建一个布尔掩码,这要快得多。现在一个历元又需要~1'。比我在其他地方找到的其他排列解决方案要好。

    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # message dropout -> randomly ignore p % of edges in the graph
            mask = F.dropout(torch.ones(edge_index.shape[1]), self.message_dropout, self.training) > 0
            edge_index_to_use = edge_index[:, mask]
            edge_attr_to_use = edge_attr[mask] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr
        ...

最新更新