-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][Arith] Prevent IR modification for non-matching pattern #150103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Arith] Prevent IR modification for non-matching pattern #150103
Conversation
The F4E2M1 truncation emulation was expanding or truncating operations to F32 even when the pattern did not apply, causing non-convergent rewrites when operating on doubles.
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) ChangesThe F4E2M1 truncation emulation was expanding or truncating operations to F32 even when the pattern did not apply, causing non-convergent rewrites when operating on doubles. Full diff: https://github.com/llvm/llvm-project/pull/150103.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index f497d2db3bf7c..5e575de4065ca 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -518,10 +518,10 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
- if (!isa<Float32Type>(operandETy))
- operand = b.create<arith::ExtFOp>(f32Ty, operand);
if (!isa<Float4E2M1FNType>(resultETy))
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+ if (!isa<Float32Type>(operandETy))
+ operand = b.create<arith::ExtFOp>(f32Ty, operand);
Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you take a look at the other patterns here? From a quick scan I wonder if there's similar issues, e.g. here:
llvm-project/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
Lines 654 to 664 in 864392b
if (scaleETy.getIntOrFloatBitWidth() >= 16) { | |
scaleETy = b.getF8E8M0Type(); | |
scaleTy = cloneToShapedType(scaleTy, scaleETy); | |
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr, | |
op.getFastmathAttr()); | |
} | |
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { | |
return rewriter.notifyMatchFailure( | |
op, "scaling_extf is using scales of type which can not be converted " | |
"to f8E8M0FNU"); | |
} |
edit: looking a bit closer it seems this case is fine, though it's not super obvious
I did do an overall check. That particular |
…150103) The F4E2M1 truncation emulation was expanding or truncating operations to F32 even when the pattern did not apply, causing non-convergent rewrites when operating on doubles. Also, fix a pair of whitespace issues that snuck in.
The F4E2M1 truncation emulation was expanding or truncating operations to F32 even when the pattern did not apply, causing non-convergent rewrites when operating on doubles.
Also, fix a pair of whitespace issues that snuck in.