diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 71c7d4378677f..d05f91d7e3b12 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -415,6 +415,8 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module); /// The returned module is null when the input operation was not a ModuleOp. MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op); +MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs); + //===----------------------------------------------------------------------===// // Operation state. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 4b3a06cbce854..67bd6fad58637 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails. See also: https://mlir.llvm.org/docs/LangRef/ )"; +static const char kModuleCAPICreate[] = + R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr). +Note this returns a new object BUT _clear_mlir_module(module) must be called to +prevent double-frees (of the underlying mlir::Module). +)"; + static const char kOperationCreateDocstring[] = R"(Creates a new operation. @@ -702,84 +708,6 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } -size_t PyMlirContext::getLiveOperationCount() { - nb::ft_lock_guard lock(liveOperationsMutex); - return liveOperations.size(); -} - -std::vector PyMlirContext::getLiveOperationObjects() { - std::vector liveObjects; - nb::ft_lock_guard lock(liveOperationsMutex); - for (auto &entry : liveOperations) - liveObjects.push_back(entry.second.second); - return liveObjects; -} - -size_t PyMlirContext::clearLiveOperations() { - - LiveOperationMap operations; - { - nb::ft_lock_guard lock(liveOperationsMutex); - std::swap(operations, liveOperations); - } - for (auto &op : operations) - op.second.second->setInvalid(); - size_t numInvalidated = operations.size(); - return numInvalidated; -} - -void PyMlirContext::clearOperation(MlirOperation op) { - PyOperation *py_op; - { - nb::ft_lock_guard lock(liveOperationsMutex); - auto it = liveOperations.find(op.ptr); - if (it == liveOperations.end()) { - return; - } - py_op = it->second.second; - liveOperations.erase(it); - } - py_op->setInvalid(); -} - -void PyMlirContext::clearOperationsInside(PyOperationBase &op) { - typedef struct { - PyOperation &rootOp; - bool rootSeen; - } callBackData; - callBackData data{op.getOperation(), false}; - // Mark all ops below the op that the passmanager will be rooted - // at (but not op itself - note the preorder) as invalid. - MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - callBackData *data = static_cast(userData); - if (LLVM_LIKELY(data->rootSeen)) - data->rootOp.getOperation().getContext()->clearOperation(op); - else - data->rootSeen = true; - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - static_cast(&data), MlirWalkPreOrder); -} -void PyMlirContext::clearOperationsInside(MlirOperation op) { - PyOperationRef opRef = PyOperation::forOperation(getRef(), op); - clearOperationsInside(opRef->getOperation()); -} - -void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { - MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - PyMlirContextRef &contextRef = *static_cast(userData); - contextRef->clearOperation(op); - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - &op.getOperation().getContext(), MlirWalkPreOrder); -} - -size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } - nb::object PyMlirContext::contextEnter(nb::object context) { return PyThreadContextEntry::pushContext(context); } @@ -1151,38 +1079,20 @@ PyLocation &DefaultingPyLocation::resolve() { PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) : BaseContextObject(std::move(contextRef)), module(module) {} -PyModule::~PyModule() { - nb::gil_scoped_acquire acquire; - auto &liveModules = getContext()->liveModules; - assert(liveModules.count(module.ptr) == 1 && - "destroying module not in live map"); - liveModules.erase(module.ptr); - mlirModuleDestroy(module); -} +PyModule::~PyModule() { mlirModuleDestroy(module); } PyModuleRef PyModule::forModule(MlirModule module) { MlirContext context = mlirModuleGetContext(module); PyMlirContextRef contextRef = PyMlirContext::forContext(context); - nb::gil_scoped_acquire acquire; - auto &liveModules = contextRef->liveModules; - auto it = liveModules.find(module.ptr); - if (it == liveModules.end()) { - // Create. - PyModule *unownedModule = new PyModule(std::move(contextRef), module); - // Note that the default return value policy on cast is automatic_reference, - // which does not take ownership (delete will not be called). - // Just be explicit. - nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); - unownedModule->handle = pyRef; - liveModules[module.ptr] = - std::make_pair(unownedModule->handle, unownedModule); - return PyModuleRef(unownedModule, std::move(pyRef)); - } - // Use existing. - PyModule *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); - return PyModuleRef(existing, std::move(pyRef)); + // Create. + PyModule *unownedModule = new PyModule(std::move(contextRef), module); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); + unownedModule->handle = pyRef; + return PyModuleRef(unownedModule, std::move(pyRef)); } nb::object PyModule::createFromCapsule(nb::object capsule) { @@ -1207,15 +1117,11 @@ PyOperation::~PyOperation() { // If the operation has already been invalidated there is nothing to do. if (!valid) return; - - // Otherwise, invalidate the operation and remove it from live map when it is - // attached. - if (isAttached()) { - getContext()->clearOperation(*this); - } else { - // And destroy it when it is detached, i.e. owned by Python, in which case - // all nested operations must be invalidated at removed from the live map as - // well. + // Otherwise, invalidate the operation when it is attached. + if (isAttached()) + setInvalid(); + else { + // And destroy it when it is detached, i.e. owned by Python. erase(); } } @@ -1252,35 +1158,16 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - auto it = liveOperations.find(operation.ptr); - if (it == liveOperations.end()) { - // Create. - PyOperationRef result = createInstance(std::move(contextRef), operation, - std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(result.getObject(), result.get()); - return result; - } - // Use existing. - PyOperation *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); - return PyOperationRef(existing, std::move(pyRef)); + // Create. + return createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - assert(liveOperations.count(operation.ptr) == 0 && - "cannot create detached operation that already exists"); - (void)liveOperations; PyOperationRef created = createInstance(std::move(contextRef), operation, std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(created.getObject(), created.get()); created->attached = false; return created; } @@ -1652,7 +1539,7 @@ nb::object PyOperation::createOpView() { void PyOperation::erase() { checkValid(); - getContext()->clearOperationAndInside(*this); + setInvalid(); mlirOperationDestroy(operation); } @@ -3023,14 +2910,6 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyMlirContextRef ref = PyMlirContext::forContext(self.get()); return ref.releaseObject(); }) - .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) - .def("_get_live_operation_objects", - &PyMlirContext::getLiveOperationObjects) - .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) - .def("_clear_live_operations_inside", - nb::overload_cast( - &PyMlirContext::clearOperationsInside)) - .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) @@ -3348,7 +3227,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "Module", nb::is_weak_referenceable()) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule, + kModuleCAPICreate) + .def("_clear_mlir_module", &PyModule::clearMlirModule) .def_static( "parse", [](const std::string &moduleAsm, DefaultingPyMlirContext context) { @@ -3428,7 +3309,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Defer to the operation's __str__. return self.attr("operation").attr("__str__")(); }, - kOperationStrDunderDocstring); + kOperationStrDunderDocstring) + .def( + "__eq__", + [](PyModule &self, PyModule &other) { + return mlirModuleEqual(self.get(), other.get()); + }, + "other"_a); //---------------------------------------------------------------------------- // Mapping of Operation. @@ -3440,7 +3327,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { - return &self.getOperation() == &other.getOperation(); + return mlirOperationEqual(self.getOperation().get(), + other.getOperation().get()); }) .def("__eq__", [](PyOperationBase &self, nb::object other) { return false; }) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index fa16ae3ce3294..c1fdfd64ee1e7 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -218,40 +218,6 @@ class PyMlirContext { /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); - /// Get a list of Python objects which are still in the live context map. - std::vector getLiveOperationObjects(); - - /// Gets the count of live operations associated with this context. - /// Used for testing. - size_t getLiveOperationCount(); - - /// Clears the live operations map, returning the number of entries which were - /// invalidated. To be used as a safety mechanism so that API end-users can't - /// corrupt by holding references they shouldn't have accessed in the first - /// place. - size_t clearLiveOperations(); - - /// Removes an operation from the live operations map and sets it invalid. - /// This is useful for when some non-bindings code destroys the operation and - /// the bindings need to made aware. For example, in the case when pass - /// manager is run. - /// - /// Note that this does *NOT* clear the nested operations. - void clearOperation(MlirOperation op); - - /// Clears all operations nested inside the given op using - /// `clearOperation(MlirOperation)`. - void clearOperationsInside(PyOperationBase &op); - void clearOperationsInside(MlirOperation op); - - /// Clears the operaiton _and_ all operations inside using - /// `clearOperation(MlirOperation)`. - void clearOperationAndInside(PyOperationBase &op); - - /// Gets the count of live modules associated with this context. - /// Used for testing. - size_t getLiveModuleCount(); - /// Enter and exit the context manager. static nanobind::object contextEnter(nanobind::object context); void contextExit(const nanobind::object &excType, @@ -278,25 +244,6 @@ class PyMlirContext { static nanobind::ft_mutex live_contexts_mutex; static LiveContextMap &getLiveContexts(); - // Interns all live modules associated with this context. Modules tracked - // in this map are valid. When a module is invalidated, it is removed - // from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveModuleMap = - llvm::DenseMap>; - LiveModuleMap liveModules; - - // Interns all live operations associated with this context. Operations - // tracked in this map are valid. When an operation is invalidated, it is - // removed from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveOperationMap = - llvm::DenseMap>; - nanobind::ft_mutex liveOperationsMutex; - - // Guarded by liveOperationsMutex in free-threading mode. - LiveOperationMap liveOperations; - bool emitErrorDiagnostics = false; MlirContext context; @@ -548,8 +495,8 @@ class PyModule; using PyModuleRef = PyObjectRef; class PyModule : public BaseContextObject { public: - /// Returns a PyModule reference for the given MlirModule. This may return - /// a pre-existing or new object. + /// Returns a PyModule reference for the given MlirModule. This always returns + /// a new object. static PyModuleRef forModule(MlirModule module); PyModule(PyModule &) = delete; PyModule(PyMlirContext &&) = delete; @@ -570,11 +517,12 @@ class PyModule : public BaseContextObject { nanobind::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. - /// Note that PyModule instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirModule - /// is taken by calling this function. + /// Note this returns a new object BUT clearMlirModule() must be called to + /// prevent double-frees (of the underlying mlir::Module). static nanobind::object createFromCapsule(nanobind::object capsule); + void clearMlirModule() { module = {nullptr}; } + private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 20017e25b69bb..817479ee2421b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "ValueError if the pipeline can't be parsed.") .def( "run", - [](PyPassManager &passManager, PyOperationBase &op, - bool invalidateOps) { - if (invalidateOps) { - op.getOperation().getContext()->clearOperationsInside(op); - } + [](PyPassManager &passManager, PyOperationBase &op) { // Actually run the pass manager. PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRunOnOp( @@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { throw MLIRError("Failure while executing pass pipeline", errors.take()); }, - "operation"_a, "invalidate_ops"_a = true, + "operation"_a, "Run the pass manager on the provided operation, raising an " "MLIRError on failure.") .def( diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index f9b0fed62778f..920bca886f617 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) { // root. This is awkward, but we don't have access to PyMlirContext // object here otherwise. nb::object obj = nb::cast(payloadRoot); - obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); MlirLogicalResult result = mlirTransformApplyNamedSequence( payloadRoot, transformRoot, transformModule, options.options); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 8491553dab76f..c7069f0017b5d 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) { return wrap(dyn_cast(unwrap(op))); } +bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) { + return unwrap(lhs) == unwrap(rhs); +} + //===----------------------------------------------------------------------===// // Operation state API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index 6065e59fd6ed9..a552eaa662af4 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -121,27 +121,17 @@ def testRoundtripBinary(): def testModuleOperation(): ctx = Context() module = Module.parse(r"""module @successfulParse {}""", ctx) - assert ctx._get_live_module_count() == 1 op1 = module.operation - assert ctx._get_live_operation_count() == 1 - live_ops = ctx._get_live_operation_objects() - assert len(live_ops) == 1 - assert live_ops[0] is op1 - live_ops = None # CHECK: module @successfulParse print(op1) # Ensure that operations are the same on multiple calls. op2 = module.operation - assert ctx._get_live_operation_count() == 1 - assert op1 is op2 + assert not op1 is op2 + assert op1 == op2 # Test live operation clearing. op1 = module.operation - assert ctx._get_live_operation_count() == 1 - num_invalidated = ctx._clear_live_operations() - assert num_invalidated == 1 - assert ctx._get_live_operation_count() == 0 op1 = None gc.collect() op1 = module.operation @@ -155,9 +145,6 @@ def testModuleOperation(): op1 = None op2 = None gc.collect() - print("LIVE OPERATIONS:", ctx._get_live_operation_count()) - assert ctx._get_live_operation_count() == 0 - assert ctx._get_live_module_count() == 0 # CHECK-LABEL: TEST: testModuleCapsule @@ -165,16 +152,17 @@ def testModuleOperation(): def testModuleCapsule(): ctx = Context() module = Module.parse(r"""module @successfulParse {}""", ctx) - assert ctx._get_live_module_count() == 1 # CHECK: "mlir.ir.Module._CAPIPtr" module_capsule = module._CAPIPtr print(module_capsule) module_dup = Module._CAPICreate(module_capsule) - assert module is module_dup + assert not module is module_dup + assert module == module_dup + module._clear_mlir_module() + assert not module == module_dup assert module_dup.context is ctx # Gc and verify destructed. module = None module_capsule = None module_dup = None gc.collect() - assert ctx._get_live_module_count() == 0 diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index bf16e3f75d60d..94f39c0fbd077 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -907,7 +907,13 @@ def testCapsuleConversions(): m_capsule = m._CAPIPtr assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule) m2 = Operation._CAPICreate(m_capsule) - assert m2 is m + assert not m2 is m + assert m2 == m + # Gc and verify destructed. + m = None + m_capsule = None + m2 = None + gc.collect() # CHECK-LABEL: TEST: testOperationErase diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py index 8b6d7ea5a197d..7afd539271d21 100644 --- a/mlir/test/python/ir/symbol_table.py +++ b/mlir/test/python/ir/symbol_table.py @@ -56,14 +56,6 @@ def testSymbolTableInsert(): print(m1) assert "bar" not in symbol_table - try: - print(bar) - except RuntimeError as e: - if "the operation has been invalidated" not in str(e): - raise - else: - assert False, "expected RuntimeError due to invalidated operation" - qux = m2.body.operations[0] m1.body.append(qux) symbol_table.insert(qux) diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index e26d42bb32913..0896cd9784641 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -176,14 +176,6 @@ def testRunPipelineError(): @run def testPostPassOpInvalidation(): with Context() as ctx: - log_op_count = lambda: log("live ops:", ctx._get_live_operation_count()) - - # CHECK: invalidate_ops=False - log("invalidate_ops=False") - - # CHECK: live ops: 0 - log_op_count() - module = ModuleOp.parse( """ module { @@ -196,9 +188,6 @@ def testPostPassOpInvalidation(): """ ) - # CHECK: live ops: 1 - log_op_count() - outer_const_op = module.body.operations[0] # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64 log(outer_const_op) @@ -214,12 +203,7 @@ def testPostPassOpInvalidation(): # CHECK: %[[VAL1]] = arith.constant 10 : i64 log(inner_const_op) - # CHECK: live ops: 4 - log_op_count() - - PassManager.parse("builtin.module(canonicalize)").run( - module, invalidate_ops=False - ) + PassManager.parse("builtin.module(canonicalize)").run(module) # CHECK: func.func @foo() { # CHECK: return # CHECK: } @@ -233,9 +217,6 @@ def testPostPassOpInvalidation(): # CHECK: invalidate_ops=True log("invalidate_ops=True") - # CHECK: live ops: 4 - log_op_count() - module = ModuleOp.parse( """ module { @@ -247,36 +228,9 @@ def testPostPassOpInvalidation(): } """ ) - outer_const_op = module.body.operations[0] - func_op = module.body.operations[1] - inner_const_op = func_op.body.blocks[0].operations[0] - - # CHECK: live ops: 4 - log_op_count() PassManager.parse("builtin.module(canonicalize)").run(module) - # CHECK: live ops: 1 - log_op_count() - - try: - log(func_op) - except RuntimeError as e: - # CHECK: the operation has been invalidated - log(e) - - try: - log(outer_const_op) - except RuntimeError as e: - # CHECK: the operation has been invalidated - log(e) - - try: - log(inner_const_op) - except RuntimeError as e: - # CHECK: the operation has been invalidated - log(e) - # CHECK: func.func @foo() { # CHECK: return # CHECK: }