如何使用 scala case class 作为函数参数来指定要使用的不同字段?

How to use scala case class as function parameter to specify different fields to use?

由于必须对一个案例 class 中的 3 个不同字段进行一些分组,然后用这些填充新案例 class,因此我有一些重复代码。由于它们共享一个通用模式,因此我应该可以执行一个函数,该函数可以接受 3 个不同字段的输入并相应地进行填充。但是,我不太确定该怎么做。

模式:

case class Transaction(
  senderBank: Bank,
  receiverBank: Bank,
  intermediaryBank: Bank)

case class Bank(
  name: String,
  country: Option[String],
  countryCode: Option[String])

case class GroupedBank(
  name: String,
  country: Option[String],
  countryCode: Option[String],
  bankType: String)

我尝试执行的当前功能:

def groupedBank(transactionSeq: Seq[Transaction], bankName: Bank, bankTypeString: String): Iterable[Seq[GroupedBank]] = {
 transactionSeq.groupBy(_ => bankName.name).map {
  case (key, transactionSeq) =>
    val bankGroupedSeq = transactionSeq.map(_ => {
      GroupedBank(
        name = bankName.name,
        country = bankName.country,
        countryCode = bankName.countryCode,
        bankType = bankTypeString)
    })
    bankGroupedSeq
  }
}

我需要对 SenderBankreceiverBankintermediaryBank 进行分组。但是,我不确定如何在函数参数 bankName 中正确引用它们。因此,对于 SenderBank,我想做类似 Transaction.senderBank 的事情,它会为 senderBank 指向正确的名称、国家等字段。对于 receiverBank 它应该是相似的,所以 Transactions.receiverBank,然后引用 receiverBank 的正确字段等等。 intermediaryBank 同样的逻辑。因此,我的问题是我怎样才能完成这样的事情,或者还有其他更合适的方法吗?

您可以传递一个函数来从交易中提取正确类型的银行:

def groupedBank(
  transactionSeq: Seq[Transaction], 
  getBank: Transaction => Bank, 
  bankTypeString: String
): Iterable[Seq[GroupedBank]] = {
  transactionSeq.groupBy(getBank(_).name).map {
    case (key, transactionSeq) =>
      transactionSeq.map { transaction =>
        val bank = getBank(transaction)
        GroupedBank(
          name = bank.name,
          country = bank.country,
          countryCode = bank.countryCode,
          bankType = bankTypeString)
      }
  }
}

然后这样称呼它:

groupedBank(transactionSeq, _.senderBank, "sender")

将银行类型概念抽象成一个单独的特征也是一个好主意:

sealed trait BankGroup {
  def name: String
  def getBank(transaction: Transaction): Bank

  def groupBanks(transactionSeq: Seq[Transaction]): Iterable[Seq[GroupedBank]] = {
    transactionSeq.groupBy(getBank(_).name).map {
      case (key, transactionSeq) =>
        transactionSeq.map { transaction =>
          val bank = getBank(transaction)
          GroupedBank(
            name = bank.name,
            country = bank.country,
            countryCode = bank.countryCode,
            bankType = name)
        }
    }
  }
}

object BankGroup {
  object Sender extends BankGroup {
    def name: String = "sender"
    def getBank(transaction: Transaction): Bank = transaction.senderBank
  }

  object Receiver extends BankGroup {
    def name: String = "receiver"
    def getBank(transaction: Transaction): Bank = transaction.receiverBank
  }

  object Intermediary extends BankGroup {
    def name: String = "intermediary"
    def getBank(transaction: Transaction): Bank = transaction.intermediaryBank
  }

  val values: Seq[BankGroup] = Seq(Sender, Receiver, Intermediary)
  def byName(name: String): BankGroup = values.find(_.name == name)
    .getOrElse(sys.error(s"unknown bank type: $name"))
}

您可以通过以下方式之一调用它:

BankGroup.Sender.groupBanks(transactionSeq)
BankGroup.byName("sender").groupBanks(transactionSeq)