Chisel:如何从命令行更改模块参数?

Chisel: How to change module parameters from command line?

我有很多带有多个参数的模块。以模板中 GCD 的修改版本为例:

/** Subtraction-based GCD unit.
  *
  * @param len       bit width of both operands and of the result
  * @param validHigh when true `outputValid` is active high, otherwise it is active low
  */
class GCD(len: Int = 16, validHigh: Boolean = true) extends Module {
  val io = IO(new Bundle {
    val value1        = Input(UInt(len.W))
    val value2        = Input(UInt(len.W))
    val loadingValues = Input(Bool())
    val outputGCD     = Output(UInt(len.W))
    val outputValid   = Output(Bool())
  })

  // Working registers; no explicit width is given, so Chisel infers it.
  val x = Reg(UInt())
  val y = Reg(UInt())

  // One subtraction step per cycle: the larger operand loses the smaller one.
  when(x > y) {
    x := x - y
  }.otherwise {
    y := y - x
  }

  // Declared after the subtraction so that, by last-connect semantics, loading wins.
  when(io.loadingValues) {
    x := io.value1
    y := io.value2
  }

  io.outputGCD := x
  // Done once y reaches zero; the polarity is chosen at elaboration time.
  io.outputValid := (if (validHigh) y === 0.U else y =/= 0.U)
}

为了测试或综合许多不同的设计,我想在调用测试器或生成器应用程序时从命令行更改值。最好是这样:

[generation or test command] --len 12 --validHigh false

但是这个或类似的东西也可以

[generation or test command] --param "len=12" --param "validHigh=false"

经过反复试验,我想出了一个如下所示的解决方案:

gcd.scala

package gcd

import firrtl._
import chisel3._

/** Elaboration-time parameters of the GCD unit.
  *
  * @param len       operand/result bit width
  * @param validHigh polarity of `outputValid` (true = active high)
  */
case class GCDConfig(len: Int = 16, validHigh: Boolean = true)

/** Subtraction-based GCD unit configured by a [[GCDConfig]].
  *
  * @param conf bit width (`len`) and `outputValid` polarity (`validHigh`)
  */
class GCD(val conf: GCDConfig = GCDConfig()) extends Module {
  val io = IO(new Bundle {
    val value1        = Input(UInt(conf.len.W))
    val value2        = Input(UInt(conf.len.W))
    val loadingValues = Input(Bool())
    val outputGCD     = Output(UInt(conf.len.W))
    val outputValid   = Output(Bool())
  })

  // Working registers; no explicit width is given, so Chisel infers it.
  val x = Reg(UInt())
  val y = Reg(UInt())

  // One subtraction step per cycle.
  when(x > y) {
    x := x - y
  }.otherwise {
    y := y - x
  }

  // Declared last so that loading new operands overrides the subtraction step.
  when(io.loadingValues) {
    x := io.value1
    y := io.value2
  }

  io.outputGCD := x
  // Done once y reaches zero; polarity fixed at elaboration time.
  io.outputValid := (if (conf.validHigh) y === 0.U else y =/= 0.U)
}

/** Adds a `-p/--params k1=v1,k2=v2` option to an [[ExecutionOptionsManager]]. */
trait HasParams {
  self: ExecutionOptionsManager =>

  /** Raw design parameters, replaced by whatever `--params` parses to. */
  var params: Map[String, String] = Map.empty

  parser.note("Design Parameters")

  parser
    .opt[Map[String, String]]('p', "params")
    .valueName("k1=v1,k2=v2")
    .foreach { parsed => params = parsed }
    .text("Parameters of Design")
}

object GCD {
  /** Builds a [[GCD]] whose configuration is derived from the raw `params` map. */
  def apply(params: Map[String, String]): GCD = new GCD(params2conf(params))

  /** Folds the recognised keys (`len`, `validHigh`) into a default [[GCDConfig]].
    * Unrecognised keys are ignored.
    *
    * @throws NumberFormatException    if `len` is not an integer
    * @throws IllegalArgumentException if `validHigh` is not `true`/`false`
    */
  def params2conf(params: Map[String, String]): GCDConfig =
    params.foldLeft(GCDConfig()) {
      case (acc, ("len", raw))       => acc.copy(len = raw.toInt)
      case (acc, ("validHigh", raw)) => acc.copy(validHigh = raw.toBoolean)
      case (acc, _)                  => acc
    }
}

