通过插入单个数字获得最大可能数字的高效算法
Efficient algorithm to get the largest possible number by inserting a single digit
我刚刚编写了下面的算法,通过插入任何特定数字来获得最大可能的数字。
getLargestPossible
接收两个参数,旨在通过将 numInsertion
插入 num
中的任何位置来找到最大可能的数字。例如:
getLargestPossible(623, 5)
和 getLargestPossible(-482, 5)
分别返回 6523 和 -4582。
如何以最有效的方式编写此算法?
def attachRest(list_int, str_num, index):
    '''
    Build an integer from a prefix of values and the tail of a digit string.

    list_int -- iterable whose items are converted with str() to form the prefix
    str_num  -- sequence of characters; only str_num[index:] is used
    index    -- position in str_num at which the kept tail starts

    Returns int() of the prefix text followed by the tail text.
    '''
    pieces = [*map(str, list_int), *str_num[index:]]
    return int("".join(pieces))
def getLargestPossible(num, numInsertion):
    '''
    Print and return the largest number obtainable by inserting numInsertion
    into num.

    e.g.: getLargestPossible(623, 5) gives 6523
          getLargestPossible(-482, 5) gives -4582
          getLargestPossible(0, 5) gives 50

    num          -- integer to insert into; may be negative or zero
    numInsertion -- the digit to insert (the algorithm compares it against
                    single digits of num, so it is meant to be 0-9)

    The result is printed, as before, and is now also returned so that the
    function matches its documented contract.
    '''
    # Zero is handled as a non-negative number; treating it as negative
    # used to turn (0, 5) into -5 instead of 50.
    is_negative = num < 0
    digits = str(abs(num))
    for index, char in enumerate(digits):
        digit = int(char)
        if is_negative:
            # A negative number is largest when its magnitude is smallest:
            # insert before the first digit that is larger.
            insert_here = numInsertion < digit
        else:
            # A non-negative number is largest when the new digit goes
            # before the first digit that is smaller.
            insert_here = numInsertion > digit
        if insert_here:
            magnitude = int(digits[:index] + str(numInsertion) + digits[index:])
            result = -magnitude if is_negative else magnitude
            break
    else:
        # No insertion point found: append at the end,
        # e.g. if num==666 and numInsertion==5, return 6665
        if is_negative:
            result = num * 10 - numInsertion
        else:
            result = num * 10 + numInsertion
    print(result)
    return result
你能不能简化成这样:
def get_largest_insertion(long_n, n):
    '''
    Return the largest integer produced by inserting str(n) at any position
    in the decimal digits of long_n (brute force over every position).

    long_n -- integer to insert into; its sign is kept in front of the digits
    n      -- value to insert; its str() form is spliced between the digits
    '''
    digits = str(abs(long_n))
    inserted = str(n)
    prefix = '-' if long_n < 0 else '+'
    candidates = []
    # One candidate per gap, including before the first and after the last digit.
    for cut in range(len(digits) + 1):
        candidates.append(int(prefix + digits[:cut] + inserted + digits[cut:]))
    return max(candidates)
>>> get_largest_insertion(623, 5)
6523
>>> get_largest_insertion(-482,5)
-4582
>>> get_largest_insertion(666,5)
6665
好的,让我们更快
通过简单的检查,我相信如果数字是负数,插入将永远是以下之一:
- 字符串结束(如
f(-524242, 8): -5242428)
;或
- 从左起第一个比待插入数字大的数字之前(例如
f(-3251342,2): -23251342
或 f(-1251342,2): -12251342)。
对于正数:
- 字符串结束(如
f(345342,2): 3453422)
;或
- 从左起第一个比待插入数字小的数字之前(例如
f(3251342,2): 32521342)。
您可以修改我的函数,使其一次性找到插入点:
def f3(long_n, n):
    '''
    Insert str(n) into the digits of long_n at the first position found by
    a single left-to-right scan and return the result as an int.

    For a negative long_n the scan stops at the first character greater than
    str(n); otherwise at the first character smaller than str(n). If no
    character qualifies, str(n) is appended at the end. Characters are
    compared as strings.
    '''
    digits = str(abs(long_n))
    negative = long_n < 0
    token = str(n)
    pos = len(digits)  # default: append after the last digit
    for idx, ch in enumerate(digits):
        hit = (ch > token) if negative else (ch < token)
        if hit:
            pos = idx
            break
    prefix = '-' if negative else '+'
    return int(prefix + digits[:pos] + token + digits[pos:])
现在让我们对这些函数做基准测试:
# your original function is f1
# my original is f2
# new one is f3
import time
# ====================
def attachRest(list_int, str_num, index):
    '''
    Return int() of the str() forms of list_int followed by str_num[index:].
    '''
    pieces = [*map(str, list_int), *str_num[index:]]
    return int("".join(pieces))
def f1(num, numInsertion):
    '''
    Return the largest number obtainable by inserting the digit numInsertion
    into num.

    num          -- integer to insert into; may be negative or zero
    numInsertion -- the digit to insert (compared against single digits of
                    num, so it is meant to be 0-9)

    Zero is treated as a non-negative number; it used to fall into the
    negative branch, which made f1(0, 5) return -5 instead of 50.
    '''
    is_negative = num < 0
    digits = str(abs(num))
    for index, char in enumerate(digits):
        digit = int(char)
        # Negative: minimise the magnitude (insert before first larger digit).
        # Non-negative: maximise it (insert before first smaller digit).
        if is_negative:
            insert_here = numInsertion < digit
        else:
            insert_here = numInsertion > digit
        if insert_here:
            magnitude = int(digits[:index] + str(numInsertion) + digits[index:])
            return -magnitude if is_negative else magnitude
    # No digit qualified: append numInsertion at the end.
    if is_negative:
        return num * 10 - numInsertion
    return num * 10 + numInsertion
# ===============
def f2(long_n, n):
    '''
    Brute force: try str(n) in every gap of the digits of long_n and return
    the largest resulting integer.
    '''
    body = str(abs(long_n))
    extra = str(n)
    prefix = '-' if long_n < 0 else '+'
    best = None
    # range() always yields at least one position, so best is always set.
    for split in range(len(body) + 1):
        value = int(prefix + body[:split] + extra + body[split:])
        if best is None or value > best:
            best = value
    return best
def f3(long_n, n):
    '''
    Insert str(n) into the digits of long_n at the first position found by
    a single left-to-right scan and return the result as an int.

    For a negative long_n the scan stops at the first character greater than
    str(n); otherwise at the first character smaller than str(n). If no
    character qualifies, str(n) is appended at the end. Characters are
    compared as strings.
    '''
    digits = str(abs(long_n))
    negative = long_n < 0
    token = str(n)
    pos = len(digits)  # default: append after the last digit
    for idx, ch in enumerate(digits):
        hit = (ch > token) if negative else (ch < token)
        if hit:
            pos = idx
            break
    prefix = '-' if negative else '+'
    return int(prefix + digits[:pos] + token + digits[pos:])
