快速评估大量输入值的数学表达式(函数)
Evaluating a mathematical expression (function) for a large number of input values fast
以下问题
- Evaluating a mathematical expression in a string
- Equation parsing in Python
- Safe way to parse user-supplied mathematical formula in Python
- Evaluate math equations from unsafe user input in Python
和他们各自的答案让我思考如何能够有效地解析一个数学表达式(按照这个答案 的一般术语)由(或多或少受信任的)用户有效地解析 20k 到来自数据库的 30k 个输入值。我实施了一个快速而肮脏的基准测试,因此我可以比较不同的解决方案。
# Runs with Python 3(.4)
import pprint
import time
# This is what I have
userinput_function = '5*(1-(x*0.1))' # String - numbers should be handled as floats
demo_len = 20000 # Parameter for benchmark (20k to 30k in real life)
print_results = False
# Some database, represented by an array of dicts (simplified for this example)
database_xy = []
for a in range(1, demo_len, 1):
database_xy.append({
'x':float(a),
'y_eval':0,
'y_sympya':0,
'y_sympyb':0,
'y_sympyc':0,
'y_aevala':0,
'y_aevalb':0,
'y_aevalc':0,
'y_numexpr': 0,
'y_simpleeval':0
})
# 解决方案 #1:eval [是的,完全不安全]
time_start = time.time()
func = eval("lambda x: " + userinput_function)
for item in database_xy:
item['y_eval'] = func(item['x'])
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('1 eval: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #2a:sympy - evalf (http://www.sympy.org)
import sympy
time_start = time.time()
x = sympy.symbols('x')
sympy_function = sympy.sympify(userinput_function)
for item in database_xy:
item['y_sympya'] = float(sympy_function.evalf(subs={x:item['x']}))
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('2a sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #2b:sympy - lambdify (http://www.sympy.org)
from sympy.utilities.lambdify import lambdify
import sympy
import numpy
time_start = time.time()
sympy_functionb = sympy.sympify(userinput_function)
func = lambdify(x, sympy_functionb, 'numpy') # returns a numpy-ready function
xx = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
xx[index] = item['x']
yy = func(xx)
for index, item in enumerate(database_xy):
item['y_sympyb'] = yy[index]
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('2b sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #2c:sympy - lambdify with numexpr [and numpy] (http://www.sympy.org)
from sympy.utilities.lambdify import lambdify
import sympy
import numpy
import numexpr
time_start = time.time()
sympy_functionb = sympy.sympify(userinput_function)
func = lambdify(x, sympy_functionb, 'numexpr') # returns a numpy-ready function
xx = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
xx[index] = item['x']
yy = func(xx)
for index, item in enumerate(database_xy):
item['y_sympyc'] = yy[index]
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('2c sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #3a:asteval [基于 ast] - 使用字符串魔法 (http://newville.github.io/asteval/index.html)
from asteval import Interpreter
aevala = Interpreter()
time_start = time.time()
aevala('def func(x):\n\treturn ' + userinput_function)
for item in database_xy:
item['y_aevala'] = aevala('func(' + str(item['x']) + ')')
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('3a aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #3b (M Newville):asteval [基于 ast] - 解析 & 运行 (http://newville.github.io/asteval/index.html)
from asteval import Interpreter
aevalb = Interpreter()
time_start = time.time()
exprb = aevalb.parse(userinput_function)
for item in database_xy:
aevalb.symtable['x'] = item['x']
item['y_aevalb'] = aevalb.run(exprb)
time_end = time.time()
print('3b aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
# Solution #3c (M Newville): asteval [基于 ast] - parse & 运行 with numpy (http://newville.github.io/asteval/index.html)
from asteval import Interpreter
import numpy
aevalc = Interpreter()
time_start = time.time()
exprc = aevalc.parse(userinput_function)
x = numpy.array([item['x'] for item in database_xy])
aevalc.symtable['x'] = x
y = aevalc.run(exprc)
for index, item in enumerate(database_xy):
item['y_aevalc'] = y[index]
time_end = time.time()
print('3c aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #4:simpleeval [基于 ast] (https://github.com/danthedeckie/simpleeval)
from simpleeval import simple_eval
time_start = time.time()
for item in database_xy:
item['y_simpleeval'] = simple_eval(userinput_function, names={'x': item['x']})
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('4 simpleeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
#解决方案 #5 numexpr [和 numpy] (https://github.com/pydata/numexpr)
import numpy
import numexpr
time_start = time.time()
x = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
x[index] = item['x']
y = numexpr.evaluate(userinput_function)
for index, item in enumerate(database_xy):
item['y_numexpr'] = y[index]
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('5 numexpr: ' + str(round(time_end - time_start, 4)) + ' seconds')
在我的旧测试机器上(Python 3.4,Linux 3.11 x86_64,两个内核,1.8GHz)我得到以下结果:
1 eval: 0.0185 seconds
2a sympy: 10.671 seconds
2b sympy: 0.0315 seconds
2c sympy: 0.0348 seconds
3a aeval: 2.8368 seconds
3b aeval: 0.5827 seconds
3c aeval: 0.0246 seconds
4 simpleeval: 1.2363 seconds
5 numexpr: 0.0312 seconds
突出的是 eval 令人难以置信的速度,尽管我不想在现实生活中使用它。第二个最佳解决方案似乎是 numexpr,它依赖于 numpy - 我想避免的依赖性,尽管这不是硬性要求。下一个最好的是 simpleeval,它是围绕 ast 构建的。 aeval 是另一个基于 ast 的解决方案,它的缺点是我必须首先将每个浮点输入值转换为字符串,但我找不到解决方法。 sympy 最初是我最喜欢的,因为它提供了最灵活且显然最安全的解决方案,但它最终以与倒数第二个解决方案的一些令人印象深刻的差距结束。
更新 1:使用 sympy 有一种更快的方法。请参见解决方案 2b。它几乎和 numexpr 一样好,尽管我不确定 sympy 是否真的在内部使用它。
更新 2:sympy 实现现在使用 sympify 而不是 simplify(由其首席开发人员 asmeurer 推荐 - 谢谢)。它不使用 numexpr 除非明确要求这样做(参见解决方案 2c)。我还添加了两个基于 asteval 的明显更快的解决方案(感谢 M Newville)。
我有哪些选择可以进一步加快任何相对安全的解决方案的速度?例如,还有其他安全的(-ish)方法直接使用 ast 吗?
如果您将字符串传递给 sympy.simplify
(不推荐使用;建议明确使用 sympify
),这将使用 sympy.sympify
将其转换为SymPy 表达式,内部使用 eval
。
既然您询问了 asteval, 有一种使用它并获得更快结果的方法:
aeval = Interpreter()
time_start = time.time()
expr = aeval.parse(userinput_function)
for item in database_xy:
aeval.symtable['x'] = item['x']
item['y_aeval'] = aeval.run(expr)
time_end = time.time()
即可以先解析("pre-compile")用户输入函数,然后将x
的每一个新值插入到符号table中,并使用Interpreter.run()
来计算该值的编译表达式。根据您的规模,我认为这将使您接近 0.5 秒。
如果愿意使用numpy
,混合方案:
aeval = Interpreter()
time_start = time.time()
expr = aeval.parse(userinput_function)
x = numpy.array([item['x'] for item in database_xy])
aeval.symtable['x'] = x
y = aeval.run(expr)
time_end = time.time()
应该快得多,并且在 运行 时间上与使用 numexpr
.
相当
CPython(和 pypy)使用非常简单的堆栈语言来执行函数,使用 ast 模块自己编写字节码相当容易。
import sys
PY3 = sys.version_info.major > 2
import ast
from ast import parse
import types
from dis import opmap
ops = {
ast.Mult: opmap['BINARY_MULTIPLY'],
ast.Add: opmap['BINARY_ADD'],
ast.Sub: opmap['BINARY_SUBTRACT'],
ast.Div: opmap['BINARY_TRUE_DIVIDE'],
ast.Pow: opmap['BINARY_POWER'],
}
LOAD_CONST = opmap['LOAD_CONST']
RETURN_VALUE = opmap['RETURN_VALUE']
LOAD_FAST = opmap['LOAD_FAST']
def process(consts, bytecode, p, stackSize=0):
if isinstance(p, ast.Expr):
return process(consts, bytecode, p.value, stackSize)
if isinstance(p, ast.BinOp):
szl = process(consts, bytecode, p.left, stackSize)
szr = process(consts, bytecode, p.right, stackSize)
if type(p.op) in ops:
bytecode.append(ops[type(p.op)])
else:
print(p.op)
raise Exception("unspported opcode")
return max(szl, szr) + stackSize + 1
if isinstance(p, ast.Num):
if p.n not in consts:
consts.append(p.n)
idx = consts.index(p.n)
bytecode.append(LOAD_CONST)
bytecode.append(idx % 256)
bytecode.append(idx // 256)
return stackSize + 1
if isinstance(p, ast.Name):
bytecode.append(LOAD_FAST)
bytecode.append(0)
bytecode.append(0)
return stackSize + 1
raise Exception("unsupported token")
def makefunction(inp):
def f(x):
pass
if PY3:
oldcode = f.__code__
kwonly = oldcode.co_kwonlyargcount
else:
oldcode = f.func_code
stack_size = 0
consts = [None]
bytecode = []
p = ast.parse(inp).body[0]
stack_size = process(consts, bytecode, p, stack_size)
bytecode.append(RETURN_VALUE)
bytecode = bytes(bytearray(bytecode))
consts = tuple(consts)
if PY3:
code = types.CodeType(oldcode.co_argcount, oldcode.co_kwonlyargcount, oldcode.co_nlocals, stack_size, oldcode.co_flags, bytecode, consts, oldcode.co_names, oldcode.co_varnames, oldcode.co_filename, 'f', oldcode.co_firstlineno, b'')
f.__code__ = code
else:
code = types.CodeType(oldcode.co_argcount, oldcode.co_nlocals, stack_size, oldcode.co_flags, bytecode, consts, oldcode.co_names, oldcode.co_varnames, oldcode.co_filename, 'f', oldcode.co_firstlineno, '')
f.func_code = code
return f
这具有明显的优势,可以生成与 eval
基本相同的函数,并且它的缩放比例几乎与 compile
+eval
一样(compile
步骤比 eval
稍慢,并且 eval
会预先计算它可以做的任何事情(1+1+x
被编译为 2+x
)。
相比之下,eval
在 0.0125 秒内完成 20k 测试,makefunction
在 0.014 秒内完成。将迭代次数增加到 2,000,000 次,eval
在 1.23 秒内完成,makefunction
在 1.32 秒内完成。
有趣的是,pypy 认识到 eval
和 makefunction
产生基本相同的功能,因此第一个的 JIT 预热加速了第二个。
我不是 Python 编码员,所以我无法提供 Python 代码。但是我想我可以提供一个简单的方案来最小化你的依赖性并且仍然运行得非常快。
这里的关键是构建一些接近 eval 而不是 eval 的东西。所以你想要做的是 "compile" 将用户方程式转化为可以快速评估的东西。 OP给出了很多解决方案。
这是另一个基于Reverse Polish.
计算方程的方法
为了便于讨论,假设您可以将等式转换为 RPN(逆波兰表示法)。这意味着操作数在运算符之前,例如,对于用户公式:
sqrt(x**2 + y**2)
您从左到右得到 RPN 等效读数:
x 2 ** y 2 ** + sqrt
事实上,我们可以将 "operands"(例如,变量和常量)视为采用零操作数的运算符。现在 RPN 中的每个人都是运算符。
如果我们将每个运算符元素视为一个标记(假设每个运算符元素都有一个唯一的小整数,下面写为“RPNelement”)并将它们存储在一个数组中"RPN" ,我们可以使用下推堆栈非常快速地评估这样的公式:
stack = {}; // make the stack empty
do i=1,len(RPN),1
case RPN[i]:
"0": push(stack,0);
"1": push(stack,1);
"+": push(stack,pop(stack)+pop(stack));break;
"-": push(stack,pop(stack)-pop(stack));break;
"**": push(stack,power(pop(stack),pop(stack)));break;
"x": push(stack,x);break;
"y": push(stack,y);break;
"K1": push(stack,K1);break;
... // as many K1s as you have typical constants in a formula
endcase
enddo
answer=pop(stack);
您可以内联 push 和 pop 操作以加快速度。
如果提供的 RPN 格式正确,则此代码是绝对安全的。
现在,如何获得RPN?答案:构建一个小型递归下降解析器,其操作将 RPN 运算符附加到 RPN 数组。有关典型方程,请参阅 my SO answer for how to build a recursive descent parser easily。
如果它们不是特殊的、经常出现的值(如我在“0”和“1”中显示的那样;您如果有帮助,可以添加更多)。
这个解决方案最多应该只有几百行,并且对其他包的依赖性为零。
(Python 专家:请随意编辑代码以使其成为 Python 风格)。
我使用过 C++ ExprTK library in the past with great success. Here 是其他 C++ 解析器(例如 Muparser、MathExpr、ATMSP 等...)中的基准速度测试,而 ExprTK 名列前茅。
有一个名为 cexprtk 的 ExprTK 的 Python 包装器,我已经使用过并且发现速度非常快。您可以只编译一次数学表达式,然后根据需要多次计算此序列化表达式。这是一个使用 cexprtk
和 userinput_function
:
的简单示例代码
import cexprtk
import time
userinput_function = '5*(1-(x*0.1))' # String - numbers should be handled as floats
demo_len = 20000 # Parameter for benchmark (20k to 30k in real life)
time_start = time.time()
x = 1
st = cexprtk.Symbol_Table({"x":x}, add_constants = True) # Setup the symbol table
Expr = cexprtk.Expression(userinput_function, st) # Apply the symbol table to the userinput_function
for x in range(0,demo_len,1):
st.variables['x'] = x # Update the symbol table with the new x value
Expr() # evaluate expression
time_end = time.time()
print('1 cexprtk: ' + str(round(time_end - time_start, 4)) + ' seconds')
在我的机器上(Linux,双核,2.5GHz),演示长度为 20000,这在 0.0202 秒内完成。
对于长度为 2,000,000 的演示,cexprtk
在 1.23 秒内完成。
以下问题
- Evaluating a mathematical expression in a string
- Equation parsing in Python
- Safe way to parse user-supplied mathematical formula in Python
- Evaluate math equations from unsafe user input in Python
和他们各自的答案让我思考如何能够有效地解析一个数学表达式(按照这个答案 的一般术语)由(或多或少受信任的)用户有效地解析 20k 到来自数据库的 30k 个输入值。我实施了一个快速而肮脏的基准测试,因此我可以比较不同的解决方案。
# Runs with Python 3(.4)
import pprint
import time
# This is what I have
userinput_function = '5*(1-(x*0.1))' # String - numbers should be handled as floats
demo_len = 20000 # Parameter for benchmark (20k to 30k in real life)
print_results = False
# Some database, represented by an array of dicts (simplified for this example)
database_xy = []
for a in range(1, demo_len, 1):
database_xy.append({
'x':float(a),
'y_eval':0,
'y_sympya':0,
'y_sympyb':0,
'y_sympyc':0,
'y_aevala':0,
'y_aevalb':0,
'y_aevalc':0,
'y_numexpr': 0,
'y_simpleeval':0
})
# 解决方案 #1:eval [是的,完全不安全]
time_start = time.time()
func = eval("lambda x: " + userinput_function)
for item in database_xy:
item['y_eval'] = func(item['x'])
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('1 eval: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #2a:sympy - evalf (http://www.sympy.org)
import sympy
time_start = time.time()
x = sympy.symbols('x')
sympy_function = sympy.sympify(userinput_function)
for item in database_xy:
item['y_sympya'] = float(sympy_function.evalf(subs={x:item['x']}))
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('2a sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #2b:sympy - lambdify (http://www.sympy.org)
from sympy.utilities.lambdify import lambdify
import sympy
import numpy
time_start = time.time()
sympy_functionb = sympy.sympify(userinput_function)
func = lambdify(x, sympy_functionb, 'numpy') # returns a numpy-ready function
xx = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
xx[index] = item['x']
yy = func(xx)
for index, item in enumerate(database_xy):
item['y_sympyb'] = yy[index]
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('2b sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #2c:sympy - lambdify with numexpr [and numpy] (http://www.sympy.org)
from sympy.utilities.lambdify import lambdify
import sympy
import numpy
import numexpr
time_start = time.time()
sympy_functionb = sympy.sympify(userinput_function)
func = lambdify(x, sympy_functionb, 'numexpr') # returns a numpy-ready function
xx = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
xx[index] = item['x']
yy = func(xx)
for index, item in enumerate(database_xy):
item['y_sympyc'] = yy[index]
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('2c sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #3a:asteval [基于 ast] - 使用字符串魔法 (http://newville.github.io/asteval/index.html)
from asteval import Interpreter
aevala = Interpreter()
time_start = time.time()
aevala('def func(x):\n\treturn ' + userinput_function)
for item in database_xy:
item['y_aevala'] = aevala('func(' + str(item['x']) + ')')
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('3a aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #3b (M Newville):asteval [基于 ast] - 解析 & 运行 (http://newville.github.io/asteval/index.html)
from asteval import Interpreter
aevalb = Interpreter()
time_start = time.time()
exprb = aevalb.parse(userinput_function)
for item in database_xy:
aevalb.symtable['x'] = item['x']
item['y_aevalb'] = aevalb.run(exprb)
time_end = time.time()
print('3b aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
# Solution #3c (M Newville): asteval [基于 ast] - parse & 运行 with numpy (http://newville.github.io/asteval/index.html)
from asteval import Interpreter
import numpy
aevalc = Interpreter()
time_start = time.time()
exprc = aevalc.parse(userinput_function)
x = numpy.array([item['x'] for item in database_xy])
aevalc.symtable['x'] = x
y = aevalc.run(exprc)
for index, item in enumerate(database_xy):
item['y_aevalc'] = y[index]
time_end = time.time()
print('3c aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
# 解决方案 #4:simpleeval [基于 ast] (https://github.com/danthedeckie/simpleeval)
from simpleeval import simple_eval
time_start = time.time()
for item in database_xy:
item['y_simpleeval'] = simple_eval(userinput_function, names={'x': item['x']})
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('4 simpleeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
#解决方案 #5 numexpr [和 numpy] (https://github.com/pydata/numexpr)
import numpy
import numexpr
time_start = time.time()
x = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
x[index] = item['x']
y = numexpr.evaluate(userinput_function)
for index, item in enumerate(database_xy):
item['y_numexpr'] = y[index]
time_end = time.time()
if print_results:
pprint.pprint(database_xy)
print('5 numexpr: ' + str(round(time_end - time_start, 4)) + ' seconds')
在我的旧测试机器上(Python 3.4,Linux 3.11 x86_64,两个内核,1.8GHz)我得到以下结果:
1 eval: 0.0185 seconds
2a sympy: 10.671 seconds
2b sympy: 0.0315 seconds
2c sympy: 0.0348 seconds
3a aeval: 2.8368 seconds
3b aeval: 0.5827 seconds
3c aeval: 0.0246 seconds
4 simpleeval: 1.2363 seconds
5 numexpr: 0.0312 seconds
突出的是 eval 令人难以置信的速度,尽管我不想在现实生活中使用它。第二个最佳解决方案似乎是 numexpr,它依赖于 numpy - 我想避免的依赖性,尽管这不是硬性要求。下一个最好的是 simpleeval,它是围绕 ast 构建的。 aeval 是另一个基于 ast 的解决方案,它的缺点是我必须首先将每个浮点输入值转换为字符串,但我找不到解决方法。 sympy 最初是我最喜欢的,因为它提供了最灵活且显然最安全的解决方案,但它最终以与倒数第二个解决方案的一些令人印象深刻的差距结束。
更新 1:使用 sympy 有一种更快的方法。请参见解决方案 2b。它几乎和 numexpr 一样好,尽管我不确定 sympy 是否真的在内部使用它。
更新 2:sympy 实现现在使用 sympify 而不是 simplify(由其首席开发人员 asmeurer 推荐 - 谢谢)。它不使用 numexpr 除非明确要求这样做(参见解决方案 2c)。我还添加了两个基于 asteval 的明显更快的解决方案(感谢 M Newville)。
我有哪些选择可以进一步加快任何相对安全的解决方案的速度?例如,还有其他安全的(-ish)方法直接使用 ast 吗?
如果您将字符串传递给 sympy.simplify
(不推荐使用;建议明确使用 sympify
),这将使用 sympy.sympify
将其转换为SymPy 表达式,内部使用 eval
。
既然您询问了 asteval, 有一种使用它并获得更快结果的方法:
aeval = Interpreter()
time_start = time.time()
expr = aeval.parse(userinput_function)
for item in database_xy:
aeval.symtable['x'] = item['x']
item['y_aeval'] = aeval.run(expr)
time_end = time.time()
即可以先解析("pre-compile")用户输入函数,然后将x
的每一个新值插入到符号table中,并使用Interpreter.run()
来计算该值的编译表达式。根据您的规模,我认为这将使您接近 0.5 秒。
如果愿意使用numpy
,混合方案:
aeval = Interpreter()
time_start = time.time()
expr = aeval.parse(userinput_function)
x = numpy.array([item['x'] for item in database_xy])
aeval.symtable['x'] = x
y = aeval.run(expr)
time_end = time.time()
应该快得多,并且在 运行 时间上与使用 numexpr
.
CPython(和 pypy)使用非常简单的堆栈语言来执行函数,使用 ast 模块自己编写字节码相当容易。
import sys
PY3 = sys.version_info.major > 2
import ast
from ast import parse
import types
from dis import opmap
ops = {
ast.Mult: opmap['BINARY_MULTIPLY'],
ast.Add: opmap['BINARY_ADD'],
ast.Sub: opmap['BINARY_SUBTRACT'],
ast.Div: opmap['BINARY_TRUE_DIVIDE'],
ast.Pow: opmap['BINARY_POWER'],
}
LOAD_CONST = opmap['LOAD_CONST']
RETURN_VALUE = opmap['RETURN_VALUE']
LOAD_FAST = opmap['LOAD_FAST']
def process(consts, bytecode, p, stackSize=0):
if isinstance(p, ast.Expr):
return process(consts, bytecode, p.value, stackSize)
if isinstance(p, ast.BinOp):
szl = process(consts, bytecode, p.left, stackSize)
szr = process(consts, bytecode, p.right, stackSize)
if type(p.op) in ops:
bytecode.append(ops[type(p.op)])
else:
print(p.op)
raise Exception("unspported opcode")
return max(szl, szr) + stackSize + 1
if isinstance(p, ast.Num):
if p.n not in consts:
consts.append(p.n)
idx = consts.index(p.n)
bytecode.append(LOAD_CONST)
bytecode.append(idx % 256)
bytecode.append(idx // 256)
return stackSize + 1
if isinstance(p, ast.Name):
bytecode.append(LOAD_FAST)
bytecode.append(0)
bytecode.append(0)
return stackSize + 1
raise Exception("unsupported token")
def makefunction(inp):
def f(x):
pass
if PY3:
oldcode = f.__code__
kwonly = oldcode.co_kwonlyargcount
else:
oldcode = f.func_code
stack_size = 0
consts = [None]
bytecode = []
p = ast.parse(inp).body[0]
stack_size = process(consts, bytecode, p, stack_size)
bytecode.append(RETURN_VALUE)
bytecode = bytes(bytearray(bytecode))
consts = tuple(consts)
if PY3:
code = types.CodeType(oldcode.co_argcount, oldcode.co_kwonlyargcount, oldcode.co_nlocals, stack_size, oldcode.co_flags, bytecode, consts, oldcode.co_names, oldcode.co_varnames, oldcode.co_filename, 'f', oldcode.co_firstlineno, b'')
f.__code__ = code
else:
code = types.CodeType(oldcode.co_argcount, oldcode.co_nlocals, stack_size, oldcode.co_flags, bytecode, consts, oldcode.co_names, oldcode.co_varnames, oldcode.co_filename, 'f', oldcode.co_firstlineno, '')
f.func_code = code
return f
这具有明显的优势,可以生成与 eval
基本相同的函数,并且它的缩放比例几乎与 compile
+eval
一样(compile
步骤比 eval
稍慢,并且 eval
会预先计算它可以做的任何事情(1+1+x
被编译为 2+x
)。
相比之下,eval
在 0.0125 秒内完成 20k 测试,makefunction
在 0.014 秒内完成。将迭代次数增加到 2,000,000 次,eval
在 1.23 秒内完成,makefunction
在 1.32 秒内完成。
有趣的是,pypy 认识到 eval
和 makefunction
产生基本相同的功能,因此第一个的 JIT 预热加速了第二个。
我不是 Python 编码员,所以我无法提供 Python 代码。但是我想我可以提供一个简单的方案来最小化你的依赖性并且仍然运行得非常快。
这里的关键是构建一些接近 eval 而不是 eval 的东西。所以你想要做的是 "compile" 将用户方程式转化为可以快速评估的东西。 OP给出了很多解决方案。
这是另一个基于Reverse Polish.
计算方程的方法为了便于讨论,假设您可以将等式转换为 RPN(逆波兰表示法)。这意味着操作数在运算符之前,例如,对于用户公式:
sqrt(x**2 + y**2)
您从左到右得到 RPN 等效读数:
x 2 ** y 2 ** + sqrt
事实上,我们可以将 "operands"(例如,变量和常量)视为采用零操作数的运算符。现在 RPN 中的每个人都是运算符。
如果我们将每个运算符元素视为一个标记(假设每个运算符元素都有一个唯一的小整数,下面写为“RPNelement”)并将它们存储在一个数组中"RPN" ,我们可以使用下推堆栈非常快速地评估这样的公式:
stack = {}; // make the stack empty
do i=1,len(RPN),1
case RPN[i]:
"0": push(stack,0);
"1": push(stack,1);
"+": push(stack,pop(stack)+pop(stack));break;
"-": push(stack,pop(stack)-pop(stack));break;
"**": push(stack,power(pop(stack),pop(stack)));break;
"x": push(stack,x);break;
"y": push(stack,y);break;
"K1": push(stack,K1);break;
... // as many K1s as you have typical constants in a formula
endcase
enddo
answer=pop(stack);
您可以内联 push 和 pop 操作以加快速度。 如果提供的 RPN 格式正确,则此代码是绝对安全的。
现在,如何获得RPN?答案:构建一个小型递归下降解析器,其操作将 RPN 运算符附加到 RPN 数组。有关典型方程,请参阅 my SO answer for how to build a recursive descent parser easily。
如果它们不是特殊的、经常出现的值(如我在“0”和“1”中显示的那样;您如果有帮助,可以添加更多)。
这个解决方案最多应该只有几百行,并且对其他包的依赖性为零。
(Python 专家:请随意编辑代码以使其成为 Python 风格)。
我使用过 C++ ExprTK library in the past with great success. Here 是其他 C++ 解析器(例如 Muparser、MathExpr、ATMSP 等...)中的基准速度测试,而 ExprTK 名列前茅。
有一个名为 cexprtk 的 ExprTK 的 Python 包装器,我已经使用过并且发现速度非常快。您可以只编译一次数学表达式,然后根据需要多次计算此序列化表达式。这是一个使用 cexprtk
和 userinput_function
:
import cexprtk
import time
userinput_function = '5*(1-(x*0.1))' # String - numbers should be handled as floats
demo_len = 20000 # Parameter for benchmark (20k to 30k in real life)
time_start = time.time()
x = 1
st = cexprtk.Symbol_Table({"x":x}, add_constants = True) # Setup the symbol table
Expr = cexprtk.Expression(userinput_function, st) # Apply the symbol table to the userinput_function
for x in range(0,demo_len,1):
st.variables['x'] = x # Update the symbol table with the new x value
Expr() # evaluate expression
time_end = time.time()
print('1 cexprtk: ' + str(round(time_end - time_start, 4)) + ' seconds')
在我的机器上(Linux,双核,2.5GHz),演示长度为 20000,这在 0.0202 秒内完成。
对于长度为 2,000,000 的演示,cexprtk
在 1.23 秒内完成。