/** Generator entry point, e.g. `sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'`. */
object GCDGen extends App {
  val optionsManager = new ExecutionOptionsManager("gcdgen")
    with HasChiselExecutionOptions
    with HasFirrtlOptions
    with HasParams

  if (optionsManager.parse(args)) {
    chisel3.Driver.execute(optionsManager, () => GCD(optionsManager.params))
  } else {
    // NOTE(review): this failure value is discarded, so this branch neither reports
    // anything itself nor sets a non-zero exit status.
    ChiselExecutionFailure("could not parse results")
  }
}

和测试

GCDSpec.scala

package gcd

import chisel3._
import firrtl._
import chisel3.tester._
import org.scalatest.FreeSpec
import chisel3.experimental.BundleLiterals._
import chiseltest.internal._
import chiseltest.experimental.TestOptionBuilder._

/** Test entry point, e.g. `sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'`. */
object GCDTest extends App {
  val optionsManager = new ExecutionOptionsManager("gcdtest") with HasParams

  if (!optionsManager.parse(args)) {
    // NOTE(review): this failure value is discarded, so this branch neither reports
    // anything itself nor sets a non-zero exit status.
    ChiselExecutionFailure("could not parse results")
  } else {
    new GCDSpec(optionsManager.params).execute()
  }
}

/** ScalaTest spec for [[GCD]]; `params` are the raw `--params` key/value pairs. */
class GCDSpec(params: Map[String, String] = Map()) extends FreeSpec with ChiselScalatestTester {

  "Gcd should calculate proper greatest common denominator" in {
    test(GCD(params)) { dut =>
      // Present the operands 95 and 10 while loadingValues is high for one cycle.
      dut.io.value1.poke(95.U)
      dut.io.value2.poke(10.U)
      dut.io.loadingValues.poke(true.B)
      dut.clock.step()
      dut.io.loadingValues.poke(false.B)

      // outputValid may be active high or low, so compare against the configured level.
      val doneLevel = dut.conf.validHigh
      while (dut.io.outputValid.peek().litToBoolean != doneLevel) {
        dut.clock.step()
      }

      dut.io.outputGCD.expect(5.U)
    }
  }
}

这样,我就可以用下面的命令生成不同的设计并对其进行测试:

sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'
sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'

但是这个解决方案有几个问题或烦恼:

  1. 它使用了已弃用的功能(ExecutionOptionsManager 和 HasFirrtlOptions)。我不确定这个方案能否移植到新的 FirrtlStage 基础架构上。
  2. 涉及大量样板代码。每个模块都要编写新的 case class 和 params2conf 函数,并且在添加或删除参数时还要同时修改这两处,十分繁琐。
  3. 到处都要写 conf.x 而不是直接写 x。不过我想这是无法避免的,因为 Scala 没有 Python 那样的 kwargs。

是否有更好的方法或至少没有被弃用的方法?

好问题。我认为你的做法基本上是对的。我自己通常不需要通过命令行来修改测试参数,我的开发流程一般就是直接修改测试中的参数值然后运行。我使用的是 IntelliJ,这让这种做法很方便(不过这可能只适合我的习惯和我所参与项目的规模)。

但是我想给你一个建议,让你远离 ExecutionOptions 风格,因为它很快就会消失。

在下面的示例代码中,我基本上提供了两个文件。第一个文件是一些库代码,例如使用现代注解(annotation)写法的工具,我相信它能尽量减少样板代码。它们依赖于字符串匹配,不过这一点是可以改进的。第二个文件是你的 GCD 和 GCDSpec,稍作修改以便用不同的方式提取参数。第二个文件的末尾是一些非常少量的样板代码,可以让你获得所需的命令行访问能力。

祝你好运,我希望这基本上是不言自明的。

第一个文件:

import chisel3.stage.ChiselCli
import firrtl.AnnotationSeq
import firrtl.annotations.{Annotation, NoTargetAnnotation}
import firrtl.options.{HasShellOptions, Shell, ShellOption, Stage, Unserializable}
import firrtl.stage.FirrtlCli

/** Marker for annotations defined by this tester library; may only be mixed into an [[Annotation]]. */
trait TesterAnnotation { self: Annotation => }

