如何确保 Scala 泛型函数输出原始数组而不是对象数组?

How to ensure Scala generic function outputs a primitive array instead of an object array?

在我的项目中,我使用各种处理程序对不同原始类型的数组执行逻辑,我遇到了这个运行时错误:

[error] java.lang.ClassCastException: class [Ljava.lang.Object; cannot be cast to class [B ([Ljava.lang.Object; and [B are in module java.base of loader 'bootstrap')

在这种情况下,我使用的是 Bytes,因此是 [B,但是我的一个函数 returned [Ljava.lang.Object 而不是(在运行时!)。如何确保我的泛型函数 return 原始数组而不是对象数组?

这是一个最小的重现示例:

package asdf

import scala.reflect.ClassTag
import java.nio.charset.StandardCharsets

object Main {
  trait Helper[A, B] {
    def decode(bytes: Array[Byte]): Array[A]
    def aggregate(values: Array[A]): B
  }

  object StringHelper extends Helper[Byte, String] {
    def decode(bytes: Array[Byte]): Array[Byte] = bytes.filter(_ != 0)
    def aggregate(values: Array[Byte]): String = new String(values, StandardCharsets.UTF_8)
  }

  object IntHelper extends Helper[Int, Int] {
    def decode(bytes: Array[Byte]): Array[Int] = bytes.map(_.toInt)
    def aggregate(values: Array[Int]): Int = values.sum
  }

  def decodeAgg[A, B](bytes: Array[Byte], helper: Helper[A, B])(implicit ev: ClassTag[A]): B = {
    val decoded = helper.decode(bytes)
    val notFirstDecoded = decoded
      .zipWithIndex
      .filter({case (_, i) => i != 0})
      .map(_._1)
      .toArray
    helper.aggregate(notFirstDecoded)
  }

  def main(args: Array[String]) {
    val helper: Helper[_, _] = args(1) match {
      case "int" => IntHelper
      case "string" => StringHelper
      case _ => throw new Exception("unknown data type")
    }
    val bytes = Array(97, 98, 99).map(_.toByte)
    val aggregated = decodeAgg(bytes, helper)
    println(s"aggregated to $aggregated")
  }
}

运行 与 sbt "run -- string".

此示例的完整堆栈跟踪:

[error] java.lang.ClassCastException: class [Ljava.lang.Object; cannot be cast to class [B ([Ljava.lang.Object; and [B are in module java.base of loader 'bootstrap')
[error]     at asdf.Main$StringHelper$.aggregate(Main.scala:12)
[error]     at asdf.Main$.decodeAgg(Main.scala:29)
[error]     at asdf.Main$.main(Main.scala:39)

我使用的是 Scala 2.12,JDK13。 我试过使用 @specialized 没有效果。

问题出在你的toArray in call:

val notFirstDecoded = decoded
      .zipWithIndex
      .filter({case (_, i) => i != 0})
      .map(_._1)
      .toArray

toArray 采用隐式 ClassTag 参数 - 它需要知道数组元素的运行时类型才能创建数组。

当您向 decodeAgg 提供隐式 ClassTag 参数时,toArray 很乐意接受您提供的内容。

def decodeAgg[A, B](bytes: Array[Byte], helper: Helper[A, B])(implicit ev: ClassTag[A]): B

可以看到ClassTag对应Helper的第一个泛型参数

你通过了以下帮手:

val helper: Helper[_, _] = args(1) match {
      case "int" => IntHelper
      case "string" => StringHelper
      case _ => throw new Exception("unknown data type")
    }

ClassTag 因此被推断为 Object,这就是为什么您会得到一个对象数组。

请注意,如果您直接使用 IntHelper,ClassTag 会被限制为正确的类型,并且函数会按预期工作。

val aggregated = decodeAgg(bytes, IntHelper)

解决思路

可能有多种解决方法。 一种想法可能是通过 Helper

显式提供 classTag
import scala.reflect.ClassTag
import java.nio.charset.StandardCharsets

object Main {
  trait Helper[A, B] {
    def decode(bytes: Array[Byte]): Array[A]
    def aggregate(values: Array[A]): B
    def classTag: ClassTag[A]
  }

  object StringHelper extends Helper[Byte, String] {
    def decode(bytes: Array[Byte]): Array[Byte] = bytes.filter(_ != 0)
    def aggregate(values: Array[Byte]): String = new String(values, StandardCharsets.UTF_8)
    def classTag: ClassTag[Byte] = ClassTag(classOf[Byte])
  }

  object IntHelper extends Helper[Int, Int] {
    def decode(bytes: Array[Byte]): Array[Int] = bytes.map(_.toInt)
    def aggregate(values: Array[Int]): Int = values.sum
    def classTag: ClassTag[Int] = ClassTag(classOf[Int])
  }

  def decodeAgg[A, B](bytes: Array[Byte], helper: Helper[A, B]): B = {
    val decoded = helper.decode(bytes)
    val notFirstDecoded = decoded
      .zipWithIndex
      .filter({case (_, i) => i != 0})
      .map(_._1)
      .toArray(helper.classTag)
    helper.aggregate(notFirstDecoded)
  }

  def main(args: Array[String]) {
    val helper = args(1) match {
      case "int" => IntHelper
      case "string" => StringHelper
      case _ => throw new Exception("unknown data type")
    }
    val bytes = Array(97, 98, 99).map(_.toByte)
    val aggregated = decodeAgg(bytes, helper)
    println(s"aggregated to $aggregated")
  }
}