Querying a pandas DataFrame under multiple conditions



I have the following pandas DataFrame.

import pandas as pd

columns = ['question_id', 'answer', 'is_correct']
# Note: every value, including is_correct, is stored as a string
data = [['1', 'hello', '1.0'],
        ['1', 'hello', '1.0'],
        ['1', 'hello', '1.0'],
        ['2', 'dog', '0.0'],
        ['2', 'cat', '1.0'],
        ['2', 'dog', '0.0'],
        ['2', 'the answer is cat', '1.0'],
        ['3', 'Milan', '1.0'],
        ['3', 'Paris', '0.0'],
        ['3', 'The capital is Paris', '0.0'],
        ['3', 'MILAN', '1.0'],
        ['4', 'The capital is Paris', '1.0'],
        ['4', 'London', '0.0'],
        ['4', 'Paris', '1.0'],
        ['4', 'paris', '1.0'],
        ['5', 'lol', '0.0'],
        ['5', 'rofl', '0.0'],
        ['6', '5.5', '1.0'],
        ['6', '5.2', '0.0']]
df = pd.DataFrame(columns=columns, data=data)
df

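I want to return a list of lists. Each inner list should contain exactly two correct (is_correct = 1.0) answers (a1 and a2) to the same question, with one inner list per question_id. Any other answers within a question_id can simply be ignored.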

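Edge cases:

  1. All answers are correct -> then just duplicate it. See question_id=1
  2. No answer is correct -> then skip this question, e.g. output None. See question_id=5
  3. Only one answer is correct -> then skip this question, e.g. output None. See question_id=6

Example: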

[['Paris', 'The capital is Paris'], ['MILAN', 'milano'],...]

My current approach outputs the same value for a1 and a2. What am I doing wrong?

# This takes around 1min on cpu
def filter(grp):
    is_correct = grp['is_correct'] == 1.0
    if is_correct.any():
        sample = grp.sample()
        a1 = grp['answer'][is_correct].iloc[0]
        a2 = grp['answer'][is_correct].iloc[0]
        n = 6
        _ = 0
        # I will compare a1 and a2 6 times to see if they are the same
        # and if they are the same grab another one for a2... probably not smart
        while _ < n:
            if a1.index == a2.index:
                a2 = grp['answer'][is_correct].iloc[0]
            _ += 1
        return [a1, a2]

data = df.groupby('question_id').apply(filter).to_list()
# Drop None values
data_clean = [x for x in data if x is not None and x[1] is not None]
data_clean

You can do:

# get groups with at least one correct answer
res = df[df['is_correct'].astype(float).gt(0)].groupby('question_id')['answer'].agg(lambda x: x.head(2).to_list()).to_list()
# filter out groups with only one element
out = [l for l in res if len(l) > 1]
print(out)

Output:

[['hello', 'hello'], ['cat', 'the answer is cat'], ['Milan', 'MILAN'], ['The capital is Paris', 'Paris']]
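
As for why the original attempt returned the same value twice: both a1 and a2 are read with .iloc[0], so each one is the first correct answer of the group, and the retry loop re-reads that same element again. Reading positions 0 and 1 instead gives two different rows. A minimal sketch of that idea, assuming is_correct is stored as strings as in the sample data (the helper name first_two_correct and the small toy frame are made up for illustration):

```python
import pandas as pd

def first_two_correct(df):
    """Hypothetical helper: first two correct answers per question."""
    # is_correct holds strings such as '1.0', so convert before comparing
    correct = df[df["is_correct"].astype(float) == 1.0]
    pairs = []
    for _, answers in correct.groupby("question_id")["answer"]:
        # Skip questions with fewer than two correct answers
        if len(answers) >= 2:
            # Positions 0 and 1 are two different rows
            pairs.append([answers.iloc[0], answers.iloc[1]])
    return pairs

# Toy data: question 2 has two correct answers, question 6 only one
toy = pd.DataFrame({
    "question_id": ["2", "2", "2", "6", "6"],
    "answer": ["dog", "cat", "the answer is cat", "5.5", "5.2"],
    "is_correct": ["0.0", "1.0", "1.0", "1.0", "0.0"],
})
print(first_two_correct(toy))  # [['cat', 'the answer is cat']]
```

Two distinct rows can still hold the same text (as in question_id=1), which matches the first edge case in the question.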

If you also need the results shuffled:

def filter(g):
    # is_correct is stored as strings in the sample data, so convert first
    answers = g.loc[g['is_correct'].astype(float) == 1.0, 'answer']
    # Presumably we want a random shuffle of the answers
    answers = list(answers.sample(frac=1))
    # Require at least one answer
    if len(answers) == 0:
        return None
    # Duplicate if only one answer
    elif len(answers) == 1:
        answers = answers * 2
    return answers[:2]  # answers is a list already, so can index

list(df.groupby('question_id').apply(filter))

Output (the selection and order vary between runs because of the shuffle):

[['hello', 'hello'],
['cat', 'the answer is cat'],
['MILAN', 'Milan'],
['paris', 'Paris'],
None,
['5.5', '5.5']]
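
Note that this version duplicates a lone correct answer (question_id=6) and returns None for question_id=5, whereas the question asked to skip both cases. A sketch of a shuffled variant that skips them instead, assuming string-typed is_correct values as in the sample data; the function name random_pairs, its seed parameter, and the toy frame are hypothetical and only there for illustration:

```python
import pandas as pd

def random_pairs(df, seed=None):
    """Hypothetical helper: two randomly chosen correct answers per question."""
    # Convert the string flags to floats before comparing
    correct = df[df["is_correct"].astype(float) == 1.0]
    pairs = []
    for _, answers in correct.groupby("question_id")["answer"]:
        # Questions with zero or one correct answer are skipped entirely
        if len(answers) < 2:
            continue
        # sample(n=2) draws two different rows without replacement
        pairs.append(answers.sample(n=2, random_state=seed).to_list())
    return pairs

toy = pd.DataFrame({
    "question_id": ["1", "1", "1", "2", "2", "2", "5", "6", "6"],
    "answer": ["hello", "hello", "hello", "dog", "cat",
               "the answer is cat", "lol", "5.5", "5.2"],
    "is_correct": ["1.0", "1.0", "1.0", "0.0", "1.0",
                   "1.0", "0.0", "1.0", "0.0"],
})
pairs = random_pairs(toy, seed=0)
print(pairs)
```

Passing a fixed seed makes the draw repeatable; leaving it as None gives a different selection on each run.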
