Incorrect event dimension for ComposeTransform if transform reduces event dimensions #1893

@tillahoffmann

An example is probably the clearest illustration. If the transform increases the number of event dimensions, everything works as expected.

>>> from numpyro.distributions.transforms import CorrCholeskyTransform, ComposeTransform
>>> part = CorrCholeskyTransform()
>>> part.domain.event_dim, part.codomain.event_dim
(1, 2)
>>> composed = ComposeTransform([part])
>>> composed.domain.event_dim, composed.codomain.event_dim
(1, 2)

If the transform reduces the number of event dimensions, wrapping the transform in a ComposeTransform leads to unexpected results.

>>> part = CorrCholeskyTransform().inv
>>> part.domain.event_dim, part.codomain.event_dim
(2, 1)
>>> composed = ComposeTransform([part])
>>> composed.domain.event_dim, composed.codomain.event_dim
(3, 1)

I had a brief look at the code below, but I couldn't quite get my head around it.

def _get_compose_transform_output_event_dim(parts):
    output_event_dim = parts[0].codomain.event_dim
    for part in parts[1:]:
        output_event_dim = part.codomain.event_dim + max(
            output_event_dim - part.domain.event_dim, 0
        )
    return output_event_dim
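
To make the recurrence easier to follow, here is a minimal pure-Python sketch that models each part as a plain `(domain event_dim, codomain event_dim)` pair. `output_event_dim` mirrors the helper quoted above. `input_event_dim` and `input_event_dim_revisiting_last` are hypothetical counterparts written only for illustration. They are not copied from numpyro, and whether the library's input-side helper resembles either of them is an assumption.

```python
from typing import List, Tuple

# Each part is modelled as (domain_event_dim, codomain_event_dim).
Part = Tuple[int, int]


def output_event_dim(parts: List[Part]) -> int:
    """Mirror of the quoted helper, operating on plain integer pairs."""
    out = parts[0][1]
    for dom, cod in parts[1:]:
        # Extra leading event dims that the part does not consume carry over.
        out = cod + max(out - dom, 0)
    return out


def input_event_dim(parts: List[Part]) -> int:
    """Hypothetical backward counterpart (an assumption, not numpyro code)."""
    inp = parts[-1][0]
    # Walk backwards over every part except the last, which seeded `inp`.
    for dom, cod in reversed(parts[:-1]):
        inp = dom + max(inp - cod, 0)
    return inp


def input_event_dim_revisiting_last(parts: List[Part]) -> int:
    """Hypothetical off-by-one variant: the loop visits the last part again."""
    inp = parts[-1][0]
    for dom, cod in reversed(parts):
        inp = dom + max(inp - cod, 0)
    return inp


forward = (1, 2)  # CorrCholeskyTransform(): event_dim 1 -> 2
inverse = (2, 1)  # CorrCholeskyTransform().inv: event_dim 2 -> 1

print(output_event_dim([forward]), output_event_dim([inverse]))
print(input_event_dim([forward]), input_event_dim([inverse]))
print(
    input_event_dim_revisiting_last([forward]),
    input_event_dim_revisiting_last([inverse]),
)
```

For a single part, the output-side recurrence returns that part's codomain event dimension, which agrees with the codomain values observed above (2 and 1). The variant that revisits the last part yields 1 for the forward transform and 2 + max(2 - 1, 0) = 3 for the inverse, which happens to match the reported domain values. That is only a guess at where to look, not a confirmed diagnosis.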

Here's a minimal test that can be added to test/test_transforms.py to reproduce.

@pytest.mark.parametrize("transform", [
    CorrCholeskyTransform(),  # passes
    CorrCholeskyTransform().inv,  # fails
])
def test_compose_domain_codomain(transform):
    composed = ComposeTransform([transform])
    assert transform.domain.event_dim == composed.domain.event_dim
    assert transform.codomain.event_dim == composed.codomain.event_dim

Labels: bug
