如果数据帧的索引在某个自定义类中,如何定义"in"来完成函数?



我已经在一些自定义类中实现了包含,例如

class A:
def __init__(self):
self.l = [1,2,3]
def __contain__(self, i:int):
if i in self.l:
return True
return False

它可以很好地与单一元素配合使用

if 1 in A:
return True

但现在我想做一些类似的事情:

df = pd.DataFrame(np.random.randn(10,10))
a = df[df.index in A]

获取索引在A中的行(也就是说索引在[1,2,3]中(但它向我显示了诸如"TypeError:类型为"A"的参数不可迭代"之类的错误

我知道可以通过表格完成

a = df[[id for id in df.index if id in A]]

但我想知道是否有像df[df.index in A]这样的形式,因为它看起来既漂亮又高效~~

每当我试图使__contains__返回可迭代时,我只得到一个bool

import pandas as pd
from typing import Iterable, Union
class A:
def __init__(self):
self.l = [1,2,3]
def __contains__(self, i:Union[int,Iterable]):
if isinstance(i, Iterable):
return [j in self.l for j in i]
elif i in self.l:
return True
return False
a = A()
df = pd.DataFrame(np.random.randn(10,10))
print(df.index in a)

输出:

True

看起来python隐式地将bool应用于来自__contains__的任何内容
不过,您可以使用类似Series的接口来实现它

import pandas as pd
from typing import Iterable, Union
class A:
def __init__(self):
self.l = [1,2,3]
def isin(self,i:Iterable):
return [j in self.l for j in i]
a = A()
df = pd.DataFrame(np.random.randn(10,10))
print(df[a.isin(df.index)])

输出:

0         1         2         3         4         5         6  
1 -0.899868  0.830076  1.106072 -1.664480  1.291234  0.257702 -1.486293   
2  1.060163  1.143478  0.861907  1.480999 -1.238395 -0.130496 -0.441712   
3  1.176099  0.105020  0.502756  0.993179  1.561893  1.036998  0.551943   
7         8         9  
1  0.394313  0.434380 -1.554062  
2 -2.538269  0.188291 -0.451774  
3 -0.342378 -0.779410 -1.491517  

相关内容

最新更新