拆分数据以按条件进行训练和测试



假设我有一个包含贷款信息的pandas DataFrame,我想预测用户不还钱的概率(由数据框中的default列表示)。我想使用sklearn.model_selection.train_test_split来分割训练集和测试集中的数据。

然而,我想确保具有相同客户ID的贷款不会同时出现在测试和火车上。我该怎么做?

下面是我的数据样本:

d = {'loan_date': ['20170101','20170701','20170301','20170415','20170515'],
'customerID': [111,111,222,333,444],
'loanID': ['aaa','fff','ccc','ddd','bbb'],
'loan_duration' : [6,3,12,5,12],
'gender':['F','F','M','F','M'],
'loan_amount': [20000,10000,30000,10000,40000],
'default':[0,1,0,0,1]}
df = pd.DataFrame(data=d)

例如,CustomerID==111贷款记录应出现在测试或列车集中,但不能同时出现在两者中。

我提出以下解决方案。具有相同customerID的客户不会出现在培训和测试中;按活动划分的aslo客户,也就是说,拥有相同贷款数量的用户中,大约有同等比例的用户将接受培训和测试。

我扩展了用于演示目的的数据样本:

d = {'loan_date': ['20170101','20170701','20170301','20170415','20170515','20170905', '20170814', '20170819', '20170304'],         
'customerID': [111,111,222,333,444,222,111,444,555],        
'loanID': ['aaa','fff','ccc','ddd','bbb','eee', 'kkk', 'zzz', 'yyy'],                                                         
'loan_duration' : [6,3,12,5,12, 3, 17, 4, 6],
'gender':['F','F','M','F','M','M', 'F', 'M','F'],
'loan_amount': [20000,10000,30000,10000,40000,20000,30000,30000,40000],
'default':[0,1,0,0,1,0,1,1,0]}
df = pd.DataFrame(data=d) 

代码:

from sklearn.model_selection import train_test_split
def group_customers_by_activity(df):
value_count = df.customerID.value_counts().reset_index()
df_by_customer = df.set_index('customerID')
df_s = [df_by_customer.loc[value_count[value_count.customerID == count]['index']] for count in value_count.customerID.unique()]
return df_s

-此函数将df除以customerID活动(具有相同customerID的条目数)
此功能的示例输出:

group_customers_by_activity(df)
Out:
[           loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
111         20170101    aaa              6      F        20000        0
111         20170701    fff              3      F        10000        1
111         20170814    kkk             17      F        30000        1,
loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
222         20170301    ccc             12      M        30000        0
222         20170905    eee              3      M        20000        0
444         20170515    bbb             12      M        40000        1
444         20170819    zzz              4      M        30000        1,
loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
333         20170415    ddd              5      F        10000        0
555         20170304    yyy              6      F        40000        0]

-拥有1、2、3笔贷款等的用户组。

该功能以用户进入列车或测试的方式划分一组:

def split_group(df_group, train_size=0.8):
customers = df_group.index.unique()
train_customers, test_customers = train_test_split(customers, train_size=train_size)
train_df, test_df = df_group.loc[train_customers], df_group.loc[test_customers]
return train_df, test_df
split_group(df_s[2])
Out:
(           loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
444         20170515    bbb             12      M        40000        1
444         20170819    zzz              4      M        30000        1,
loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
222         20170301    ccc             12      M        30000        0
222         20170905    eee              3      M        20000        0)

剩下的就是将其应用于所有"客户活动"组:

def get_sized_splits(df_s, train_size):
train_splits, test_splits = zip(*[split_group(df_group, train_size) for df_group in df_s])
return train_splits, test_splits
df_s = group_customers_by_activity(df)
train_splits, test_splits = get_sized_splits(df_s, 0.8)
train_splits, test_splits
Out:
((Empty DataFrame
Columns: [loan_date, loanID, loan_duration, gender, loan_amount, default]
Index: [],
loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
444         20170515    bbb             12      M        40000        1
444         20170819    zzz              4      M        30000        1,
loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
333         20170415    ddd              5      F        10000        0),
(           loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
111         20170101    aaa              6      F        20000        0
111         20170701    fff              3      F        10000        1
111         20170814    kkk             17      F        30000        1,
loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
222         20170301    ccc             12      M        30000        0
222         20170905    eee              3      M        20000        0,
loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
555         20170304    yyy              6      F        40000        0))

不要害怕emty DataFrame,它很快就会被连接起来。split函数具有以下定义:

def split(df, train_size):
df_s = group_customers_by_activity(df)
train_splits, test_splits = get_sized_splits(df_s, train_size=train_size)
return pd.concat(train_splits), pd.concat(test_splits)
split(df, 0.8)
Out[106]: 
(           loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
444         20170515    bbb             12      M        40000        1
444         20170819    zzz              4      M        30000        1
555         20170304    yyy              6      F        40000        0,
loan_date loanID  loan_duration gender  loan_amount  default
customerID                                                             
111         20170101    aaa              6      F        20000        0
111         20170701    fff              3      F        10000        1
111         20170814    kkk             17      F        30000        1
222         20170301    ccc             12      M        30000        0
222         20170905    eee              3      M        20000        0
333         20170415    ddd              5      F        10000        0)

-因此,customerID被放置在训练数据或测试数据中。我想这样的裂缝(火车>测试)是因为输入数据的大小很小
如果您不需要按"customerID活动"进行分组,则可以省略它,只需使用split_group即可实现目标。

相关内容

  • 没有找到相关文章

最新更新