光滑的代码生成器和超过 22 列的表格

Slick codegen & tables with > 22 columns

我是 Slick 的新手。我正在使用 Scala、ScalaTest 和 Slick 为 Java 应用程序创建一个测试套件。我正在使用 slick 在测试前准备数据,并在测试后对数据进行断言。使用的数据库有一些表超过 22 列。我使用 slick-codegen 来生成我的模式代码。

对于超过 22 列的表,slick-codegen 不会生成案例 class,而是生成基于 HList 的自定义类型和伴随的“构造函数”方法。据我了解,这是因为元组和大小写 classes 只能有 22 个字段的限制。代码生成方式,Row-object的字段只能通过索引访问。

我对此有几个问题:

  1. 据我了解,case classes 的 22 个字段限制已经在 Scala 2.11 中修复,对吗?
  2. 如果是这样,是否可以自定义 slick-codegen 来为所有表生成 case classes?我调查了这个:我设法在覆盖的 SourceCodeGenerator 中设置 override def hlistEnabled = false。但这会导致 Cannot generate tuple for > 22 columns, please set hlistEnable=true or override compound. 所以我不明白能够禁用 HList 的意义。可能是在“或覆盖复合”部分中出现问题,但我不明白那是什么意思。
  3. 在 slick 和 22 列上搜索 Internet,我遇到了一些基于嵌套元组的解决方案。是否可以自定义代码生成器以使用这种方法?
  4. 如果用 case classes 和 > 22 个字段生成代码不是一个可行的选择,我认为可以生成一个普通的 class,它有一个“访问器”功能每个列,从而提供从基于索引的访问到基于名称的访问的“映射”。我很乐意自己实现这一代,但我想我需要一些指导从哪里开始。我认为它应该能够为此覆盖标准代码生成器。我已经为某些自定义数据类型使用了覆盖的 SourceCodeGenerator。但是除了这个用例之外,代码生成器的文档对我帮助不大那个

非常感谢您的帮助。提前致谢!

正如您已经发现的那样,可用的选项很少 - 嵌套元组、从 Slick HList 到 Shapeless HList 的转换,然后再到 case classes 等等。

我发现所有这些选项对于这项任务来说都太复杂了,于是使用自定义的 Slick Codegen 来生成带有访问器的简单包装器 class。

看看这个gist

class MyCodegenCustomisations(model: Model) extends slick.codegen.SourceCodeGenerator(model){
import ColumnDetection._


override def Table = new Table(_){
    table =>

    val columnIndexByName = columns.map(_.name).zipWithIndex.toMap
    def getColumnIndex(columnName: String): Option[Int] = {
        columnIndexByName.get(columnName)

    }

    private def getWrapperCode: Seq[String] = {
        if (columns.length <= 22) {
            //do not generate wrapper for tables which get case class generated by Slick
            Seq.empty[String]
        } else {
            val lines =
                columns.map{c =>
                    getColumnIndex(c.name) match {
                        case Some(colIndex) =>
                            //lazy val firstname: Option[String] = row.productElement(1).asInstanceOf[Option[String]]
                            val colType = c.exposedType
                            val line = s"lazy val ${c.name}: $colType = values($colIndex).asInstanceOf[$colType]"
                            line
                        case None => ""
                    }
                }
            Seq("",
                "/*",
                "case class Wrapper(private val row: Row) {",
                "// addressing HList by index is very slow, let's convert it to vector",
                "private lazy val values = row.toList.toVector",
                ""

            ) ++ lines ++ Seq("}", "*/", "")

        }
    }


    override def code: Seq[String] = {
        val originalCode = super.code
        originalCode ++ this.getWrapperCode
    }


}

}

我最终进一步定制了 slick-codegen。首先,我会回答我自己的问题,然后我会 post 我的解决方案。

问题的答案

  1. 对于 classes 的情况,可能会取消 22 个元数限制,但不适用于元组。而且 slick-codegen 也会生成一些元组,我问的时候没有完全意识到。
  2. 不相关,请参阅答案 1。(如果元组的 22 元数限制也被取消,这可能会变得相关。)
  3. 我选择不进一步调查,所以这个问题暂时没有答案。
  4. 这是我最终采用的方法。

解决方法:生成代码

所以,我最终为超过 22 列的表生成了 "ordinary" classes。让我举一个我现在生成的例子。 (生成器代码如下。)(出于简洁和可读性的原因,此示例少于 22 列。)

case class BigAssTableRow(val id: Long, val name: String, val age: Option[Int] = None)

