Tensorflow probability: ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)

Tensorflow probability: ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)

我正在尝试让 NUTS 采样器在玩具模型上工作。我已经 运行 解决了标题中提到的问题。

这是重现错误的代码:

import tensorflow_probability as tfp
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from functools import partial
import numpy as np

tfd = tfp.distributions


A = tf.random.normal(
    [10,10], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
noise_std = tf.random.normal([1])

x1 = tfd.Normal(0, 10 * tf.ones(10)).sample()
x1 = x1[..., tf.newaxis]
y = tf.linalg.matmul(A, x1) + noise_std


model = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.), #sigma 
    tfd.Normal(0, 10 * tf.ones(10)),
    lambda x_rv, sigma : tfd.Normal(loc=tf.linalg.matmul(A, x_rv[...,tf.newaxis]) + sigma, scale=1.0)
])

def target_log_prob_fn(sigma, x_rv):
    return model.log_prob([sigma, x_rv, y[tf.newaxis, ...]])


def trace_fn(_, pkr):  
    return (
        pkr.inner_results.inner_results.target_log_prob,
        pkr.inner_results.inner_results.leapfrogs_taken,
        pkr.inner_results.inner_results.has_divergence,
        pkr.inner_results.inner_results.energy,
        pkr.inner_results.inner_results.log_accept_ratio)

n_chains = 2

def run_nuts_template(
    trace_fn,
    target_log_prob_fn,
    inits,
    bijectors_list=None, 
    num_steps=500,
    num_burnin=500,
    n_chains=n_chains):
    
    step_size = np.random.rand(n_chains, 1)*.5 + 1.
    
    if not isinstance(inits, list):
        inits = [inits]
        
    if bijectors_list is None:
        bijectors_list = [tfb.Identity()]*len(inits)

    kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
        tfp.mcmc.TransformedTransitionKernel(
            inner_kernel=tfp.mcmc.NoUTurnSampler(
                target_log_prob_fn,
                step_size=[step_size]*len(inits)
            ),
            bijector=bijectors_list
        ),
        target_accept_prob=.8,
        num_adaptation_steps=int(0.8*num_burnin),
        step_size_setter_fn=lambda pkr, new_step_size: pkr._replace(
              inner_results=pkr.inner_results._replace(step_size=new_step_size)
          ),
        step_size_getter_fn=lambda pkr: pkr.inner_results.step_size,
        log_accept_prob_getter_fn=lambda pkr: pkr.inner_results.log_accept_ratio,
    )
    
    res = tfp.mcmc.sample_chain(
        num_results=num_steps,
        num_burnin_steps=num_burnin,
        current_state=inits,
        kernel=kernel,
        trace_fn=trace_fn
    )
    return res



inits = model.sample(n_chains)
run_nuts = partial(run_nuts_template, trace_fn)


inits = [tf.random.uniform(s.shape, -2, 2, tf.float32, name="initializer") for s in inits]

run_nuts(target_log_prob_fn, inits[:-1])


错误:ValueError:张量的形状 (2, 2) 与提供的形状 (2,) 不兼容

完整堆栈跟踪:https://pastebin.com/zAA58P53

ValueError                                Traceback (most recent call last)
<ipython-input-17-ab5eae0dd51a> in <module>
      3 ]
      4 
----> 5 run_nuts(
      6     target_log_prob_fn,
      7             inits[:-1]
 
<ipython-input-13-2d7a920574a8> in run_nuts_template(trace_fn, target_log_prob_fn, inits, bijectors_list, num_steps, num_burnin, n_chains)
     45     )
     46 