/** String-keyed design/test parameters supplied on the command line.
  *
  * Lookups consult `params` first and fall back to `defaults`, which callers are expected
  * to fill in (e.g. `params.defaults ++= Seq("len" -> "16")`) before the first lookup.
  *
  * @param params values parsed from the command line
  */
case class TestParams(params: Map[String, String] = Map.empty) {
  // Mutable and declared in the body, so it is not part of case-class equality/hashCode/copy.
  // Kept as a mutable HashMap for compatibility with callers that use `defaults ++= ...`.
  val defaults: collection.mutable.HashMap[String, String] = new collection.mutable.HashMap()

  /** Raw string value for `key`.
    *
    * @throws NoSuchElementException if `key` is neither in `params` nor in `defaults`;
    *                                the message names the key and lists the known keys
    */
  private def lookup(key: String): String =
    params.get(key).orElse(defaults.get(key)).getOrElse {
      val known = (params.keySet ++ defaults.keySet).toSeq.sorted.mkString(", ")
      throw new NoSuchElementException(
        s"parameter '$key' was not given on the command line and has no default; known keys: [$known]"
      )
    }

  /** @throws NumberFormatException if the value is not an integer */
  def getInt(key: String): Int = lookup(key).toInt

  /** @throws IllegalArgumentException if the value is not `true`/`false` */
  def getBoolean(key: String): Boolean = lookup(key).toBoolean

  def getString(key: String): String = lookup(key)
}
/** Carries the `--param-string` key/value pairs through the annotation pipeline.
  * Marked [[Unserializable]], so it is not written to annotation files.
  */
case class TesterParameterAnnotation(paramString: TestParams)
    extends TesterAnnotation
    with NoTargetAnnotation
    with Unserializable

/** Registers the `--param-string k1=v1,k2=v2` command-line option. */
object TesterParameterAnnotation extends HasShellOptions {
  val options = Seq(
    new ShellOption[Map[String, String]](
      longOption = "param-string",
      toAnnotationSeq = (a: Map[String, String]) => Seq(TesterParameterAnnotation(TestParams(a))),
      helpText = """a comma separated, space free list of additional parameters, e.g. --param-string "k1=7,k2=dog" """
    )
  )
}

/** Mixes the `--param-string` option into a firrtl [[Shell]]. */
trait TesterCli { self: Shell =>
  TesterParameterAnnotation.addOptions(parser)
}

/** A firrtl [[Stage]] that collects `--param-string` values into a [[TestParams]] and passes
  * it, together with all parsed annotations, to `thunk`.
  *
  * Every `--param-string` occurrence is honoured: the maps are merged in command-line order,
  * so a later occurrence overrides an earlier one for the same key.
  *
  * @param thunk the action to run with the parameters and annotations
  */
class GenericTesterStage(thunk: (TestParams, AnnotationSeq) => Unit) extends Stage {
  val shell: Shell = new Shell("chiseltest") with TesterCli with ChiselCli with FirrtlCli

  def run(annotations: AnnotationSeq): AnnotationSeq = {
    val supplied: Seq[Map[String, String]] =
      annotations.collect { case TesterParameterAnnotation(p) => p.params }
    val params: TestParams =
      if (supplied.isEmpty) TestParams() else TestParams(supplied.reduce(_ ++ _))

    thunk(params, annotations)
    annotations
  }
}

第二个文件:

import chisel3._
import chisel3.tester._
import chiseltest.experimental.TestOptionBuilder._
import chiseltest.{ChiselScalatestTester, GenericTesterStage, TestParams}
import firrtl._
import firrtl.options.StageMain
import org.scalatest.freespec.AnyFreeSpec

/** GCD unit whose parameters are looked up by name in `testParams`:
  * `len` is the operand/result width and `validHigh` the polarity of `outputValid`.
  */
case class GCD(testParams: TestParams) extends Module {
  val bitWidth: Int = testParams.getInt("len")
  val validHigh: Boolean = testParams.getBoolean("validHigh")

  val io = IO(new Bundle {
    val value1 = Input(UInt(bitWidth.W))
    val value2 = Input(UInt(bitWidth.W))
    val loadingValues = Input(Bool())
    val outputGCD = Output(UInt(bitWidth.W))
    val outputValid = Output(Bool())
  })

  // Working registers; no explicit width is given, so Chisel infers it.
  val x = Reg(UInt())
  val y = Reg(UInt())

  // One subtraction step per cycle.
  when(x > y) {
    x := x - y
  }.otherwise {
    y := y - x
  }

  // Declared last so that loading new operands overrides the subtraction step.
  when(io.loadingValues) {
    x := io.value1
    y := io.value2
  }

  io.outputGCD := x
  // Done once y reaches zero; polarity fixed at elaboration time.
  io.outputValid := (if (validHigh) y === 0.U else y =/= 0.U)
}