type BigAssTableRowList = HCons[Long,HCons[String,HCons[Option[Int]]], HNil]

object BigAssTableRow {
  def apply(hList: BigAssTableRowList) = new BigAssTableRow(hlist.head, hList.tail.head, hList.tail.tail.head)
  def unapply(row: BigAssTableRow) = Some(row.id :: row.name :: row.age)
}

implicit def GetResultBoekingenRow(implicit e0: GR[Long], e1: GR[String], e2: GR[Optional[Int]]) = GR{
  prs => import prs._
  BigAssTableRow.apply(<<[Long] :: <<[String] :: <<?[Int] :: HNil)
}

class BigAssTable(_tableTag: Tag) extends Table[BigAssTableRow](_tableTag, "big_ass") {
  def * = id :: name :: age :: :: HNil <> (BigAssTableRow.apply, BigAssTableRow.unapply)

  val id: Rep[Long] = column[Long]("id", O.PrimaryKey)
  val name: Rep[String] = column[String]("name", O.Length(255,varying=true))
  val age: Rep[Option[Int]] = column[Option[Int]]("age", O.Default(None))
}

lazy val BigAssTable = new TableQuery(tag => new BigAssTable(tag))

最难的部分是找出 * 映射在 Slick 中的工作原理。文档不多,但我发现 this Whosebug answer 很有启发性。

我创建了 BigAssTableRow object 以使 HList 的使用对客户端代码透明。请注意,对象中的 apply 函数重载了案例 class 中的 apply。所以我仍然可以通过调用 BigAssTableRow(id: 1L, name: "Foo") 创建实体,而 * 投影仍然可以使用带有 HList.

apply 函数

所以,我现在可以做这样的事情了:

// I left out the driver import as well as the scala.concurrent imports 
// for the Execution context.

val collection = TableQuery[BigAssTable]
val row = BigAssTableRow(id: 1L, name: "Qwerty") // Note that I leave out the optional age

Await.result(db.run(collection += row), Duration.Inf)

Await.result(db.run(collection.filter(_.id === 1L).result), Duration.Inf)

对于此代码,无论在幕后使用元组还是 HList,它都是完全透明的。

解决方案:这是如何生成的

我将 post 我的整个生成器代码放在这里。它并不完美;如果您有改进建议,请告诉我!大部分内容只是从 slick.codegen.AbstractSourceCodeGenerator 和相关的 class 中复制而来,然后稍作更改。还有一些和本题没有直接关系的东西,比如增加java.time.*数据类型,过滤特定表等。我把它们留在里面,因为它们可能会有用。另请注意,此示例适用于 Postgres 数据库。

import slick.codegen.SourceCodeGenerator
import slick.driver.{JdbcProfile, PostgresDriver}
import slick.jdbc.meta.MTable
import slick.model.Column

import scala.concurrent.Await
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

object MySlickCodeGenerator {
  val slickDriver = "slick.driver.PostgresDriver"
  val jdbcDriver = "org.postgresql.Driver"
  val url = "jdbc:postgresql://localhost:5432/dbname"
  val outputFolder = "/path/to/project/src/test/scala"
  val pkg = "my.package"
  val user = "user"
  val password = "password"

  val driver: JdbcProfile = Class.forName(slickDriver + "$").getField("MODULE$").get(null).asInstanceOf[JdbcProfile]
  val dbFactory = driver.api.Database
  val db = dbFactory.forURL(url, driver = jdbcDriver, user = user, password = password, keepAliveConnection = true)

  // The schema is generated using Liquibase, which creates these tables that I don't want to use
  def excludedTables = Array("databasechangelog", "databasechangeloglock")

  def tableFilter(table: MTable): Boolean = {
    !excludedTables.contains(table.name.name) && schemaFilter(table.name.schema)
  }

  // There's also an 'audit' schema in the database, I don't want to use that one
  def schemaFilter(schema: Option[String]): Boolean = {
    schema match {
      case Some("public") => true
      case None => true
      case _ => false
    }
  }

  // Fetch data model
  val modelAction = PostgresDriver.defaultTables
    .map(_.filter(tableFilter))
    .flatMap(PostgresDriver.createModelBuilder(_, ignoreInvalidDefaults = false).buildModel)

  val modelFuture = db.run(modelAction)

