对类型为 Literal 的参数使用字符串值

Use string value for argument typed as Literal

我使用 boto3 包中的 kms.decrypt() 方法。 对于打字支持,我使用 boto3-stubs 包。

解密方法有属性EncryptionAlgorithm,类型为

EncryptionAlgorithmSpecType = Literal["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256", "SYMMETRIC_DEFAULT"]

我通过输入字符串中的正则表达式解析加密算法,因此在我的例子中,EncryptionAlgorithm 值是 str,而不是文字。 Mypy 投诉了。

我不想禁用此行的类型检查,我找到的唯一解决方案是以这种方式将 str 值转换为文字的辅助方法:

import boto3
from mypy_boto3_kms.literals import EncryptionAlgorithmSpecType

def getEncryptionAlgorithmLiteral(algorithm: str) -> EncryptionAlgorithmSpecType:
  result: EncryptionAlgorithmSpecType
  if algorithm == "RSAES_OAEP_SHA_1":
    result = "RSAES_OAEP_SHA_1"
  elif algorithm == "RSAES_OAEP_SHA_256":
    result = "RSAES_OAEP_SHA_256"
  elif algorithm == "SYMMETRIC_DEFAULT":
    result = "SYMMETRIC_DEFAULT"
  else:
    raise Exception(f"Unexpected algorithm '{algorithm}'. It must be one of {EncryptionAlgorithmSpecType}")
  return result

def main(binaryEncData: bytes, keyId: str, algorithm: str):
  kms = boto3.client('kms')
  kmsResult = kms.decrypt(CiphertextBlob=binaryEncData, 
                          KeyId=keyId,
                          EncryptionAlgorithm=getEncryptionAlgorithmLiteral(algorithm))

我想知道是否有更好的方法来实现 getEncryptionAlgorithmLiteral() 方法中的转换,这样我就不需要输入所有这些值两次。理想情况下,我想直接使用 EncryptionAlgorithmSpecType 类型的值,而不是再次将它们输入到我的代码中。

您可以使用 typing.get_args to get the arguments passed in to typing.Literal. In this case, you'll need to combine it with typing.cast,这样您就可以向“mypy”发出信号,表明函数 returns 的字符串值是可接受的 Literal 值。

from typing import cast
# Import below from `typing` in Python 3.8+
from typing_extensions import Literal, get_args


# noinspection SpellCheckingInspection
EncryptionAlgorithmSpecType = Literal["RSAES_OAEP_SHA_1",
                                      "RSAES_OAEP_SHA_256",
                                      "SYMMETRIC_DEFAULT"]

_valid_algorithms = get_args(EncryptionAlgorithmSpecType)
print(_valid_algorithms)


def get_encryption_algorithm_literal(
        algorithm: str) -> EncryptionAlgorithmSpecType:

    if algorithm not in _valid_algorithms:
        valid_values = str(list(_valid_algorithms)).replace("'", "")
        raise Exception(f"Unexpected algorithm '{algorithm}'. "
                        f"It must be one of {valid_values}")

    # cast string to literal, so static type checkers such as 'mypy'
    # don't complain.
    return cast(EncryptionAlgorithmSpecType, algorithm)


def main():
    # noinspection SpellCheckingInspection
    string = 'RSAES_OAEP_SHA_256'

    my_algorithm = get_encryption_algorithm_literal(string)
    print(type(my_algorithm), my_algorithm)


if __name__ == '__main__':
    main()

输出:

('RSAES_OAEP_SHA_1', 'RSAES_OAEP_SHA_256', 'SYMMETRIC_DEFAULT')
<class 'str'> RSAES_OAEP_SHA_256

无效输入的结果,如 'RSAES_OAEP_SHA_2567':

Traceback (most recent call last):
  ...
    raise Exception(...)
Exception: Unexpected algorithm 'RSAES_OAEP_SHA_2567'. It must be one of [RSAES_OAEP_SHA_1, RSAES_OAEP_SHA_256, SYMMETRIC_DEFAULT]