使用 z3,其中约束取决于函数的输出

Using z3 where constraint depends on output of function

我想用z3解决这个问题。输入是一个 10 个字符的字符串。输入的每个字符都是一个可打印字符 (ASCII)。输入应该是这样的,当以输入作为参数调用 calc2() 函数时,结果应该是:0x0009E38E1FB7629B.

在这种情况下如何使用 z3py?

通常我只会添加独立方程作为对 z3 的约束。在这种情况下,我不确定如何使用 z3。

def calc2(input):
    result = 0

    for i in range(len(input)):
        r1 = (result << 0x5) & 0xffffffffffffffff
        r2 = result >> 0x1b
        r3 = (r1 ^ r2)

        result = (r3 ^ ord(input[i]))

    return result

if __name__ == "__main__":
        input = sys.argv[1]
        result = calc2(input)

        if result == 0x0009E38E1FB7629B:
             print "solved"

更新: 我尝试了以下但它没有给我正确的答案:

from z3 import *

def calc2(input):
    result = 0

    for i in range(len(input)):
        r1 = (result << 0x5) & 0xffffffffffffffff
        r2 = result >> 0x1b
        r3 = (r1 ^ r2)

        result = r3 ^ Concat(BitVec(0, 56), input[i])


    return result

if __name__ == "__main__":
    s = Solver()
    X = [BitVec('x' + str(i), 8) for i in range(10)]

    s.add(calc2(X) == 0x0009E38E1FB7629B)

    if s.check() == sat:
        print(s.model())

您可以在 Z3 中对 calc2 进行编码。您需要将循环展开 1,2,3,4,..,n 次(对于 n = 预期的最大输入大小),仅此而已。 (你实际上不需要展开循环,你可以使用 z3py 来创建约束)

我希望这不是家庭作业,但这是一种解决方法:

from z3 import *

s = Solver()

# Input is 10 character long; represent with 10 8-bit symbolic variables
input = [BitVec("input%s" % i, 8) for i in range(10)]

# Make sure each character is printable ASCII, i.e., between 0x20 and 0x7E
for i in range(10):
  s.add(input[i] >= 0x20)
  s.add(input[i] <= 0x7E)

def calc2(input):

    # result is a 64-bit value
    result = BitVecVal(0, 64)

    for i in range(len(input)):
        # NB. We don't actually need to mask with 0xffffffffffffffff
        # Since we explicitly have a 64-bit value in result.
        # But it doesn't hurt to mask it, so we do it here.
        r1 = (result << 0x5) & 0xffffffffffffffff
        r2 = result >> 0x1b
        r3 = r1 ^ r2

        # We need to zero-extend to match sizes
        result = r3 ^ ZeroExt(56, input[i])

    return result


# Assert the required equality
s.add(calc2(input) == 0x0009E38E1FB7629B)

# Check and get model
print s.check()
m = s.model()

# reconstruct the string:
s = ''.join([chr (m[input[i]].as_long()) for i in range(10)])
print s

这会打印:

$ python a.py
sat
L`p:LxlBVU

看起来你的秘密字符串是

"L`p:LxlBVU"

我在程序中添加了一些注释,以帮助您了解 z3py 中的编码方式,但请随时要求澄清。希望这对您有所帮助!

获取所有解决方案

要获得其他解决方案,您只需循环并断言该解决方案不应是前一个解决方案。您可以在断言后使用以下 while 循环:

while s.check() == sat:
   m = s.model()
   print ''.join([chr (m[input[i]].as_long()) for i in range(10)])
   s.add(Or([input[i] != m[input[i]] for i in range(10)]))

当我运行它时,它一直在继续!您可能想在一段时间后停止它。