Introduce ORTSessionMixin and enable general io binding (works for diffusers as well) #2234
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…ConditionalGeneration with positional arguments.
Force-pushed from 420f59c to 3d7b976
Pull Request Overview
This PR extends the IO binding functionality for multi-part pipelines by introducing new mixin classes (ORTSessionMixin and ORTSessionsWrapper) and refactoring existing tests and utility functions for both diffusion and non-diffusion workflows.
- Updates tests for onnxruntime modeling and diffusion pipelines to use updated function names and parameters.
- Refactors utility and exporter modules to adopt new patterns for provider and session handling, and removes deprecated or redundant code.
Reviewed Changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/onnxruntime/test_modeling.py | Removed deprecated parameters and updated assertions to reflect changes in the IO binding interface and model naming. |
| tests/onnxruntime/test_diffusion.py | Renamed helper functions and added tests for IO binding functionality with diffusion pipelines. |
| optimum/utils/constant.py | Added and reorganized constants used across the codebase. |
| optimum/onnxruntime/utils.py | Refactored provider handling and added a new helper for inferring a session's dtype. |
| optimum/onnxruntime/modeling_decoder.py | Updated deprecated code patterns to use instance configuration attributes consistently and improved IO binding calls. |
| Other files (e.g., optimization.py, exporters tasks, workflows) | Similar refactoring to align with the new multi-part pipeline design and clean up unused code. |
Comments suppressed due to low confidence (3)
tests/onnxruntime/test_modeling.py:140
- The `use_io_binding` parameter has been removed in favor of a default behavior. Please ensure that this API change is clearly documented in the migration guide to avoid confusion among users.
`model = model_cls.from_pretrained(self.LOCAL_MODEL_PATH, use_cache=False, use_io_binding=False)`
optimum/onnxruntime/modeling_decoder.py:249
- Updating the condition to use `self.config` instead of relying on a local `config` variable improves maintainability. Confirm that similar updates are applied consistently throughout the file.
`if self.config.model_type != "gpt_bigcode":`
optimum/onnxruntime/utils.py:446
- [nitpick] Consider adding tests and a detailed docstring example for `get_dtype_from_session` to verify that the correct torch dtype is returned based on the ONNX input/output types.
`def get_dtype_from_session(session: ort.InferenceSession) -> torch.dtype:`
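For context, a minimal sketch of what such a dtype-inference helper could look like; only the name and signature come from the diff above, the body here is an illustrative assumption, not the PR's actual code:

```python
import onnxruntime as ort
import torch

# hypothetical mapping from ONNX tensor type strings to torch dtypes;
# the actual implementation in this PR may cover more types
ONNX_TO_TORCH_DTYPE = {
    "tensor(float16)": torch.float16,
    "tensor(float)": torch.float32,
    "tensor(double)": torch.float64,
}

def get_dtype_from_session(session: ort.InferenceSession) -> torch.dtype:
    # return the first floating-point dtype found among the session's
    # inputs/outputs, falling back to float32
    for node in list(session.get_inputs()) + list(session.get_outputs()):
        if node.type in ONNX_TO_TORCH_DTYPE:
            return ONNX_TO_TORCH_DTYPE[node.type]
    return torch.float32
```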
LGTM! Thanks a lot for the great PR @IlyasMoutawwakil 🔥
optimum/onnxruntime/base.py (Outdated)
self.output_dtypes = {output.name: output.type for output in session.get_outputs()}

self.model_path = Path(session._model_path)
self.model_name = self.model_path.name
yes, talking about `model_name`, which seems to be used in the tests (and can easily be extracted from `model_path`), so we can remove it imo; not specific to this PR though, so feel free to close!
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
…timum into ort-session-mixin
… cuda device, with specific dtype, etc)
What does this PR do?
This PR further extends the IO binding feature to multi-part pipelines and enables it more generally through two classes:

- `ORTSessionMixin`: should be used whenever a class has an underlying inference session. It enables inference with and without IO binding, and handles the torch/onnxruntime interface (provider/device/dtype).
- `ORTSessionsWrapper`: should be used when a class has many models/parts (like encoder-decoders or diffusion pipelines). It unifies the interface to these parts without compromising flexibility: for example, one can set `pipeline.use_io_binding` to True/False, which applies it to all parts, but one can also target a specific part differently via `pipeline.unet.use_io_binding` (see the usage sketch after this list). Parts can also be on different devices.
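A minimal usage sketch of the per-pipeline and per-part toggles described above; the model id and `from_pretrained` arguments are illustrative assumptions, not part of this PR:

```python
from optimum.onnxruntime import ORTStableDiffusionPipeline

# load a multi-part ONNX diffusion pipeline (illustrative model id)
pipeline = ORTStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", export=True
)

# enable IO binding for every part of the pipeline at once...
pipeline.use_io_binding = True

# ...or override it for a single part
pipeline.unet.use_io_binding = False
```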
I also added a special feature to diffusion pipelines, specifically the UNet/Transformer components: they reuse the same input tensor as the output buffer (instead of creating a new one). This makes each diffusion step more of an in-place operation, which should (theoretically) make it faster, or at least use less memory.
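Conceptually, this buffer reuse amounts to binding the same device tensor as both input and output through onnxruntime's IO binding API. A rough sketch under assumed input/output names and shapes (`sample`/`out_sample`), not the PR's actual code:

```python
import numpy as np
import torch

# `session` is an existing ort.InferenceSession for a UNet-like model whose
# output has the same shape and dtype as its main input
latents = torch.zeros((1, 4, 64, 64), dtype=torch.float16, device="cuda")

binding = session.io_binding()
binding.bind_input(
    "sample", device_type="cuda", device_id=0,
    element_type=np.float16, shape=tuple(latents.shape),
    buffer_ptr=latents.data_ptr(),
)
# (other model inputs such as timestep/encoder_hidden_states omitted for brevity)

# bind the SAME tensor as the output buffer, making the step (almost) in-place
binding.bind_output(
    "out_sample", device_type="cuda", device_id=0,
    element_type=np.float16, shape=tuple(latents.shape),
    buffer_ptr=latents.data_ptr(),
)
session.run_with_iobinding(binding)
```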
Before submitting
Who can review?