# =====
def cmpthese(funcs, args=(), cnt=100, rate=True, micro=True, deepcopy=True):
    """Generate a Perl style function benchmark.

    Calls every function in ``funcs`` ``cnt`` times with ``*args``, timing
    each call with ``time.perf_counter_ns``, then prints a table ordered
    from the slowest to the fastest mean time.

    funcs    -- callables to compare; results are keyed by ``__name__``
    args     -- positional arguments passed to every call (may be empty)
    cnt      -- number of timed calls per function
    rate     -- include a 'rate/sec' column (calls per second)
    micro    -- include a microseconds-per-call column
    deepcopy -- when true, pass a fresh deep copy of ``args`` to each call;
                when false, pass ``args`` itself

    Each remaining cell is the relative difference between the column
    function's mean time and the row function's mean time.
    """
    # Imported as a module so the ``deepcopy`` flag is not shadowed.
    import copy

    def pprint_table(table):
        """Perl style table output"""
        def format_field(field, fmt='{:,.0f}'):
            if isinstance(field, str):
                return field
            if isinstance(field, tuple):
                return field[1].format(field[0])
            return fmt.format(field)

        def get_max_col_w(table, index):
            return max(len(format_field(row[index])) for row in table)

        col_paddings = [get_max_col_w(table, i) for i in range(len(table[0]))]
        for row in table:
            # left col
            row_tab = [row[0].ljust(col_paddings[0])]
            # rest of the cols
            row_tab += [format_field(row[j]).rjust(col_paddings[j])
                        for j in range(1, len(row))]
            print(' '.join(row_tab))

    results = {}
    for _ in range(cnt):
        for f in funcs:
            # Always bind local_args, so an empty ``args`` works as well.
            local_args = copy.deepcopy(args) if deepcopy else args
            start = time.perf_counter_ns()
            f(*local_args)
            stop = time.perf_counter_ns()
            results.setdefault(f.__name__, []).append(stop - start)

    # Mean duration of one call, in nanoseconds.
    results = {k: float(sum(v)) / len(v) for k, v in results.items()}
    # Slowest first, as in Perl's cmpthese.
    ordered = sorted(results, key=results.get, reverse=True)

    table = [['']]
    if rate:
        table[0].append('rate/sec')
    if micro:
        table[0].append('\u03bcsec/pass')
    table[0].extend(ordered)
    for e in ordered:
        mean_ns = results[e]
        tmp = [e]
        if rate:
            # 1e9 ns per second / ns per call = calls per second.
            tmp.append('{:,}'.format(int(round(1000000000.0 / mean_ns))))
        if micro:
            # ns -> microseconds.
            tmp.append('{:,.1f}'.format(mean_ns / 1000.0))
        for x in ordered:
            if x == e:
                tmp.append('--')
            else:
                tmp.append('{:.1%}'.format((results[x] - mean_ns) / mean_ns))
        table.append(tmp)
    pprint_table(table)
if __name__ == '__main__':
    import sys
    import time

    print(sys.version)
    funcs = [f1, f2, f3]
    # (number, digit to insert) pairs: short and long, negative and positive.
    cases = (
        (-524242, 8),
        (345342, 2),
        (-34734573524242, 8),
        (71347345345342, 2),
    )
    for args in cases:
        ln, n = args
        # Show each implementation's answer before timing them.
        for f in funcs:
            print(f'{f.__name__}{ln, n}: {f(ln,n)}')
        cmpthese(funcs, args)
        print()
该基准打印:
3.9.0 (default, Nov 21 2020, 14:55:42)
[Clang 12.0.0 (clang-1200.0.32.27)]
f1(-524242, 8): -5242428
f2(-524242, 8): -5242428
f3(-524242, 8): -5242428
rate/sec μsec/pass f2 f1 f3
f2 19,777 50.6 -- -37.0% -57.1%
f1 31,402 31.8 58.8% -- -32.0%
f3 46,148 21.7 133.3% 47.0% --
f1(345342, 2): 3453422
f2(345342, 2): 3453422
f3(345342, 2): 3453422
rate/sec μsec/pass f2 f1 f3
f2 19,671 50.8 -- -38.4% -56.5%
f1 31,916 31.3 62.2% -- -29.5%
f3 45,260 22.1 130.1% 41.8% --
f1(-34734573524242, 8): -347345735242428
f2(-34734573524242, 8): -347345735242428
f3(-34734573524242, 8): -347345735242428
rate/sec μsec/pass f2 f1 f3
f2 10,331 96.8 -- -38.1% -72.6%
f1 16,689 59.9 61.5% -- -55.7%
f3 37,679 26.5 264.7% 125.8% --
f1(71347345345342, 2): 721347345345342
f2(71347345345342, 2): 721347345345342
f3(71347345345342, 2): 721347345345342
rate/sec μsec/pass f2 f1 f3
f2 10,579 94.5 -- -67.7% -77.4%
f1 32,779 30.5 209.8% -- -29.9%
f3 46,743 21.4 341.8% 42.6% --
思路相同,但直接查找插入点比蛮力枚举快 2 到 3 倍……
我刚刚编写了下面的算法,通过插入任何特定数字来获得最大可能的数字。
getLargestPossible
接收两个参数,旨在通过将 numInsertion
插入 num
中的任何位置来找到最大可能的数字。例如:
getLargestPossible(623, 5)
和 getLargestPossible(-482, 5)
分别返回 6523 和 -4582。
如何以最有效的方式编写此算法?
def attachRest(list_int, str_num, index):
    '''
    Build an integer from a prefix of values and the tail of a digit string.

    list_int -- iterable whose items are converted with str() to form the prefix
    str_num  -- sequence of characters; only str_num[index:] is used
    index    -- position in str_num at which the kept tail starts

    Returns int() of the prefix text followed by the tail text.
    '''
    pieces = [*map(str, list_int), *str_num[index:]]
    return int("".join(pieces))
def getLargestPossible(num, numInsertion):
    '''
    Print and return the largest number obtainable by inserting numInsertion
    into num.

    e.g.: getLargestPossible(623, 5) gives 6523
          getLargestPossible(-482, 5) gives -4582
          getLargestPossible(0, 5) gives 50

    num          -- integer to insert into; may be negative or zero
    numInsertion -- the digit to insert (the algorithm compares it against
                    single digits of num, so it is meant to be 0-9)

    The result is printed, as before, and is now also returned so that the
    function matches its documented contract.
    '''
    # Zero is handled as a non-negative number; treating it as negative
    # used to turn (0, 5) into -5 instead of 50.
    is_negative = num < 0
    digits = str(abs(num))
    for index, char in enumerate(digits):
        digit = int(char)
        if is_negative:
            # A negative number is largest when its magnitude is smallest:
            # insert before the first digit that is larger.
            insert_here = numInsertion < digit
        else:
            # A non-negative number is largest when the new digit goes
            # before the first digit that is smaller.
            insert_here = numInsertion > digit
        if insert_here:
            magnitude = int(digits[:index] + str(numInsertion) + digits[index:])
            result = -magnitude if is_negative else magnitude
            break
    else:
        # No insertion point found: append at the end,
        # e.g. if num==666 and numInsertion==5, return 6665
        if is_negative:
            result = num * 10 - numInsertion
        else:
            result = num * 10 + numInsertion
    print(result)
    return result
你能不能简化成这样:
def get_largest_insertion(long_n, n):
    '''
    Return the largest integer produced by inserting str(n) at any position
    in the decimal digits of long_n (brute force over every position).

    long_n -- integer to insert into; its sign is kept in front of the digits
    n      -- value to insert; its str() form is spliced between the digits
    '''
    digits = str(abs(long_n))
    inserted = str(n)
    prefix = '-' if long_n < 0 else '+'
    candidates = []
    # One candidate per gap, including before the first and after the last digit.
    for cut in range(len(digits) + 1):
        candidates.append(int(prefix + digits[:cut] + inserted + digits[cut:]))
    return max(candidates)
>>> get_largest_insertion(623, 5)
6523
>>> get_largest_insertion(-482,5)
-4582
>>> get_largest_insertion(666,5)
6665
好的,让我们更快
通过简单的检查,我相信如果数字是负数,插入将永远是以下之一:
- 字符串结束(如
f(-524242, 8): -5242428)
;或 - 从左起第一个比待插入数字大的数字之前(例如
f(-3251342,2): -23251342
或 f(-1251342,2): -12251342)。
对于正数:
- 字符串结束(如
f(345342,2): 3453422)
;或 - 从左起第一个比待插入数字小的数字之前(例如
f(3251342,2): 32521342)。
您可以修改我的函数,使其一次性找到插入点:
def f3(long_n, n):
    '''
    Insert str(n) into the digits of long_n at the first position found by
    a single left-to-right scan and return the result as an int.

    For a negative long_n the scan stops at the first character greater than
    str(n); otherwise at the first character smaller than str(n). If no
    character qualifies, str(n) is appended at the end. Characters are
    compared as strings.
    '''
    digits = str(abs(long_n))
    negative = long_n < 0
    token = str(n)
    pos = len(digits)  # default: append after the last digit
    for idx, ch in enumerate(digits):
        hit = (ch > token) if negative else (ch < token)
        if hit:
            pos = idx
            break
    prefix = '-' if negative else '+'
    return int(prefix + digits[:pos] + token + digits[pos:])