/** ScalaTest spec for [[GCD]].
  *
  * @param params      design parameters handed to the DUT
  * @param annotations extra annotations (e.g. parsed CLI options) passed to chiseltest
  */
class GCDSpec(params: TestParams, annotations: AnnotationSeq = Seq()) extends AnyFreeSpec with ChiselScalatestTester {

  "Gcd should calculate proper greatest common denominator" in {
    test(GCD(params)).withAnnotations(annotations) { dut =>
      // Present the operands 95 and 10 while loadingValues is high for one cycle.
      dut.io.value1.poke(95.U)
      dut.io.value2.poke(10.U)
      dut.io.loadingValues.poke(true.B)
      dut.clock.step()
      dut.io.loadingValues.poke(false.B)

      // outputValid may be active high or low, so compare against the configured level.
      val doneLevel = dut.validHigh
      while (dut.io.outputValid.peek().litToBoolean != doneLevel) {
        dut.clock.step()
      }

      dut.io.outputGCD.expect(5.U)
    }
  }
}

/** Stage that fills in GCD parameter defaults and runs [[GCDSpec]]. */
class GcdTesterStage
    extends GenericTesterStage({ (params, annotations) =>
      // Fallbacks for keys not supplied via --param-string.
      // NOTE(review): validHigh defaults to "false" here, whereas the question's GCD
      // defaulted to true -- confirm which is intended.
      params.defaults ++= Seq("len" -> "16", "validHigh" -> "false")
      new GCDSpec(params, annotations).execute()
    })

/** Command-line entry point for [[GcdTesterStage]]. */
object GcdTesterStage extends StageMain(new GcdTesterStage)

参考 http://blog.echo.sh/2013/11/04/exploring-scala-macros-map-to-case-class-conversion.html,我找到了另一种借助 Scala 宏来去掉 params2conf 样板代码的方法。我还在 Chick 的回答基础上加入了 Verilog 生成,因为这也是原问题的一部分。我的完整解决方案可以在 GitHub 上的仓库中找到。

基本上就是三四个文件:

将 Map 转换为 case class 的宏:

package mappable

import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context

/** Type class converting a `T` to and from a flat `String -> String` map.
  * Instances can be derived with the `materializeMappable` macro in the companion.
  */
trait Mappable[T] {
  /** Encodes `t` as a map keyed by field name. */
  def toMap(t: T): Map[String, String]

  /** Rebuilds a `T` from a map with the keys produced by `toMap`. */
  def fromMap(map: Map[String, String]): T
}

object Mappable {
  /** Derives a [[Mappable]] instance for a case class at compile time. */
  implicit def materializeMappable[T]: Mappable[T] = macro materializeMappableImpl[T]

