Ray Tune error when using Trainable class with tune.with_parameters

Using a very simple example taken from the Tune documentation itself:

from ray import tune
import numpy as np
class MyTrainable(tune.Trainable):
    def setup(self, config, dataset=None):
        print(config, dataset)
        self.dataset = dataset
        self.iter = iter(self.dataset)
        self.next_sample = next(self.iter)

    def step(self):
        loss = 0.1
        return {"loss": loss, "done": True}  # "done" must be a string key

tune.run(
    tune.with_parameters(MyTrainable, dataset=np.ones([1,])),
)

https://docs.ray.io/en/master/tune/api_docs/trainable.html?highlight=with_parameters#tune-with-parameters

This returns an error:

TypeError: __init__() got an unexpected keyword argument 'dataset'

Full traceback:

Failure # 1 (occurred at 2021-07-04_14-56-21)
Traceback (most recent call last):
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/trial_runner.py", line 586, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/ray_trial_executor.py", line 609, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/worker.py", line 1456, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TuneError): ray::ImplicitFunc.train_buffered() (pid=239786, ip=10.10.5.103)
  File "python/ray/_raylet.pyx", line 480, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 432, in ray._raylet.execute_task.function_executor
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/trainable.py", line 167, in train_buffered
    result = self.train()
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/trainable.py", line 226, in train
    result = self.step()
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/function_runner.py", line 366, in step
    self._report_thread_runner_error(block=True)
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/function_runner.py", line 512, in _report_thread_runner_error
    raise TuneError(
ray.tune.error.TuneError: Trial raised an exception. Traceback:
ray::ImplicitFunc.train_buffered() (pid=239786, ip=10.10.5.103)
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/function_runner.py", line 248, in run
    self._entrypoint()
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/function_runner.py", line 315, in entrypoint
    return self._trainable_func(self.config, self._status_reporter,
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/function_runner.py", line 576, in _trainable_func
    output = fn()
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/function_runner.py", line 651, in _inner
    inner(config, checkpoint_dir=None)
  File "/users/username/miniconda3/envs/EnvName/lib/python3.8/site-packages/ray/tune/function_runner.py", line 645, in inner
    fn(config, **fn_kwargs)
TypeError: __init__() got an unexpected keyword argument 'dataset'

I am using Ray version 1.2.0.

Any help is much appreciated.

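Can you try upgrading Ray? The latest release is 1.4.1, and the documentation you linked is built from the latest master. In 1.2.0, tune.with_parameters only supports function trainables. Just run pip install -U ray and it should work.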
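To illustrate why 1.2.0 fails, here is a minimal pure-Python sketch of the mechanism suggested by the last frames of the traceback (fn(config, **fn_kwargs)). In that version the wrapper appears to treat whatever it receives as a function, so passing a class ends up calling MyTrainable(config, dataset=...), and the Trainable constructor rejects the unknown dataset keyword. The names with_parameters_like_1_2, FakeTrainable, train_fn and try_call are hypothetical; this is a simplified mimic for explanation, not Ray's source, and it does not need Ray installed.

```python
def with_parameters_like_1_2(fn, **kwargs):
    """Hypothetical, simplified mimic of the wrapping pattern visible in the
    1.2.0 traceback (`fn(config, **fn_kwargs)`); NOT Ray's actual code."""
    def inner(config):
        # Extra objects are forwarded as keyword arguments to whatever `fn` is,
        # whether it is a plain function or a class.
        return fn(config, **kwargs)
    return inner


class FakeTrainable:
    """Hypothetical stand-in for tune.Trainable: its constructor does not
    accept arbitrary keyword arguments such as `dataset`."""
    def __init__(self, config=None, logger_creator=None):
        self.config = config


def train_fn(config, dataset=None):
    # A function trainable can declare `dataset` as a parameter, so it works.
    return {"loss": 0.1, "n_samples": len(dataset)}


def try_call(trainable):
    """Wrap `trainable` with a dataset and call it, capturing any TypeError."""
    wrapped = with_parameters_like_1_2(trainable, dataset=[1.0])
    try:
        return wrapped({})
    except TypeError as err:
        return f"TypeError: {err}"


if __name__ == "__main__":
    print(try_call(train_fn))       # the function trainable receives `dataset`
    print(try_call(FakeTrainable))  # the class constructor rejects `dataset`
```

With Ray 1.2.0 itself, the corresponding workaround would be to rewrite the trainable as a function that accepts dataset as a keyword parameter and reports results with tune.report; otherwise, as the answer suggests, upgrade Ray so that class trainables are accepted by tune.with_parameters.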