Swift 中用于 RSA 实现的大量计算

Big number computation in Swift for RSA implementation

我正在尝试在 Swift 中为 CryptoSwift lib (to fix #63 实施 RSA 算法。算法本身是可行的,但我需要提高大数计算性能才能在合理的时间内运行。

我实现了自己的 GiantUInt 结构(将字节存储为 UInt8)来计算 RSA 大数(例如 2048 位长度)的运算,但速度太慢(主要是余数运算, 但我认为一切都可以改进):

precedencegroup PowerPrecedence { higherThan: MultiplicationPrecedence }
infix operator ^^ : PowerPrecedence

public struct GiantUInt: Equatable, Comparable, ExpressibleByIntegerLiteral, ExpressibleByArrayLiteral {
  
  // Properties
  
  public let bytes: Array<UInt8>
  
  // Initialization
  
  public init(_ raw: Array<UInt8>) {
    var bytes = raw
    
    while bytes.last == 0 {
      bytes.removeLast()
    }
    
    self.bytes = bytes
  }
  
  // ExpressibleByIntegerLiteral
  
  public typealias IntegerLiteralType = UInt8
  
  public init(integerLiteral value: UInt8) {
    self = GiantUInt([value])
  }
  
  // ExpressibleByArrayLiteral
  
  public typealias ArrayLiteralElement = UInt8
  
  public init(arrayLiteral elements: UInt8...) {
    self = GiantUInt(elements)
  }
    
  // Equatable
  
  public static func == (lhs: GiantUInt, rhs: GiantUInt) -> Bool {
    lhs.bytes == rhs.bytes
  }
  
  // Comparable
  
  public static func < (rhs: GiantUInt, lhs: GiantUInt) -> Bool {
    for i in (0 ..< max(rhs.bytes.count, lhs.bytes.count)).reversed() {
      let r = rhs.bytes[safe: i] ?? 0
      let l = lhs.bytes[safe: i] ?? 0
      if r < l {
        return true
      } else if r > l {
        return false
      }
    }
    
    return false
  }
  
  // Operations
  
  public static func + (rhs: GiantUInt, lhs: GiantUInt) -> GiantUInt {
    var bytes = [UInt8]()
    var r: UInt8 = 0
    
    for i in 0 ..< max(rhs.bytes.count, lhs.bytes.count) {
      let res = UInt16(rhs.bytes[safe: i] ?? 0) + UInt16(lhs.bytes[safe: i] ?? 0) + UInt16(r)
      r = UInt8(res >> 8)
      bytes.append(UInt8(res & 0xff))
    }
    
    if r != 0 {
      bytes.append(r)
    }
    
    return GiantUInt(bytes)
  }
  
  public static func - (rhs: GiantUInt, lhs: GiantUInt) -> GiantUInt {
    var bytes = [UInt8]()
    var r: UInt8 = 0
    
    for i in 0 ..< max(rhs.bytes.count, lhs.bytes.count) {
      let rhsb = UInt16(rhs.bytes[safe: i] ?? 0)
      let lhsb = UInt16(lhs.bytes[safe: i] ?? 0) + UInt16(r)
      r = UInt8(rhsb < lhsb ? 1 : 0)
      let res = (UInt16(r) << 8) + rhsb - lhsb
      bytes.append(UInt8(res & 0xff))
    }
    
    if r != 0 {
      bytes.append(r)
    }
    
    return GiantUInt(bytes)
  }
  
  public static func * (rhs: GiantUInt, lhs: GiantUInt) -> GiantUInt {
    var offset = 0
    var sum = [GiantUInt]()
    
    for rbyte in rhs.bytes {
      var bytes = [UInt8](repeating: 0, count: offset)
      var r: UInt8 = 0
      
      for lbyte in lhs.bytes {
        let res = UInt16(rbyte) * UInt16(lbyte) + UInt16(r)
        r = UInt8(res >> 8)
        bytes.append(UInt8(res & 0xff))
      }
      
      if r != 0 {
        bytes.append(r)
      }
      
      sum.append(GiantUInt(bytes))
      offset += 1
    }
    
    return sum.reduce(0, +)
  }
  
  public static func % (rhs: GiantUInt, lhs: GiantUInt) -> GiantUInt {
    var remainder = rhs
    
    // This needs serious optimization (but works)
    while remainder >= lhs {
      remainder = remainder - lhs
    }
  
    return remainder
  }
  
  static func ^^ (rhs: GiantUInt, lhs: GiantUInt) -> GiantUInt {
    let count = lhs.bytes.count
    var result = GiantUInt([1])
    
    for iByte in 0 ..< count {
      let byte = lhs.bytes[iByte]
      for i in 0 ..< 8 {
        if iByte != count - 1 || byte >> i > 0 {
          result = result * result
          if (byte >> i) & 1 == 1 {
            result = result * rhs
          }
        }
      }
    }
    
    return result
  }
  
  public static func exponentiateWithModulus(rhs: GiantUInt, lhs: GiantUInt, modulus: GiantUInt) -> GiantUInt {
    let count = lhs.bytes.count
    var result = GiantUInt([1])
    
    for iByte in 0 ..< count {
      let byte = lhs.bytes[iByte]
      for i in 0 ..< 8 {
        if iByte != count - 1 || byte >> i > 0 {
          result = (result * result) % modulus
          if (byte >> i) & 1 == 1 {
            result = (result * rhs) % modulus
          }
        }
      }
    }
    
    return result
  }
  
}

(此文件可用 here on my fork

我怎样才能改进它以使其(大大)更快?

How can I improve it to make it (a lot) quicker?

不要使用字节。性能取决于大数字中“数字”的数量;所以越小的数字越差,越少的大数字越好。例如,对于一对 2048 位的大数相乘,如果它是使用字节实现的,你最终会得到“256 位数字 * 256 位数字 = 65536 次数字乘法”,如果它是用 64 位整数实现的,那么你最终会得到“ 32 位 * 32 位 = 1024 次数字乘法”(大约快 64 倍)。

喜欢破坏性操作

对于大数字;对于像“a = b + c”这样的东西,CPU 必须处理 3 组缓存行,对于像“a += b”这样的东西,CPU 只需要处理 2 组缓存行缓存行。对于很大的数字,这可能是“它都适合缓存”和“性能因缓存未命中而破坏”之间的区别。

不要使用追加

bytes.append(r)这样的事情可能涉及缓冲区容量检查和底层缓冲区的潜在调整大小;这种额外的开销是不必要的,也是可以避免的——你应该能够提前确定结果的大小,提前创建一个正确大小的数组,然后在不进行任何检查和调整大小的情况下计算结果。

不使用乘法求平方

对于平方,依靠两个数是同一个数这一事实,“数字乘法”的次数几乎可以减半。要理解这一点,假设您正在做十进制的 1234 * 1234 并将中间值表示为这样的网格:

     1        2         3           4
    --------------
 1 | 1*1    + 2*10    + 3*100     + 4*1000 +
 2 | 2*10   + 4*100   + 6*1000    + 8*10000 +
 3 | 3*100  + 6*1000  + 9*10000   + 12*100000 +
 4 | 4*1000 + 8*10000 + 12*100000 + 16*1000000

您可以看到网格右上角“不完全一半”是左下角“不完全一半”的镜像,所以您可以这样做:

     1          2           3              4
    --------------
 1 | 1*1
 2 | (2*10)*2 + 4*100
 3 | (3*100   + 6*1000)*2 + 9*10000
 4 | (4*1000  + 8*10000   + 12*100000)*2 + 16*1000000

当然可以重新排列,使 *2 只发生一次;像“result = (2*10 + 3*100 + 6*1000 + 4*1000 + 8*10000 + 12*100000) * 2 + 1*1 + 4*100 + 9*10000 + 16*1000000”。

你的模can/should有待提高

一种方法是将除数左移,直到它大于分子(同时跟踪移位计数);然后执行“当移位计数不为零时{右移除数;如果除数不大于分子,则从分子中减去除数;减少移位计数}”。

不要使用 Swift 或高级语言

大多数 CPU 都有特殊的说明,可以更有效地处理大数,即使是最基本的东西(例如“加进位”)在大多数高级语言中也是不可能的。结果是使用最佳数字大小(例如 64 位 CPUs 上的 64 位数字)是痛苦的,因此您以不同的方式实现算法,然后编译器无法正确优化,因为算法不同。

只有使用汇编语言才能实现最佳性能(例如,可以从 swift 代码中使用的本机库)。如果您将(例如)GMP 库(使用大量内联汇编语言)与 mini-GMP(不使用汇编语言,仅旨在“数字速度不超过 10 倍”)进行比较,您可以清楚地看到这种差异最多几百位").