DataFrame with a column MultiIndex, advanced to_dict



I have a pandas DataFrame structured as follows:

import numpy as np
import pandas as pd 
import pprint as pp
np.random.seed(0)
times = np.linspace(0, 3.0, num=5)
positions = np.linspace(0, 0.1, num=8)
fields = ["g", "h"]
columns = pd.MultiIndex.from_product([times, fields], names=["time", "field"])
index = pd.Index(positions, name="position")

data = np.random.randn(len(positions), len(times)*len(fields))
df = pd.DataFrame(data, columns=columns, index=index)
print(df)

which looks like:

time          0.00                0.75                1.50                2.25                3.00          
field            g         h         g         h         g         h         g         h         g         h
position                                                                                                    
0.000000  1.764052  0.400157  0.978738  2.240893  1.867558 -0.977278  0.950088 -0.151357 -0.103219  0.410599
0.014286  0.144044  1.454274  0.761038  0.121675  0.443863  0.333674  1.494079 -0.205158  0.313068 -0.854096
0.028571 -2.552990  0.653619  0.864436 -0.742165  2.269755 -1.454366  0.045759 -0.187184  1.532779  1.469359
0.042857  0.154947  0.378163 -0.887786 -1.980796 -0.347912  0.156349  1.230291  1.202380 -0.387327 -0.302303
0.057143 -1.048553 -1.420018 -1.706270  1.950775 -0.509652 -0.438074 -1.252795  0.777490 -1.613898 -0.212740
0.071429 -0.895467  0.386902 -0.510805 -1.180632 -0.028182  0.428332  0.066517  0.302472 -0.634322 -0.362741
0.085714 -0.672460 -0.359553 -0.813146 -1.726283  0.177426 -0.401781 -1.630198  0.462782 -0.907298  0.051945
0.100000  0.729091  0.128983  1.139401 -1.234826  0.402342 -0.684810 -0.870797 -0.578850 -0.311553  0.056165

The idea is that the columns are a MultiIndex: the first level holds a list of "time" values, and for each "time" there are several "fields".
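To make that column layout concrete, here is a small sketch that rebuilds the same frame and inspects its column levels. The name `g_block` is only illustrative and does not appear in the original post.

```python
import numpy as np
import pandas as pd

# Same construction as in the question, with a fixed seed.
np.random.seed(0)
times = np.linspace(0, 3.0, num=5)
positions = np.linspace(0, 0.1, num=8)
fields = ["g", "h"]
columns = pd.MultiIndex.from_product([times, fields], names=["time", "field"])
df = pd.DataFrame(
    np.random.randn(len(positions), len(columns)),
    columns=columns,
    index=pd.Index(positions, name="position"),
)

# Each column label is a (time, field) tuple.
print(df.columns.nlevels)        # 2
print(list(df.columns.names))    # ['time', 'field']
print(df.columns[:3].tolist())

# Cross-section: all "g" columns, one per time value (hypothetical name g_block).
g_block = df.xs("g", level="field", axis=1)
print(g_block.shape)             # (8, 5)
```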

In the real scenario, the "position", "time" and "fields" ...

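My goal is to convert this DataFrame into a dictionary in which, for each "field", the values across all "time" values are grouped into a single array.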

To be clearer, I would like to produce something like this:

{'g': array([[ 1.76405235,  0.97873798,  1.86755799,  0.95008842, -0.10321885],
[ 0.14404357,  0.76103773,  0.44386323,  1.49407907,  0.3130677 ],
[-2.55298982,  0.8644362 ,  2.26975462,  0.04575852,  1.53277921],
[ 0.15494743, -0.88778575, -0.34791215,  1.23029068, -0.38732682],
[-1.04855297, -1.70627019, -0.50965218, -1.25279536, -1.61389785],
[-0.89546656, -0.51080514, -0.02818223,  0.06651722, -0.63432209],
[-0.67246045, -0.81314628,  0.17742614, -1.63019835, -0.90729836],
[ 0.72909056,  1.13940068,  0.40234164, -0.87079715, -0.31155253]]),
'h': array([[ 0.40015721,  2.2408932 , -0.97727788, -0.15135721,  0.4105985 ],
[ 1.45427351,  0.12167502,  0.33367433, -0.20515826, -0.85409574],
[ 0.6536186 , -0.74216502, -1.45436567, -0.18718385,  1.46935877],
[ 0.37816252, -1.98079647,  0.15634897,  1.20237985, -0.30230275],
[-1.42001794,  1.9507754 , -0.4380743 ,  0.77749036, -0.21274028],
[ 0.3869025 , -1.18063218,  0.42833187,  0.3024719 , -0.36274117],
[-0.35955316, -1.7262826 , -0.40178094,  0.46278226,  0.0519454 ],
[ 0.12898291, -1.23482582, -0.68481009, -0.57884966,  0.05616534]]),
'position': array([0.        , 0.01428571, 0.02857143, 0.04285714, 0.05714286,
0.07142857, 0.08571429, 0.1       ]),
'time': array([0.  , 0.75, 1.5 , 2.25, 3.  ])}

It can be built manually as follows:

output = {'position': positions,
'time': times,
fields[0] : data[:, ::len(fields)],
fields[1] : data[:, 1::len(fields)]
}
pp.pprint(output)

I was thinking of something built around df.to_dict('list'), similar to what is described here: https://stackoverflow.com/a/39074579/10812478
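For reference, a quick sketch of what df.to_dict('list') returns on this frame. It yields one flat list per column, keyed by a (time, field) tuple, so on its own it does not produce the per-field 2-D arrays wanted above. The names `flat` and `first_key` are illustrative only.

```python
import numpy as np
import pandas as pd

np.random.seed(0)
times = np.linspace(0, 3.0, num=5)
positions = np.linspace(0, 0.1, num=8)
columns = pd.MultiIndex.from_product([times, ["g", "h"]], names=["time", "field"])
df = pd.DataFrame(
    np.random.randn(len(positions), len(columns)),
    columns=columns,
    index=pd.Index(positions, name="position"),
)

# orient="list" maps every column label to a plain Python list of its values.
flat = df.to_dict("list")
first_key = next(iter(flat))

print(first_key)               # a (time, field) tuple
print(len(flat))               # one entry per column: 10
print(type(flat[first_key]))   # list, with one value per position
```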

You can use groupby with a dictionary comprehension to build the field arrays, then add the other keys:

d = {k: d.to_numpy() for k,d in df.groupby(level='field', axis=1)}
d['position'] = df.index.to_numpy()
d['time'] = df.stack('field').columns.get_level_values('time').to_numpy()

NB. I used np.random.seed(0) to generate the input.

Output:

{'g': array([[ 1.76405235,  0.97873798,  1.86755799,  0.95008842, -0.10321885],
[ 0.14404357,  0.76103773,  0.44386323,  1.49407907,  0.3130677 ],
[-2.55298982,  0.8644362 ,  2.26975462,  0.04575852,  1.53277921],
[ 0.15494743, -0.88778575, -0.34791215,  1.23029068, -0.38732682],
[-1.04855297, -1.70627019, -0.50965218, -1.25279536, -1.61389785],
[-0.89546656, -0.51080514, -0.02818223,  0.06651722, -0.63432209],
[-0.67246045, -0.81314628,  0.17742614, -1.63019835, -0.90729836],
[ 0.72909056,  1.13940068,  0.40234164, -0.87079715, -0.31155253]]),
'h': array([[ 0.40015721,  2.2408932 , -0.97727788, -0.15135721,  0.4105985 ],
[ 1.45427351,  0.12167502,  0.33367433, -0.20515826, -0.85409574],
[ 0.6536186 , -0.74216502, -1.45436567, -0.18718385,  1.46935877],
[ 0.37816252, -1.98079647,  0.15634897,  1.20237985, -0.30230275],
[-1.42001794,  1.9507754 , -0.4380743 ,  0.77749036, -0.21274028],
[ 0.3869025 , -1.18063218,  0.42833187,  0.3024719 , -0.36274117],
[-0.35955316, -1.7262826 , -0.40178094,  0.46278226,  0.0519454 ],
[ 0.12898291, -1.23482582, -0.68481009, -0.57884966,  0.05616534]]),
'position': array([0.        , 0.01428571, 0.02857143, 0.04285714, 0.05714286,
0.07142857, 0.08571429, 0.1       ]),
'time': array([0.  , 0.75, 1.5 , 2.25, 3.  ])}
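As far as I know, recent pandas releases (2.1 and later) emit a deprecation warning for groupby(..., axis=1). Below is a hedged alternative sketch that builds the same dictionary with DataFrame.xs instead. It is not part of the original answer, and the names `field_names` and `out` are illustrative.

```python
import numpy as np
import pandas as pd

np.random.seed(0)
times = np.linspace(0, 3.0, num=5)
positions = np.linspace(0, 0.1, num=8)
fields = ["g", "h"]
columns = pd.MultiIndex.from_product([times, fields], names=["time", "field"])
data = np.random.randn(len(positions), len(columns))
df = pd.DataFrame(data, columns=columns, index=pd.Index(positions, name="position"))

# One 2-D array per field, taken as a cross-section of the column MultiIndex.
field_names = df.columns.get_level_values("field").unique()
out = {f: df.xs(f, level="field", axis=1).to_numpy() for f in field_names}

# Add the axis values: row labels and the distinct time values.
out["position"] = df.index.to_numpy()
out["time"] = df.columns.get_level_values("time").unique().to_numpy()

# Sanity check against the manual construction from the question.
assert np.array_equal(out["g"], data[:, 0::len(fields)])
assert np.array_equal(out["h"], data[:, 1::len(fields)])
```

The cross-section keeps the columns in their existing order, so each array has one row per position and one column per time value.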