  /** Macro implementation behind `materializeMappable`.
    *
    * `toMap` maps every field of the first constructor parameter list to its `toString`.
    * `fromMap` parses those strings back and calls the companion's `apply`.
    * Supported field types: Int, Long, Double, String, Boolean. Compilation is aborted
    * with a descriptive message in three cases: an unsupported field type, no primary
    * constructor, or no companion object. Previously these produced an invalid tree or
    * a bare `.get` failure.
    */
  def materializeMappableImpl[T: c.WeakTypeTag](c: Context): c.Expr[Mappable[T]] = {
    import c.universe._
    val tpe = weakTypeOf[T]

    val companion = tpe.typeSymbol.companion
    if (companion == NoSymbol)
      c.abort(c.enclosingPosition, s"Mappable: $tpe has no companion object (is it a case class?)")

    val primaryCtor = tpe.decls
      .collectFirst { case m: MethodSymbol if m.isPrimaryConstructor => m }
      .getOrElse(c.abort(c.enclosingPosition, s"Mappable: $tpe has no primary constructor"))

    // Only the first constructor parameter list is mapped.
    val fields = primaryCtor.paramLists.headOption.getOrElse(Nil)

    val (toMapParams, fromMapParams) = fields.map { field =>
      val name = field.name.toTermName
      val decoded = name.decodedName.toString
      val returnType = tpe.decl(name).typeSignature

      // `map` below refers to the parameter of the generated `fromMap`.
      val fromMapLine = returnType match {
        case NullaryMethodType(res) if res =:= typeOf[Int]     => q"map($decoded).toInt"
        case NullaryMethodType(res) if res =:= typeOf[Long]    => q"map($decoded).toLong"
        case NullaryMethodType(res) if res =:= typeOf[Double]  => q"map($decoded).toDouble"
        case NullaryMethodType(res) if res =:= typeOf[String]  => q"map($decoded)"
        case NullaryMethodType(res) if res =:= typeOf[Boolean] => q"map($decoded).toBoolean"
        case other =>
          c.abort(
            c.enclosingPosition,
            s"Mappable: field '$decoded' of $tpe has unsupported type $other " +
              "(supported: Int, Long, Double, String, Boolean)"
          )
      }

      (q"$decoded -> t.$name.toString", fromMapLine)
    }.unzip

    c.Expr[Mappable[T]] { q"""
      new Mappable[$tpe] {
        def toMap(t: $tpe): Map[String, String] = Map(..$toMapParams)
        def fromMap(map: Map[String, String]): $tpe = $companion(..$fromMapParams)
      }
    """ }
  }
}

公共库代码:

package cliparams

import chisel3.stage.{ChiselStage, ChiselGeneratorAnnotation, ChiselCli}
import firrtl.AnnotationSeq
import firrtl.annotations.{Annotation, NoTargetAnnotation}
import firrtl.options.{HasShellOptions, Shell, ShellOption, Stage, Unserializable, StageMain}
import firrtl.stage.FirrtlCli

import mappable._

/** Marker for annotations defined by this library; may only be mixed into an [[Annotation]].
  * (The misspelled name is kept for source compatibility.)
  */
trait SomeAnnotaion { self: Annotation => }

/** Carries the key/value pairs given with `--params`.
  * Marked [[Unserializable]], so it is not written to annotation files.
  */
case class ParameterAnnotation(map: Map[String, String])
    extends SomeAnnotaion
    with NoTargetAnnotation
    with Unserializable

/** Registers the `--params k1=v1,k2=v2` command-line option. */
object ParameterAnnotation extends HasShellOptions {
  val options = Seq(
    new ShellOption[Map[String, String]](
      longOption = "params",
      toAnnotationSeq = (a: Map[String, String]) => Seq(ParameterAnnotation(a)),
      // The example must name the option actually registered above (`--params`).
      helpText = """a comma separated, space free list of additional parameters, e.g. --params "k1=7,k2=dog" """
    )
  )
}

/** Mixes the `--params` option into a firrtl [[Shell]]. */
trait ParameterCli { self: Shell =>
  ParameterAnnotation.addOptions(parser)
}

/** A firrtl [[Stage]] that builds a parameter object of type `P` from the command line.
  * It hands that object, together with all parsed annotations, to `thunk`.
  *
  * All `--params` occurrences are applied on top of `default` in command-line order, so a
  * later occurrence overrides an earlier one for the same key. With the macro-derived
  * [[Mappable]], keys that do not name a field of `P` are ignored.
  *
  * @param thunk   the action to run (generator, test, ...)
  * @param default parameter values used for keys not given on the command line
  */
class GenericParameterCliStage[P: Mappable](thunk: (P, AnnotationSeq) => Unit, default: P) extends Stage {

  def mapify(p: P): Map[String, String] = implicitly[Mappable[P]].toMap(p)
  def materialize(map: Map[String, String]): P = implicitly[Mappable[P]].fromMap(map)

  val shell: Shell = new Shell("chiseltest") with ParameterCli with ChiselCli with FirrtlCli

  def run(annotations: AnnotationSeq): AnnotationSeq = {
    val overrides: Seq[Map[String, String]] =
      annotations.collect { case ParameterAnnotation(map) => map }
    val params: P =
      if (overrides.isEmpty) default
      else materialize(overrides.foldLeft(mapify(default))(_ ++ _))

    thunk(params, annotations)
    annotations
  }
}

GCD 源文件

// See README.md for license details.

package gcd

