-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][EmitC]Expand the MemRefToEmitC pass - Adding scalars #148055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Jaden Angella (Jaddyen) Changes: This aims to expand the MemRefToEmitC pass so that it can accept global scalars. Full diff: https://github.com/llvm/llvm-project/pull/148055.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1d1cac8..e55c8e48ad105 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,7 +16,9 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -83,7 +85,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
-
+ MemRefType type = op.getType();
if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform global with dynamic shape");
@@ -95,7 +97,13 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
op.getLoc(), "global variable with alignment requirement is "
"currently not supported");
}
- auto resultTy = getTypeConverter()->convertType(op.getType());
+
+ Type resultTy;
+ if (type.getRank() == 0)
+ resultTy = getTypeConverter()->convertType(type.getElementType());
+ else
+ resultTy = getTypeConverter()->convertType(type);
+
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
@@ -114,6 +122,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
bool externSpecifier = !staticSpecifier;
Attribute initialValue = operands.getInitialValueAttr();
+ if (type.getRank() == 0) {
+ auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
+ initialValue = elementsAttr.getSplatValue<Attribute>();
+ }
if (isa_and_present<UnitAttr>(initialValue))
initialValue = {};
@@ -132,7 +144,17 @@ struct ConvertGetGlobal final
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
- auto resultTy = getTypeConverter()->convertType(op.getType());
+ MemRefType type = op.getType();
+ Type resultTy;
+ if (type.getRank() == 0)
+ resultTy = emitc::LValueType::get(
+ getTypeConverter()->convertType(type.getElementType()));
+ else
+ resultTy = getTypeConverter()->convertType(type);
+
+ if (!resultTy)
+ return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1de90add..445a28534325a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
module @globals {
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+ memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
+ // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
memref.global @public_global : memref<3x7xf32>
// CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
@@ -50,6 +52,8 @@ module @globals {
func.func @use_global() {
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
%0 = memref.get_global @public_global : memref<3x7xf32>
+ // CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
+ %1 = memref.get_global @__constant_xi32 : memref<i32>
return
}
}
|
@@ -83,7 +85,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { | |||
LogicalResult | |||
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, | |||
ConversionPatternRewriter &rewriter) const override { | |||
|
|||
MemRefType type = op.getType(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MemRefType type = op.getType(); | |
MemRefType opTy = op.getType(); |
I'm not sure we want to use `type` as a variable name...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. Thanks for the pointer!
if (type.getRank() == 0) | ||
resultTy = getTypeConverter()->convertType(type.getElementType()); | ||
else | ||
resultTy = getTypeConverter()->convertType(type); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe introduce a helper, since I see a similar pattern in a few spots?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Def! Thanks for the pointer.
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 { | |||
module @globals { | |||
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0> | |||
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00> | |||
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was expecting to see a corresponding change to memref-to-emitc-failed.mlir
. Looking I suppose it isn't there, but are any of the cases, like
%0 = memref.alloca() : memref<0xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not yet. This would require changing those particular ops. I could circle back to this once we can completely lower the inliner model.
|
||
Type resultTy; | ||
if (type.getRank() == 0) | ||
resultTy = getTypeConverter()->convertType(type.getElementType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should be able to just do
resultTy = getTypeConverter()->convertType(getElementTypeOrSelf(type));
@@ -50,6 +52,8 @@ module @globals { | |||
func.func @use_global() { | |||
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32> | |||
%0 = memref.get_global @public_global : memref<3x7xf32> | |||
// CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32> | |||
%1 = memref.get_global @__constant_xi32 : memref<i32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to check, so previously `memref<1xi32>` worked, but `memref<i32>` didn't work before this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes!
@@ -114,6 +122,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { | |||
bool externSpecifier = !staticSpecifier; | |||
|
|||
Attribute initialValue = operands.getInitialValueAttr(); | |||
if (type.getRank() == 0) { | |||
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue()); | |||
initialValue = elementsAttr.getSplatValue<Attribute>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the initialValue before this vs splat value returned?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without using getSplatValue, we have the initial value as: initial_value = dense<-1> : tensor<i32>
After getting the splat value, we have emitc.global static const @__constant_xi32 : i32 = -1
MemRefType type = op.getType(); | ||
Type resultTy; | ||
if (type.getRank() == 0) | ||
resultTy = emitc::LValueType::get( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does this need to be LValue, while one below not?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`memref.getGlobal` gets converted to `emitc.getglobal`, but `emitc.getglobal` only returns LValue or Array. So in the case that we have a constant, we create an LValue, else we return an array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that in general lowering getglobal to lvalue doesn't work correctly when the result is passed to function calls for example. So I would expect rank 0 memrefs to be lowered to pointers (at least when it might escape).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out. I've updated the conversion to reflect this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for being unclear. I think globals should still be lowered to lvalues so that it allocates the necessary storage. But the get_global may be lowered to EmitC.get_global + EmitC.apply "&" to get a Pointer to the variable. But I haven't tried this out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I've tried and been able to implement is the conversion:
From:
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
func.func @globals() {
memref.get_global @__constant_xi32 : memref<i32>
}
To:
emitc.global static const @__constant_xi32 : i32 = -1
emitc.func @globals() {
%0 = get_global @__constant_xi32 : !emitc.lvalue<i32>
%1 = apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
return
}
resultTy = getTypeConverter()->convertType(type); | ||
|
||
if (!resultTy) | ||
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't you have this check just below too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do! Thanks for the pointer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good in general. I believe you also tested it with compiling the output too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Yes, I did! |
This aims to expand the MemRefToEmitC pass so that it can accept global scalars.
From:
To: