为什么DataLoader的实例是pytorch中的全局变量



我在调试pytorch代码时发现DataLoader类的实例默认情况下似乎是全局变量。我不明白为什么会出现这种情况,但我已经建立了一个最低限度的工作示例,如下所示,应该可以重现我的观察结果。代码如下:

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, df, n_feats, mode):
data = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]).transpose()
x = data[:, list(range(n_feats))]   # features
y = data[:, -1]  # target
self.x = torch.FloatTensor(x)
self.y = torch.FloatTensor(y)

def __getitem__(self, index):
return self.x[index], self.y[index]

def __len__(self):
return len(self.x)
def prep_dataloader(df, n_feats, mode, batch_size):
dataset = MyDataset(df, n_feats, mode)
dataloader = DataLoader(dataset, batch_size, shuffle=False)

return dataloader
tr_set = prep_dataloader(df, 1, 'train', 200)
def test():
print(tr_set)

如上所示,tr_set是在函数test之前创建的,并且不传递给test。然而,运行上面的代码,我得到了以下结果:

<torch.utils.data.dataloader.DataLoader object at 0x7fb6c2ea7610>

起初,我希望得到一个类似";名称错误:未定义名称"tr_set";。但是,即使tr_set没有作为参数传递,函数也知道tr_set并打印tr_set的对象。我对此感到困惑,因为在这种情况下,tr_set似乎是一个全局变量。

我想知道这是什么原因,以及如何防止它成为一个全局变量。非常感谢。

(更新:如果这很重要,我在jupyter笔记本上运行上面的代码。(

这与DataLoader或PyTorch的工作方式无关。

它实际上不是一个全局变量,但由于tr_set在外部范围内,因此在文件的第一级中,同一文件的其他组件可以访问它。然而,例如,其他模块无法访问同一变量,因此不是全局变量。函数test能够访问tr_set的原因是对该变量进行了闭包,即变量被传递到test的内部范围。

相关内容

  • 没有找到相关文章

最新更新