Google AutoMl 表:如何使用 JAVA 客户端库在数据集中设置目标列

Google AutoMl Tables: How to set target column in the dataset using JAVA client library

根据我的要求,我需要自动化整个周期:-

  1. 从数据库中以正确格式将数据作为 CSV 格式。
  2. 正在将其上传到 google 存储桶
  3. 在 google AutoML 表中创建数据集
  4. 正在将我上传的 CSV 从存储桶导入数据集
  5. 根据创建的数据集训练模型
  6. 训练后部署模型

我已成功完成第 1 步到第 4 步,并且还完成了第 6 步,但 我在第 5 步中遇到问题,我需要在其中定义目标列我的 CSV 文件.

为了解决这个问题,我查看了 AutoMl 文档,但找不到任何通过它的 (AutoMl Tables) JAVA 客户端库来解决这个问题的方法。这是文档中提供的代码,我们需要 2 个重要参数,它们是 tableSpecIdcolumnSpecId 下面是文档中的代码:- https://cloud.google.com/automl-tables/docs/train

import com.google.cloud.automl.v1beta1.AutoMlClient;
import com.google.cloud.automl.v1beta1.ColumnSpec;
import com.google.cloud.automl.v1beta1.ColumnSpecName;
import com.google.cloud.automl.v1beta1.LocationName;
import com.google.cloud.automl.v1beta1.Model;
import com.google.cloud.automl.v1beta1.OperationMetadata;
import com.google.cloud.automl.v1beta1.TablesModelMetadata;
import java.io.IOException;
import java.util.concurrent.ExecutionException;

class TablesCreateModel {

  public static void main(String[] args)
      throws IOException, ExecutionException, InterruptedException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String tableSpecId = "YOUR_TABLE_SPEC_ID";
    String columnSpecId = "YOUR_COLUMN_SPEC_ID";
    String displayName = "YOUR_DATASET_NAME";
    createModel(projectId, datasetId, tableSpecId, columnSpecId, displayName);
  }

  // Create a model
  static void createModel(
      String projectId,
      String datasetId,
      String tableSpecId,
      String columnSpecId,
      String displayName)
      throws IOException, ExecutionException, InterruptedException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (AutoMlClient client = AutoMlClient.create()) {
      // A resource that represents Google Cloud Platform location.
      LocationName projectLocation = LocationName.of(projectId, "us-central1");

      // Get the complete path of the column.
      ColumnSpecName columnSpecName =
          ColumnSpecName.of(projectId, "us-central1", datasetId, tableSpecId, columnSpecId);

      // Build the get column spec.
      ColumnSpec targetColumnSpec =
          ColumnSpec.newBuilder().setName(columnSpecName.toString()).build();

      // Set model metadata.
      TablesModelMetadata metadata =
          TablesModelMetadata.newBuilder()
              .setTargetColumnSpec(targetColumnSpec)
              .setTrainBudgetMilliNodeHours(24000)
              .build();

      Model model =
          Model.newBuilder()
              .setDisplayName(displayName)
              .setDatasetId(datasetId)
              .setTablesModelMetadata(metadata)
              .build();

      // Create a model with the model metadata in the region.
      OperationFuture<Model, OperationMetadata> future =
          client.createModelAsync(projectLocation, model);
      // OperationFuture.get() will block until the model is created, which may take several hours.
      // You can use OperationFuture.getInitialFuture to get a future representing the initial
      // response to the request, which contains information while the operation is in progress.
      System.out.format("Training operation name: %s%n", future.getInitialFuture().get().getName());
      *System*.out.println("Training started...");
    }
  }
}

在成功将我的 CSV 导入数据集(第 4 步)后,我能够从这 2 个重要参数中获得 tableSpecId,但我无法获得 columnSpecId 因为它定义了列之间的相关性,而且根据我的理解,它还定义了哪一列是目标列。

在网上做了更多研究后,我发现下面提到的 REST API 发送请求以在数据集中设置目标列 https://automl.clients6.google.com/v1beta1/projects/[projectId]/locations/[Location]/datasets/[DatasetId]?updateMask=tablesDatasetMetadata.targetColumnSpecId&key=[auth stuff] 但我正在使用 SpringBoot,这意味着我正在使用 AutoMl 的 JAVA 客户端库。现在,在 AutoML 的客户端库中,我 无法找到任何可以用来显式发送请求的方法,告诉 Google AutoMl 表该特定列是我的目标列 。我想 python 客户端库中有可用的方法,但 java 中没有。我先谢谢你了。

错误(在 columnSpecId 中传递空字符串时):-

