Flax NNX integration #1990
e6cd827 to
b612512
Compare
|
There are test failures that are unrelated to this change:
FAILED test/infer/test_autoguide.py::test_laplace_approximation_warning - Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
Emitted warnings: [].
FAILED test/test_transforms.py::test_bijective_transforms[transform9-shape9] - assert Array(False, dtype=bool)
FAILED test/test_distributions.py::test_entropy_categorical - AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
Thanks @juanitorduz! I left a few comments. I think you can use the filter DSL (https://flax.readthedocs.io/en/latest/guides/filters_guide.html#the-filter-dsl) to extract the params, rngs, and mutable state, and the surgery guide (https://flax.readthedocs.io/en/latest/guides/surgery.html) to replace the params/state.
(I also feel that rng is a mutable state under nnx: https://flax.readthedocs.io/en/latest/guides/randomness.html)
|
Thanks for the comments @fehiepsi I will look into them 🙏 |
init fix fix fix lint rm files rm files rm files rm files rm files
29b9da0 to
1270561
Compare
Hi @juanitorduz, I marked this as milestone 0.18, but there is no need to rush. Looking at the implementation, it seems to me that it's painful to deal with nnx (surprising? not sure why the flax team is switching to it given the complexity). If possible, try to come up with a solution similar to flax/haiku. If it gets complicated, then it's not worth supporting.
I would suggest reaching out to the nnx devs if something is unclear. The two main questions are:
- how to get the initial parameters and the initial mutable states of an nn module
- given parameters and states, how to get the output and the updated states
We need the params stored as a nested dict. For the state, if it is not a dict, you can just store it as a dict via {"state": state}.
|
Thanks @fehiepsi! Indeed, it has been harder than expected, but that also reflects my limited experience with the NNX internals 😄. My aim is indeed to simplify! I'll use this branch to push my progress (experiments). I'll let you know whenever I feel really blocked or need a review. I will also try to reach out to the nnx team for some tips :) It would be super nice to have this integration to ensure mid- to long-term integration with the JAX neural network community :)
d6bd250 to
2b89566
Compare
|
@fehiepsi I simplified the code a lot (it was a mess, I know 🙈). I think it is now more modular (thanks to your advice) and looks similar to the flax case. In 2b89566 you can see all tests pass (I had to write some custom code to make this work for Python 3.9). I think it is in much better shape if you want to take a look (most of the new lines of code are actually tests). Moreover, I am working on a notebook example (wip) adapting a simple example (https://juanitorduz.github.io/flax_numpyro/) and showing how to use NNX + NumPyro. I hope to clean up the notebook and add the relevant content this weekend :)
Nice tutorial, looking forward to seeing the nnx version! :) I think the implementation could be even simpler. Let's see. ;) |
|
Thanks for the feedback @fehiepsi ! I will try to simplify even more based on your input 🙏 |
eager inizialization approach rm unused code 1 rm unused code 2 rm code rm code rm code rm code
|
Minor comment: when re-running the notebook I see
0%| | 0/20000 [00:00<?, ?it/s]/Users/juanitorduz/Documents/numpyro/numpyro/infer/elbo.py:253: UserWarning: mutable state is currently ignored when num_particles > 1.
Is this expected? (The results do not change, though.)
Yes, it is expected. It's not clear how mutable states should be updated with num_particles > 1. For the dropout rng, we need special logic to vectorize the split. For batch_stats, we need to update sequentially and then aggregate. I don't know; it is tricky to come up with a reasonable solution for this issue.
I would suggest using num_particles=1 if batch norm plays a key role in your network.
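To illustrate only the dropout half of that remark, here is a minimal JAX sketch of one possible reading of "vectorize the split": one key per particle, then `jax.vmap` over the key axis. The `dropout` function and the shapes are hypothetical, and this is not how NumPyro handles (or plans to handle) mutable state.

```python
import jax
import jax.numpy as jnp


def dropout(key, x, rate=0.5):
    """Hypothetical inverted dropout driven by an explicit PRNG key."""
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


num_particles = 4
key = jax.random.PRNGKey(0)
x = jnp.ones((3, 5))

# One independent key per particle instead of a single shared key.
particle_keys = jax.random.split(key, num_particles)

# Vectorize over the key axis only; the input is shared by all particles.
outs = jax.vmap(dropout, in_axes=(0, None))(particle_keys, x)
```

The result has a leading particle axis of size `num_particles`. The batch_stats case is deliberately left out, since the comment above says no reasonable aggregation rule is settled.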
numpyro/contrib/module.py (outdated)
if params:
    graph_def, state = nnx.split(model)

if params:
See the comment above - you don't need to split the model again
params_state = eager_params_state
if params:
    nnx.replace_by_pure_dict(params_state, params)
other_state = eager_other_state
if mutable_holder:
    nnx.replace_by_pure_dict(other_state, mutable_holder["state"])
model = nnx.merge(graph_def, params_state, other_state)
You will need to use the latest mutable state, not just the latest params.
Hmm, I am a bit confused: I tried this in 17db8c1 and I get the following from the tests:
FAILED test/contrib/test_module.py::test_nnx_state_dropout_smoke[True-True] - TypeError: unsupported operand type(s) for *: 'float' and 'VariableState'
FAILED test/contrib/test_module.py::test_nnx_state_dropout_smoke[True-False] - TypeError: unsupported operand type(s) for *: 'float' and 'VariableState'
Is it OK that the apply function has access to objects defined outside of it, like eager_params_state and mutable_holder?
I guess it is due to the line mutable_holder["state"] = eager_other_state above. We can remove that line because we already have the logic to get the latest mutable state from the numpyro.mutable call.
I wonder if we are updating the parameters more than needed in the apply. In 12933f6 I removed the last split and update just thinking we did the update just before calling the model. It seems the test with batch-normalization passes now. What do you think, or could the error be coming from somewhere else 🙏 ?
I brought it back in d25de60 because you have pointed this out many times... I can't figure out the error in the tests...
numpyro/contrib/module.py (outdated)
if mutable_holder is None:
    numpyro_mutable(name + "$state", {"state": eager_other_state_dict})

state_dict = {"state": eager_other_state}
I don't think we want to use state_dict. We just want to update the mutable state via mutable_holder above and below.
yeah, makes sense, fixed in 18fe2a2
Sorry, by above, I meant the logic to get the latest mutable state via numpyro.mutable call. By below, I meant the logic after model call. We should remove this line, rather than rename.
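The read-then-write-back pattern described here can be sketched without NumPyro or NNX. In this minimal, hypothetical sketch, `holder` is a plain dict standing in for the value returned by the `numpyro.mutable` call, and `forward` is a made-up stateful function that tracks a running mean; neither is taken from the PR's implementation.

```python
def forward(x, state, momentum=0.9):
    """Hypothetical stateful layer: centers x with a running mean and updates it."""
    batch_mean = sum(x) / len(x)
    new_mean = momentum * state["mean"] + (1.0 - momentum) * batch_mean
    out = [xi - new_mean for xi in x]
    return out, {"mean": new_mean}


def make_apply(initial_state):
    # Stand-in for the mutable site: a dict whose "state" entry is overwritten.
    holder = {"state": dict(initial_state)}

    def apply(x):
        state = holder["state"]  # "above": read the latest mutable state
        out, new_state = forward(x, state)
        holder["state"] = new_state  # "below": write back after the model call
        return out

    return apply, holder


apply, holder = make_apply({"mean": 0.0})
apply([1.0, 3.0])  # the running mean moves away from its initial value
apply([1.0, 3.0])  # the second call starts from the state left by the first
```

There is a single place where the state is read and a single place where it is written, which is why an extra assignment such as a separate state_dict would be redundant in this pattern.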
|
@fehiepsi Thank you very much for all the fantastic feedback these last days! I have learned a lot! Here is a summary on where we are:
I think we are very close to the finish line (I hope XD) |
The tutorials look awesome! I just caught a few small nits. Thanks, Juan!
test/contrib/test_module.py (outdated)
assert "b" in samples["nn$params"]


def test_nnx_transformer_module():
I think we can remove this now.
|
fyi @cgarciae |
|
Thank you very much for your guidance (and patience) @fehiepsi ! |
) from e

graph_def, eager_params_state, eager_other_state = nnx.split(
    nn_module, nnx.Param, nnx.Not(nnx.Param)
I know this is merged but if you want to slightly simplify this in the future you can use ... (Ellipsis) to match anything else e.g:
graphdef, params, rest = nnx.split(model, nnx.Param, ...)
Closes #1987
I gave this issue a go, with the help of Claude Sonnet 3.7 to get started and to iterate until it "works", so all feedback is welcome.