Conversation

@juanitorduz
Collaborator

Closes #1987

I gave this issue a go with the help of Claude Sonnet 3.7 to get started, and iterated to make it "work". All feedback is welcome.

@juanitorduz
Collaborator Author

juanitorduz commented Feb 28, 2025

There are a few failing tests that are unrelated to this change:

FAILED test/infer/test_autoguide.py::test_laplace_approximation_warning - Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
Emitted warnings: [].

FAILED test/test_transforms.py::test_bijective_transforms[transform9-shape9] - assert Array(False, dtype=bool)

FAILED test/test_distributions.py::test_entropy_categorical - AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Member

@fehiepsi fehiepsi left a comment


Thanks @juanitorduz! I left a few comments. I think you can use https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl to extract the (params, rngs, mutable) pieces, and https://flax.readthedocs.io/en/latest/guides/surgery.html to replace params/state.

(I also feel that rng is a mutable state under nnx https://flax.readthedocs.io/en/latest/guides/randomness.html)
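The filter idea linked above routes each state leaf into the first group whose filter matches it. A rough, hypothetical pure-Python analogue of that partitioning (not the real nnx API; all names here are illustrative) might look like:

```python
# Hypothetical sketch: partition a flat state dict into groups by predicates,
# mimicking how nnx.split(module, filter1, filter2, ...) routes each leaf
# into the first matching group. Plain Python only; not flax code.

def partition(state, *predicates):
    """Route each (key, value) pair into the first group whose predicate matches."""
    groups = [{} for _ in predicates]
    for key, value in state.items():
        for group, pred in zip(groups, predicates):
            if pred(key, value):
                group[key] = value
                break
    return groups

state = {
    "linear/kernel": [[1.0, 2.0]],  # trainable parameter
    "linear/bias": [0.0],           # trainable parameter
    "bn/mean": [0.5],               # mutable batch stat
    "dropout/rng": 42,              # rng state
}

params, rngs, rest = partition(
    state,
    lambda k, v: "kernel" in k or "bias" in k,  # "Param"-like filter
    lambda k, v: "rng" in k,                    # "Rng"-like filter
    lambda k, v: True,                          # catch-all, like `...`
)
print(sorted(params))  # ['linear/bias', 'linear/kernel']
print(sorted(rngs))    # ['dropout/rng']
print(sorted(rest))    # ['bn/mean']
```

The order of the filters matters: each leaf lands in the first group that claims it, and the trailing catch-all collects everything else.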

@juanitorduz
Collaborator Author

Thanks for the comments @fehiepsi I will look into them 🙏

@fehiepsi fehiepsi added the WIP label Mar 6, 2025
Juan Orduz added 5 commits March 6, 2025 19:59
init

fix

fix

fix lint

rm files

@fehiepsi fehiepsi added this to the 0.18 milestone Mar 7, 2025
Member

@fehiepsi fehiepsi left a comment


Hi @juanitorduz, I marked this for milestone 0.18, but there is no need to rush. Looking at the implementation, it seems to me that it's painful to deal with nnx (surprising? I'm not sure why the flax team is switching to it given the complexity). If possible, try to come up with a solution similar to the flax/haiku ones. If it stays complicated, then it's not worth supporting.

I would suggest reaching out to the nnx devs if something is unclear. The two main questions are:

  • how to get the initial parameters and the initial mutable states of an nn module
  • given parameters and states, how to get the output and the updated states

We need params stored as a nested dict. For the state, if it is not a dict, you can just wrap it as a dict via {"state": state}.
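The contract described above can be sketched in plain Python (a minimal, hypothetical illustration of the init/apply shape; the names, values, and running-mean "batch stat" are assumptions, not nnx or NumPyro code):

```python
# Hypothetical sketch of the two questions above:
# init() -> (params as a nested dict, mutable state wrapped as {"state": ...})
# apply(params, state, x) -> (output, updated state)

def init():
    """Return initial params (nested dict) and initial mutable state (wrapped in a dict)."""
    params = {"linear": {"weight": 2.0, "bias": 0.5}}
    # the non-dict state is wrapped as {"state": state}, as suggested above
    state = {"state": {"running_mean": 0.0, "count": 0}}
    return params, state

def apply(params, state, x):
    """Given parameters and states, return the output and the updated states."""
    y = params["linear"]["weight"] * x + params["linear"]["bias"]
    inner = state["state"]
    count = inner["count"] + 1
    # incremental running mean, standing in for a mutable batch stat
    mean = inner["running_mean"] + (x - inner["running_mean"]) / count
    return y, {"state": {"running_mean": mean, "count": count}}

params, state = init()
y, state = apply(params, state, 3.0)
print(y)                               # 6.5
print(state["state"]["running_mean"])  # 3.0
```

The key point is that apply is a pure function of (params, state, input), returning the updated state alongside the output rather than mutating anything in place.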

@juanitorduz
Collaborator Author

juanitorduz commented Mar 7, 2025

Thanks @fehiepsi ! Indeed, it has been harder than expected, though that also reflects my limited experience with the NNX internals 😄.

My aim is indeed to simplify! I'll use this branch to push my progress (experiments). I'll let you know whenever I feel really blocked or need a review. I will also try to reach out to the nnx team for some tips :)

It would be super nice to have this integration to strengthen the mid- to long-term ties with the JAX neural network community :)

@juanitorduz juanitorduz marked this pull request as draft March 7, 2025 07:37
Juan Orduz added 2 commits March 7, 2025 13:11
add example

rm test patch

feedback 1

feedback 2

feedback 3

improvements

fix test

hacky way to find batch normalization layers

remove unused code

simplyfy code

modularize

tests

simplify

simplify

@juanitorduz
Collaborator Author

juanitorduz commented Mar 7, 2025

@fehiepsi I simplified the code a lot (it was a mess, I know 🙈). I think it is now more modular (thanks to your advice) and looks similar to the flax case. In 2b89566 you can see that all tests pass (I had to write some custom code to make this work for Python 3.9). It is in much better shape if you want to take a look (most of the new lines of code are actually tests).

Moreover, I am working on a notebook example (WIP) adapting a simple example (https://juanitorduz.github.io/flax_numpyro/) and showing how to use NNX + NumPyro. I hope to clean up the notebook and add the relevant content this weekend :)

@fehiepsi
Member

fehiepsi commented Mar 8, 2025

Moreover, I am working on a notebook example (wip) adapting a simple example (https://juanitorduz.github.io/flax_numpyro/) and showing how to use NNX + NumPyro. I hope to clean the notebook and add the relevant content this weekend :)

Nice tutorial, looking forward to seeing the nnx version! :)

I think the implementation could be even simpler. Let's see. ;)

@juanitorduz
Collaborator Author

Thanks for the feedback @fehiepsi ! I will try to simplify even more based on your input 🙏

eager inizialization approach

rm unused code 1

rm unused code 2

rm code

rm code

rm  code

rm code
@juanitorduz juanitorduz requested a review from fehiepsi March 9, 2025 14:46
@juanitorduz
Collaborator Author

Minor comment: when re-running the notebook I see

0%|          | 0/20000 [00:00<?, ?it/s]/Users/juanitorduz/Documents/numpyro/numpyro/infer/elbo.py:253: UserWarning: mutable state is currently ignored when num_particles > 1.

Is this expected? (The results do not change, though.)

Member

@fehiepsi fehiepsi left a comment


Yes, it is expected. It's not clear how mutable states should be updated with num_particles > 1. For the dropout rng, we need special logic to vectorize the split. For batch_stats, we need sequential updates followed by an aggregation. I don't know; it is tricky to come up with a reasonable solution for this issue.

I would suggest using num_particles=1 if batch norm plays a key role in your network.
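One conceivable aggregation scheme for batch stats across particles (purely a sketch of the "sequential updates then aggregate" idea mentioned above; this is an assumption, not what NumPyro implements) would be to run each particle's update independently and then average the results:

```python
# Hypothetical sketch: per-particle running-mean updates, then a mean
# aggregation across particles. Not NumPyro code; names are illustrative.

def update_stats(running_mean, count, batch_mean):
    """One incremental running-mean update, standing in for a batch-stat update."""
    count += 1
    return running_mean + (batch_mean - running_mean) / count, count

def aggregate_particles(running_mean, count, per_particle_batch_means):
    """Update the shared stat once per particle, then average the candidates."""
    updated = [
        update_stats(running_mean, count, m)[0] for m in per_particle_batch_means
    ]
    new_count = count + 1
    return sum(updated) / len(updated), new_count

mean, count = aggregate_particles(0.0, 0, [1.0, 3.0])
print(mean, count)  # 2.0 1
```

Even this simple scheme shows why the problem is awkward: averaging is only one of several defensible choices, and the rng case needs a different (vectorized split) treatment entirely.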

if params:
    graph_def, state = nnx.split(model)
Member


See the comment above - you don't need to split the model again

params_state = eager_params_state
if params:
    nnx.replace_by_pure_dict(params_state, params)
other_state = eager_other_state
if mutable_holder:
    nnx.replace_by_pure_dict(other_state, mutable_holder["state"])
model = nnx.merge(graph_def, params_state, other_state)

Member


You will need to use the latest mutable state, not just the latest params.

Collaborator Author


Hmm, I am a bit confused: I tried this in 17db8c1 and I get the following from the tests:

FAILED test/contrib/test_module.py::test_nnx_state_dropout_smoke[True-True] - TypeError: unsupported operand type(s) for *: 'float' and 'VariableState'
FAILED test/contrib/test_module.py::test_nnx_state_dropout_smoke[True-False] - TypeError: unsupported operand type(s) for *: 'float' and 'VariableState'

Is it ok that the apply function has access to variables defined outside of it, like eager_params_state and mutable_holder?

Member


I guess it is due to the line mutable_holder["state"] = eager_other_state above. We can remove that line because we already have the logic to get the latest mutable state from the numpyro.mutable call.

Collaborator Author


Hmm, I removed this line (5ffcfe3) and I still see the tests failing. I wonder if there is something missing in 17db8c1 🤔

Collaborator Author


I wonder if we are updating the parameters more than needed in the apply. In 12933f6 I removed the last split and update, reasoning that we had already done the update just before calling the model. The batch-normalization test passes now. What do you think, or could the error be coming from somewhere else 🙏?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I brought it back in d25de60 because you have pointed this out many times... I still can't figure out the error in the tests...

if mutable_holder is None:
    numpyro_mutable(name + "$state", {"state": eager_other_state_dict})

state_dict = {"state": eager_other_state}
Member


I don't think we want to use state_dict. We just want to update the mutable state via mutable_holder above and below.

Collaborator Author


yeah, makes sense, fixed in 18fe2a2

Member


Sorry, by "above" I meant the logic to get the latest mutable state via the numpyro.mutable call; by "below" I meant the logic after the model call. We should remove this line rather than rename it.

@juanitorduz
Collaborator Author

@fehiepsi Thank you very much for all the fantastic feedback these last days! I have learned a lot! Here is a summary of where we are:

  • All comments are addressed except for Flax NNX integration #1990 (comment), which is about a failing test and maybe a hidden bug. I would love some feedback (see the comments in that thread).
  • I have updated the example notebook to use both nnx_module and random_nnx_module. Both work as expected 😄! (Unless I missed something.)
  • I also built the docs locally, and the nnx function and the examples are displayed as expected 🙌.

I think we are very close to the finish line (I hope XD)

Member

@fehiepsi fehiepsi left a comment


The tutorials look awesome! I just caught some small nits. Thanks, Juan!

assert "b" in samples["nn$params"]


def test_nnx_transformer_module():
Member


I think we can remove this now.

@juanitorduz
Collaborator Author

Thanks @fehiepsi ! All of the remaining comments were addressed in e9d6b83! 🙏

@fehiepsi fehiepsi merged commit f5ae79b into pyro-ppl:master Mar 10, 2025
10 checks passed
@juanitorduz juanitorduz deleted the nnx_integration branch March 10, 2025 19:37
@juanitorduz
Collaborator Author

fyi @cgarciae

@juanitorduz
Collaborator Author

Thank you very much for your guidance (and patience) @fehiepsi !

) from e

graph_def, eager_params_state, eager_other_state = nnx.split(
    nn_module, nnx.Param, nnx.Not(nnx.Param)
)


I know this is merged, but if you want to slightly simplify this in the future, you can use ... (Ellipsis) to match anything else, e.g.:

graphdef, params, rest = nnx.split(model, nnx.Param, ...)


Development

Successfully merging this pull request may close these issues.

[FR] Update the flax module to adapt flax new nnx api
