From 71bd6a1d6e06bf90ccf83ce698c9fee65fccfbe9 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Thu, 17 Jul 2025 14:52:30 +0800 Subject: [PATCH 1/2] [mlir][mesh] Add null check for dyn_cast to prevent crash This PR adds a null check for dyn_cast to prevent crash, and use `isa` instead `dyn_cast` to make code clean. --- .../mlir/Dialect/Mesh/Transforms/Simplifications.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h index 3f1041cb25103..243dbf081b999 100644 --- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h +++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h @@ -62,9 +62,11 @@ void populateAllReduceEndomorphismSimplificationPatterns( auto isEndomorphismOp = [reduction](Operation *op, std::optional referenceOp) { auto allReduceOp = llvm::dyn_cast(op); + if (!allReduceOp) + return false; auto inType = cast(allReduceOp.getInput().getType()); auto outType = cast(allReduceOp.getResult().getType()); - if (!allReduceOp || inType.getElementType() != outType.getElementType() || + if (inType.getElementType() != outType.getElementType() || allReduceOp.getReduction() != reduction) { return false; } @@ -87,9 +89,7 @@ void populateAllReduceEndomorphismSimplificationPatterns( return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() && inType.getElementType() == refType.getElementType(); }; - auto isAlgebraicOp = [](Operation *op) { - return static_cast(llvm::dyn_cast(op)); - }; + auto isAlgebraicOp = [](Operation *op) { return isa(op); }; using ConcreteEndomorphismSimplification = EndomorphismSimplification< std::decay_t, From 7071c472001e5b08e289cc3bedb9362b0e345e2e Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Thu, 17 Jul 2025 14:54:59 +0800 Subject: [PATCH 2/2] add test --- mlir/test/Dialect/Mesh/simplifications.mlir | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir index 2540fbf9510c4..e955f4c134259 100644 --- a/mlir/test/Dialect/Mesh/simplifications.mlir +++ b/mlir/test/Dialect/Mesh/simplifications.mlir @@ -165,3 +165,15 @@ func.func @all_reduce_arith_minsi_endomorphism( // CHECK: return %[[ALL_REDUCE_RES]] return %2 : tensor<5xi32> } + +// Ensure this case without endomorphism op not crash. +// CHECK-LABEL: func.func @no_endomorphism_op +func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 { + %c0 = arith.constant 0 : index + %c1_i64 = arith.constant 1 : i64 + // CHECK: tensor.extract + %extracted = tensor.extract %arg0[%c0] : tensor<2xi64> + // CHECK: arith.maxsi + %0 = arith.maxsi %extracted, %c1_i64 : i64 + return %0 : i64 +}