条件状态 monad 表达式

Conditional state monad expressions

我正在使用 Scala Cats 库中的 State monad 以功能方式编写状态转换的命令序列。

我的实际用例相当复杂,所以为了简化问题,考虑以下最小问题:有一个 Counter 状态保持计数值可以递增或递减;但是,如果计数变为负数或溢出,则会出错。如果遇到错误,我需要保留错误发生时的状态,并有效地停止处理后续的状态转换。

我使用每个状态转换的 return 值来报告任何错误,使用类型 Try[Unit]。成功完成的操作 return 是新状态加上值 Success(()),而失败 return 是现有状态加上包裹在 Failure.[=32 中的异常=]

注意:很明显,我可以在遇到错误时抛出异常。但是,这将违反 引用透明性 并且还需要我做一些额外的工作来将计数器状态存储在抛出的异常中。我还拒绝使用 Try[Counter] 作为状态类型(而不仅仅是 Counter),因为我无法使用它来跟踪失败和失败状态。我还没有探索过的一种选择是使用 (Counter, Try[Unit]) 元组作为状态,因为这看起来太麻烦了,但我愿意接受建议。

import cats.data.State
import scala.util.{Failure, Success, Try}

// State being maintained: an immutable counter.
final case class Counter(count: Int)

// Type for state transition operations.
type Transition[M] = State[Counter, Try[M]]

// Operation to increment a counter.
val increment: Transition[Unit] = State {c =>

  // If the count is at its maximum, incrementing it must fail.
  if(c.count == Int.MaxValue) {
    (c, Failure(new ArithmeticException("Attempt to overflow counter failed")))
  }

  // Otherwise, increment the count and indicate success.
  else (c.copy(count = c.count + 1), Success(()))
}

// Operation to decrement a counter.
val decrement: Transition[Unit] = State {c =>

  // If the count is zero, decrementing it must fail.
  if(c.count == 0) {
    (c, Failure(new ArithmeticException("Attempt to make count negative failed")))
  }

  // Otherwise, decrement the count and indicate success.
  else (c.copy(count = c.count - 1), Success(()))
}

但是,我正在努力确定将转换串联在一起的最佳方法,同时以所需的方式处理任何失败。 (如果您愿意,对我的问题的更笼统的陈述是我需要根据前一个转换的 returned 值有条件地执行后续转换。)

例如,以下一组转换可能会在第一步、第三步或第四步失败(但我们假设它也可能在第二步失败),具体取决于计数器的起始状态,但它仍会尝试执行无条件下一步:

val counterManip: Transition[Unit] = for {
  _ <- decrement
  _ <- increment
  _ <- increment
  r <- increment
} yield r

如果我 运行 此代码的初始计数器值为 0,显然我将得到新的计数器值 3 和 Success(()),因为这是最后一步的结果:

scala> counterManip.run(Counter(0)).value
res0: (Counter, scala.util.Try[Unit]) = (Counter(3),Success(()))

但我想要的是获得初始计数器状态(decrement 操作失败的状态)和包裹在 Failure 中的 ArithmeticException,因为第一步失败了。

到目前为止,我能想到的唯一解决方案极其复杂、重复且容易出错:

val counterManip: Transition[Unit] = State {s0 =>
  val r1 = decrement.run(s0).value
  if(r1._2.isFailure) r1
  else {
    val r2 = increment.run(r1._1).value
    if(r2._2.isFailure) r2
    else {
      val r3 = increment.run(r2._1).value
      if(r3._2.isFailure) r3
      else increment.run(r3._1).value
    }
  }
}

给出正确的结果:

scala> counterMap.run(Counter(0)).value
res1: (Counter, scala.util.Try[Unit]) = (Counter(0),Failure(java.lang.ArithmeticException: Attempt to make count negative failed))

更新

我想出了 untilFailure 方法 ,用于 运行 转换序列,直到它们完成或直到发生错误(以先到者为准)。我很喜欢它,因为它使用起来简单而优雅。

但是,我仍然很好奇是否有一种优雅的方式可以更直接地将转换链接在一起。 (例如,如果转换只是 returned Try[T] 的常规函数​​——并且没有状态——那么我们可以使用 flatMap 将调用链接在一起,从而允许构建 for 表达式,它将成功转换的结果传递到下一个转换。)

