Skip to content

[mlir][MemRef] Address TODO to use early_inc to simplify elimination of uses #155123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

snarang181
Copy link
Contributor

No description provided.

@snarang181 snarang181 marked this pull request as ready for review August 23, 2025 23:22
@llvmbot
Copy link
Member

llvmbot commented Aug 23, 2025

@llvm/pr-subscribers-mlir

Author: Samarth Narang (snarang181)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/155123.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp (+27-43)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 5d3cec402cab1..860384f954536 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -43,50 +43,34 @@ static bool overrideBuffer(Operation *op, Value buffer) {
 /// propagate the type change and erase old subview ops.
 static void replaceUsesAndPropagateType(RewriterBase &rewriter,
                                         Operation *oldOp, Value val) {
-  SmallVector<Operation *> opsToDelete;
-  SmallVector<OpOperand *> operandsToReplace;
-
-  // Save the operand to replace / delete later (avoid iterator invalidation).
-  // TODO: can we use an early_inc iterator?
-  for (OpOperand &use : oldOp->getUses()) {
-    // Non-subview ops will be replaced by `val`.
-    auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
-    if (!subviewUse) {
-      operandsToReplace.push_back(&use);
+  // Iterate with early_inc to erase current user inside the loop.
+  for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) {
+    Operation *user = use.getOwner();
+    if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) {
+      // `subview(old_op)` is replaced by a new `subview(val)`.
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(subviewUse);
+      MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
+          subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
+          subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
+          subviewUse.getStaticStrides());
+      Value newSubview = memref::SubViewOp::create(
+          rewriter, subviewUse->getLoc(), newType, val,
+          subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
+          subviewUse.getMixedStrides());
+
+      // Ouch recursion ... is this really necessary?
+      replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
+
+      // Safe to erase.
+      rewriter.eraseOp(subviewUse);
       continue;
     }
-
-    // `subview(old_op)` is replaced by a new `subview(val)`.
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(subviewUse);
-    MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
-        subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
-        subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
-        subviewUse.getStaticStrides());
-    Value newSubview = memref::SubViewOp::create(
-        rewriter, subviewUse->getLoc(), newType, val,
-        subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
-        subviewUse.getMixedStrides());
-
-    // Ouch recursion ... is this really necessary?
-    replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
-
-    opsToDelete.push_back(use.getOwner());
+    // Non-subview: replace with new value.
+    rewriter.startOpModification(user);
+    use.set(val);
+    rewriter.finalizeOpModification(user);
   }
-
-  // Perform late replacement.
-  // TODO: can we use an early_inc iterator?
-  for (OpOperand *operand : operandsToReplace) {
-    Operation *op = operand->getOwner();
-    rewriter.startOpModification(op);
-    operand->set(val);
-    rewriter.finalizeOpModification(op);
-  }
-
-  // Perform late op erasure.
-  // TODO: can we use an early_inc iterator?
-  for (Operation *op : opsToDelete)
-    rewriter.eraseOp(op);
 }
 
 // Transformation to do multi-buffering/array expansion to remove dependencies
@@ -216,8 +200,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
                                             offsets, sizes, strides);
   LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
 
-  // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
-  // handle dealloc uses separately..
+  // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need
+  // to handle dealloc uses separately..
   for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
     auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
     if (!deallocOp)

@llvmbot
Copy link
Member

llvmbot commented Aug 23, 2025

@llvm/pr-subscribers-mlir-memref

Author: Samarth Narang (snarang181)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/155123.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp (+27-43)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 5d3cec402cab1..860384f954536 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -43,50 +43,34 @@ static bool overrideBuffer(Operation *op, Value buffer) {
 /// propagate the type change and erase old subview ops.
 static void replaceUsesAndPropagateType(RewriterBase &rewriter,
                                         Operation *oldOp, Value val) {
-  SmallVector<Operation *> opsToDelete;
-  SmallVector<OpOperand *> operandsToReplace;
-
-  // Save the operand to replace / delete later (avoid iterator invalidation).
-  // TODO: can we use an early_inc iterator?
-  for (OpOperand &use : oldOp->getUses()) {
-    // Non-subview ops will be replaced by `val`.
-    auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
-    if (!subviewUse) {
-      operandsToReplace.push_back(&use);
+  // Iterate with early_inc to erase current user inside the loop.
+  for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) {
+    Operation *user = use.getOwner();
+    if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) {
+      // `subview(old_op)` is replaced by a new `subview(val)`.
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(subviewUse);
+      MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
+          subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
+          subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
+          subviewUse.getStaticStrides());
+      Value newSubview = memref::SubViewOp::create(
+          rewriter, subviewUse->getLoc(), newType, val,
+          subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
+          subviewUse.getMixedStrides());
+
+      // Ouch recursion ... is this really necessary?
+      replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
+
+      // Safe to erase.
+      rewriter.eraseOp(subviewUse);
       continue;
     }
-
-    // `subview(old_op)` is replaced by a new `subview(val)`.
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(subviewUse);
-    MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
-        subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
-        subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
-        subviewUse.getStaticStrides());
-    Value newSubview = memref::SubViewOp::create(
-        rewriter, subviewUse->getLoc(), newType, val,
-        subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
-        subviewUse.getMixedStrides());
-
-    // Ouch recursion ... is this really necessary?
-    replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
-
-    opsToDelete.push_back(use.getOwner());
+    // Non-subview: replace with new value.
+    rewriter.startOpModification(user);
+    use.set(val);
+    rewriter.finalizeOpModification(user);
   }
-
-  // Perform late replacement.
-  // TODO: can we use an early_inc iterator?
-  for (OpOperand *operand : operandsToReplace) {
-    Operation *op = operand->getOwner();
-    rewriter.startOpModification(op);
-    operand->set(val);
-    rewriter.finalizeOpModification(op);
-  }
-
-  // Perform late op erasure.
-  // TODO: can we use an early_inc iterator?
-  for (Operation *op : opsToDelete)
-    rewriter.eraseOp(op);
 }
 
 // Transformation to do multi-buffering/array expansion to remove dependencies
@@ -216,8 +200,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
                                             offsets, sizes, strides);
   LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
 
-  // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to
-  // handle dealloc uses separately..
+  // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need
+  // to handle dealloc uses separately..
   for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
     auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
     if (!deallocOp)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants