问题: 创建最有效的函数,将 1D 数组(group_id列(转换为另一个 1D 数组(输出列(。
条件是:
-
最多
n
组可以位于任何批次中,在此示例中为n=2
。 -
每个批次必须包含相同大小的组。
-
琐碎条件:尽量减少批次数量。
该函数会将这些不同大小的组分发到具有唯一标识符的批次中,条件是每个批次具有固定大小,并且每个批次仅包含具有相同大小的组。
data = {'group_size': [1,2,3,1,2,3,4,5,1,2,1,1,1],
'batch_id': [1,4,6,1,4,6,7,8,2,5,2,3,3]}
df = pd.DataFrame(data=data)
print(df)
group_size batch_id
0 1 1
1 2 4
2 3 6
3 1 1
4 2 4
5 3 6
6 4 7
7 5 8
8 1 2
9 2 5
10 1 2
11 1 3
12 1 3
我需要什么:
some_function( data['group_size'] )
给我data['batch_id']
编辑:
我的笨拙功能
def generate_array():
out = 1
batch_size = 2
dictionary = {}
for i in range(df['group_size'].max()):
# get the mini df corresponding to the group size
sub_df = df[df['group_size'] == i+1 ]
# how many batches will we create?
no_of_new_batches = np.ceil ( sub_df.shape[0] / batch_size )
# create new array
a = np.repeat(np.arange(out, out+no_of_new_batches ), batch_size)
shift = len(a) - sub_df.shape[0]
# remove last elements from array to match the size
if len(a) != sub_df.shape[0]:
a = a[0:-shift]
# update batch id
out = out + no_of_new_batches
# create dictionary to store idx
indexes = sub_df.index.values
d = dict(zip(indexes, a))
dictionary.update(d)
array = [dictionary[i] for i in range(len(dictionary))]
return array
generate_array()
Out[78]:
[1.0, 4.0, 6.0, 1.0, 4.0, 6.0, 7.0, 8.0, 2.0, 5.0, 2.0, 3.0, 3.0]
这是我的解决方案。我不认为它给出的结果与您的函数完全相同,但它满足您的三个规则:
import numpy as np
def package(data, mxsz):
idx = data.argsort()
ds = data[idx]
chng = np.empty((ds.size + 1,), bool)
chng[0] = True
chng[-1] = True
chng[1:-1] = ds[1:] != ds[:-1]
szs = np.diff(*np.where(chng))
corr = (-szs) % mxsz
result = np.empty_like(idx)
result[idx] = (np.arange(idx.size) + corr.cumsum().repeat(szs)) // mxsz
return result
data = np.random.randint(0, 4, (20,))
result = package(data, 3)
print(f'group_size {data}')
print(f'batch_id {result}')
check = np.lexsort((data, result))
print('sorted:')
print(f'group_size {data[check]}')
print(f'batch_id {result[check]}')
样本运行时 n=3,输出的最后两行与前两行相同,只是排序是为了便于检查:
group_size [1 1 0 1 2 0 2 2 2 3 1 2 3 2 1 0 1 0 2 0]
batch_id [3 3 1 3 6 1 6 5 6 7 2 5 7 5 2 1 2 0 4 0]
sorted:
group_size [0 0 0 0 0 1 1 1 1 1 1 2 2 2 2 2 2 2 3 3]
batch_id [0 0 1 1 1 2 2 2 3 3 3 4 5 5 5 6 6 6 7 7]
工作原理:
1( 对数据进行排序
2( 检测排序数据的变化位置,以识别值相等的组("组大小组"(
3(确定组大小组的大小,并为每个组计算遗漏的内容到n的干净倍数
4(枚举排序后的数据,同时在每次切换到一组新的组大小,跳转到n的下一个干净倍数;我们使用 (3( 以矢量化的方式执行此操作
5( 楼层除以 n 得到批次 ID
6( 洗牌回原始顺序