现在让我们对这些函数做基准测试:
# your original function is f1
# my original is f2
# new one is f3
import time
# ====================
def attachRest(list_int, str_num, index):
    '''
    Return int() of the str() forms of list_int followed by str_num[index:].
    '''
    pieces = [*map(str, list_int), *str_num[index:]]
    return int("".join(pieces))
def f1(num, numInsertion):
    '''
    Return the largest number obtainable by inserting the digit numInsertion
    into num.

    num          -- integer to insert into; may be negative or zero
    numInsertion -- the digit to insert (compared against single digits of
                    num, so it is meant to be 0-9)

    Zero is treated as a non-negative number; it used to fall into the
    negative branch, which made f1(0, 5) return -5 instead of 50.
    '''
    is_negative = num < 0
    digits = str(abs(num))
    for index, char in enumerate(digits):
        digit = int(char)
        # Negative: minimise the magnitude (insert before first larger digit).
        # Non-negative: maximise it (insert before first smaller digit).
        if is_negative:
            insert_here = numInsertion < digit
        else:
            insert_here = numInsertion > digit
        if insert_here:
            magnitude = int(digits[:index] + str(numInsertion) + digits[index:])
            return -magnitude if is_negative else magnitude
    # No digit qualified: append numInsertion at the end.
    if is_negative:
        return num * 10 - numInsertion
    return num * 10 + numInsertion
# ===============
def f2(long_n, n):
    '''
    Brute force: try str(n) in every gap of the digits of long_n and return
    the largest resulting integer.
    '''
    body = str(abs(long_n))
    extra = str(n)
    prefix = '-' if long_n < 0 else '+'
    best = None
    # range() always yields at least one position, so best is always set.
    for split in range(len(body) + 1):
        value = int(prefix + body[:split] + extra + body[split:])
        if best is None or value > best:
            best = value
    return best
def f3(long_n, n):
    '''
    Insert str(n) into the digits of long_n at the first position found by
    a single left-to-right scan and return the result as an int.

    For a negative long_n the scan stops at the first character greater than
    str(n); otherwise at the first character smaller than str(n). If no
    character qualifies, str(n) is appended at the end. Characters are
    compared as strings.
    '''
    digits = str(abs(long_n))
    negative = long_n < 0
    token = str(n)
    pos = len(digits)  # default: append after the last digit
    for idx, ch in enumerate(digits):
        hit = (ch > token) if negative else (ch < token)
        if hit:
            pos = idx
            break
    prefix = '-' if negative else '+'
    return int(prefix + digits[:pos] + token + digits[pos:])
