Description
I ran across this issue when trying to evaluate the log_prob of the guide at the positions from an HMC chain. To do this, the positions need to be flattened into the same order the guide uses internally. I found that guide._unpack_latent turns a flat vector into a structured one, so I expected guide._unpack_latent._inverse to do the opposite and turn a structured position into a flat one. Instead, the order of the packed vector did not match the order of the original.
Here is some reproducing code:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# make a model where the sample sites are not in alphabetical order
def model_funnel():
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(1), jnp.exp(y / 2)))
# make an auto guide
guide_funnel = numpyro.infer.autoguide.AutoDiagonalNormal(model_funnel)
# fit the guide
optim = numpyro.optim.Adam(step_size=1e-4)
svi = numpyro.infer.SVI(model_funnel, guide_funnel, optim, loss=numpyro.infer.Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 2000)
# make a flat vector to unpack and re-pack
values = jnp.array([1.0, 2.0])
unpack = guide_funnel._unpack_latent(values)
pack = guide_funnel._unpack_latent._inverse(unpack)
print(pack)
# [2.0, 1.0]
print(values)
# [1.0, 2.0]
What I think is going on is that the _inverse method of UnpackTransform uses ravel_pytree: https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py#L1137
But the forward transform uses the custom _unravel_dict function: https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/autoguide.py#L609
From what I can tell, ravel_pytree always flattens a dict in sorted (alphabetical) key order, whereas _unravel_dict uses the order in which the sample sites appear in the model.
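The ordering behaviour of ravel_pytree can be checked without NumPyro. This is a minimal sketch; the dict below is an illustrative stand-in for the unpacked latent, not something taken from the guide itself:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Illustrative stand-in for the structured latent; "y" is inserted first,
# mirroring the order of the sample sites in model_funnel.
latent = {"y": jnp.array(1.0), "x": jnp.array([2.0])}

# JAX flattens dicts by sorted key, so "x" comes before "y" in the flat
# vector regardless of insertion order.
flat, unravel = ravel_pytree(latent)
print(flat)

# The returned unravel function is consistent with that same sorted order.
restored = unravel(flat)
```

Here the value stored under "x" lands first in flat, which matches the swapped output in the reproduction above.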
Possible solutions:
- Update _unravel_dict to use ravel_pytree's unflatten function instead of defining a custom order (or emulate the same order that function would give)
- Update UnpackTransform._inverse to invert the _unravel_dict operation rather than assume it is in the same order as ravel_pytree
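As a rough illustration of the second option, packing could follow an explicit site order rather than relying on ravel_pytree. The helper name pack_in_site_order and the site_order argument are hypothetical and not part of NumPyro; the sketch assumes the site order is known from the model:

```python
import jax.numpy as jnp

def pack_in_site_order(structured, site_order):
    """Hypothetical helper: flatten each site and concatenate in the given order."""
    return jnp.concatenate([jnp.ravel(structured[name]) for name in site_order])

# Illustrative structured position with the same shapes as model_funnel's sites.
structured = {"y": jnp.array(1.0), "x": jnp.array([2.0])}

# Assumed site order: the order in which the sites appear in the model.
packed = pack_in_site_order(structured, ["y", "x"])
print(packed)
```

With the model's site order, the packed vector keeps "y" first, so it would round-trip with a forward transform that unpacks in that same order.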