  // customize code generator
  val codegenFuture = modelFuture.map(model => new SourceCodeGenerator(model) {

    // add custom import for added data types
    override def code = "import my.package.Java8DateTypes._" + "\n" + super.code

    override def Table = new Table(_) {
      table =>

      // Use different factory and extractor functions for tables with > 22 columns
      override def factory   = if(columns.size == 1) TableClass.elementType else if(columns.size <= 22) s"${TableClass.elementType}.tupled" else s"${EntityType.name}.apply"
      override def extractor = if(columns.size <= 22) s"${TableClass.elementType}.unapply" else s"${EntityType.name}.unapply"

      override def EntityType = new EntityTypeDef {
        override def code = {
          val args = columns.map(c =>
            c.default.map( v =>
              s"${c.name}: ${c.exposedType} = $v"
            ).getOrElse(
              s"${c.name}: ${c.exposedType}"
            )
          )
          val callArgs = columns.map(c => s"${c.name}")
          val types = columns.map(c => c.exposedType)

          if(classEnabled){
            val prns = (parents.take(1).map(" extends "+_) ++ parents.drop(1).map(" with "+_)).mkString("")
            s"""case class $name(${args.mkString(", ")})$prns"""
          } else {
            s"""
/** Constructor for $name providing default values if available in the database schema. */
case class $name(${args.map(arg => {s"val $arg"}).mkString(", ")})
type ${name}List = ${compoundType(types)}
object $name {
  def apply(hList: ${name}List): $name = new $name(${callArgs.zipWithIndex.map(pair => s"hList${tails(pair._2)}.head").mkString(", ")})
  def unapply(row: $name) = Some(${compoundValue(callArgs.map(a => s"row.$a"))})
}
          """.trim
          }
        }
      }

      override def PlainSqlMapper = new PlainSqlMapperDef {
        override def code = {
          val positional = compoundValue(columnsPositional.map(c => if (c.fakeNullable || c.model.nullable) s"<<?[${c.rawType}]" else s"<<[${c.rawType}]"))
          val dependencies = columns.map(_.exposedType).distinct.zipWithIndex.map{ case (t,i) => s"""e$i: GR[$t]"""}.mkString(", ")
          val rearranged = compoundValue(desiredColumnOrder.map(i => if(columns.size > 22) s"r($i)" else tuple(i)))
          def result(args: String) = s"$factory($args)"
          val body =
            if(autoIncLastAsOption && columns.size > 1){
              s"""
val r = $positional
import r._
${result(rearranged)} // putting AutoInc last
              """.trim
            } else {
              result(positional)
            }

              s"""
implicit def $name(implicit $dependencies): GR[${TableClass.elementType}] = GR{
  prs => import prs._
  ${indent(body)}
}
          """.trim
        }
      }

      override def TableClass = new TableClassDef {
        override def star = {
          val struct = compoundValue(columns.map(c=>if(c.fakeNullable)s"Rep.Some(${c.name})" else s"${c.name}"))
          val rhs = s"$struct <> ($factory, $extractor)"
          s"def * = $rhs"
        }
      }

      def tails(n: Int) = {
        List.fill(n)(".tail").mkString("")
      }

      // override column generator to add additional types
      override def Column = new Column(_) {
        override def rawType = {
          typeMapper(model).getOrElse(super.rawType)
        }
      }
    }
  })

  def typeMapper(column: Column): Option[String] = {
    column.tpe match {
      case "java.sql.Date" => Some("java.time.LocalDate")
      case "java.sql.Timestamp" => Some("java.time.LocalDateTime")
      case _ => None
    }
  }

  def doCodeGen() = {
    def generator = Await.result(codegenFuture, Duration.Inf)
    generator.writeToFile(slickDriver, outputFolder, pkg, "Tables", "Tables.scala")
  }

  def main(args: Array[String]) {
    doCodeGen()
    db.close()
  }
}

更新 2019-02-15:*随着 Slick 3.3.0 的发布,@Marcus built-in 支持代码生成列数超过 22 的表。

从 Slick 3.2.0 开始,>22 参数情况的最简单解决方案 class 是在 * method using mapTo instead of the <> operator (per documented unit test 中定义默认投影):

case class BigCase(id: Int,
                   p1i1: Int, p1i2: Int, p1i3: Int, p1i4: Int, p1i5: Int, p1i6: Int,
                   p2i1: Int, p2i2: Int, p2i3: Int, p2i4: Int, p2i5: Int, p2i6: Int,
                   p3i1: Int, p3i2: Int, p3i3: Int, p3i4: Int, p3i5: Int, p3i6: Int,
                   p4i1: Int, p4i2: Int, p4i3: Int, p4i4: Int, p4i5: Int, p4i6: Int)

