TensorFlow in Scala reflection

I am trying to get TensorFlow for Java working from Scala. I am using the TensorFlow Java library without any Scala wrapper.

My sbt dependencies are listed in the simplified build.sbt further below.

If I run the HelloWord example found here, adapted to Scala, it WORKS fine:

import org.tensorflow.Graph
import org.tensorflow.Session
import org.tensorflow.Tensor
import org.tensorflow.TensorFlow


val g = new Graph()
val value = "Hello from " + TensorFlow.version()
val t = Tensor.create(value.getBytes("UTF-8"))
// The Java API doesn't yet include convenience functions for adding operations.
g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();

val s = new Session(g)
val output = s.runner().fetch("MyConst").run().get(0)

However, if I try to use Scala reflection to compile the function from a string, it does not work. This is the snippet I use to run it:
import scala.reflect.runtime.{universe => ru}
import scala.tools.reflect.ToolBox
val fnStr = """
    {() =>
      import org.tensorflow.Graph
      import org.tensorflow.Session
      import org.tensorflow.Tensor
      import org.tensorflow.TensorFlow

      val g = new Graph()
      val value = "Hello from " + TensorFlow.version()
      val t = Tensor.create(value.getBytes("UTF-8"))
      g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();

      val s = new Session(g)

      s.runner().fetch("MyConst").run().get(0)
    }
    """
val mirror = ru.runtimeMirror(getClass.getClassLoader)
val tb = mirror.mkToolBox()
val t = tb.parse(fnStr)
val fn = tb.eval(t).asInstanceOf[() => Any]
// and finally, executing the function
fn()
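
For contrast, here is a minimal sketch of the same parse/eval calls applied to a pure Scala lambda, which the toolbox handles without trouble. The object name ToolBoxDemo and the lambda string are made up for illustration, and the sketch assumes scala-compiler is on the classpath, as in the build.sbt below.

```scala
import scala.reflect.runtime.{universe => ru}
import scala.tools.reflect.ToolBox

object ToolBoxDemo {
  // Compile a function from a string with the toolbox and return it.
  def compileIncrement(): Int => Int = {
    val mirror = ru.runtimeMirror(getClass.getClassLoader)
    val tb = mirror.mkToolBox()
    val tree = tb.parse("(x: Int) => x + 1")
    tb.eval(tree).asInstanceOf[Int => Int]
  }

  def main(args: Array[String]): Unit = {
    val fn = compileIncrement()
    println(fn(41)) // prints 42
  }
}
```

Since this works, the failure is specific to the TensorFlow code inside the string rather than to the toolbox setup.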

Here is a simplified build.sbt to reproduce the error:

lazy val commonSettings = Seq(
    scalaVersion := "2.12.10",

    libraryDependencies ++= {
      Seq(
        // To support runtime compilation
        "org.scala-lang" % "scala-reflect" % scalaVersion.value,
        "org.scala-lang" % "scala-compiler" % scalaVersion.value,

        // for tensorflow4java
        "org.tensorflow" % "tensorflow" % "1.15.0",
        "org.tensorflow" % "proto" % "1.15.0",
        "org.tensorflow" % "libtensorflow_jni" % "1.15.0"

      )
    }
)

lazy val `test-proj` = project
  .in(file("."))
  .settings(commonSettings)

When I run the above, for example in sbt console, I get the following error and stack trace:

java.lang.NoSuchMethodError: org.tensorflow.Session.runner()Lorg/tensorflow/Session$$Runner;
  at __wrapper$f093d26a3c504d4381a37ef78b6c3d54.__wrapper$f093d26a3c504d4381a37ef78b6c3d54$.$anonfun$wrapper(<no source file>:15)

Please ignore the memory leaks in the code above, which does not use resource contexts (to call close()).
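
For readers who do want to close the native resources, a minimal sketch of a loan-pattern helper follows. Scala 2.12 has no scala.util.Using (it arrived in 2.13), so the helper is hand-written. FakeResource is a hypothetical stand-in for Graph, Session, or Tensor, which implement AutoCloseable in the TensorFlow Java 1.x API.

```scala
object UsingDemo {
  // Loan pattern: run `body` with the resource and always close it afterwards.
  def using[A <: AutoCloseable, B](resource: A)(body: A => B): B =
    try body(resource) finally resource.close()

  // Hypothetical resource standing in for Graph, Session or Tensor.
  final class FakeResource(val name: String) extends AutoCloseable {
    var closed = false
    override def close(): Unit = closed = true
  }

  def main(args: Array[String]): Unit = {
    val res = new FakeResource("graph")
    val length = using(res)(r => r.name.length)
    println(s"length=$length closed=${res.closed}") // prints length=5 closed=true
  }
}
```

With the real library, the calls could be nested, for example using(new Graph()) { g => using(new Session(g)) { s => ... } }.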

The problem is the following bug, which appears when reflective compilation is combined with Scala-Java interop:

https://github.com/scala/bug/issues/8956

The toolbox cannot type-check a value (s.runner()) of a path-dependent type (s.Runner) when that type comes from a Java non-static inner class. And Runner is exactly such a class inside org.tensorflow.Session.
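
To illustrate what "path-dependent type" means here, a small sketch in plain Scala follows. FakeSession and its inner Runner are hypothetical stand-ins for org.tensorflow.Session and Session.Runner. The bug itself only concerns Java inner classes seen through the toolbox, so this sketch shows the typing rule, not the failure.

```scala
// Hypothetical stand-in for org.tensorflow.Session with its inner Runner class.
class FakeSession {
  class Runner { def owner: FakeSession = FakeSession.this }
  def runner(): Runner = new Runner
}

object PathDependentDemo {
  // The type s.Runner is tied to the specific instance s.
  def ownerMatches(s: FakeSession): Boolean = {
    val r: s.Runner = s.runner()
    r.owner eq s
  }

  def main(args: Array[String]): Unit = {
    val s1 = new FakeSession
    val s2 = new FakeSession
    // A type projection accepts a Runner of any FakeSession instance.
    val anyRunner: FakeSession#Runner = s2.runner()
    // val bad: s1.Runner = s2.runner() // would not compile: s2.Runner is not s1.Runner
    println(ownerMatches(s1))      // prints true
    println(anyRunner.owner eq s2) // prints true
  }
}
```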

You can run the compiler manually (similar to how the toolbox runs it):

import org.tensorflow.Tensor
import scala.reflect.internal.util.{AbstractFileClassLoader, BatchSourceFile}
import scala.reflect.io.{AbstractFile, VirtualDirectory}
import scala.reflect.runtime
import scala.reflect.runtime.universe
import scala.reflect.runtime.universe._
import scala.tools.nsc.{Global, Settings}

val code: String =
  """
    |import org.tensorflow.Graph
    |import org.tensorflow.Session
    |import org.tensorflow.Tensor
    |import org.tensorflow.TensorFlow
    |
    |object Main {
    |  def foo() = () => {
    |      val g = new Graph()
    |      val value = "Hello from " + TensorFlow.version()
    |      val t = Tensor.create(value.getBytes("UTF-8"))
    |      g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
    |
    |      val s = new Session(g)
    |
    |      s.runner().fetch("MyConst").run().get(0)
    |  }
    |}
""".stripMargin

val directory = new VirtualDirectory("(memory)", None)
val runtimeMirror = createRuntimeMirror(directory, runtime.currentMirror)
compileCode(code, List(), directory)
val tensor = runObjectMethod("Main", runtimeMirror, "foo").asInstanceOf[() => Tensor[_]]
tensor() // STRING tensor with shape []

def compileCode(code: String, classpathDirectories: List[AbstractFile], outputDirectory: AbstractFile): Unit = {
  val settings = new Settings
  classpathDirectories.foreach(dir => settings.classpath.prepend(dir.toString))
  settings.outputDirs.setSingleOutput(outputDirectory)
  settings.usejavacp.value = true
  val global = new Global(settings)
  (new global.Run).compileSources(List(new BatchSourceFile("(inline)", code)))
}

def runObjectMethod(objectName: String, runtimeMirror: Mirror, methodName: String, arguments: Any*): Any = {
  val objectSymbol = runtimeMirror.staticModule(objectName)
  val objectModuleMirror = runtimeMirror.reflectModule(objectSymbol)
  val objectInstance = objectModuleMirror.instance
  val objectType = objectSymbol.typeSignature
  val methodSymbol = objectType.decl(TermName(methodName)).asMethod
  val objectInstanceMirror = runtimeMirror.reflect(objectInstance)
  val methodMirror = objectInstanceMirror.reflectMethod(methodSymbol)
  methodMirror(arguments: _*)
}

def createRuntimeMirror(directory: AbstractFile, parentMirror: Mirror): Mirror = {
  val classLoader = new AbstractFileClassLoader(directory, parentMirror.classLoader)
  universe.runtimeMirror(classLoader)
}

See also:

dynamically parse json in flink map

Dynamic compilation of multiple Scala classes at runtime

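As Dmytro pointed out in his answer, this is not possible with the toolbox. He pointed to another answer. I think there is a neat solution: take the Compiler class defined there and use it in place of the toolbox.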

In that case, the final snippet looks like this:

import your.pkg.Compiler // placeholder: the package where the Compiler class below is defined
val fnStr = """
    {() =>
      import org.tensorflow.Graph
      import org.tensorflow.Session
      import org.tensorflow.Tensor
      import org.tensorflow.TensorFlow

      val g = new Graph()
      val value = "Hello from " + TensorFlow.version()
      val t = Tensor.create(value.getBytes("UTF-8"))
      g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();

      val s = new Session(g)

      s.runner().fetch("MyConst").run().get(0)
    }
    """
val tb = new Compiler() // this replaces the mirror and toolbox instantiation
// Compiler.eval takes the source string directly, so there is no separate parse step
val fn = tb.eval[() => Any](fnStr)
// and finally, executing the function
println(fn())

For completeness, here is the solution copied from the linked answer:

  class Compiler() {
    import scala.reflect.internal.util.{AbstractFileClassLoader, BatchSourceFile}
    import scala.reflect.io.{AbstractFile, VirtualDirectory}
    import scala.reflect.runtime
    import scala.reflect.runtime.universe
    import scala.reflect.runtime.universe._
    import scala.tools.nsc.{Global, Settings}
    import scala.collection.mutable
    import java.security.MessageDigest
    import java.math.BigInteger
       
    val target  = new VirtualDirectory("(memory)", None)
       
    val classCache = mutable.Map[String, Class[_]]()
       
    private val settings = new Settings()
    settings.deprecation.value = true // enable detailed deprecation warnings
    settings.unchecked.value = true // enable detailed unchecked warnings
    settings.outputDirs.setSingleOutput(target)
    settings.usejavacp.value = true
       
    private val global = new Global(settings)
    private lazy val run = new global.Run
       
    val classLoader = new AbstractFileClassLoader(target, this.getClass.getClassLoader)
       
    /**Compiles the code as a class into the class loader of this compiler.
      * 
      * @param code
      * @return
      */
    def compile(code: String) = {
      val className = classNameForCode(code)
      findClass(className).getOrElse {
        val sourceFiles = List(new BatchSourceFile("(inline)", wrapCodeInClass(className, code)))
        run.compileSources(sourceFiles)
        findClass(className).get
      } 
    }   
       
    /** Compiles the source string into the class loader and
      * evaluates it.
      * 
      * @param code
      * @tparam T
      * @return
      */
    def eval[T](code: String): T = {
      val cls = compile(code)
      cls.getConstructor().newInstance().asInstanceOf[() => Any].apply().asInstanceOf[T]
    }  
        
    def findClass(className: String): Option[Class[_]] = {
      synchronized {
        classCache.get(className).orElse {
          try {
            val cls = classLoader.loadClass(className)
            classCache(className) = cls
            Some(cls)
          } catch {
            case e: ClassNotFoundException => None
          }
        }
      } 
    }   
  
    protected def classNameForCode(code: String): String = {
      val digest = MessageDigest.getInstance("SHA-1").digest(code.getBytes)
      "sha"+new BigInteger(1, digest).toString(16)
    }   
  
    /*
     * Wrap source code in a new class with an apply method.
     */
    private def wrapCodeInClass(className: String, code: String) = {
      "class " + className + " extends (() => Any) {\n" +
      "  def apply() = {\n" +
      code + "\n" +
      "  }\n" +
      "}\n"
    }
  }
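
The caching in the class above relies on naming each generated class after a hash of its source. As a rough sketch of that idea in isolation, the following copies the two helper methods into a standalone object. The name ClassNameDemo is made up, and the sketch passes an explicit UTF-8 charset to getBytes, whereas the class above uses the platform default.

```scala
import java.math.BigInteger
import java.security.MessageDigest

object ClassNameDemo {
  // Same naming scheme as classNameForCode above, with an explicit charset.
  def classNameForCode(code: String): String = {
    val digest = MessageDigest.getInstance("SHA-1").digest(code.getBytes("UTF-8"))
    "sha" + new BigInteger(1, digest).toString(16)
  }

  // Same wrapping as wrapCodeInClass above: the snippet becomes the body of apply().
  def wrapCodeInClass(className: String, code: String): String =
    "class " + className + " extends (() => Any) {\n" +
    "  def apply() = {\n" +
    code + "\n" +
    "  }\n" +
    "}\n"

  def main(args: Array[String]): Unit = {
    val snippet = "1 + 1"
    val name = classNameForCode(snippet)
    println(name == classNameForCode(snippet)) // same source, same class name
    println(name == classNameForCode("1 + 2")) // different source, different name
    println(wrapCodeInClass(name, snippet))    // the source handed to the compiler
  }
}
```

Because identical source strings map to the same class name, compile can look the class up in the cache or class loader first and skip recompilation.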