如何检查命名元组是否在列表中?



我试图检查NamedTuple"Transition"的缩进是否等于列表"self.memory"中的任何对象。

这是我尝试运行的代码:

from typing import NamedTuple
import random
import torch as t
Transition = NamedTuple('Transition', state=t.Tensor, action=int, reward=int, next_state=t.Tensor, done=int, hidden=t.Tensor)

class ReplayMemory:
def __init__(self, capacity):
self.memory = []
self.capacity = capacity
self.position = 0
def store(self, *args):
print(self.memory == Transition(*args))
if Transition(*args) in self.memory:
return
if len(self.memory) < self.capacity:
self.memory.append(None)
self.memory[self.position] = Transition(*args)
...

这是输出:

False
False

我得到的错误是:

...
if Transition(*args) in self.memory:
RuntimeError: bool value of Tensor with more than one value is ambiguous

这对我来说似乎很奇怪,因为打印告诉我"=="操作返回布尔值。

如何正确完成此操作?

谢谢

编辑:

*args 是一个元组,由

torch.Size([16, 12])
int
int
torch.Size([16, 12])
int
torch.Size([4])

我认为你应该明确定义平等。

from typing import NamedTuple
import random
import torch as t

class Sample(NamedTuple):
state: t.Tensor
action: int
def __eq__(self, other):
return bool(t.all(self.state == other.state)) and self.action == other.action

最新更新