Closed
Labels
bug (Something isn't working)
Description
An example is probably the clearest illustration. If the transform increases the number of event dimensions, everything works as expected.
```pycon
>>> from numpyro.distributions.transforms import CorrCholeskyTransform, ComposeTransform
>>> part = CorrCholeskyTransform()
>>> part.domain.event_dim, part.codomain.event_dim
(1, 2)
>>> composed = ComposeTransform([part])
>>> composed.domain.event_dim, composed.codomain.event_dim
(1, 2)
```

If the transform reduces the number of event dimensions, wrapping it in a `ComposeTransform` leads to unexpected results.
```pycon
>>> part = CorrCholeskyTransform().inv
>>> part.domain.event_dim, part.codomain.event_dim
(2, 1)
>>> composed = ComposeTransform([part])
>>> composed.domain.event_dim, composed.codomain.event_dim
(3, 1)
```

The composed transform reports a domain `event_dim` of 3 instead of 2. I had a brief look at the code below, but I couldn't quite get my head around it.
numpyro/numpyro/distributions/transforms.py, lines 292 to 298 in 8ace34f:

```python
def _get_compose_transform_output_event_dim(parts):
    output_event_dim = parts[0].codomain.event_dim
    for part in parts[1:]:
        output_event_dim = part.codomain.event_dim + max(
            output_event_dim - part.domain.event_dim, 0
        )
    return output_event_dim
```
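To see how event dimensions propagate through a composition, the arithmetic can be replayed without numpyro. This is a sketch under stated assumptions: each part is a hypothetical `(domain_event_dim, codomain_event_dim)` pair rather than a real `Transform`; `output_event_dim` copies the logic of the function quoted above; `input_event_dim` is an assumed mirror for the domain side, not code taken from the repository; and `input_event_dim_double_counted` is a hypothetical variant whose loop revisits the last part it started from. That variant is only a guess at the kind of slip that could produce the numbers reported above.

```python
# Each part is a hypothetical (domain_event_dim, codomain_event_dim) pair
# standing in for a numpyro Transform; no numpyro import is needed.


def output_event_dim(parts):
    """Forward pass, mirroring the quoted _get_compose_transform_output_event_dim."""
    event_dim = parts[0][1]
    for domain_dim, codomain_dim in parts[1:]:
        # Dims beyond what this part consumes are carried through unchanged.
        event_dim = codomain_dim + max(event_dim - domain_dim, 0)
    return event_dim


def input_event_dim(parts):
    """Assumed backward mirror: start at the last part, then walk the rest in reverse."""
    event_dim = parts[-1][0]
    for domain_dim, codomain_dim in reversed(parts[:-1]):
        event_dim = domain_dim + max(event_dim - codomain_dim, 0)
    return event_dim


def input_event_dim_double_counted(parts):
    """Hypothetical buggy variant: the loop revisits the last part it started from."""
    event_dim = parts[-1][0]
    for domain_dim, codomain_dim in reversed(parts):
        event_dim = domain_dim + max(event_dim - codomain_dim, 0)
    return event_dim


corr_cholesky = (1, 2)      # stands in for CorrCholeskyTransform(): 1 -> 2
corr_cholesky_inv = (2, 1)  # stands in for CorrCholeskyTransform().inv: 2 -> 1

for name, parts in [
    ("[t]", [corr_cholesky]),
    ("[t.inv]", [corr_cholesky_inv]),
    ("[t, t.inv]", [corr_cholesky, corr_cholesky_inv]),
]:
    print(
        name,
        (input_event_dim(parts), output_event_dim(parts)),
        input_event_dim_double_counted(parts),
    )
```

With these definitions, the double-counted variant returns 1 for `[t]` but 3 for `[t.inv]`, which matches the (1, 2) and (3, 1) results in the report, while the mirrored version returns 2 for `[t.inv]`. Whether the real domain-side computation in `transforms.py` has this shape would need to be checked against the source.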
Here's a minimal test that can be added to `test/test_transforms.py` to reproduce the problem.

```python
import pytest
from numpyro.distributions.transforms import ComposeTransform, CorrCholeskyTransform

@pytest.mark.parametrize("transform", [
    CorrCholeskyTransform(),  # passes
    CorrCholeskyTransform().inv,  # fails
])
def test_compose_domain_codomain(transform):
    # Wrapping a single transform should leave its event dimensions unchanged.
    composed = ComposeTransform([transform])
    assert transform.domain.event_dim == composed.domain.event_dim
    assert transform.codomain.event_dim == composed.codomain.event_dim
```