Tensorflow probability: ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)
Tensorflow probability: ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)
我正在尝试让 NUTS 采样器在玩具模型上工作。我已经 运行 解决了标题中提到的问题。
这是重现错误的代码:
import tensorflow_probability as tfp
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from functools import partial
import numpy as np
tfd = tfp.distributions
A = tf.random.normal(
[10,10], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
noise_std = tf.random.normal([1])
x1 = tfd.Normal(0, 10 * tf.ones(10)).sample()
x1 = x1[..., tf.newaxis]
y = tf.linalg.matmul(A, x1) + noise_std
model = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), #sigma
tfd.Normal(0, 10 * tf.ones(10)),
lambda x_rv, sigma : tfd.Normal(loc=tf.linalg.matmul(A, x_rv[...,tf.newaxis]) + sigma, scale=1.0)
])
def target_log_prob_fn(sigma, x_rv):
return model.log_prob([sigma, x_rv, y[tf.newaxis, ...]])
def trace_fn(_, pkr):
return (
pkr.inner_results.inner_results.target_log_prob,
pkr.inner_results.inner_results.leapfrogs_taken,
pkr.inner_results.inner_results.has_divergence,
pkr.inner_results.inner_results.energy,
pkr.inner_results.inner_results.log_accept_ratio)
n_chains = 2
def run_nuts_template(
trace_fn,
target_log_prob_fn,
inits,
bijectors_list=None,
num_steps=500,
num_burnin=500,
n_chains=n_chains):
step_size = np.random.rand(n_chains, 1)*.5 + 1.
if not isinstance(inits, list):
inits = [inits]
if bijectors_list is None:
bijectors_list = [tfb.Identity()]*len(inits)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
tfp.mcmc.TransformedTransitionKernel(
inner_kernel=tfp.mcmc.NoUTurnSampler(
target_log_prob_fn,
step_size=[step_size]*len(inits)
),
bijector=bijectors_list
),
target_accept_prob=.8,
num_adaptation_steps=int(0.8*num_burnin),
step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
inner_results=pkr.inner_results._replace(step_size=new_step_size)
),
step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
)
res = tfp.mcmc.sample_chain(
num_results=num_steps,
num_burnin_steps=num_burnin,
current_state=inits,
kernel=kernel,
trace_fn=trace_fn
)
return res
inits = model.sample(n_chains)
run_nuts = partial(run_nuts_template, trace_fn)
inits = [tf.random.uniform(s.shape, -2, 2, tf.float32, name="initializer") for s in inits]
run_nuts(target_log_prob_fn, inits[:-1])
错误:ValueError:张量的形状 (2, 2) 与提供的形状 (2,) 不兼容
完整堆栈跟踪:https://pastebin.com/zAA58P53
ValueError Traceback (most recent call last)
<ipython-input-17-ab5eae0dd51a> in <module>
3 ]
4
----> 5 run_nuts(
6 target_log_prob_fn,
7 inits[:-1]
<ipython-input-13-2d7a920574a8> in run_nuts_template(trace_fn, target_log_prob_fn, inits, bijectors_list, num_steps, num_burnin, n_chains)
45 )
46
---> 47 res = tfp.mcmc.sample_chain(
48 num_results=num_steps,
49 num_burnin_steps=num_burnin,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in sample_chain(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, trace_fn, return_final_kernel_results, parallel_iterations, seed, name)
359 return seed, next_state, current_kernel_results
360
--> 361 (_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
362 loop_fn=_trace_scan_fn,
363 initial_state=(seed, current_state, previous_kernel_results),
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in trace_scan(loop_fn, initial_state, elems, trace_fn, trace_criterion_fn, static_trace_allocation_size, parallel_iterations, name)
462 return i + 1, state, num_steps_traced, trace_arrays
463
--> 464 _, final_state, _, trace_arrays = tf.while_loop(
465 cond=lambda i, *_: i < length,
466 body=_body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in _body(i, state, num_steps_traced, trace_arrays)
452 def _body(i, state, num_steps_traced, trace_arrays):
453 elem = elems_array.read(i)
--> 454 state = loop_fn(state, elem)
455
456 trace_arrays, num_steps_traced = ps.cond(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _trace_scan_fn(seed_state_and_results, num_steps)
352
353 def _trace_scan_fn(seed_state_and_results, num_steps):
--> 354 seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
355 loop_num_iter=num_steps,
356 body_fn=_seeded_one_step,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in smart_for_loop(loop_num_iter, body_fn, initial_loop_vars, parallel_iterations, unroll_threshold, name)
346 # where while/LoopCond needs it.
347 loop_num_iter = tf.cast(loop_num_iter, dtype=tf.int32)
--> 348 return tf.while_loop(
349 cond=lambda i, *args: i < loop_num_iter,
350 body=lambda i, *args: [i + 1] + list(body_fn(*args)),
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in <lambda>(i, *args)
348 return tf.while_loop(
349 cond=lambda i, *args: i < loop_num_iter,
--> 350 body=lambda i, *args: [i + 1] + list(body_fn(*args)),
351 loop_vars=[np.int32(0)] + initial_loop_vars,
352 parallel_iterations=parallel_iterations
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _seeded_one_step(seed, *state_and_results)
349 one_step_kwargs = dict(seed=step_seed) if is_seeded else {}
350 return [passalong_seed] + list(
--> 351 kernel.one_step(*state_and_results, **one_step_kwargs))
352
353 def _trace_scan_fn(seed_state_and_results, num_steps):
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py in one_step(self, current_state, previous_kernel_results, seed)
454 # Step the inner kernel.
455 inner_kwargs = {} if seed is None else dict(seed=seed)
--> 456 new_state, new_inner_results = self.inner_kernel.one_step(
457 current_state, inner_results, **inner_kwargs)
458
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py in one_step(self, current_state, previous_kernel_results, seed)
397 self.name, 'transformed_kernel', 'one_step')):
398 inner_kwargs = {} if seed is None else dict(seed=seed)
--> 399 transformed_next_state, kernel_results = self._inner_kernel.one_step(
400 previous_kernel_results.transformed_state,
401 previous_kernel_results.inner_results,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in one_step(self, current_state, previous_kernel_results, seed)
392 )
393
--> 394 _, _, _, new_step_metastate = tf.while_loop(
395 cond=lambda iter_, seed, state, metastate: ( # pylint: disable=g-long-lambda
396 (iter_ < self.max_tree_depth) &
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, state, metastate)
396 (iter_ < self.max_tree_depth) &
397 tf.reduce_any(metastate.continue_tree)),
--> 398 body=lambda iter_, seed, state, metastate: self._loop_tree_doubling( # pylint: disable=g-long-lambda
399 previous_kernel_results.step_size,
400 previous_kernel_results.momentum_state_memory,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_tree_doubling(self, step_size, momentum_state_memory, current_step_meta_info, iter_, initial_step_state, initial_step_metastate, seed)
570 momentum_subtree_cumsum,
571 leapfrogs_taken
--> 572 ] = self._build_sub_tree(
573 directions_expanded,
574 integrator,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _build_sub_tree(self, directions, integrator, current_step_meta_info, nsteps, initial_state, continue_tree, not_divergence, momentum_state_memory, seed, name)
750 final_not_divergence,
751 momentum_state_memory,
--> 752 ] = tf.while_loop(
753 cond=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum, # pylint: disable=g-long-lambda
754 leapfrogs_taken, state, state_c, continue_tree,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, energy_diff_sum, init_momentum_cumsum, leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory)
758 leapfrogs_taken, state, state_c, continue_tree,
759 not_divergence, momentum_state_memory: (
--> 760 self._loop_build_sub_tree(
761 directions, integrator, current_step_meta_info,
762 iter_, energy_diff_sum, init_momentum_cumsum,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_build_sub_tree(self, directions, integrator, current_step_meta_info, iter_, energy_diff_sum_previous, momentum_cumsum_previous, leapfrogs_taken, prev_tree_state, candidate_tree_state, continue_tree_previous, not_divergent_previous, momentum_state_memory, seed)
811 next_target,
812 next_target_grad_parts
--> 813 ] = integrator(prev_tree_state.momentum,
814 prev_tree_state.state,
815 prev_tree_state.target,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in __call__(self, momentum_parts, state_parts, target, target_grad_parts, kinetic_energy_fn, name)
295 next_target,
296 next_target_grad_parts,
--> 297 ] = tf.while_loop(
298 cond=lambda i, *_: i < self.num_steps,
299 body=lambda i, *args: [i + 1] + list(_one_step( # pylint: disable=no-value-for-parameter,g-long-lambda
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in <lambda>(i, *args)
297 ] = tf.while_loop(
298 cond=lambda i, *_: i < self.num_steps,
--> 299 body=lambda i, *args: [i + 1] + list(_one_step( # pylint: disable=no-value-for-parameter,g-long-lambda
300 self.target_fn, self.step_sizes, get_velocity_parts, *args)),
301 loop_vars=[
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in _one_step(target_fn, step_sizes, get_velocity_parts, half_next_momentum_parts, state_parts, target, target_grad_parts)
353 next_target_grad_parts))
354
--> 355 tensorshape_util.set_shape(next_target, target.shape)
356 for ng, g in zip(next_target_grad_parts, target_grad_parts):
357 tensorshape_util.set_shape(ng, g.shape)
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py in set_shape(tensor, shape)
326 """
327 if hasattr(tensor, 'set_shape'):
--> 328 tensor.set_shape(shape)
329
330
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in set_shape(self, shape)
1213 def set_shape(self, shape):
1214 if not self.shape.is_compatible_with(shape):
-> 1215 raise ValueError(
1216 "Tensor's shape %s is not compatible with supplied shape %s" %
1217 (self.shape, shape))
ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)
形状为 (2, ) 的唯一变量是 sigma
,但我不知道它是如何变成 (2, 2)
张量的。
问题出在步长的形状上。如所写,链状态部分分别具有(包括 n_chains
批量形状)形状 [2]
和 [2, 10]
。然而,您的步长都被初始化为形状 [2, 1]
。这对于第二个状态部分([2, 1]
用 [2, 10]
广播)是正确的,但对于第一个状态部分不正确——你最终会在某个地方得到一个 [2, 2]
,大概是在 step_size * grad(tlp, state[0])
项中(释义)在积分器中。
我将步长初始化重写为此,它可能会使用一些 code-golfing,但按预期工作:
step_size = np.random.rand(n_chains) * .5 + 1.
step_size = [np.reshape(step_size, [n_chains] + [1] * (x.shape.ndims - 1))
for x in inits]
# now step size shapes are [2] and [2, 1]
一些其他注意事项:
- 查看
tf.linalg.matvec
——它会避免使用 newaxis + matmul ,从而为您节省一些字符(可能还有一些 flops)
- 您似乎在进行 stan-style 初始化,在
[-2, 2]^n
超立方体中进行随机初始化。一般来说,您会希望在 unconstrained space 中执行此操作,因此您可能希望通过约束双射器推动那些随机初始化,至少在它们不是时默认值 (Identity
)。 TransformedTransitionKernel
假设 user-provided 状态处于 constrained(样本)space。希望这很清楚...如果不清楚请告诉我!
- 这里有一个新的实验性功能“固定”,它可能会为您简化一些事情。您首先要确保您的分布有名称(将
name
arg 传递给构造函数),然后您可以调用 pinned = model.pinned(y=y)
来获取 tfp.experimental.distributions.JointDistributionPinned
的实例。你可以使用 pinned.unnormalized_log_prob
代替你的 def target_log_prob_fn
(它做同样的事情),你也可以调用 pinned.experimental_default_event_space_bijector()
来获得一个双射器,它“做正确的事”来转换un-pinned 个变量。也就是说,你可以把那个东西交给 TransformedTransitionKernel
。它实际上是一个“多部分”(或“联合”)双射器,所以它吃列表和 returns 列表;您不再需要 list-wrap 您的 passed-in 双射器。最近,TTK 知道如何使用这些多部分双射器。顾名思义,这些都是相当新的 experimental/subject 到 API 调整,但应该处于良好的工作状态;如果您尝试它们并 运行 遇到问题,请告诉我们!
这是一个带有略微修改的步长代码的 colab:https://colab.research.google.com/drive/1o-nygALqdq2ppj5rU9d6UVafZM5SBCd4
HTH!
我正在尝试让 NUTS 采样器在玩具模型上工作。我已经 运行 解决了标题中提到的问题。
这是重现错误的代码:
import tensorflow_probability as tfp
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from functools import partial
import numpy as np
tfd = tfp.distributions
A = tf.random.normal(
[10,10], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
noise_std = tf.random.normal([1])
x1 = tfd.Normal(0, 10 * tf.ones(10)).sample()
x1 = x1[..., tf.newaxis]
y = tf.linalg.matmul(A, x1) + noise_std
model = tfd.JointDistributionSequentialAutoBatched([
tfd.Normal(loc=0., scale=1.), #sigma
tfd.Normal(0, 10 * tf.ones(10)),
lambda x_rv, sigma : tfd.Normal(loc=tf.linalg.matmul(A, x_rv[...,tf.newaxis]) + sigma, scale=1.0)
])
def target_log_prob_fn(sigma, x_rv):
return model.log_prob([sigma, x_rv, y[tf.newaxis, ...]])
def trace_fn(_, pkr):
return (
pkr.inner_results.inner_results.target_log_prob,
pkr.inner_results.inner_results.leapfrogs_taken,
pkr.inner_results.inner_results.has_divergence,
pkr.inner_results.inner_results.energy,
pkr.inner_results.inner_results.log_accept_ratio)
n_chains = 2
def run_nuts_template(
trace_fn,
target_log_prob_fn,
inits,
bijectors_list=None,
num_steps=500,
num_burnin=500,
n_chains=n_chains):
step_size = np.random.rand(n_chains, 1)*.5 + 1.
if not isinstance(inits, list):
inits = [inits]
if bijectors_list is None:
bijectors_list = [tfb.Identity()]*len(inits)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
tfp.mcmc.TransformedTransitionKernel(
inner_kernel=tfp.mcmc.NoUTurnSampler(
target_log_prob_fn,
step_size=[step_size]*len(inits)
),
bijector=bijectors_list
),
target_accept_prob=.8,
num_adaptation_steps=int(0.8*num_burnin),
step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
inner_results=pkr.inner_results._replace(step_size=new_step_size)
),
step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
)
res = tfp.mcmc.sample_chain(
num_results=num_steps,
num_burnin_steps=num_burnin,
current_state=inits,
kernel=kernel,
trace_fn=trace_fn
)
return res
inits = model.sample(n_chains)
run_nuts = partial(run_nuts_template, trace_fn)
inits = [tf.random.uniform(s.shape, -2, 2, tf.float32, name="initializer") for s in inits]
run_nuts(target_log_prob_fn, inits[:-1])
错误:ValueError:张量的形状 (2, 2) 与提供的形状 (2,) 不兼容
完整堆栈跟踪:https://pastebin.com/zAA58P53
ValueError Traceback (most recent call last)
<ipython-input-17-ab5eae0dd51a> in <module>
3 ]
4
----> 5 run_nuts(
6 target_log_prob_fn,
7 inits[:-1]
<ipython-input-13-2d7a920574a8> in run_nuts_template(trace_fn, target_log_prob_fn, inits, bijectors_list, num_steps, num_burnin, n_chains)
45 )
46
---> 47 res = tfp.mcmc.sample_chain(
48 num_results=num_steps,
49 num_burnin_steps=num_burnin,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in sample_chain(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, trace_fn, return_final_kernel_results, parallel_iterations, seed, name)
359 return seed, next_state, current_kernel_results
360
--> 361 (_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
362 loop_fn=_trace_scan_fn,
363 initial_state=(seed, current_state, previous_kernel_results),
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in trace_scan(loop_fn, initial_state, elems, trace_fn, trace_criterion_fn, static_trace_allocation_size, parallel_iterations, name)
462 return i + 1, state, num_steps_traced, trace_arrays
463
--> 464 _, final_state, _, trace_arrays = tf.while_loop(
465 cond=lambda i, *_: i < length,
466 body=_body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in _body(i, state, num_steps_traced, trace_arrays)
452 def _body(i, state, num_steps_traced, trace_arrays):
453 elem = elems_array.read(i)
--> 454 state = loop_fn(state, elem)
455
456 trace_arrays, num_steps_traced = ps.cond(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _trace_scan_fn(seed_state_and_results, num_steps)
352
353 def _trace_scan_fn(seed_state_and_results, num_steps):
--> 354 seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
355 loop_num_iter=num_steps,
356 body_fn=_seeded_one_step,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in smart_for_loop(loop_num_iter, body_fn, initial_loop_vars, parallel_iterations, unroll_threshold, name)
346 # where while/LoopCond needs it.
347 loop_num_iter = tf.cast(loop_num_iter, dtype=tf.int32)
--> 348 return tf.while_loop(
349 cond=lambda i, *args: i < loop_num_iter,
350 body=lambda i, *args: [i + 1] + list(body_fn(*args)),
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in <lambda>(i, *args)
348 return tf.while_loop(
349 cond=lambda i, *args: i < loop_num_iter,
--> 350 body=lambda i, *args: [i + 1] + list(body_fn(*args)),
351 loop_vars=[np.int32(0)] + initial_loop_vars,
352 parallel_iterations=parallel_iterations
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _seeded_one_step(seed, *state_and_results)
349 one_step_kwargs = dict(seed=step_seed) if is_seeded else {}
350 return [passalong_seed] + list(
--> 351 kernel.one_step(*state_and_results, **one_step_kwargs))
352
353 def _trace_scan_fn(seed_state_and_results, num_steps):
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py in one_step(self, current_state, previous_kernel_results, seed)
454 # Step the inner kernel.
455 inner_kwargs = {} if seed is None else dict(seed=seed)
--> 456 new_state, new_inner_results = self.inner_kernel.one_step(
457 current_state, inner_results, **inner_kwargs)
458
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py in one_step(self, current_state, previous_kernel_results, seed)
397 self.name, 'transformed_kernel', 'one_step')):
398 inner_kwargs = {} if seed is None else dict(seed=seed)
--> 399 transformed_next_state, kernel_results = self._inner_kernel.one_step(
400 previous_kernel_results.transformed_state,
401 previous_kernel_results.inner_results,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in one_step(self, current_state, previous_kernel_results, seed)
392 )
393
--> 394 _, _, _, new_step_metastate = tf.while_loop(
395 cond=lambda iter_, seed, state, metastate: ( # pylint: disable=g-long-lambda
396 (iter_ < self.max_tree_depth) &
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, state, metastate)
396 (iter_ < self.max_tree_depth) &
397 tf.reduce_any(metastate.continue_tree)),
--> 398 body=lambda iter_, seed, state, metastate: self._loop_tree_doubling( # pylint: disable=g-long-lambda
399 previous_kernel_results.step_size,
400 previous_kernel_results.momentum_state_memory,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_tree_doubling(self, step_size, momentum_state_memory, current_step_meta_info, iter_, initial_step_state, initial_step_metastate, seed)
570 momentum_subtree_cumsum,
571 leapfrogs_taken
--> 572 ] = self._build_sub_tree(
573 directions_expanded,
574 integrator,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _build_sub_tree(self, directions, integrator, current_step_meta_info, nsteps, initial_state, continue_tree, not_divergence, momentum_state_memory, seed, name)
750 final_not_divergence,
751 momentum_state_memory,
--> 752 ] = tf.while_loop(
753 cond=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum, # pylint: disable=g-long-lambda
754 leapfrogs_taken, state, state_c, continue_tree,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, energy_diff_sum, init_momentum_cumsum, leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory)
758 leapfrogs_taken, state, state_c, continue_tree,
759 not_divergence, momentum_state_memory: (
--> 760 self._loop_build_sub_tree(
761 directions, integrator, current_step_meta_info,
762 iter_, energy_diff_sum, init_momentum_cumsum,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_build_sub_tree(self, directions, integrator, current_step_meta_info, iter_, energy_diff_sum_previous, momentum_cumsum_previous, leapfrogs_taken, prev_tree_state, candidate_tree_state, continue_tree_previous, not_divergent_previous, momentum_state_memory, seed)
811 next_target,
812 next_target_grad_parts
--> 813 ] = integrator(prev_tree_state.momentum,
814 prev_tree_state.state,
815 prev_tree_state.target,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in __call__(self, momentum_parts, state_parts, target, target_grad_parts, kinetic_energy_fn, name)
295 next_target,
296 next_target_grad_parts,
--> 297 ] = tf.while_loop(
298 cond=lambda i, *_: i < self.num_steps,
299 body=lambda i, *args: [i + 1] + list(_one_step( # pylint: disable=no-value-for-parameter,g-long-lambda
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
603 func.__module__, arg_name, arg_value, 'in a future version'
604 if date is None else ('after %s' % date), instructions)
--> 605 return func(*args, **kwargs)
606
607 doc = _add_deprecated_arg_value_notice_to_docstring(
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
2487
2488 """
-> 2489 return while_loop(
2490 cond=cond,
2491 body=body,
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
2733 list(loop_vars))
2734 while cond(*loop_vars):
-> 2735 loop_vars = body(*loop_vars)
2736 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2737 packed = True
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in <lambda>(i, *args)
297 ] = tf.while_loop(
298 cond=lambda i, *_: i < self.num_steps,
--> 299 body=lambda i, *args: [i + 1] + list(_one_step( # pylint: disable=no-value-for-parameter,g-long-lambda
300 self.target_fn, self.step_sizes, get_velocity_parts, *args)),
301 loop_vars=[
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in _one_step(target_fn, step_sizes, get_velocity_parts, half_next_momentum_parts, state_parts, target, target_grad_parts)
353 next_target_grad_parts))
354
--> 355 tensorshape_util.set_shape(next_target, target.shape)
356 for ng, g in zip(next_target_grad_parts, target_grad_parts):
357 tensorshape_util.set_shape(ng, g.shape)
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py in set_shape(tensor, shape)
326 """
327 if hasattr(tensor, 'set_shape'):
--> 328 tensor.set_shape(shape)
329
330
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in set_shape(self, shape)
1213 def set_shape(self, shape):
1214 if not self.shape.is_compatible_with(shape):
-> 1215 raise ValueError(
1216 "Tensor's shape %s is not compatible with supplied shape %s" %
1217 (self.shape, shape))
ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)
形状为 (2, ) 的唯一变量是 sigma
,但我不知道它是如何变成 (2, 2)
张量的。
问题出在步长的形状上。如所写,链状态部分分别具有(包括 n_chains
批量形状)形状 [2]
和 [2, 10]
。然而,您的步长都被初始化为形状 [2, 1]
。这对于第二个状态部分([2, 1]
用 [2, 10]
广播)是正确的,但对于第一个状态部分不正确——你最终会在某个地方得到一个 [2, 2]
,大概是在 step_size * grad(tlp, state[0])
项中(释义)在积分器中。
我将步长初始化重写为此,它可能会使用一些 code-golfing,但按预期工作:
step_size = np.random.rand(n_chains) * .5 + 1.
step_size = [np.reshape(step_size, [n_chains] + [1] * (x.shape.ndims - 1))
for x in inits]
# now step size shapes are [2] and [2, 1]
一些其他注意事项:
- 查看
tf.linalg.matvec
——它会避免使用 newaxis + matmul ,从而为您节省一些字符(可能还有一些 flops)
- 您似乎在进行 stan-style 初始化,在
[-2, 2]^n
超立方体中进行随机初始化。一般来说,您会希望在 unconstrained space 中执行此操作,因此您可能希望通过约束双射器推动那些随机初始化,至少在它们不是时默认值 (Identity
)。TransformedTransitionKernel
假设 user-provided 状态处于 constrained(样本)space。希望这很清楚...如果不清楚请告诉我! - 这里有一个新的实验性功能“固定”,它可能会为您简化一些事情。您首先要确保您的分布有名称(将
name
arg 传递给构造函数),然后您可以调用pinned = model.pinned(y=y)
来获取tfp.experimental.distributions.JointDistributionPinned
的实例。你可以使用pinned.unnormalized_log_prob
代替你的def target_log_prob_fn
(它做同样的事情),你也可以调用pinned.experimental_default_event_space_bijector()
来获得一个双射器,它“做正确的事”来转换un-pinned 个变量。也就是说,你可以把那个东西交给TransformedTransitionKernel
。它实际上是一个“多部分”(或“联合”)双射器,所以它吃列表和 returns 列表;您不再需要 list-wrap 您的 passed-in 双射器。最近,TTK 知道如何使用这些多部分双射器。顾名思义,这些都是相当新的 experimental/subject 到 API 调整,但应该处于良好的工作状态;如果您尝试它们并 运行 遇到问题,请告诉我们!
这是一个带有略微修改的步长代码的 colab:https://colab.research.google.com/drive/1o-nygALqdq2ppj5rU9d6UVafZM5SBCd4
HTH!