string .join(iterable)方法在Python中是如何实现的/线性时间字符串连接 &



我正在尝试在Python中实现我自己的str.join方法,例如:''.join(['aa','bbb','cccc'])返回'aabbbcccc'。我知道使用连接方法的字符串连接会导致线性(结果的字符数)复杂性,我想知道如何做到这一点,因为在for循环中使用'+'操作符会导致二次复杂度,例如:

res=''
for word in ['aa','bbb','cccc']:
res = res +  word

由于字符串是不可变的,因此在每次迭代时复制一个新字符串,从而导致二次运行时间。但是,我想知道如何在线性时间内做到这一点,或者找到''.join的确切工作原理。

我找不到任何线性时间算法,也没有实现str.join(iterable)。如有任何帮助,不胜感激。

str加入实际的str是一种转移注意力的行为,而不是Python本身所做的:Python操作可变的bytes,而不是str,这也消除了了解字符串内部结构的需要。具体而言,str.join将其参数转换为字节,然后预分配并改变的结果。

这直接对应于:

  1. 一个包装器,用于将str参数编码/解码到bytes
  2. 元素和分隔符的len之和
  3. 分配一个可变的bytesarray来构造结果
  4. 将每个元素/分隔符直接复制到结果
# helper to convert to/from joinable bytes
def str_join(sep: "str", elements: "list[str]") -> "str":
joined_bytes = bytes_join(
sep.encode(),
[elem.encode() for elem in elements],
)
return joined_bytes.decode()
# actual joining at bytes level
def bytes_join(sep: "bytes", elements: "list[bytes]") -> "bytes":
# create a mutable buffer that is long enough to hold the result
total_length = sum(len(elem) for elem in elements)
total_length += (len(elements) - 1) * len(sep)
result = bytearray(total_length)
# copy all characters from the inputs to the result
insert_idx = 0
for elem in elements:
result[insert_idx:insert_idx+len(elem)] = elem
insert_idx += len(elem)
if insert_idx < total_length:
result[insert_idx:insert_idx+len(sep)] = sep
insert_idx += len(sep)
return bytes(result)
print(str_join(" ", ["Hello", "World!"]))

值得注意的是,虽然元素迭代和元素复制基本上是两个嵌套循环,但它们迭代的对象是不同的。该算法仍然只接触每个字符/字节三次/一次。

最新更新