你能推荐一个更好的方法吗?

呸!我不知道为什么我没有早点想到这一点。有时只是用更简单的术语解释你的问题会迫使你重新审视它,我猜......

一种可能性是处理转换序列,以便仅当当前任务成功时才执行下一个任务。

// Run a sequence of transitions, until one fails.
def untilFailure[M](ts: List[Transition[M]]): Transition[M] = State {s =>
  ts match {

    // If we have an empty list, that's an error. (Cannot report a success value.)
    case Nil => (s, Failure(new RuntimeException("Empty transition sequence")))

    // If there's only one transition left, perform it and return the result.
    case t :: Nil => t.run(s).value

    // Otherwise, we have more than one transition remaining.
    //
    // Run the next transition. If it fails, report the failure, otherwise repeat
    // for the tail.
    case t :: tt => {
      val r = t.run(s).value
      if(r._2.isFailure) r
      else untilFailure(tt).run(r._1).value
    }
  }
}

然后我们可以将 counterManip 实现为一个序列。

val counterManip: Transition[Unit] = for {
  r <- untilFailure(List(decrement, increment, increment, increment))
} yield r

给出正确的结果:

scala> counterManip.run(Counter(0)).value
res0: (Counter, scala.util.Try[Unit]) = (Counter(0),Failure(java.lang.ArithmeticException: Attempt to make count negative failed))

scala> counterManip.run(Counter(1)).value
res1: (Counter, scala.util.Try[Unit]) = (Counter(3),Success(()))

scala> counterManip.run(Counter(Int.MaxValue - 2)).value
res2: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Success(()))

scala> counterManip.run(Counter(Int.MaxValue - 1)).value
res3: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Failure(java.lang.ArithmeticException: Attempt to overflow counter failed))

scala> counterManip.run(Counter(Int.MaxValue)).value
res4: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Failure(java.lang.ArithmeticException: Attempt to overflow counter failed))

缺点是所有转换都需要有一个共同的 return 值(除非您接受 Any 结果)。

据我了解,您的计算有两种状态,您可以将其定义为 ADT

sealed trait CompState[A]
case class Ok[A](value: A) extends CompState[A]
case class Err[A](lastValue: A, cause: Exception) extends CompState[A]

您可以采取的下一步是为 CompState 定义一个 update 方法,以封装您在链接计算时应该发生什么的逻辑。

def update(f: A => A): CompState[A] = this match {
  case Ok(a) => 
    try Ok(f(a))
    catch { case e: Exception => Err(a, e) }
  case Err(a, e) => Err(a, e)
}

从那里,重新定义

type Transition[M] = State[CompState[Counter], M]

// Operation to increment a counter.
// note: using `State.modify` instead of `.apply`
val increment: Transition[Unit] = State.modify { cs =>
  // use the new `update` method to take advantage of your chaining semantics
  cs update{ c =>
    // If the count is at its maximum, incrementing it must fail.
    if(c.count == Int.MaxValue) {
      throw new ArithmeticException("Attempt to overflow counter failed")
    }

    // Otherwise, increment the count and indicate success.
    else c.copy(count = c.count + 1)
  }
}

// Operation to decrement a counter.
val decrement: Transition[Unit] = State.modify { cs =>
  cs update { c =>
    // If the count is zero, decrementing it must fail.
    if(c.count == 0) {
      throw new ArithmeticException("Attempt to make count negative failed")
    }

    // Otherwise, decrement the count and indicate success.
    else c.copy(count = c.count - 1)
  }
}

请注意,在上面更新的 increment/decrement 转换中,我使用了 State.modify,它会更改状态,但不会生成结果。看起来在转换结束时获取当前状态的 "idiomatic" 方法是使用 State.get,即

val counterManip: State[CompState[Counter], CompState[Counter]] = for {
    _ <- decrement
    _ <- increment
    _ <- increment
    _ <- increment
    r <- State.get
} yield r

你可以 运行 使用 runA 帮助程序丢弃最终状态,即

counterManip.runA(Ok(Counter(0))).value
// Err(Counter(0),java.lang.ArithmeticException: Attempt to make count negative failed)