- 
                Notifications
    You must be signed in to change notification settings 
- Fork 270
Refactoring SteinVI #1883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring SteinVI #1883
Conversation
| @fehiepsi Let me know if you want me to split it into multiple PRs. | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending minor comments.
|  | ||
| import jax | ||
| from jax import numpy as jnp, random, vmap | ||
| from jax.tree_util import tree_flatten, tree_map | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is deprecated. You can use jax.tree.flatten and jax.tree.map now.
        
          
                numpyro/contrib/einstein/steinvi.py
              
                Outdated
          
        
      | from jax import grad, jacfwd, numpy as jnp, random, vmap | ||
| from jax import grad, numpy as jnp, random, vmap | ||
| from jax.flatten_util import ravel_pytree | ||
| from jax.tree_util import tree_map | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above
|  | ||
| import jax | ||
| from jax import numpy as jnp | ||
| from jax.tree_util import tree_flatten, tree_map | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above
| Thanks! I'll change it to  | 
SVGD has been extended to a few varieties (ASVGD, GSVGD, HSVGD) which I want to include in the einstein module. This PR consists of the changes to
SteinVIto allow for the extensions and introduction of SVGD and ASVGD. GSVGD and HSVGD will be in subsequent PRs.To this end, I've:
1/mformparticles.num_elbo_particlesbecause it's always 1 for SVGD.setup_runmethod to SteinVI.setup_runto change theloss_temperature.AutoIAFNormal,AutoBNAFNormal,AutoDAIS,AutoSemiDAISandAutoSurrogateLikelihoodDAIS.Misc changes
ProbabilityProductKernelhas been removed. The kernel is still a proper kernel; however, this version avoids vanishing/exploding when the guide variances deviate from 1 for "high" dimensional models.enumparameter as it is currently unsupported.TODO
Fix SteinVI documentationI'll do this in a separate PR.