com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
java.util.concurrent.ExecutionException: com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
    at com.google.common.util.concurrent.AbstractFuture.getDoneValue(AbstractFuture.java:566)
    at com.google.common.util.concurrent.AbstractFuture.get(AbstractFuture.java:547)
    at com.google.common.util.concurrent.FluentFuture$TrustedFuture.get(FluentFuture.java:86)
    at com.google.common.util.concurrent.ForwardingFuture.get(ForwardingFuture.java:62)
    at com.realcoderz.ai.AutoMlTables.TablesCreateModel.createModel(TablesCreateModel.java:76)
    at com.realcoderz.ai.AutoMlTables.TablesCreateModel.getTablesCreateModel(TablesCreateModel.java:29)
    at com.realcoderz.ai.controller.AiMasterController.CreateModel(AiMasterController.java:105)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:64)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:564)
    at org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:197)
    at org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:141)
    at org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:106)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:894)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:808)
    at org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
    at org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:1060)
    at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:962)
    at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:1006)
    at org.springframework.web.servlet.FrameworkServlet.doGet(FrameworkServlet.java:898)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:626)
    at org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:883)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:733)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:53)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:100)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:119)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.FormContentFilter.doFilterInternal(FormContentFilter.java:93)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:119)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:201)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:119)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:202)
    at org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:97)
    at org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:542)
    at org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:143)
    at org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:92)
    at org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:78)
    at org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:343)
    at org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:374)
    at org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:65)
    at org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:888)
    at org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1597)
    at org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
    at org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
    at java.base/java.lang.Thread.run(Thread.java:832)
Caused by: com.google.api.gax.rpc.InvalidArgumentException: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
    at com.google.api.gax.rpc.ApiExceptionFactory.createException(ApiExceptionFactory.java:49)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:72)
    at com.google.api.gax.grpc.GrpcApiExceptionFactory.create(GrpcApiExceptionFactory.java:60)
    at com.google.api.gax.grpc.GrpcExceptionCallable$ExceptionTransformingFuture.onFailure(GrpcExceptionCallable.java:97)
    at com.google.api.core.ApiFutures.onFailure(ApiFutures.java:68)
    at com.google.common.util.concurrent.Futures$CallbackListener.run(Futures.java:1041)
    at com.google.common.util.concurrent.DirectExecutor.execute(DirectExecutor.java:30)
    at com.google.common.util.concurrent.AbstractFuture.executeListener(AbstractFuture.java:1215)
    at com.google.common.util.concurrent.AbstractFuture.complete(AbstractFuture.java:983)
    at com.google.common.util.concurrent.AbstractFuture.setException(AbstractFuture.java:771)
    at io.grpc.stub.ClientCalls$GrpcFuture.setException(ClientCalls.java:563)
    at io.grpc.stub.ClientCalls$UnaryStreamToFuture.onClose(ClientCalls.java:533)
    at io.grpc.internal.DelayedClientCall$DelayedListener.run(DelayedClientCall.java:464)
    at io.grpc.internal.DelayedClientCall$DelayedListener.delayOrExecute(DelayedClientCall.java:428)
    at io.grpc.internal.DelayedClientCall$DelayedListener.onClose(DelayedClientCall.java:461)
    at io.grpc.internal.ClientCallImpl.closeObserver(ClientCallImpl.java:617)
    at io.grpc.internal.ClientCallImpl.access0(ClientCallImpl.java:70)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImplStreamClosed.runInternal(ClientCallImpl.java:803)
    at io.grpc.internal.ClientCallImpl$ClientStreamListenerImplStreamClosed.runInContext(ClientCallImpl.java:782)
    at io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
    at io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:123)
    at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
    at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
    at java.base/java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.run(ScheduledThreadPoolExecutor.java:304)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
    ... 1 more
Caused by: io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Request contains an invalid argument.
    at io.grpc.Status.asRuntimeException(Status.java:533)
    ... 16 more

这里有 2 张图片用于比较 图片 1 我已经通过我的 springboot 应用程序发送了请求 我离开 columnSpecId Blank 并且在image 2 我手动设置目标并在 google 云控制台中训练模型 UI Image 1 Image 2

您可以通过在 TableSpec 上调用 ListColumnSpecs 来获取 columnSpecId:https://googleapis.dev/java/google-cloud-clients/latest/com/google/cloud/automl/v1beta1/AutoMlClient.html#listColumnSpecs-com.google.cloud.automl.v1beta1.ListColumnSpecsRequest-

像这样:

TableSpecName parent = TableSpecName.of("[PROJECT]", "[LOCATION]", "[DATASET]", "[TABLE_SPEC]")
ColumnSpec targetColumnSpec;
for (ColumnSpec element : autoMlClient.listColumnSpecs(parent).iterateAll()) {
  if (element.getDisplayName().equals("[MY_TARGET_COLUMN]")) {
    targetColumnSpec = element;
    break;
  }
}