str.join(iterable)方法如何在Python/线性时间字符串连接中实现
How is str.join(iterable) method implemented in Python/ Linear time string concatenation
我正在尝试在 Python 中实现我自己的 str.join
方法,例如:
''.join(['aa','bbb','cccc'])
returns 'aabbbcccc'
。我知道使用 join 方法的字符串连接会导致线性(结果的字符数)复杂性,我想知道如何做到这一点,因为在 for 循环中使用 '+'
运算符会导致二次复杂度,例如:
res=''
for word in ['aa','bbb','cccc']:
res = res + word
由于字符串是不可变的,因此每次迭代都会复制一个新字符串,从而导致二次 运行 时间。但是,我想知道如何在线性时间内完成或找到 ''.join
的工作原理。
我在任何地方都找不到线性时间算法,也找不到 str.join(iterable) 的实现。非常感谢任何帮助。
实际加入 str
str
是一个转移注意力的问题,not what Python itself does: Python operates on mutable bytes
, not the str
, which also removes the need to know string internals. In specific, str.join
converts its arguments to bytes, then pre-allocates and mutates its result。
这直接对应于:
- encode/decode
str
参数的包装器 to/from bytes
- 对元素和分隔符的
len
求和
- 分配一个可变的
bytesarray
来构造结果
- 将每个 element/separator 直接复制到结果中
# helper to convert to/from joinable bytes
def str_join(sep: "str", elements: "list[str]") -> "str":
joined_bytes = bytes_join(
sep.encode(),
[elem.encode() for elem in elements],
)
return joined_bytes.decode()
# actual joining at bytes level
def bytes_join(sep: "bytes", elements: "list[bytes]") -> "bytes":
# create a mutable buffer that is long enough to hold the result
total_length = sum(len(elem) for elem in elements)
total_length += (len(elements) - 1) * len(sep)
result = bytearray(total_length)
# copy all characters from the inputs to the result
insert_idx = 0
for elem in elements:
result[insert_idx:insert_idx+len(elem)] = elem
insert_idx += len(elem)
if insert_idx < total_length:
result[insert_idx:insert_idx+len(sep)] = sep
insert_idx += len(sep)
return bytes(result)
print(str_join(" ", ["Hello", "World!"]))
值得注意的是,虽然元素迭代和元素复制基本上是两个嵌套循环,但它们迭代不同的事物。该算法仍然仅触及每个 character/byte thrice/once.
我正在尝试在 Python 中实现我自己的 str.join
方法,例如:
''.join(['aa','bbb','cccc'])
returns 'aabbbcccc'
。我知道使用 join 方法的字符串连接会导致线性(结果的字符数)复杂性,我想知道如何做到这一点,因为在 for 循环中使用 '+'
运算符会导致二次复杂度,例如:
res=''
for word in ['aa','bbb','cccc']:
res = res + word
由于字符串是不可变的,因此每次迭代都会复制一个新字符串,从而导致二次 运行 时间。但是,我想知道如何在线性时间内完成或找到 ''.join
的工作原理。
我在任何地方都找不到线性时间算法,也找不到 str.join(iterable) 的实现。非常感谢任何帮助。
实际加入 str
str
是一个转移注意力的问题,not what Python itself does: Python operates on mutable bytes
, not the str
, which also removes the need to know string internals. In specific, str.join
converts its arguments to bytes, then pre-allocates and mutates its result。
这直接对应于:
- encode/decode
str
参数的包装器 to/frombytes
- 对元素和分隔符的
len
求和 - 分配一个可变的
bytesarray
来构造结果 - 将每个 element/separator 直接复制到结果中
# helper to convert to/from joinable bytes
def str_join(sep: "str", elements: "list[str]") -> "str":
joined_bytes = bytes_join(
sep.encode(),
[elem.encode() for elem in elements],
)
return joined_bytes.decode()
# actual joining at bytes level
def bytes_join(sep: "bytes", elements: "list[bytes]") -> "bytes":
# create a mutable buffer that is long enough to hold the result
total_length = sum(len(elem) for elem in elements)
total_length += (len(elements) - 1) * len(sep)
result = bytearray(total_length)
# copy all characters from the inputs to the result
insert_idx = 0
for elem in elements:
result[insert_idx:insert_idx+len(elem)] = elem
insert_idx += len(elem)
if insert_idx < total_length:
result[insert_idx:insert_idx+len(sep)] = sep
insert_idx += len(sep)
return bytes(result)
print(str_join(" ", ["Hello", "World!"]))
值得注意的是,虽然元素迭代和元素复制基本上是两个嵌套循环,但它们迭代不同的事物。该算法仍然仅触及每个 character/byte thrice/once.