class bigCaseTable(tag: Tag) extends Table[BigCase](tag, "t_wide") {
      def id = column[Int]("id", O.PrimaryKey)
      def p1i1 = column[Int]("p1i1")
      def p1i2 = column[Int]("p1i2")
      def p1i3 = column[Int]("p1i3")
      def p1i4 = column[Int]("p1i4")
      def p1i5 = column[Int]("p1i5")
      def p1i6 = column[Int]("p1i6")
      def p2i1 = column[Int]("p2i1")
      def p2i2 = column[Int]("p2i2")
      def p2i3 = column[Int]("p2i3")
      def p2i4 = column[Int]("p2i4")
      def p2i5 = column[Int]("p2i5")
      def p2i6 = column[Int]("p2i6")
      def p3i1 = column[Int]("p3i1")
      def p3i2 = column[Int]("p3i2")
      def p3i3 = column[Int]("p3i3")
      def p3i4 = column[Int]("p3i4")
      def p3i5 = column[Int]("p3i5")
      def p3i6 = column[Int]("p3i6")
      def p4i1 = column[Int]("p4i1")
      def p4i2 = column[Int]("p4i2")
      def p4i3 = column[Int]("p4i3")
      def p4i4 = column[Int]("p4i4")
      def p4i5 = column[Int]("p4i5")
      def p4i6 = column[Int]("p4i6")

      // HList-based wide case class mapping
      def m3 = (
        id ::
        p1i1 :: p1i2 :: p1i3 :: p1i4 :: p1i5 :: p1i6 ::
        p2i1 :: p2i2 :: p2i3 :: p2i4 :: p2i5 :: p2i6 ::
        p3i1 :: p3i2 :: p3i3 :: p3i4 :: p3i5 :: p3i6 ::
        p4i1 :: p4i2 :: p4i3 :: p4i4 :: p4i5 :: p4i6 :: HNil
      ).mapTo[BigCase]

      def * = m3
}

编辑

因此,如果您希望 slick-codegen 使用上述 mapTo 方法生成巨大的表格,您可以将 the relevant parts 覆盖到代码生成器并添加一个 mapTo 语句:

package your.package
import slick.codegen.SourceCodeGenerator
import slick.{model => m}


class HugeTableCodegen(model: m.Model) extends SourceCodeGenerator(model) with GeneratorHelpers[String, String, String]{


  override def Table = new Table(_) {
    table =>

    // always defines types using case classes
    override def EntityType = new EntityTypeDef{
      override def classEnabled = true
    }

    // allow compound statements using HNil, but not for when "def *()" is being defined, instead use mapTo statement
    override def compoundValue(values: Seq[String]): String = {
      // values.size>22 assumes that this must be for the "*" operator and NOT a primary/foreign key
      if(hlistEnabled && values.size > 22) values.mkString("(", " :: ", s" :: HNil).mapTo[${StringExtensions(model.name.table).toCamelCase}Row]")
      else if(hlistEnabled) values.mkString(" :: ") + " :: HNil"
      else if (values.size == 1) values.head
      else s"""(${values.mkString(", ")})"""
    }

    // should always be case classes, so no need to handle hlistEnabled here any longer
    override def compoundType(types: Seq[String]): String = {
      if (types.size == 1) types.head
      else s"""(${types.mkString(", ")})"""
    }
  }
}

然后在单独的项目中构建代码生成代码 as documented 以便它在编译时生成源代码。您可以将您的 class 名称作为参数传递给您要扩展的 SourceCodeGenerator

lazy val generateSlickSchema = taskKey[Seq[File]]("Generates Schema definitions for SQL tables")
generateSlickSchema := {

  val managedSourceFolder = sourceManaged.value / "main" / "scala"
  val packagePath = "your.sql.table.package"

  (runner in Compile).value.run(
    "slick.codegen.SourceCodeGenerator", (dependencyClasspath in Compile).value.files,
    Array(
      "env.db.connectorProfile",
      "slick.db.driver",
      "slick.db.url",
      managedSourceFolder.getPath,
      packagePath,
      "slick.db.user",
      "slick.db.password",
      "true",
      "your.package.HugeTableCodegen"
    ),
    streams.value.log
  )
  Seq(managedSourceFolder / s"${packagePath.replace(".","/")}/Tables.scala")
}

此问题已在 Slick 3.3 中解决: https://github.com/slick/slick/pull/1889/

此解决方案提供 def *def ?,还支持普通 SQL。