Tensorflow:不能将 tf.case 与输入参数一起使用
Tensorflow: Can't use tf.case with input argument
我需要创建一个变量 epsilon_n
,它根据当前 step
更改定义(和值)。由于我有两个以上的案例,所以我似乎不能使用 tf.cond
。我正在尝试按如下方式使用 tf.case
:
import tensorflow as tf
####
EPSILON_DELTA_PHASE1 = 33e-4
EPSILON_DELTA_PHASE2 = 2.5
####
step = tf.placeholder(dtype=tf.float32, shape=None)
def fn1(step):
return tf.constant([1.])
def fn2(step):
return tf.constant([1.+step*EPSILON_DELTA_PHASE1])
def fn3(step):
return tf.constant([1.+step*EPSILON_DELTA_PHASE2])
epsilon_n = tf.case(
pred_fn_pairs=[
(tf.less(step, 3e4), lambda step: fn1(step)),
(tf.less(step, 6e4), lambda step: fn2(step)),
(tf.less(step, 1e5), lambda step: fn3(step))],
default=lambda: tf.constant([1e5]),
exclusive=False)
但是,我不断收到此错误消息:
TypeError: <lambda>() missing 1 required positional argument: 'step'
我尝试了以下方法:
epsilon_n = tf.case(
pred_fn_pairs=[
(tf.less(step, 3e4), fn1),
(tf.less(step, 6e4), fn2),
(tf.less(step, 1e5), fn3)],
default=lambda: tf.constant([1e5]),
exclusive=False)
我还是会出现同样的错误。 Tensorflow 文档中的示例权衡了没有将输入参数传递给可调用函数的情况。我在互联网上找不到关于 tf.case 的足够信息!请帮忙?
您需要进行以下几处更改。
为了保持一致性,您可以将所有 return 值设置为变量。
# Since step is a scalar, scalar shape [() or [], not None] much be provided
step = tf.placeholder(dtype=tf.float32, shape=())
def fn1(step):
return tf.constant([1.])
# Here you need to use Variable not constant, since you are modifying the value using placeholder
def fn2(step):
return tf.Variable([1.+step*EPSILON_DELTA_PHASE1])
def fn3(step):
return tf.Variable([1.+step*EPSILON_DELTA_PHASE2])
epsilon_n = tf.case(
pred_fn_pairs=[
(tf.less(step, 3e4), lambda : fn1(step)),
(tf.less(step, 6e4), lambda : fn2(step)),
(tf.less(step, 1e5), lambda : fn3(step))],
default=lambda: tf.constant([1e5]),
exclusive=False)
我需要创建一个变量 epsilon_n
,它根据当前 step
更改定义(和值)。由于我有两个以上的案例,所以我似乎不能使用 tf.cond
。我正在尝试按如下方式使用 tf.case
:
import tensorflow as tf
####
EPSILON_DELTA_PHASE1 = 33e-4
EPSILON_DELTA_PHASE2 = 2.5
####
step = tf.placeholder(dtype=tf.float32, shape=None)
def fn1(step):
return tf.constant([1.])
def fn2(step):
return tf.constant([1.+step*EPSILON_DELTA_PHASE1])
def fn3(step):
return tf.constant([1.+step*EPSILON_DELTA_PHASE2])
epsilon_n = tf.case(
pred_fn_pairs=[
(tf.less(step, 3e4), lambda step: fn1(step)),
(tf.less(step, 6e4), lambda step: fn2(step)),
(tf.less(step, 1e5), lambda step: fn3(step))],
default=lambda: tf.constant([1e5]),
exclusive=False)
但是,我不断收到此错误消息:
TypeError: <lambda>() missing 1 required positional argument: 'step'
我尝试了以下方法:
epsilon_n = tf.case(
pred_fn_pairs=[
(tf.less(step, 3e4), fn1),
(tf.less(step, 6e4), fn2),
(tf.less(step, 1e5), fn3)],
default=lambda: tf.constant([1e5]),
exclusive=False)
我还是会出现同样的错误。 Tensorflow 文档中的示例权衡了没有将输入参数传递给可调用函数的情况。我在互联网上找不到关于 tf.case 的足够信息!请帮忙?
您需要进行以下几处更改。 为了保持一致性,您可以将所有 return 值设置为变量。
# Since step is a scalar, scalar shape [() or [], not None] much be provided
step = tf.placeholder(dtype=tf.float32, shape=())
def fn1(step):
return tf.constant([1.])
# Here you need to use Variable not constant, since you are modifying the value using placeholder
def fn2(step):
return tf.Variable([1.+step*EPSILON_DELTA_PHASE1])
def fn3(step):
return tf.Variable([1.+step*EPSILON_DELTA_PHASE2])
epsilon_n = tf.case(
pred_fn_pairs=[
(tf.less(step, 3e4), lambda : fn1(step)),
(tf.less(step, 6e4), lambda : fn2(step)),
(tf.less(step, 1e5), lambda : fn3(step))],
default=lambda: tf.constant([1e5]),
exclusive=False)