如何缩小状态机编码中上限类型参数的类型?

How to narrow the type of an upper-bounded type parameter in a state machine encoding?

假设我有一个 Cake 可以在多个状态之间循环:

sealed trait State extends Product with Serializable
object State {
    final case object Raw extends State
    final case class JustRight(temperature: Int) extends State
    final case class Burnt(charCoalContent: Double) extends State
}
final case class Cake[S <: State](name: String, state: S)

这很好,因为现在我可以确保只尝试将 Raw 蛋糕放入烤箱,而不是立即食用。

但有时我只是有一个 Cake[State] 躺在身边,想尝试吃它,但前提是它处于可食用状态。我当然可以总是在 cake.state 上进行模式匹配,但我认为应该可以通过添加方法 def narrow[S <: State]: Cake[State] => Option[Cake[S]].

来节省我自己的一些击键次数

但是,现在我正在努力实际实现该功能。编译器接受 Try(cake.asInstanceOf[Cake[S]]).toOption,但似乎总是会成功(我猜是因为类型参数被删除了,实际上任何类型 A 都会在这里被接受,而不仅仅是 S)。似乎有效的是 Try(cake.copy(state = cake.state.asInstanceOf[S])).toOption,但现在我制作了一个多余的数据副本。还有其他更好的方法吗?或者整个编码可能从一开始就存在缺陷?

如果您在编译时不知道 Cake 状态的具体类型,您可能想使用 ClassTag 来检查您的状态类型,因为抛出和抓住 ClassCastExceptions 对我来说似乎不是个好主意:

def narrow[S <: State](cake: Cake[_])(implicit classTag: ClassTag[S]): Option[Cake[S]] =
  Option.when(classTag.runtimeClass.isInstance(cake.state))(cake.asInstanceOf[Cake[S]])

Scastie

这通过检查蛋糕的状态是否是 S 的已擦除 class 的实例来工作。但是,如果您的状态采用类型参数,您可能希望改用 TypeTag

但是,如果您确实知道具体类型,则可能需要使用 =:= <:< 来检查(<:< 拒绝 Cake[State]):

def narrow[S <: State] = new NarrowFn[S]

class NarrowFn[S <: State] {
  def apply[S2 <: State](cake: Cake[S2])(implicit ev: S2 <:< S = null): Option[Cake[S]] =
    Option.when(ev != null)(cake.asInstanceOf[Cake[S]])
}

Scastie

请注意,这些都不是很好的解决方案,我建议只为每种情况制作一个单独的方法,并使用普通模式匹配来获得答案。

您可以使用 typeclass 来解决这个问题,它检查并转换 (以类型安全的方式) 状态的类型。

sealed trait State extends Product with Serializable
object State {
    final case object Raw extends State
    type Raw = Raw.type
    final case class JustRight(temperature: Int) extends State
    final case class Burnt(charCoalContent: Double) extends State
  
    sealed trait Checker[S <: State] {
      def check(state: State): Option[S]
    }
    object Checker {
      private def instance[S <: State](pf: PartialFunction[State, S]): Checker[S] =
        new Checker[S] {
          val f = pf.lift
          override def check(state: State): Option[S] = f(state)
        }
      
      implicit final val RawChecker: Checker[Raw] = instance {
        case Raw => Raw
      }
      
      implicit final val JustRightChecker: Checker[JustRight] = instance {
        case s @ JustRight(_) => s
      }
      
      implicit final val BurntChecker: Checker[Burnt] = instance {
        case s @ Burnt(_) => s
      }
    }
}

final case class Cake[S <: State](name: String, state: S)

def narrow[S <: State](cake: Cake[State])(implicit checker: State.Checker[S]): Option[Cake[S]] =
  checker.check(cake.state).map(s => cake.copy(state = s))

你可以这样使用:

val rawCake: Cake[State] = Cake(name = "Foo", state = State.Raw)

narrow[State.Raw](rawCake)
// res: Option[Cake[State.Raw]] = Some(Cake(Foo,Raw))
narrow[State.JustRight](rawCake)
// res: Option[Cake[State.JustRight] = None

顺便说一句,如果你想避免 copy,你可以将 check 更改为 return Boolean 并使用脏 asInstanceOf.

// Technically speaking it is unsafe, but it seems to work just right.
def narrowUnsafe[S <: State](cake: Cake[State])(implicit checker: State.Checker[S]): Option[Cake[S]] =
  if (checker.check(cake.state)) Some(cake.asInstanceOf[Cake[S]])
  else None

(可以看代码运行 here)