@juanitorduz
Collaborator

Closes #1726

Trying here a "good first issue".

@juanitorduz juanitorduz marked this pull request as draft February 7, 2024 20:31
@fehiepsi
Member

fehiepsi commented Feb 7, 2024

Looks great to me. Could you add a simple test with a while loop in the model?
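
For context, here is a minimal JAX-only sketch of why a while loop makes a useful test case: reverse-mode differentiation is not supported through `lax.while_loop`, whereas forward mode is. The names `cube_via_while` and `reverse_mode_fails` are hypothetical and chosen for illustration; this is not the test added in this PR.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cube_via_while(x):
    """Compute x**3 with a lax.while_loop (illustrative function only)."""

    def cond_fn(state):
        i, _ = state
        return i < 3

    def body_fn(state):
        i, acc = state
        return i + 1, acc * x

    init = (jnp.array(0), jnp.ones_like(x))
    _, result = lax.while_loop(cond_fn, body_fn, init)
    return result


def reverse_mode_fails(f, x):
    """Return True if jax.grad (reverse mode) raises when applied to f at x."""
    try:
        jax.grad(f)(x)
    except Exception:
        return True
    return False


x = jnp.array(2.0)
# Forward mode differentiates through the loop; d/dx x**3 = 3 * x**2.
forward_grad = jax.jacfwd(cube_via_while)(x)
```

A model containing such a loop can therefore only be fitted if the optimizer is allowed to compute gradients in forward mode, which is what the flag in this PR is meant to enable.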

@juanitorduz
Collaborator Author

I added a simple test, test/infer/test_svi.py::test_forward_mode_differentiation, in 532fb5b, but it is failing with:

    @functools.wraps(update)
    def tree_update(i, grad_tree, opt_state):
      states_flat, tree, subtrees = opt_state
      grad_flat, tree2 = tree_flatten(grad_tree)
      if tree2 != tree:
        msg = ("optimizer update function was passed a gradient tree that did "
               "not match the parameter tree structure with which it was "
               "initialized: parameter tree {} and grad tree {}.")
>       raise TypeError(msg.format(tree, tree2))
E       TypeError: optimizer update function was passed a gradient tree that did not match the parameter tree structure with which it was initialized: parameter tree PyTreeDef({'loc': *, 'scale': *}) and grad tree PyTreeDef(({'loc': *, 'scale': *}, None))

I am not sure whether this is because the implementation of the test is wrong 😑. Any tips? Thanks
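
The mismatch in the traceback can be reproduced outside NumPyro. The sketch below assumes a hypothetical loss `loss_with_aux` that returns a `(value, aux)` pair; with `has_aux=True`, `jacfwd` returns a `(jacobian, aux)` tuple, so passing the whole result on as the gradient yields a tree whose structure differs from the parameter tree.

```python
import jax.numpy as jnp
from jax import jacfwd
from jax.tree_util import tree_structure

params = {"loc": jnp.array(0.5), "scale": jnp.array(1.5)}


def loss_with_aux(p):
    # Hypothetical loss returning (value, aux); aux is None here.
    return p["loc"] ** 2 + p["scale"] ** 2, None


# With has_aux=True the result is a (jacobian, aux) pair, not the jacobian alone.
result = jacfwd(loss_with_aux, has_aux=True)(params)
grads, aux = result

print(tree_structure(params))  # structure of the parameter dict
print(tree_structure(result))  # structure of the (jacobian, aux) tuple
print(tree_structure(grads))   # structure of the jacobian alone
```

Only the first element of the pair has the same structure as `params`, so it is the part an optimizer update would expect.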

numpyro/optim.py Outdated

def _value_and_grad(f, x, forward_mode_differentiation=False):
    if forward_mode_differentiation:
        return f(x), jacfwd(f, has_aux=True)(x)
Member

@fehiepsi fehiepsi Feb 8, 2024

I think you can set has_aux to False.

A faster way, which just does 1 forward pass, is to redefine f:

def wrapper(x):
  out, aux = f(x)
  return out, (out, aux)

grad, (out, aux) = jacfwd(wrapper, has_aux=True)(x)
return (out, aux), grad

We can also apply this trick in the HMC forward-mode implementation.
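
A self-contained sketch of the wrapper trick described above, checked against reverse-mode `jax.value_and_grad`. The function name `value_and_grad_fwd` and the toy `loss_fn` are assumptions made for illustration; they are not the code merged in this PR.

```python
import jax
import jax.numpy as jnp
from jax import jacfwd


def value_and_grad_fwd(f, x):
    """Forward-mode analogue of value_and_grad for f returning (out, aux)."""

    def wrapper(x):
        out, aux = f(x)
        # Return the value again inside aux so a single jacfwd call
        # yields both the jacobian and the function value.
        return out, (out, aux)

    grads, (out, aux) = jacfwd(wrapper, has_aux=True)(x)
    return (out, aux), grads


def loss_fn(params):
    # Toy scalar loss with an auxiliary output (hypothetical example).
    residual = params["loc"] - 1.0
    loss = jnp.sum(residual**2) + params["scale"] ** 2
    return loss, {"residual": residual}


params = {"loc": jnp.array([0.0, 2.0]), "scale": jnp.array(3.0)}
(loss, aux), grads = value_and_grad_fwd(loss_fn, params)

# Reverse-mode reference for comparison.
(ref_loss, ref_aux), ref_grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
```

Because `f` is evaluated only inside `wrapper`, there is no separate `f(x)` call, and the returned `grads` has the same tree structure as `x`.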

Collaborator Author

@juanitorduz juanitorduz Feb 8, 2024

ok! So I added your suggestion in ef3a3f7, and now I am getting an error downstream 🤔 . Do I need to unpack somewhere further?

        if progress_bar:
            losses = []
            with tqdm.trange(1, num_steps + 1) as t:
                batch = max(num_steps // 20, 1)
                for i in t:
                    svi_state, loss = jit(body_fn)(svi_state, None)
                    losses.append(jax.device_get(loss))
                    if i % batch == 0:
                        if stable_update:
                            valid_losses = [x for x in losses[i - batch :] if x == x]
                            num_valid = len(valid_losses)
                            if num_valid == 0:
                                avg_loss = float("nan")
                            else:
                                avg_loss = sum(valid_losses) / num_valid
                        else:
>                           avg_loss = sum(losses[i - batch :]) / batch
E                           TypeError: unsupported operand type(s) for +: 'int' and 'dict'

numpyro/infer/svi.py:409: TypeError

Member

@fehiepsi fehiepsi Feb 8, 2024

Sorry, it should be jacfwd of wrapper, not f.

Collaborator Author

It worked in a64c092!

Collaborator Author

Same trick in 1259db7 for hmc

@juanitorduz juanitorduz requested a review from fehiepsi February 8, 2024 12:31
@juanitorduz juanitorduz marked this pull request as ready for review February 8, 2024 15:18
Member

@fehiepsi fehiepsi left a comment

Woohoo, thanks for supporting this feature!

@fehiepsi fehiepsi merged commit aec6bd5 into pyro-ppl:master Feb 8, 2024
@juanitorduz
Collaborator Author

Thank you for your guidance @fehiepsi 🙏🙂

@juanitorduz juanitorduz deleted the issue_726 branch April 30, 2024 19:03
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
* allow forward pass

* fix params

* add missing param in docstring

* add flag to svi

* typo docs

* reorder arguments

* order args

* kw argument internal function

* add arg to minimize

* decouple aux function

* rm kw argument unused

* nicer doctrings

* simple test

* add wrapper

* fix wrapper order

* add wrapper trick to hmc

Development

Successfully merging this pull request may close these issues.

Support forward mode differentiation for SVI

2 participants