使用__getitem__方法赋值时获取KeyError



我想实现bert模型。

所以我构建了一个包含__getitem__的类

我可以打印类似test[0]的内容,但当我指定一个值(如data = test[0](时,会出现KeyError


import random
"""
corpus_file = 'vocab'
vocab_size = 6
vocab_freq = 1
save_path = 'obj/'
max_sentence = 16
corpus -> org_line -> ope_line
corpus -> org_line -> token_list -> idx_to_token + token_to_idx
"""
class vocab():
def __init__(self, corpus_file, vocab_size, vocab_freq,save_path,max_sentence):
self.max_sentence = max_sentence
self.special_labels = ['PAD', 'UNK', 'SEP', 'CLS', 'MASK']
# output
self.data = []
self.idx_to_token = []
self.token_to_idx = {}
# ope
self.pre_ope(corpus_file,vocab_size,vocab_freq)
#self.save_data(save_path)
#self.print_data()
def pre_ope(self,corpus_file,vocab_size,vocab_freq):
token_list = {}
with open(corpus_file, 'r') as f:
while 1:
new_org_line = f.readline()
if new_org_line != '':
new_org_line = new_org_line.strip('n')
new_sentence = new_org_line.split('t')
sentence = []
for tmp in new_sentence:
token_sentence = tmp.split()
sentence.append(token_sentence)
for token in token_sentence:
if token_list.get(token):
token_list[token] += 1
else:
new_token = {token: 1}
token_list.update(new_token)
self.data.append(sentence)
else:
break
f.close()
token_list = sorted(token_list.items(), key=lambda i: (-i[1], i[0]))
self.build_dictionary(token_list,vocab_freq,vocab_size)
'''
Special labels:
PAD
UNK
SEP sentence separator
CLS classifier token
MASK
'''
def build_dictionary(self,token_list,vocab_freq,vocab_size):
for idx, label in enumerate(self.special_labels):
self.idx_to_token.append(label)
self.token_to_idx[label] = idx
for idx, (token, freq) in enumerate(token_list):
if freq >= vocab_freq :
self.idx_to_token.append(token)
self.token_to_idx[token] = idx + len(self.special_labels)
if len(self.idx_to_token) >= vocab_size + len(self.special_labels) and vocab_size != 0 :
break
def __len__(self):
return len(self.data)
def print_data(self):
print(self.data)
print(self.idx_to_token)
print(self.token_to_idx)
def __getitem__(self, item):
s1,s2,is_next_sentence = self.get_random_next_sentence(item)
s1,s1_label = self.get_random_sentence(s1)
s2,s2_label = self.get_random_sentence(s2)
sentence = [self.token_to_idx['CLS']] +s1 +[self.token_to_idx['SEP']] +s2 +[self.token_to_idx['SEP']]
label = [-1] +s1_label +[-1] +s2_label +[-1]
if len(sentence) > self.max_sentence :
print('sentence is greater than the setting of max sentence')
for pos in range(len(sentence),self.max_sentence):
sentence.append(self.token_to_idx['PAD'])
label.append(-1)
return {
'token' : sentence,
'label' : label,
'is_next' : is_next_sentence
}
def get_random_next_sentence(self,item):
s1 = self.data[item][0]
s2 = self.data[item][1]
if random.random() < 0.5 :
is_next = 0
s2 = self.data[self.get_random_line(item)][1]
else:
is_next = 1
return s1,s2,is_next
def get_random_line(self,item):
rand = random.randint(0,len(self.data)-1)
while rand == item :
rand = random.randint(0,len(self.data)-1)
return rand

def get_random_sentence(self,sentence):
label = []
for idx,token in enumerate(sentence):
rand = random.random()
if rand < 0.15:
rand = rand/0.15
if rand < 0.8: #mask
sentence[idx] = self.token_to_idx['MASK']
elif rand < 0.9: #rand
sentence[idx] = random.randint(len(self.special_labels),len(self.token_to_idx)-1)
else: # still
sentence[idx] = self.token_to_idx[token]
label.append(self.token_to_idx[token])
else:
sentence[idx] = self.token_to_idx[token]
label.append(-1)
return sentence,label
if __name__ == '__main__':
test = vocab('vocab', 0, 1,'obj/',16)
print(len(test))
print(test[0])
print(test[1])
data = test[0]

结果:

2
{'token': [3, 4, 18, 12, 15, 11, 2, 7, 9, 13, 2, 0, 0, 0, 0, 0], 'label': [-1, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 'is_next': 0}
{'token': [3, 6, 4, 5, 8, 5, 17, 2, 16, 5, 14, 20, 2, 0, 0, 0], 'label': [-1, -1, 19, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 'is_next': 0}
Traceback (most recent call last):
File "vocab.py", line 146, in <module>
data = test[0]
File "vocab.py", line 90, in ```__getitem__```
s1,s1_label = self.get_random_sentence(s1)
File "vocab.py", line 136, in get_random_sentence
sentence[idx] = self.token_to_idx[token]
KeyError: 4

vocab文件:

hello this is my home   nice to meet you
I want to go to school  and have lunch

更改代码:

def get_random_next_sentence(self,item):
s1 = self.data[item][0]
s2 = self.data[item][1]
if random.random() < 0.5 :
is_next = 0
s2 = self.data[self.get_random_line(item)][1]
else:
is_next = 1
return s1,s2,is_next

至:

def get_random_next_sentence(self,item):
s1 = copy.deepcopy(self.data[item][0])
s2 = copy.deepcopy(self.data[item][1])
if random.random() < 0.5 :
is_next = 0
s2 = copy.deepcopy(self.data[self.get_random_line(item)][1])
print(s2)
else:
is_next = 1
return s1,s2,is_next

最新更新