I have a 3D array of shape (2001, 128, 128) that looks like this:
array([[[48, 48, 48, ..., 48, 48, 48],
[48, 48, 48, ..., 48, 48, 48],
[48, 48, 48, ..., 48, 48, 48],
...,
[[12, 12, 12, ..., 12, 12, 12],
[12, 12, 12, ..., 12, 12, 12],
[12, 12, 12, ..., 12, 12, 12],
...,
[19, 19, 19, ..., 12, 12, 12],
[19, 19, 19, ..., 19, 12, 12],
[19, 19, 19, ..., 19, 19, 19]],
I also have a dictionary like this:
{1: [1, 39],
2: [2, 5, 9, 20, 32, 42, 47, 72, 88, 91, 95],
3: [3, 49, 55],
4: [4, 24, 34, 40, 53, 76, 81, 90, 96],
5: [6, 17, 30, 48, 83],
6: [7, 13, 15, 16, 27, 44, 51, 54, 56, 75],
7: [8, 50],
8: [10, 19, 22, 35, 61, 63, 65],
9: [11, 12, 21, 46, 52, 69, 78, 84, 89],
10: [14, 36, 74],
11: [18],
12: [23, 38, 66, 97],
13: [25],
14: [26, 28, 29, 62, 64, 86, 94],
15: [31, 59, 85],
16: [33, 80],
17: [37, 45, 60],
18: [41, 92, 93],
19: [43, 77, 79, 82],
20: [57, 67],
21: [58],
22: [68],
23: [70],
24: [71],
25: [73, 87],
0: [0]}
What I want to do is replace each array value with the dictionary key whose list contains that value, like this:
array([[[5, 5, 5, ..., 5, 5, 5],
[5, 5, 5, ..., 5, 5, 5],
[5, 5, 5, ..., 5, 5, 5],
...,
[9, 9, 9, ..., 9, 9, 9],
[9, 9, 9, ..., 9, 9, 9],
[9, 9, 9, ..., 9, 9, 9]],
...,
[8, 8, 8, ..., 9, 9, 9],
[8, 8, 8, ..., 8, 9, 9],
[8, 8, 8, ..., 8, 8, 8]],
because 48 is in the list under key 5.
You should invert the original dictionary:
lookup_dict = {1: [1, 39],
2: [2, 5, 9, 20, 32, 42, 47, 72, 88, 91, 95],
3: [3, 49, 55],
4: [4, 24, 34, 40, 53, 76, 81, 90, 96],
5: [6, 17, 30, 48, 83],
6: [7, 13, 15, 16, 27, 44, 51, 54, 56, 75],
7: [8, 50],
8: [10, 19, 22, 35, 61, 63, 65],
9: [11, 12, 21, 46, 52, 69, 78, 84, 89],
10: [14, 36, 74],
11: [18],
12: [23, 38, 66, 97],
13: [25],
14: [26, 28, 29, 62, 64, 86, 94],
15: [31, 59, 85],
16: [33, 80],
17: [37, 45, 60],
18: [41, 92, 93],
19: [43, 77, 79, 82],
20: [57, 67],
21: [58],
22: [68],
23: [70],
24: [71],
25: [73, 87],
0: [0]}
reversed_dict = {val: key for key, lst in lookup_dict.items() for val in lst}
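The inversion assumes that every value appears in exactly one list; if a value were listed under two keys, the dict comprehension would silently keep the later key. Here is a minimal sketch of a checked inversion, run on a small hypothetical dictionary (`small_lookup` and `invert_checked` are illustrative names, not part of the original code):

```python
def invert_checked(lookup):
    """Invert {key: [values]} into {value: key}, refusing duplicate values."""
    inverted = {}
    for key, values in lookup.items():
        for val in values:
            if val in inverted:
                raise ValueError(
                    f"value {val} is listed under keys {inverted[val]} and {key}"
                )
            inverted[val] = key
    return inverted


# Hypothetical miniature of the dictionary above.
small_lookup = {5: [6, 17, 48], 9: [11, 12], 8: [10, 19], 0: [0]}
small_reversed = invert_checked(small_lookup)
print(small_reversed[48], small_reversed[12], small_reversed[19])  # 5 9 8
```

If the check passes, the result is the same mapping the one-line comprehension produces.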
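Now you can iterate over the input array and set each item to its new value after looking it up in reversed_dict. This is already more efficient than JRiggles's answer, because you don't need to scan all the lists to find each new value.
However, if you put the values of reversed_dict into an array, so that each dictionary key becomes an index into that array, you can simply use NumPy's integer array indexing to index it with the input array and get a result of the right shape. I prefer this approach because it is much faster: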
import numpy as np

max_index = max(reversed_dict.keys())
lookup_array = np.zeros((max_index+1,))
for k, v in reversed_dict.items():
    lookup_array[k] = v
Finally:
input_array = np.array([[[48, 48, 48, 48, 48, 48],
[48, 48, 48, 48, 48, 48],
[48, 48, 48, 48, 48, 48]],
[[12, 12, 12, 12, 12, 12],
[12, 12, 12, 12, 12, 12],
[12, 12, 12, 12, 12, 12]],
[[19, 19, 19, 12, 12, 12],
[19, 19, 19, 19, 12, 12],
[19, 19, 19, 19, 19, 19]]])
output_array = lookup_array[input_array]
This gives:
array([[[5., 5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5., 5.]],
[[9., 9., 9., 9., 9., 9.],
[9., 9., 9., 9., 9., 9.],
[9., 9., 9., 9., 9., 9.]],
[[8., 8., 8., 9., 9., 9.],
[8., 8., 8., 8., 9., 9.],
[8., 8., 8., 8., 8., 8.]]])
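The result is floating point because np.zeros defaults to float64. If integer labels are wanted, one option is to give the lookup table an integer dtype. The sketch below does that on a small hypothetical mapping (`small_reversed` and `int_lookup` are illustrative names) and also fills the table in one indexing operation instead of a Python loop:

```python
import numpy as np

# Hypothetical miniature of reversed_dict: value -> key.
small_reversed = {0: 0, 12: 9, 19: 8, 48: 5}

# Integer lookup table; positions not listed in the dictionary stay 0.
int_lookup = np.zeros(max(small_reversed) + 1, dtype=np.int64)
values = np.fromiter(small_reversed.keys(), dtype=np.int64)
keys = np.fromiter(small_reversed.values(), dtype=np.int64)
int_lookup[values] = keys  # assign all entries in one indexing operation

sample = np.array([[48, 12], [19, 0]])
result = int_lookup[sample]
print(result.dtype)  # int64
print(result)
```

Here `result` has the same shape as `sample` and holds the keys 5, 9, 8 and 0 as integers.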
The advantage of this approach is that it works as-is for an input_array of any shape, and it is very fast.
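The indexing approach does rely on every array value being a non-negative integer no larger than the largest value in the dictionary. A larger value raises an IndexError, and a value that is missing from the dictionary silently maps to whatever the table was initialized with. One way to make this explicit, sketched here with -1 as an assumed sentinel (`MISSING`, `build_lookup` and `remap` are hypothetical names), is:

```python
import numpy as np

MISSING = -1  # assumed sentinel for values with no dictionary entry


def build_lookup(value_to_key):
    """Build an integer table where table[value] == key, or MISSING."""
    table = np.full(max(value_to_key) + 1, MISSING, dtype=np.int64)
    for value, key in value_to_key.items():
        table[value] = key
    return table


def remap(arr, table):
    """Map a non-empty integer array through table, raising on unmapped values."""
    arr = np.asarray(arr)
    if arr.min() < 0 or arr.max() >= table.size:
        raise ValueError("array contains values outside the lookup table")
    out = table[arr]
    if (out == MISSING).any():
        raise ValueError("array contains values with no dictionary entry")
    return out


table = build_lookup({0: 0, 12: 9, 19: 8, 48: 5})
print(remap([[48, 12], [19, 0]], table))
```

The sentinel only works if -1 is never a legitimate key, which holds for the dictionary in the question.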
Timings for the three approaches:
- JRiggles's answer: func1
- Looking up values in the reversed dictionary: func2
- Indexing into the new NumPy array: func3
import timeit
import numpy as np

# Random test data; randint's upper bound is exclusive, so values lie in [0, max_index)
input_array = np.random.randint(0, max_index, (100, 100, 100))

def get_key(search_value):
    for key, num_list in lookup_dict.items():
        if search_value in num_list:
            return key

def func1(arr):
    arr = np.copy(arr)
    for outer_lst in arr:
        for sub_list in outer_lst:
            for index, value in enumerate(sub_list):
                new_val = get_key(value)  # get the key from 'dict'
                sub_list[index] = new_val  # replace old subarray value
    return arr

def func2(arr):
    arr = np.copy(arr)
    for outer_lst in arr:
        for sub_list in outer_lst:
            for index, value in enumerate(sub_list):
                new_val = reversed_dict[value]
                sub_list[index] = new_val
    return arr

def func3(arr):
    return lookup_array[arr]

t1 = timeit.timeit("func1(input_array)", globals=globals(), number=2)
print("t1 =", t1)
t2 = timeit.timeit("func2(input_array)", globals=globals(), number=2)
print("t2 =", t2)
t3 = timeit.timeit("func3(input_array)", globals=globals(), number=2)
print("t3 =", t3)
On my machine, this gives:
t1 = 25.02508409996517
t2 = 1.2259434000588953
t3 = 0.01203500002156943
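In other words:
- JRiggles's approach is about 20x slower than inverting the dictionary and looking values up in it.
- JRiggles's approach is about 2000x slower than building the lookup array and indexing it with NumPy.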
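This test array has about 33 times fewer elements than your input array (100 × 100 × 100 = 10^6 versus 2001 × 128 × 128 ≈ 3.3 × 10^7), so the absolute time saved on your array will be considerably larger.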
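Before relying on the timings, it is worth confirming that the fast path returns the same values as a per-element dictionary lookup. Here is a minimal self-contained sketch on a small hypothetical dictionary (`small_lookup`, `table` and `data` are illustrative names, and the random data is seeded only to make the run repeatable):

```python
import numpy as np

# Hypothetical miniature of lookup_dict.
small_lookup = {1: [1, 4], 2: [2, 5], 3: [3], 0: [0]}
small_reversed = {val: key for key, lst in small_lookup.items() for val in lst}

# Lookup table indexed by value, holding the corresponding key.
table = np.zeros(max(small_reversed) + 1, dtype=np.int64)
for val, key in small_reversed.items():
    table[val] = key

rng = np.random.default_rng(0)
data = rng.integers(0, table.size, size=(4, 3, 2))

# Slow reference: one dictionary lookup per element.
expected = np.array(
    [small_reversed[int(v)] for v in data.ravel()]
).reshape(data.shape)

# Fast path: index the table with the whole array at once.
actual = table[data]

print(np.array_equal(expected, actual))  # True
```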
arr = [ # Example list of lists - arbitrary values
[11, 11, 12, 13],
[24, 24, 24, 35],
[16, 27, 27, 8]
]
dictionary = {
1: [1, 39],
2: [2, 5, 9, 20, 32, 42, 47, 72, 88, 91, 95],
3: [3, 49, 55],
4: [4, 24, 34, 40, 53, 76, 81, 90, 96],
5: [6, 17, 30, 48, 83],
6: [7, 13, 15, 16, 27, 44, 51, 54, 56, 75],
7: [8, 50],
8: [10, 19, 22, 35, 61, 63, 65],
9: [11, 12, 21, 46, 52, 69, 78, 84, 89],
10: [14, 36, 74],
11: [18],
12: [23, 38, 66, 97],
13: [25],
14: [26, 28, 29, 62, 64, 86, 94],
15: [31, 59, 85],
16: [33, 80],
17: [37, 45, 60],
18: [41, 92, 93],
19: [43, 77, 79, 82],
20: [57, 67],
21: [58],
22: [68],
23: [70],
24: [71],
25: [73, 87],
0: [0]
}
def get_key(search_value):
    for key, num_list in dictionary.items():
        if search_value in num_list:
            return key

for sub_list in arr:
    for index, value in enumerate(sub_list):
        new_val = get_key(value)  # get the key from 'dict'
        sub_list[index] = new_val  # replace old subarray value

print(arr)  # QED - see new array below
# [
# [9, 9, 9, 6],
# [4, 4, 4, 8],
# [6, 6, 6, 7]
# ]