---> 47     res = tfp.mcmc.sample_chain(
     48         num_results=num_steps,
     49         num_burnin_steps=num_burnin,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in sample_chain(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, trace_fn, return_final_kernel_results, parallel_iterations, seed, name)
    359       return seed, next_state, current_kernel_results
    360 
--> 361     (_, _, final_kernel_results), (all_states, trace) = mcmc_util.trace_scan(
    362         loop_fn=_trace_scan_fn,
    363         initial_state=(seed, current_state, previous_kernel_results),
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in trace_scan(loop_fn, initial_state, elems, trace_fn, trace_criterion_fn, static_trace_allocation_size, parallel_iterations, name)
    462       return i + 1, state, num_steps_traced, trace_arrays
    463 
--> 464     _, final_state, _, trace_arrays = tf.while_loop(
    465         cond=lambda i, *_: i < length,
    466         body=_body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in _body(i, state, num_steps_traced, trace_arrays)
    452     def _body(i, state, num_steps_traced, trace_arrays):
    453       elem = elems_array.read(i)
--> 454       state = loop_fn(state, elem)
    455 
    456       trace_arrays, num_steps_traced = ps.cond(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _trace_scan_fn(seed_state_and_results, num_steps)
    352 
    353     def _trace_scan_fn(seed_state_and_results, num_steps):
--> 354       seed, next_state, current_kernel_results = mcmc_util.smart_for_loop(
    355           loop_num_iter=num_steps,
    356           body_fn=_seeded_one_step,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in smart_for_loop(loop_num_iter, body_fn, initial_loop_vars, parallel_iterations, unroll_threshold, name)
    346       # where while/LoopCond needs it.
    347       loop_num_iter = tf.cast(loop_num_iter, dtype=tf.int32)
--> 348       return tf.while_loop(
    349           cond=lambda i, *args: i < loop_num_iter,
    350           body=lambda i, *args: [i + 1] + list(body_fn(*args)),
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/util.py in <lambda>(i, *args)
    348       return tf.while_loop(
    349           cond=lambda i, *args: i < loop_num_iter,
--> 350           body=lambda i, *args: [i + 1] + list(body_fn(*args)),
    351           loop_vars=[np.int32(0)] + initial_loop_vars,
    352           parallel_iterations=parallel_iterations
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/sample.py in _seeded_one_step(seed, *state_and_results)
    349       one_step_kwargs = dict(seed=step_seed) if is_seeded else {}
    350       return [passalong_seed] + list(
--> 351           kernel.one_step(*state_and_results, **one_step_kwargs))
    352 
    353     def _trace_scan_fn(seed_state_and_results, num_steps):
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/dual_averaging_step_size_adaptation.py in one_step(self, current_state, previous_kernel_results, seed)
    454       # Step the inner kernel.
    455       inner_kwargs = {} if seed is None else dict(seed=seed)
--> 456       new_state, new_inner_results = self.inner_kernel.one_step(
    457           current_state, inner_results, **inner_kwargs)
    458 
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/transformed_kernel.py in one_step(self, current_state, previous_kernel_results, seed)
    397         self.name, 'transformed_kernel', 'one_step')):
    398       inner_kwargs = {} if seed is None else dict(seed=seed)
--> 399       transformed_next_state, kernel_results = self._inner_kernel.one_step(
    400           previous_kernel_results.transformed_state,
    401           previous_kernel_results.inner_results,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in one_step(self, current_state, previous_kernel_results, seed)
    392           )
    393 
--> 394       _, _, _, new_step_metastate = tf.while_loop(
    395           cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
    396               (iter_ < self.max_tree_depth) &
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, state, metastate)
    396               (iter_ < self.max_tree_depth) &
    397               tf.reduce_any(metastate.continue_tree)),
--> 398           body=lambda iter_, seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
    399               previous_kernel_results.step_size,
    400               previous_kernel_results.momentum_state_memory,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_tree_doubling(self, step_size, momentum_state_memory, current_step_meta_info, iter_, initial_step_state, initial_step_metastate, seed)
    570           momentum_subtree_cumsum,
    571           leapfrogs_taken
--> 572       ] = self._build_sub_tree(
    573           directions_expanded,
    574           integrator,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _build_sub_tree(self, directions, integrator, current_step_meta_info, nsteps, initial_state, continue_tree, not_divergence, momentum_state_memory, seed, name)
    750           final_not_divergence,
    751           momentum_state_memory,
--> 752       ] = tf.while_loop(
    753           cond=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
    754                       leapfrogs_taken, state, state_c, continue_tree,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in <lambda>(iter_, seed, energy_diff_sum, init_momentum_cumsum, leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory)
    758                       leapfrogs_taken, state, state_c, continue_tree,
    759                       not_divergence, momentum_state_memory: (
--> 760                           self._loop_build_sub_tree(
    761                               directions, integrator, current_step_meta_info,
    762                               iter_, energy_diff_sum, init_momentum_cumsum,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/nuts.py in _loop_build_sub_tree(self, directions, integrator, current_step_meta_info, iter_, energy_diff_sum_previous, momentum_cumsum_previous, leapfrogs_taken, prev_tree_state, candidate_tree_state, continue_tree_previous, not_divergent_previous, momentum_state_memory, seed)
    811           next_target,
    812           next_target_grad_parts
--> 813       ] = integrator(prev_tree_state.momentum,
    814                      prev_tree_state.state,
    815                      prev_tree_state.target,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in __call__(self, momentum_parts, state_parts, target, target_grad_parts, kinetic_energy_fn, name)
    295           next_target,
    296           next_target_grad_parts,
--> 297       ] = tf.while_loop(
    298           cond=lambda i, *_: i < self.num_steps,
    299           body=lambda i, *args: [i + 1] + list(_one_step(  # pylint: disable=no-value-for-parameter,g-long-lambda
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    603                   func.__module__, arg_name, arg_value, 'in a future version'
    604                   if date is None else ('after %s' % date), instructions)
--> 605       return func(*args, **kwargs)
    606 
    607     doc = _add_deprecated_arg_value_notice_to_docstring(
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
   2487 
   2488   """
-> 2489   return while_loop(
   2490       cond=cond,
   2491       body=body,
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2733                                               list(loop_vars))
   2734       while cond(*loop_vars):
-> 2735         loop_vars = body(*loop_vars)
   2736         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2737           packed = True
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in <lambda>(i, *args)
    297       ] = tf.while_loop(
    298           cond=lambda i, *_: i < self.num_steps,
--> 299           body=lambda i, *args: [i + 1] + list(_one_step(  # pylint: disable=no-value-for-parameter,g-long-lambda
    300               self.target_fn, self.step_sizes, get_velocity_parts, *args)),
    301           loop_vars=[
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/mcmc/internal/leapfrog_integrator.py in _one_step(target_fn, step_sizes, get_velocity_parts, half_next_momentum_parts, state_parts, target, target_grad_parts)
    353               next_target_grad_parts))
    354 
--> 355     tensorshape_util.set_shape(next_target, target.shape)
    356     for ng, g in zip(next_target_grad_parts, target_grad_parts):
    357       tensorshape_util.set_shape(ng, g.shape)
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow_probability/python/internal/tensorshape_util.py in set_shape(tensor, shape)
    326   """
    327   if hasattr(tensor, 'set_shape'):
--> 328     tensor.set_shape(shape)
    329 
    330 
 
~/anaconda3/envs/tensor/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in set_shape(self, shape)
   1213   def set_shape(self, shape):
   1214     if not self.shape.is_compatible_with(shape):
-> 1215       raise ValueError(
   1216           "Tensor's shape %s is not compatible with supplied shape %s" %
   1217           (self.shape, shape))
 
ValueError: Tensor's shape (2, 2) is not compatible with supplied shape (2,)

形状为 (2, ) 的唯一变量是 sigma,但我不知道它是如何变成 (2, 2) 张量的。

问题出在步长的形状上。如所写,链状态部分分别具有(包括 n_chains 批量形状)形状 [2][2, 10]。然而,您的步长都被初始化为形状 [2, 1]。这对于第二个状态部分([2, 1][2, 10] 广播)是正确的,但对于第一个状态部分不正确——你最终会在某个地方得到一个 [2, 2],大概是在 step_size * grad(tlp, state[0]) 项中(释义)在积分器中。

我将步长初始化重写为此,它可能会使用一些 code-golfing,但按预期工作:

    step_size = np.random.rand(n_chains) * .5 + 1.
    step_size = [np.reshape(step_size, [n_chains] + [1] * (x.shape.ndims - 1))
                 for x in inits]
    # now step size shapes are [2] and [2, 1]

一些其他注意事项:

  1. 查看 tf.linalg.matvec——它会避免使用 newaxis + matmul
  2. ,从而为您节省一些字符(可能还有一些 flops)
  3. 您似乎在进行 stan-style 初始化,在 [-2, 2]^n 超立方体中进行随机初始化。一般来说,您会希望在 unconstrained space 中执行此操作,因此您可能希望通过约束双射器推动那些随机初始化,至少在它们不是时默认值 (Identity)。 TransformedTransitionKernel 假设 user-provided 状态处于 constrained(样本)space。希望这很清楚...如果不清楚请告诉我!
  4. 这里有一个新的实验性功能“固定”,它可能会为您简化一些事情。您首先要确保您的分布有名称(将 name arg 传递给构造函数),然后您可以调用 pinned = model.pinned(y=y) 来获取 tfp.experimental.distributions.JointDistributionPinned 的实例。你可以使用 pinned.unnormalized_log_prob 代替你的 def target_log_prob_fn (它做同样的事情),你也可以调用 pinned.experimental_default_event_space_bijector() 来获得一个双射器,它“做正确的事”来转换un-pinned 个变量。也就是说,你可以把那个东西交给 TransformedTransitionKernel。它实际上是一个“多部分”(或“联合”)双射器,所以它吃列表和 returns 列表;您不再需要 list-wrap 您的 passed-in 双射器。最近,TTK 知道如何使用这些多部分双射器。顾名思义,这些都是相当新的 experimental/subject 到 API 调整,但应该处于良好的工作状态;如果您尝试它们并 运行 遇到问题,请告诉我们!

这是一个带有略微修改的步长代码的 colab:https://colab.research.google.com/drive/1o-nygALqdq2ppj5rU9d6UVafZM5SBCd4

HTH!