是否有一种内置的方法可以实现这一点?
我希望避免使用 `pd.concat([...],1).all(1)` 这种方法,因为我正在使用的数据集缺少数据点。
main.py
import pandas as pd
import numpy as np
import numpy.typing as npt
def _index_mask(index_a: pd.Index, index_b: pd.Index) -> npt.NDArray[np.bool_]:
    """Return a boolean mask over ``index_b`` marking labels also in ``index_a``.

    Args:
        index_a: Index whose labels are looked up.
        index_b: Index the resulting mask is aligned with.

    Returns:
        A boolean array with one entry per label of ``index_b``; an entry is
        ``True`` when that label also occurs in ``index_a``.
    """
    # Narrowing ``index_a`` to the labels it shares with ``index_b`` first
    # cannot change which labels of ``index_b`` are found in it, so a single
    # membership pass is enough.
    return index_b.isin(index_a)
def mask_b(a: pd.DataFrame, b: pd.DataFrame) -> pd.DataFrame:
    """Select the rows of ``b`` chosen by ``_index_mask(a.index, b.index)``.

    Args:
        a: Frame whose row index is passed as the first argument.
        b: Frame to filter; its row index is passed as the second argument.

    Returns:
        The rows of ``b`` for which the computed boolean mask is ``True``.
    """
    keep_rows = _index_mask(a.index, b.index)
    return b[keep_rows]
if __name__ == '__main__':
    # Two small demo frames whose row labels overlap only on "C", "D", "E".
    frame_a = pd.DataFrame(
        data=np.arange(10).reshape(5, 2),
        index=list("ABCDE"),
    )
    frame_b = pd.DataFrame(
        data=np.arange(16).reshape(8, 2),
        index=list("FGHCDEIJ"),
    )

    # Filter frame_b through mask_b using frame_a's index, then show the result.
    x = mask_b(frame_a, frame_b)
    print(x)
编辑
我忘了提到,我还需要对 `frame_a` 执行反向操作。
def _index_mask(
    index_a: pd.Index,
    index_b: pd.Index,
) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
    """Compute membership masks between two indexes, one per direction.

    Args:
        index_a: First index.
        index_b: Second index.

    Returns:
        A pair of boolean arrays. The first is aligned with ``index_a`` and
        marks its labels that occur in ``index_b``; the second is aligned
        with ``index_b`` and marks its labels that occur in ``index_a``.
    """
    a_found_in_b = index_a.isin(index_b)
    b_found_in_a = index_b.isin(index_a)
    return a_found_in_b, b_found_in_a
# Build one boolean mask per frame from the two row indexes.
# NOTE: this rebinds `mask_b` to an array, shadowing any function of that
# name defined earlier in the same module.
mask_a, mask_b = _index_mask(frame_a.index, frame_b.index)

# Keep only the rows each mask selects.
frame_a, frame_b = frame_a[mask_a], frame_b[mask_b]

# Sanity check: the filtered frames must carry the same labels in the same order.
assert (frame_b.index == frame_a.index).all()
结果
0 1
C 6 7
D 8 9
E 10 11
我认为您需要 `pd.Index.intersection`:
# Labels present in both frames' row indexes.
shared_labels = frame_a.index.intersection(frame_b.index)
# Pull exactly those rows out of frame_b by label.
x = frame_b.loc[shared_labels]
输出:
>>> x
0 1
C 6 7
D 8 9
E 10 11