import firrtl._
import chisel3._
import chisel3.stage.{ChiselStage, ChiselGeneratorAnnotation}
import firrtl.options.{StageMain}

// Both have to be imported
import mappable._
import cliparams._

/** Elaboration-time parameters of the GCD unit; also the target of `--params`.
  *
  * @param len       operand/result bit width
  * @param validHigh polarity of `outputValid` (true = active high)
  */
case class GCDConfig(len: Int = 16, validHigh: Boolean = true)

/** GCD via repeated subtraction.
  *
  * Each cycle the smaller register is subtracted from the larger. Once register `y`
  * reaches zero, register `x` holds the GCD.
  *
  * @param conf bit width (`len`) and `outputValid` polarity (`validHigh`)
  */
class GCD(val conf: GCDConfig = GCDConfig()) extends Module {
  val io = IO(new Bundle {
    val value1        = Input(UInt(conf.len.W))
    val value2        = Input(UInt(conf.len.W))
    val loadingValues = Input(Bool())
    val outputGCD     = Output(UInt(conf.len.W))
    val outputValid   = Output(Bool())
  })

  // Working registers; no explicit width is given, so Chisel infers it.
  val x = Reg(UInt())
  val y = Reg(UInt())

  when(x > y) {
    x := x - y
  }.otherwise {
    y := y - x
  }

  // Declared last so that loading new operands overrides the subtraction step.
  when(io.loadingValues) {
    x := io.value1
    y := io.value2
  }

  io.outputGCD := x
  // Done once y reaches zero; polarity fixed at elaboration time.
  io.outputValid := (if (conf.validHigh) y === 0.U else y =/= 0.U)
}

/** Verilog generator stage for [[GCD]].
  *
  * `--params` selects the [[GCDConfig]]. All parsed annotations (e.g. chisel/firrtl options
  * such as `--target-dir`) are forwarded to [[ChiselStage]]. Previously they were accepted by
  * the shell but silently dropped.
  */
class GCDGenStage
    extends GenericParameterCliStage[GCDConfig](
      (params, annotations) => {
        (new chisel3.stage.ChiselStage).execute(
          Array("-X", "verilog"),
          annotations :+ ChiselGeneratorAnnotation(() => new GCD(params))
        )
      },
      GCDConfig()
    )

/** Entry point, e.g. `sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'`. */
object GCDGen extends StageMain(new GCDGenStage)

和测试

// See README.md for license details.

package gcd

import chisel3._
import firrtl._
import chisel3.tester._
import org.scalatest.FreeSpec
import chisel3.experimental.BundleLiterals._
import chiseltest.internal._
import chiseltest.experimental.TestOptionBuilder._
import firrtl.options.{StageMain}

import mappable._
import cliparams._

/** ScalaTest spec for [[GCD]].
  *
  * @param params      elaboration parameters of the design under test
  * @param annotations extra annotations (e.g. parsed CLI options) forwarded to chiseltest;
  *                    previously accepted but ignored
  */
class GCDSpec(params: GCDConfig, annotations: AnnotationSeq = Seq()) extends FreeSpec with ChiselScalatestTester {

  "Gcd should calculate proper greatest common denominator" in {
    test(new GCD(params)).withAnnotations(annotations) { dut =>
      // Present the operands 95 and 10 while loadingValues is high for one cycle.
      dut.io.value1.poke(95.U)
      dut.io.value2.poke(10.U)
      dut.io.loadingValues.poke(true.B)
      dut.clock.step(1)
      dut.io.loadingValues.poke(false.B)
      // outputValid may be active high or low, so compare against the configured level.
      while (dut.io.outputValid.peek().litToBoolean != dut.conf.validHigh) {
        dut.clock.step(1)
      }
      dut.io.outputGCD.expect(5.U)
    }
  }
}

/** Stage that runs [[GCDSpec]] with the [[GCDConfig]] selected via `--params`. */
class GCDTestStage
    extends GenericParameterCliStage[GCDConfig](
      { (params, annotations) => new GCDSpec(params, annotations).execute() },
      GCDConfig()
    )

/** Entry point, e.g. `sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'`. */
object GCDTest extends StageMain(new GCDTestStage)

生成和测试都可以像原问题中那样通过命令行进行参数化:

sbt 'runMain gcd.GCDGen --params "len=12,validHigh=false"'
sbt 'test:runMain gcd.GCDTest --params "len=12,validHigh=false"'