我在调试pytorch
代码时发现DataLoader
类的实例默认情况下似乎是全局变量。我不明白为什么会出现这种情况,但我已经建立了一个最低限度的工作示例,如下所示,应该可以重现我的观察结果。代码如下:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, df, n_feats, mode):
data = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]).transpose()
x = data[:, list(range(n_feats))] # features
y = data[:, -1] # target
self.x = torch.FloatTensor(x)
self.y = torch.FloatTensor(y)
def __getitem__(self, index):
return self.x[index], self.y[index]
def __len__(self):
return len(self.x)
def prep_dataloader(df, n_feats, mode, batch_size):
dataset = MyDataset(df, n_feats, mode)
dataloader = DataLoader(dataset, batch_size, shuffle=False)
return dataloader
tr_set = prep_dataloader(df, 1, 'train', 200)
def test():
print(tr_set)
如上所示,tr_set
是在函数test
之前创建的,并且不传递给test
。然而,运行上面的代码,我得到了以下结果:
<torch.utils.data.dataloader.DataLoader object at 0x7fb6c2ea7610>
起初,我希望得到一个类似";名称错误:未定义名称"tr_set";。但是,即使tr_set
没有作为参数传递,函数也知道tr_set
并打印tr_set
的对象。我对此感到困惑,因为在这种情况下,tr_set
似乎是一个全局变量。
我想知道这是什么原因,以及如何防止它成为一个全局变量。非常感谢。
(更新:如果这很重要,我在jupyter笔记本上运行上面的代码。(
这与DataLoader
或PyTorch的工作方式无关。
它实际上不是一个全局变量,但由于tr_set
在外部范围内,因此在文件的第一级中,同一文件的其他组件可以访问它。然而,例如,其他模块无法访问同一变量,因此不是全局变量。函数test
能够访问tr_set
的原因是对该变量进行了闭包,即变量被传递到test
的内部范围。