根据条件从 Seq 中检索值

Retrieving Values from a Seq based on a condition

我有下面的Seq

scala>   var al = Seq((1.0, 20.0, 100.0), (2.0, 30.0, 100.0), (1.0, 11.0, 100.0), (1.0, 20.0, 100.0), (1.0, 10.0, 100.0),(2.0,9.0,100.0))
al: Seq[(Double, Double, Double)] = List((1.0,20.0,100.0), (2.0,30.0,100.0), (1.0,11.0,100.0), (1.0,20.0,100.0), (1.0,10.0,100.0), (2.0,9.0,100.0))

如何从这个 Seq 中取出前 n 个元素,其中 - 第二项的总和大于第三项的 60%(这将是一个常量值)

预期输出 -

scala> Seq((1.0, 20.0, 100.0), (2.0, 30.0, 100.0), (1.0, 11.0, 100.0))
res30: Seq[(Double, Double, Double)] = List((1.0,20.0,100.0), (2.0,30.0,100.0), (1.0,11.0,100.0))

编辑-1

我用一些丑陋的方式做到了这一点。但是,如果有任何漂亮的功能性方法可以解决这个问题,那就太好了。我真的需要摆脱这个柜台,这对我来说有点棘手。

var counter =0d ; var sum = 0d 
var abc :Seq[Double] = for (x <- al) yield {
counter+=1 ;sum = sum + x._2 ; 
if (sum > 60) counter else 0
} 
println(al.take(abc.filterNot(_==0).min.toInt))

您可以使用带累加器的递归函数来获取所需的元素数量:

def getIndex(input: Seq[(Double,Double,Double)], index: Int = 0, sum: Double = 0.0): Int = {
    input match {
    case Seq() => index
    case Seq(head, _@_*) if head._2 + sum > head._3 * 0.6 => index + 1  
    case Seq(head, tail@_*) => getIndex(tail, index + 1, sum + head._2)
  }
}

al.take(getIndex(al))

或构建列表:

def getList(input: Seq[(Double,Double,Double)], sum: Double = 0.0): List[(Double,Double,Double)] = {
    input match {
    case Seq() => List.empty
    case Seq(head, _@_*) if head._2 + sum > head._3 * 0.6 => head :: List.empty
    case Seq(head, tail@_*) => head :: get(tail, sum + head._2)
  }
}

getList(al)

如评论中所述 - 将 al 限制为 List 可以允许使用 ::,这应该更有效:

def get(input: List[(Double,Double,Double)], sum: Double = 0.0): List[(Double,Double,Double)] = {
    input match {
    case Seq() => List.empty
    case head :: tail if head._2 + sum > head._3 * 0.6 => head :: List.empty
    case head :: tail => head :: get(tail,  sum + head._2)
  }
 }

如果你使用的是 Scala 2.13.x 那么你可能 unfold():

Seq.unfold((al,0.0)){ case (lst,sum) =>
  Option.when(lst.nonEmpty && sum < 0.6 * lst.head._3) {
    (lst.head, (lst.tail, sum + lst.head._2))
  }
}
//res0: Seq[(Double, Double, Double)] =
// List((1.0,20.0,100.0), (2.0,30.0,100.0), (1.0,11.0,100.0))

在早期的 Scala 版本中,您可以 scanLeft() 计算总和并使用它来计算要 take() 的元素数量。

val cnt = al.view
            .scanLeft(0.0)(_ + _._2)
            .takeWhile(_ < 0.6 * al.headOption.getOrElse((0d,0d,0d))._3)
            .length
al.take(cnt)