An auto guide's _unpack_latent and _unpack_latent._inverse don't use the same order #1809

@CKrawczyk

I ran across this issue when trying to evaluate the log_prob of the guide at the positions from an HMC chain. To do this, the positions need to be flattened into the same order the guide uses internally. I found that guide._unpack_latent takes a flat vector to a structured one, so I expected guide._unpack_latent._inverse to do the opposite and take a structured position to a flat one. Instead, the order of the flattened values did not match.

Here is some reproducing code:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# make a model where the sample sites are not in alphabetical order
def model_funnel():
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(1), jnp.exp(y / 2)))

# make an auto guide
guide_funnel = numpyro.infer.autoguide.AutoDiagonalNormal(model_funnel)

# fit the guide
optim = numpyro.optim.Adam(step_size=1e-4)
svi = numpyro.infer.SVI(model_funnel, guide_funnel, optim, loss=numpyro.infer.Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 2000)

# make a flat vector to unpack and re-pack
values = jnp.array([1.0, 2.0])
unpack = guide_funnel._unpack_latent(values)
pack = guide_funnel._unpack_latent._inverse(unpack)

print(pack)
# [2.0, 1.0]

print(values)
# [1.0, 2.0]

What I think is going on is the _inverse method of the UnpackTransform uses ravel_pytree: https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/transforms.py#L1137
But the forward transform uses the custom _unravel_dict function: https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/autoguide.py#L609

From what I can tell, ravel_pytree always flattens a dict in sorted (alphabetical) key order, whereas _unravel_dict unpacks the flat vector in the order the sample sites show up in the model.
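The ordering behavior of ravel_pytree can be checked without numpyro. The snippet below is a minimal sketch that only uses JAX; the dict `latent` is a hypothetical stand-in for the structured latent values, built in the same site order as the model above ("y" first, then "x").

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical stand-in for the guide's structured latent values,
# inserted in model order: "y" first, then "x".
latent = {"y": jnp.array(1.0), "x": jnp.array([2.0])}

# JAX flattens dict pytrees by sorted key, so "x" is placed before "y"
# regardless of insertion order.
flat, unravel = ravel_pytree(latent)
print(flat)

# The unravel function returned alongside is consistent with that same
# sorted order, so this round trip recovers the original values.
round_trip = unravel(flat)
print(round_trip)
```

The flat vector comes out with the "x" value first, which matches the swapped order seen in the reproducing code.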

Possible solutions:

  1. Update _unravel_dict to use ravel_pytree's unflatten function instead of defining a custom order (or emulate the same order that function would give)
  2. Update UnpackTransform._inverse to invert the _unravel_dict operation rather than assume it is in the same order as ravel_pytree
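As a rough illustration of the second option, here is a sketch of a pack/unpack pair that both follow the same insertion order. The helpers `unravel_in_order` and `ravel_in_order` and the `shapes` dict are hypothetical names made up for this example; they are not numpyro functions and are not meant to mirror the actual _unravel_dict implementation.

```python
import math

import jax.numpy as jnp


def unravel_in_order(flat, shapes):
    """Split a flat vector into a dict, following the insertion order of `shapes`."""
    out, start = {}, 0
    for name, shape in shapes.items():
        size = math.prod(shape)  # the empty shape () gives size 1
        out[name] = flat[start:start + size].reshape(shape)
        start += size
    return out


def ravel_in_order(tree, shapes):
    """Inverse of `unravel_in_order`: concatenate values in the order of `shapes`."""
    return jnp.concatenate([jnp.ravel(tree[name]) for name in shapes])


# Site shapes in model order ("y" first, then "x"); assumed for this example.
shapes = {"y": (), "x": (1,)}

values = jnp.array([1.0, 2.0])
unpacked = unravel_in_order(values, shapes)
packed = ravel_in_order(unpacked, shapes)
print(packed)
```

Because both directions iterate over the same `shapes` dict, the round trip returns the original vector no matter how the site names sort alphabetically.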

Labels: bug, enhancement