Feige-Fiat-Shamir 识别协议 Java

Feige-Fiat-Shamir identification protocol Java

我正在尝试实现 Feige-Fiat-Shamir 识别方案,如书 "Handbook of Applied Cryptography"(第 410 页,第 10.4.2 节)中所述。我有一个代码,但问题是有时它会成功但有时会失败。谁能帮我找出这段代码中的错误?谢谢。

    public static void main(String[] args) throws Exception {

    BigInteger p = BigInteger.probablePrime(16, new Random());
    BigInteger q = BigInteger.probablePrime(16, new Random());
    int k = 10;  // Receive k

    BigInteger trustedN = p.multiply(q);

    List<BigInteger> randomInts = new ArrayList<>();    //s1,s2...sk
    BitSet randomBits = new BitSet(k);  // b1,b2..bk
    List<BigInteger> listV = new ArrayList<>();

    Random rand = new Random();

    /*
    Choose k positive numbers less than trustedN.
    Choose k bits 0 or 1
     */

    for (int i = 0; i < k; i++) {
        // Generate random big ints less than trustedN
        randomInts.add(new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN));

        randomBits.set(i, rand.nextBoolean());
        // (-1)^bi
        BigInteger minus1pow = (((new BigInteger("-1")).pow(randomBits.get(i) ? 1 : 0))).mod(trustedN);

        // (s^2)^(-1)
        BigInteger randomIntPow = (randomInts.get(i).pow(2)).modInverse(trustedN);

        // vi = (-1)^bi * (s^2)^(-1)
        listV.add((minus1pow.multiply(randomIntPow)).mod(trustedN));
    }

    // Random r
    BigInteger randomR = new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN);
    // Random bit index
    int bitIndex = rand.nextInt(randomBits.length() + 1);

    // Calculate x
    BigInteger x = ((new BigInteger("-1")).pow(randomBits.get(bitIndex) ? 1 : 0).mod(trustedN)).multiply((randomR.pow(2)).mod(trustedN)).mod(trustedN);

    // Let pretend it was randomly selected vector (e1,e2,e3...)
    String eBits = "1100011010";


    BigInteger totalMultS = new BigInteger("1");
    for (int i = 0; i < k; i++) {
        totalMultS = totalMultS
                .multiply(randomInts.get(i).pow(eBits.charAt(i) == '1' ? 1 : 0));
    }

    totalMultS = totalMultS.mod(trustedN).multiply(randomR.mod(trustedN)).mod(trustedN);
    BigInteger y = totalMultS;


    BigInteger totalMultV = new BigInteger("1");

    for (int i = 0; i < k; i++) {
        totalMultV = totalMultV
                .multiply(listV.get(i).pow(eBits.charAt(i) == '1' ? 1 : 0));
    }

    totalMultV = totalMultV.mod(trustedN);
    BigInteger z = (y.pow(2).mod(trustedN)).multiply(totalMultV).mod(trustedN);


    if (z.toString().equals(x.toString())){
        System.out.println("SUCCESS");
    }
    else {
        System.out.println("FAIL");

        System.out.println("x: " + x.toString());
        System.out.println("z: " + z.toString());

    }



}

我找到了解决办法。问题出在最后一个条件。看代码:

    public static void main(String[] args) throws Exception {

    BigInteger p = BigInteger.probablePrime(4, new Random());

    BigInteger q = BigInteger.probablePrime(4, new Random());

    System.out.println("p: " + p.toString());
    System.out.println("q: " + q.toString());

    int k = 3;  // Receive k

    BigInteger trustedN = p.multiply(q);

    System.out.println("n: " + trustedN.toString());

    List<BigInteger> randomInts = new ArrayList<>();    //s1,s2...sk
    BitSet randomBits = new BitSet(k);  // b1,b2..bk
    List<BigInteger> listV = new ArrayList<>();

    Random rand = new Random();

    /*
    Choose k positive numbers less than trustedN.
    Choose k bits 0 or 1
     */
    System.out.print("random s: ");
    for (int i = 0; i < k; i++) {


        BigInteger si = new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN);

        while (si.gcd(trustedN).intValue() != 1){
            si = new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN);
        }
        // Generate random big ints less than trustedN
        randomInts.add(si);

        randomBits.set(i, rand.nextBoolean());
        // (-1)^bi
        System.out.print(randomInts.get(i) + " " + randomBits.get(i) + " ");
        BigInteger minus1pow = (((new BigInteger("-1")).pow(randomBits.get(i) ? 1 : 0)));

        // (s^2)^(-1)
        BigInteger randomIntPow = minus1pow.multiply(randomInts.get(i).pow(2)).modInverse(trustedN);

        // vi = (-1)^bi * (s^2)^(-1)//            listV.add((minus1pow.multiply(randomIntPow)).mod(trustedN));
        listV.add(randomIntPow);
    }
    System.out.print("\nlist v: ");
    for (BigInteger bi:
         listV) {
        System.out.print(bi.toString() + " ");
    }
    System.out.println();
    // Random r
    BigInteger randomR = new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN);
    System.out.println("r: " + randomR.toString());
    // Random bit index
    int bitIndex = (int) (Math.random() * ( randomBits.length()  ));
    System.out.println("bitIndex: " + bitIndex + " bit value: " + randomBits.get(bitIndex));
    // Calculate x//        BigInteger x = ((new BigInteger("-1")).pow(randomBits.get(bitIndex) ? 1 : 0).mod(trustedN)).multiply((randomR.pow(2)).mod(trustedN)).mod(trustedN);
    BigInteger x = (((new BigInteger("-1")).pow(randomBits.get(bitIndex) ? 1 : 0)).multiply((randomR.pow(2)))).mod(trustedN);
    // Let pretend it was randomly selected vector (e1,e2,e3)
    String eBits = "100";


    BigInteger totalMultS = new BigInteger("1");
    for (int i = 0; i < k; i++) {
        totalMultS = totalMultS
                .multiply(randomInts.get(i).pow(eBits.charAt(i) == '1' ? 1 : 0));
    }

    BigInteger y = totalMultS.multiply(randomR.mod(trustedN)).mod(trustedN);

    System.out.println("y: " + y.toString());


    BigInteger totalMultV = new BigInteger("1");

    for (int i = 0; i < k; i++) {
        totalMultV = totalMultV
                .multiply(listV.get(i).pow(eBits.charAt(i) == '1' ? 1 : 0));
    }

    System.out.println("total mult v: " + totalMultV);


    if ((z.toString().equals(x.toString()) || z.toString().equals(x.negate().mod(trustedN).toString()))
            && !z.toString().equals("0")){
        System.out.println("SUCCESS");

    }
    else {
        System.out.println("FAIL");

        System.out.println("x: " + x.toString());
        System.out.println("z: " + z.toString());

    }

}