Introduce ORTSessionMixin and enable general io binding (works for diffusers as well) #2234


Merged: 58 commits merged into main from ort-session-mixin on May 18, 2025

Conversation

@IlyasMoutawwakil (Member) commented Apr 19, 2025

What does this PR do?

This PR further extends the IO binding feature to multi-part pipelines and enables it more generally through two classes:

  • ORTSessionMixin: should be used whenever a class has an underlying inference session; it enables inference with and without IO binding and handles the torch/onnxruntime interface (provider/device/dtype).

  • ORTSessionsWrapper: should be used when a class has many models/parts (like encoder-decoders or diffusion pipelines); it unifies the interface to these parts without compromising flexibility. For example, one can set pipeline.use_io_binding to True/False, which propagates to all parts, but can also target a specific part differently via pipeline.unet.use_io_binding. Parts can also be on different devices (see the sketch after this list).
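A minimal sketch of that flexibility, assuming a pre-exported ONNX diffusion pipeline (the model id below is a placeholder, and loading details may differ):

```python
from optimum.onnxruntime import ORTDiffusionPipeline

# Load a pre-exported ONNX diffusion pipeline (placeholder model id).
pipeline = ORTDiffusionPipeline.from_pretrained("my-org/my-sd-onnx")

# Setting the flag on the wrapper propagates it to all parts...
pipeline.use_io_binding = True

# ...while a specific part can still be targeted differently.
pipeline.unet.use_io_binding = False
```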

I also added a special feature to diffusion pipelines, and especially the UNet/Transformer components, so that they reuse the same input tensor as the output buffer (instead of creating a new one). This makes the diffusion step more of an in-place operation, which should make it faster (theoretically), or at least use less memory.
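As a rough illustration of the buffer-reuse idea with onnxruntime's IO binding API (the model path and the "sample"/"out_sample" input/output names here are hypothetical, not the actual names used in the PR):

```python
import numpy as np
import onnxruntime as ort
import torch

# Hypothetical UNet-like model with input "sample" and output "out_sample".
session = ort.InferenceSession("unet/model.onnx", providers=["CUDAExecutionProvider"])
sample = torch.randn(1, 4, 64, 64, dtype=torch.float32, device="cuda")

binding = session.io_binding()
binding.bind_input(
    name="sample",
    device_type="cuda",
    device_id=0,
    element_type=np.float32,
    shape=tuple(sample.shape),
    buffer_ptr=sample.data_ptr(),
)
# Bind the SAME tensor as the output buffer: the session writes the result
# directly into `sample` instead of allocating a fresh output tensor.
binding.bind_output(
    name="out_sample",
    device_type="cuda",
    device_id=0,
    element_type=np.float32,
    shape=tuple(sample.shape),
    buffer_ptr=sample.data_ptr(),
)
session.run_with_iobinding(binding)
# `sample` now holds the model output, ready for the next denoising step.
```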

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil changed the title from "Refactor ort session mixin and enable integral and simple diffusers io binding" to "Refactor ort session mixin and enable general io binding (works for diffusers)" on Apr 29, 2025
@IlyasMoutawwakil IlyasMoutawwakil requested a review from Copilot May 16, 2025 09:37
@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR extends the IO binding functionality for multi-part pipelines by introducing new mixin classes (ORTSessionMixin and ORTSessionsWrapper) and refactoring existing tests and utility functions for both diffusion and non-diffusion workflows.

  • Updates tests for onnxruntime modeling and diffusion pipelines to use updated function names and parameters.
  • Refactors utility and exporter modules to adopt new patterns for provider and session handling, and removes deprecated or redundant code.

Reviewed Changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| tests/onnxruntime/test_modeling.py | Removed deprecated parameters and updated assertions to reflect changes in the IO binding interface and model naming. |
| tests/onnxruntime/test_diffusion.py | Renamed helper functions and added tests for IO binding functionality with diffusion pipelines. |
| optimum/utils/constant.py | Added and reorganized constants used across the codebase. |
| optimum/onnxruntime/utils.py | Refactored provider handling and added a new helper for inferring a session's dtype. |
| optimum/onnxruntime/modeling_decoder.py | Updated deprecated code patterns to use instance configuration attributes consistently and improved IO binding calls. |
| Other files (e.g., optimization.py, exporters tasks, workflows) | Similar refactoring to align with the new multi-part pipeline design and clean up unused code. |
Comments suppressed due to low confidence (3)

tests/onnxruntime/test_modeling.py:140

  • The 'use_io_binding' parameter has been removed in favor of a default behavior. Please ensure that this API change is clearly documented in the migration guide to avoid confusion among users.

```python
model = model_cls.from_pretrained(self.LOCAL_MODEL_PATH, use_cache=False, use_io_binding=False)
```

optimum/onnxruntime/modeling_decoder.py:249

  • Updating the condition to use 'self.config' instead of relying on a local 'config' variable improves maintainability. Confirm that similar updates are applied consistently throughout the file.

```python
if self.config.model_type != "gpt_bigcode":
```

optimum/onnxruntime/utils.py:446

  • [nitpick] Consider adding tests and a detailed docstring example for 'get_dtype_from_session' to verify that the correct torch dtype is returned based on the ONNX input/output types.

```python
def get_dtype_from_session(session: ort.InferenceSession) -> torch.dtype:
```
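For reference, a hypothetical sketch of what such a helper might look like (the mapping below is illustrative, not the actual implementation in optimum/onnxruntime/utils.py):

```python
import onnxruntime as ort
import torch

# Illustrative mapping from ONNX type strings to torch dtypes.
ONNX_TO_TORCH_DTYPE = {
    "tensor(float)": torch.float32,
    "tensor(float16)": torch.float16,
    "tensor(double)": torch.float64,
}

def get_dtype_from_session(session: ort.InferenceSession) -> torch.dtype:
    """Infer the session dtype from the first floating-point input or output."""
    for node in list(session.get_inputs()) + list(session.get_outputs()):
        if node.type in ONNX_TO_TORCH_DTYPE:
            return ONNX_TO_TORCH_DTYPE[node.type]
    return torch.float32  # default when no floating-point tensor is found
```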

@echarlaix (Collaborator) left a comment

LGTM! Thanks a lot for the great PR @IlyasMoutawwakil 🔥

```python
self.output_dtypes = {output.name: output.type for output in session.get_outputs()}

self.model_path = Path(session._model_path)
self.model_name = self.model_path.name
```
Collaborator comment:

Yes, talking about model_name, which seems to be used in the tests (and can easily be extracted from model_path), so we can remove it imo. Not specific to this PR though, so feel free to close!

@IlyasMoutawwakil IlyasMoutawwakil merged commit 16618fc into main May 18, 2025
32 of 33 checks passed
@IlyasMoutawwakil IlyasMoutawwakil deleted the ort-session-mixin branch May 18, 2025 08:02