Skip to content

[mlir][Arith] Prevent IR modification for non-matching pattern #150103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

krzysz00
Copy link
Contributor

@krzysz00 krzysz00 commented Jul 22, 2025

The F4E2M1 truncation emulation was expanding or truncating operations to F32 even when the pattern did not apply, causing non-convergent rewrites when operating on doubles.

Also, fix a pair of whitespace issues that snuck in.

The F4E2M1 truncation emulation was expanding or truncating operations
to F32 even when the pattern did not apply, causing non-convergent
rewrites when operating on doubles.
@llvmbot
Copy link
Member

llvmbot commented Jul 22, 2025

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

The F4E2M1 truncation emulation was expanding or truncating operations to F32 even when the pattern did not apply, causing non-convergent rewrites when operating on doubles.


Full diff: https://github.com/llvm/llvm-project/pull/150103.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+2-2)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index f497d2db3bf7c..5e575de4065ca 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -518,10 +518,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
-    if (!isa<Float32Type>(operandETy))
-      operand = b.create<arith::ExtFOp>(f32Ty, operand);
     if (!isa<Float4E2M1FNType>(resultETy))
       return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+    if (!isa<Float32Type>(operandETy))
+      operand = b.create<arith::ExtFOp>(f32Ty, operand);
 
     Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
     Value c0x3 = createConst(loc, i4Ty, 3, rewriter);

Copy link
Member

@rkayaith rkayaith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you take a look at the other patterns here? From a quick scan I wonder if there's similar issues, e.g. here:

if (scaleETy.getIntOrFloatBitWidth() >= 16) {
scaleETy = b.getF8E8M0Type();
scaleTy = cloneToShapedType(scaleTy, scaleETy);
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
op.getFastmathAttr());
}
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
return rewriter.notifyMatchFailure(
op, "scaling_extf is using scales of type which can not be converted "
"to f8E8M0FNU");
}

edit: looking a bit closer it seems this case is fine, though it's not super obvious

@krzysz00
Copy link
Contributor Author

Could you take a look at the other patterns here? From a quick scan I wonder if there's similar issues, e.g. here:

I did do an overall check.

That particular notifyMatchFailure() is probably something that should be an assert or removed entirely - the code right above ensures we never hit it.

@krzysz00 krzysz00 merged commit eb55412 into llvm:main Jul 22, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
…150103)

The F4E2M1 truncation emulation was expanding or truncating operations
to F32 even when the pattern did not apply, causing non-convergent
rewrites when operating on doubles.

Also, fix a pair of whitespace issues that snuck in.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants