如何更有效地将一串切片转换为切片对象,然后可用于在 PyTorch / NumPy 中对数组和张量进行切片?



我如何简化这个函数,该函数将PyTorch/NumPy的切片字符串转换为切片列表对象,然后可以用于切片数组&张量?

下面的代码是有效的,但就需要多少行而言,它似乎效率很低。

def str_to_slice_indices(slicing_str: str):
# Convert indices to lists
indices = [
[i if i else None for i in indice_set.strip().split(":")]
for indice_set in slicing_str.strip("[]").split(",")
]
# Handle Ellipsis "..."
indices = [
... if index_slice == ["..."] else index_slice for index_slice in indices
]
# Handle "None" values
indices = [
None if index_slice == ["None"] else index_slice for index_slice in indices
]
# Handle single number values
indices = [
int(index_slice[0])
if isinstance(index_slice, list)
and len(index_slice) == 1
and index_slice[0].lstrip("-").isdigit()
else index_slice
for index_slice in indices
]
# Create indice slicing list
indices = [
slice(*[int(i) if i and i.lstrip("-").isdigit() else None for i in index_slice])
if isinstance(index_slice, list)
else index_slice
for index_slice in indices
]
return indices

通过一个涵盖各种类型输入的示例运行上述函数,给出以下内容:

out = str_to_slice_indices("[None, :1, 3:4, 2, :, 2:, ...]")
print(out)
# out:
# [None, slice(None, 1, None), slice(3, 4, None), 2, slice(None, None, None), slice(2, None, None), Ellipsis]

不需要多次迭代。为了测试更多的案例,对示例字符串进行了略微扩展。

def str2slices(s):
d = {True: lambda e: slice(*[int(i) if i else None for i in e.split(':')]),
'None': lambda e: None,
'...': lambda e: ...}
return [d.get(':' in e or e.strip(), lambda e: int(e))(e.strip()) for e in s[1:-1].split(',')]
str2slices('[None, :1, 3:4, 2, :, -10: ,::,:4:2, 1:10:2, -32,...]')

输出

[None,
slice(None, 1, None),
slice(3, 4, None),
2,
slice(None, None, None),
slice(-10, None, None),
slice(None, None, None),
slice(None, 4, 2),
slice(1, 10, 2),
-32,
Ellipsis]

捕获到与OP解决方案中相同的错误。它们不会静默地更改结果,而是为不受支持的输入抛出ValueError


溶液分解

假设CCD_ 2切片和CCD_ 3函数是已知的。

以为例

s = '[None, :1, 3:4, 2, :, -10: ,::,:4:2, 1:10:2, -32,...]'

我们可以用找到slices

[':' in e for e in s[1:-1].split(',')]
#[False, True, True, False, True, True, True, True, True, False, False]

使用or短路可以区分的其他情况

[':' in e or e.strip() for e in s[1:-1].split(',')]
#['None', True, True, '2', True, True, True, True, True, '-32', '...']

此值可用作dictionary的密钥

d = {True: 'slice', 'None': None, '...': ...}
[d[':' in e or e.strip()] for e in s[1:-1].split(',')]
#KeyError: '2'

为了防止KeyError,我们可以使用具有默认值的get方法。

d = {True: 'slice', 'None': None, '...': ...}
[d.get(':' in e or e.strip(), 'number') for e in s[1:-1].split(',')]
#[None, 'slice', 'slice', 'number', 'slice', 'slice', 'slice', 'slice', 'slice', 'number', Ellipsis]

为了处理slices,我们需要解析附加值​​在运行时。所以我们使用lambdas作为字典值​​以便能够用CCD_ 11呼叫它们。最后,我们转换值​​如有必要,将其发送至CCD_ 12。

d = {True: lambda e: slice(*[int(i) if i else None for i in e.split(':')]),
'None': lambda e: None,
'...': lambda e: ...}
[d.get(':' in e or e.strip(), lambda e: int(e))(e.strip()) for e in s[1:-1].split(',')]

输出

[None,
slice(None, 1, None),
slice(3, 4, None),
2,
slice(None, None, None),
slice(-10, None, None),
slice(None, None, None),
slice(None, 4, 2),
slice(1, 10, 2),
-32,
Ellipsis]

@Michael建议在np.s_上使用eval

另一种演示方法是定义一个只接受getitemtuple:的简单类

In [83]: class Foo():
...:     def __getitem__(self, arg):
...:         print(arg)
...: 
In [84]: Foo()[None, :1, 3:4, 2, :, 2:, ...]
(None, slice(None, 1, None), slice(3, 4, None), 2, slice(None, None, None), slice(2, None, None), Ellipsis)

在正常的Python使用中,它是将"::"类型的字符串转换为slice(以及相关对象(的解释器。它只在索引表达式中这样做。实际上,您的代码试图复制解释器通常所做的工作。

我对eval安全问题关注不够,不知道您需要添加什么。看来索引语法是相当严格的,因为它是

看起来不符合slicestring0语法的字符串在传递时没有改变,也没有赋值。

In [90]: Foo()['if x is 1:print(x)']
if x is 1:print(x)

我的Foonp.s_不尝试评估__getitem__传递给它们的元组。np.s_几乎同样简单(代码是用来查找和读取的(。

通常ast.literal_eval被用作eval的"更安全"的替代品,但它只处理strings, bytes, numbers, tuples, lists, dicts, sets, booleans, and None

相关内容

  • 没有找到相关文章

最新更新