# =====
def cmpthese(funcs, args=(), cnt=100, rate=True, micro=True, deepcopy=True):
    """Generate a Perl style function benchmark.

    Calls every function in ``funcs`` ``cnt`` times with ``*args``, timing
    each call with ``time.perf_counter_ns``, then prints a table ordered
    from the slowest to the fastest mean time.

    funcs    -- callables to compare; results are keyed by ``__name__``
    args     -- positional arguments passed to every call (may be empty)
    cnt      -- number of timed calls per function
    rate     -- include a 'rate/sec' column (calls per second)
    micro    -- include a microseconds-per-call column
    deepcopy -- when true, pass a fresh deep copy of ``args`` to each call;
                when false, pass ``args`` itself

    Each remaining cell is the relative difference between the column
    function's mean time and the row function's mean time.
    """
    # Imported as a module so the ``deepcopy`` flag is not shadowed.
    import copy

    def pprint_table(table):
        """Perl style table output"""
        def format_field(field, fmt='{:,.0f}'):
            if isinstance(field, str):
                return field
            if isinstance(field, tuple):
                return field[1].format(field[0])
            return fmt.format(field)

        def get_max_col_w(table, index):
            return max(len(format_field(row[index])) for row in table)

        col_paddings = [get_max_col_w(table, i) for i in range(len(table[0]))]
        for row in table:
            # left col
            row_tab = [row[0].ljust(col_paddings[0])]
            # rest of the cols
            row_tab += [format_field(row[j]).rjust(col_paddings[j])
                        for j in range(1, len(row))]
            print(' '.join(row_tab))

    results = {}
    for _ in range(cnt):
        for f in funcs:
            # Always bind local_args, so an empty ``args`` works as well.
            local_args = copy.deepcopy(args) if deepcopy else args
            start = time.perf_counter_ns()
            f(*local_args)
            stop = time.perf_counter_ns()
            results.setdefault(f.__name__, []).append(stop - start)

    # Mean duration of one call, in nanoseconds.
    results = {k: float(sum(v)) / len(v) for k, v in results.items()}
    # Slowest first, as in Perl's cmpthese.
    ordered = sorted(results, key=results.get, reverse=True)

    table = [['']]
    if rate:
        table[0].append('rate/sec')
    if micro:
        table[0].append('\u03bcsec/pass')
    table[0].extend(ordered)
    for e in ordered:
        mean_ns = results[e]
        tmp = [e]
        if rate:
            # 1e9 ns per second / ns per call = calls per second.
            tmp.append('{:,}'.format(int(round(1000000000.0 / mean_ns))))
        if micro:
            # ns -> microseconds.
            tmp.append('{:,.1f}'.format(mean_ns / 1000.0))
        for x in ordered:
            if x == e:
                tmp.append('--')
            else:
                tmp.append('{:.1%}'.format((results[x] - mean_ns) / mean_ns))
        table.append(tmp)
    pprint_table(table)
if __name__ == '__main__':
    import sys
    import time

    print(sys.version)
    funcs = [f1, f2, f3]
    # (number, digit to insert) pairs: short and long, negative and positive.
    cases = (
        (-524242, 8),
        (345342, 2),
        (-34734573524242, 8),
        (71347345345342, 2),
    )
    for args in cases:
        ln, n = args
        # Show each implementation's answer before timing them.
        for f in funcs:
            print(f'{f.__name__}{ln, n}: {f(ln,n)}')
        cmpthese(funcs, args)
        print()
该基准打印:
3.9.0 (default, Nov 21 2020, 14:55:42)
[Clang 12.0.0 (clang-1200.0.32.27)]
f1(-524242, 8): -5242428
f2(-524242, 8): -5242428
f3(-524242, 8): -5242428
rate/sec μsec/pass f2 f1 f3
f2 19,777 50.6 -- -37.0% -57.1%
f1 31,402 31.8 58.8% -- -32.0%
f3 46,148 21.7 133.3% 47.0% --
f1(345342, 2): 3453422
f2(345342, 2): 3453422
f3(345342, 2): 3453422
rate/sec μsec/pass f2 f1 f3
f2 19,671 50.8 -- -38.4% -56.5%
f1 31,916 31.3 62.2% -- -29.5%
f3 45,260 22.1 130.1% 41.8% --
f1(-34734573524242, 8): -347345735242428
f2(-34734573524242, 8): -347345735242428
f3(-34734573524242, 8): -347345735242428
rate/sec μsec/pass f2 f1 f3
f2 10,331 96.8 -- -38.1% -72.6%
f1 16,689 59.9 61.5% -- -55.7%
f3 37,679 26.5 264.7% 125.8% --
f1(71347345345342, 2): 721347345345342
f2(71347345345342, 2): 721347345345342
f3(71347345345342, 2): 721347345345342
rate/sec μsec/pass f2 f1 f3
f2 10,579 94.5 -- -67.7% -77.4%
f1 32,779 30.5 209.8% -- -29.9%
f3 46,743 21.4 341.8% 42.6% --
思路相同,但直接查找插入点比蛮力枚举快 2 到 3 倍……