
Enable argument type population of torch models #1073


Open · wants to merge 4 commits into main

Conversation

@LPanosTT (Contributor) commented Aug 21, 2025

Ticket

#1041, #856

Problem description

  • We cannot populate the argument types for torch models.

What's changed

  • Added a custom torch operator, torch.ops.tt.mark_argument_attributes, which applies a stablehlo.custom_call to a tensor:
    • This operator returns the input tensor unchanged.
    • This operator has attributes "argument_type" and "name" ("name" is optional).
    • This operation can ONLY be run on an XLA device; PyTorch will raise an error if the user attempts to execute it on a non-XLA device.
    • This op can be captured by Dynamo and is valid inside a GraphModule (see the sketch after this list).
  • Added a pass that populates the argument types of the SHLO module using this custom call.
    • This pass has one special behavior: if at least one argument is annotated as an "input", then all unannotated arguments are automatically marked as "constant".
      • This works under the assumption that every user input is annotated. Users are not expected to populate the argument types themselves; it will be done either by our torch.compile backend or by the utility function we provide, mark_model_user_inputs.
  • Edited propagateRoleAttribute to populate the ttcore.argument_type arg attribute on the function directly, rather than tt.input_role. This way we can bypass PopulateArgumentTypes in tt-mlir entirely, since the graph is already populated.
  • Deleted the environment variable override for populating argument types, as it should no longer be needed.
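
For illustration, here is a minimal sketch of calling the new op directly on an XLA tensor. The keyword-argument form shown (argument_type=..., name=...) is an assumption based on the attribute names above, not necessarily the exact registered signature.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 4, device=device)

# Identity on the tensor values; it only attaches metadata that lowers to a
# stablehlo.custom_call when the graph is exported to StableHLO.
# NOTE: the keyword-argument spelling below is an assumption.
x = torch.ops.tt.mark_argument_attributes(x, argument_type="input", name="activation")

# Calling the op on a CPU tensor would raise an error, since it is only
# registered to run on XLA devices.
```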

Checklist

  • New/Existing tests provide coverage for changes

I didn't add a test because we won't really be able to take advantage of this until the tt_torch backend is merged (#954); I will add a test after both of these are merged.

I did add an example though, using the "openxla" backend.
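
For reference, a hedged sketch of what such an example might look like. The marking would normally be inserted by the torch.compile backend or by mark_model_user_inputs; the explicit call and its keyword form below are assumptions used for illustration.

```python
import torch
import torch_xla.core.xla_model as xm


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 32)

    def forward(self, x):
        # Tag the user input so the SHLO pass can recover its argument type.
        # In practice this annotation is added by the backend or by
        # mark_model_user_inputs rather than written by hand.
        x = torch.ops.tt.mark_argument_attributes(x, argument_type="input", name="x")
        return torch.relu(self.linear(x))


device = xm.xla_device()
model = TinyModel().to(device)
compiled = torch.compile(model, backend="openxla")

x = torch.randn(8, 32, device=device)
out = compiled(x)  # Dynamo captures the custom op; it lowers to stablehlo.custom_call
```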

@codecov-commenter commented Aug 21, 2025

Codecov Report

❌ Patch coverage is 54.02299% with 40 lines in your changes missing coverage. Please review.
✅ Project coverage is 68.04%. Comparing base (3a65113) to head (9999f40).

Files with missing lines | Patch % | Lines
...der/frontend_passes/shlo_input_role_propagation.cc | 53.48% | 40 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1073      +/-   ##
==========================================
- Coverage   69.61%   68.04%   -1.57%     
==========================================
  Files          25       25              
  Lines        1833     1909      +76     
==========================================
+ Hits         1276     1299      +23     
- Misses        557      610      +53     

☔ View full report in Codecov by Sentry.
