识别/删除 pandas 数据帧中的冗余列



如果我有一个如下所示的数据帧:

df = pd.DataFrame({'val':['a','b','c','d','e','f','g','h'],
'cat':['C','D','D','C','D','D','D','C'],
'num':[1,2,2,1,2,2,2,1],
'cat2':['X','Y','Y','X','Y','Y','Y','X']})

给:

val cat  num cat2
0   a   C    1    X
1   b   D    2    Y
2   c   D    2    Y
3   d   C    1    X
4   e   D    2    Y
5   f   D    2    Y
6   g   D    2    Y
7   h   C    1    X

您会注意到,我们可以确定列numcat2是多余的,因为catnumcat2行中的值始终在列中匹配:C == 1 == XD == 2 == Y

我想确定冗余的列,最终丢弃它们并且只有一个表示形式,如下所示。numcat2而不是cat在那里也很好。

val cat
0   a   C
1   b   D
2   c   D
3   d   C
4   e   D
5   f   D
6   g   D
7   h   C

我想不出不涉及嵌套循环的解决方案,嵌套循环随着更多列而成倍增加,我怀疑可能有一种聪明的方法来解决这个问题。我看到的关于冗余数据的其他问题通常是在值相等时处理。

谢谢!

你可以用factorize检查,然后drop_duplicates

out = df.loc[:,df.transform(lambda x : x.factorize()[0]).T.drop_duplicates().T.columns]
Out[56]: 
val cat
0   a   C
1   b   D
2   c   D
3   d   C
4   e   D
5   f   D
6   g   D
7   h   C

为了获得基于pandas.factorize的更快方法,对生成的数组进行哈希处理并将其用作字典键:

df[{hash(pd.factorize(df[c])[0].data.tobytes()): c for c in df.columns[::-1]}.values()]

注意:如果您有重复的列名,请改用:

df.iloc[:, list({hash(pd.factorize(df.iloc[:, i])[0].data.tobytes()): i
for i in range(df.shape[1])}.values())]

输出:

cat val
0   C   a
1   D   b
2   D   c
3   C   d
4   D   e
5   D   f
6   D   g
7   C   h

执行速度:

# factorize + hash
737 µs ± 41.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# factorize + drop_duplicates
2.25 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

在 20 倍以上的列上:

# factorize + hash
6.53 ms ± 395 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# factorize + drop_duplicates
13 ms ± 781 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

您可以生成数据帧列的幂集,并检查所有行的列集是否重复。如果是这样,则意味着此集中的列是冗余的。

from itertools import chain, combinations
def powerset(iterable):
"powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
redundant_cols = list(filter(lambda cols: (len(cols) > 1) and (df.duplicated(cols, keep=False).all()),
(powerset(df.columns))))
print(redundant_cols)
[('cat', 'num'), ('cat', 'cat2'), ('num', 'cat2'), ('cat', 'num', 'cat2')]

然后最后一个也是最长的包含所有冗余列。

out = df.drop(columns=list(redundant_cols[-1][1:]))
print(out)
val cat
0   a   C
1   b   D
2   c   D
3   d   C
4   e   D
5   f   D
6   g   D
7   h   C

您可以使用scikit learn LabelEncoder生成每列的编码版本。这样可以轻松比较所有列。使用LabelEncoder,我生成了这个版本的数据帧:

val cat num cat2
0   0   0   0   0
1   1   1   1   1
2   2   1   1   1
3   3   0   0   0
4   4   1   1   1
5   5   1   1   1
6   6   1   1   1
7   7   0   0   0

下面给出了比较和查找相同列并删除它们的完整代码:

import pandas as pd
from sklearn.preprocessing import LabelEncoder
df = pd.DataFrame({'val':['a','b','c','d','e','f','g','h'],
'cat':['C','D','D','C','D','D','D','C'],
'num':[1,2,2,1,2,2,2,1],
'cat2':['X','Y','Y','X','Y','Y','Y','X']})
# labelencoder will encode all unique values to numbers that will make it easy to compare
temp_df = pd.DataFrame()
le = LabelEncoder()
for column_name in df.columns:
temp_df[column_name] = le.fit_transform(df[column_name])
# now compare the columns to see which one are same
# this double for loop can defnitely be improved, it makes redundant comparisons
duplicate_columns = []
for column_name in df.columns:
for compare_column_name in df.columns:
if column_name != compare_column_name and column_name not in duplicate_columns:
if temp_df[column_name].equals(temp_df[compare_column_name]):
duplicate_columns.append(compare_column_name)
# now you have the duplicate columns just remove them
df.drop(duplicate_columns, axis=1, inplace=True)

这实际上是此处的副本,其中使用相关矩阵,并且可以传递自定义阈值来过滤类似的列(在本例中为 .95,但您也可以使用 .999)。

您唯一缺少的是将文本转换为数值,您可以使用因式分解,类别代码或pandas或sklearn等其他许多方法中的任何一种来完成。

import pandas as pd
df = pd.DataFrame({'val':['a','b','c','d','e','f','g','h'],
'cat':['C','D','D','C','D','D','D','C'],
'num':[1,2,2,1,2,2,2,1],
'cat2':['X','Y','Y','X','Y','Y','Y','X']})

dfc = df.astype('category').apply(lambda x: x.cat.codes).corr().abs()
upper = dfc.where(np.triu(np.ones(dfc.shape), k=1).astype(bool))
to_drop = [column for column in upper.columns if any(upper[column] > 0.95)]
df.drop(to_drop, axis=1, inplace=True)
print(df)

输出

val cat
0   a   C
1   b   D
2   c   D
3   d   C
4   e   D
5   f   D
6   g   D
7   h   C

最新更新