Skip to content

[mlir][EmitC]Expand the MemRefToEmitC pass - Adding scalars #148055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 18, 2025

Conversation

Jaddyen
Copy link
Contributor

@Jaddyen Jaddyen commented Jul 10, 2025

This aims to expand the the MemRefToEmitC pass so that it can accept global scalars.
From:

memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
func.func @globals() {
    memref.get_global @__constant_xi32 : memref<i32>
}

To:

emitc.global static const @__constant_xi32 : i32 = -1
    emitc.func @globals() {
      %0 = get_global @__constant_xi32 : !emitc.lvalue<i32>
      %1 = apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
      return
    }

@Jaddyen Jaddyen changed the title Expand the MemRef to EmitC pass Expand the MemRef to EmitC pass - Adding scalars Jul 10, 2025
@Jaddyen Jaddyen changed the title Expand the MemRef to EmitC pass - Adding scalars Expand the MemRefToEmitC pass - Adding scalars Jul 10, 2025
@Jaddyen Jaddyen marked this pull request as ready for review July 10, 2025 23:50
@llvmbot
Copy link
Member

llvmbot commented Jul 10, 2025

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Jaden Angella (Jaddyen)

Changes

This aims to expand the the MemRefToEmitC pass so that it can accept global scalars.


Full diff: https://github.com/llvm/llvm-project/pull/148055.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+25-3)
  • (modified) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir (+4)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1d1cac8..e55c8e48ad105 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -83,7 +85,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
   LogicalResult
   matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
-
+    MemRefType type = op.getType();
     if (!op.getType().hasStaticShape()) {
       return rewriter.notifyMatchFailure(
           op.getLoc(), "cannot transform global with dynamic shape");
@@ -95,7 +97,13 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
           op.getLoc(), "global variable with alignment requirement is "
                        "currently not supported");
     }
-    auto resultTy = getTypeConverter()->convertType(op.getType());
+
+    Type resultTy;
+    if (type.getRank() == 0)
+      resultTy = getTypeConverter()->convertType(type.getElementType());
+    else
+      resultTy = getTypeConverter()->convertType(type);
+
     if (!resultTy) {
       return rewriter.notifyMatchFailure(op.getLoc(),
                                          "cannot convert result type");
@@ -114,6 +122,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
     bool externSpecifier = !staticSpecifier;
 
     Attribute initialValue = operands.getInitialValueAttr();
+    if (type.getRank() == 0) {
+      auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
+      initialValue = elementsAttr.getSplatValue<Attribute>();
+    }
     if (isa_and_present<UnitAttr>(initialValue))
       initialValue = {};
 
@@ -132,7 +144,17 @@ struct ConvertGetGlobal final
   matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto resultTy = getTypeConverter()->convertType(op.getType());
+    MemRefType type = op.getType();
+    Type resultTy;
+    if (type.getRank() == 0)
+      resultTy = emitc::LValueType::get(
+          getTypeConverter()->convertType(type.getElementType()));
+    else
+      resultTy = getTypeConverter()->convertType(type);
+
+    if (!resultTy)
+      return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+
     if (!resultTy) {
       return rewriter.notifyMatchFailure(op.getLoc(),
                                          "cannot convert result type");
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1de90add..445a28534325a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
 module @globals {
   memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
   // CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+  memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
+  // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
   memref.global @public_global : memref<3x7xf32>
   // CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
   memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
@@ -50,6 +52,8 @@ module @globals {
   func.func @use_global() {
     // CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
     %0 = memref.get_global @public_global : memref<3x7xf32>
+    // CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
+    %1 = memref.get_global @__constant_xi32 : memref<i32>
     return
   }
 }

@@ -83,7 +85,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

MemRefType type = op.getType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
MemRefType type = op.getType();
MemRefType opTy = op.getType();

I'm not sure we want to use type as a variable name...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Thanks for the pointer!

Comment on lines 102 to 105
if (type.getRank() == 0)
resultTy = getTypeConverter()->convertType(type.getElementType());
else
resultTy = getTypeConverter()->convertType(type);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe introduce a helper, since I see a similar pattern in a few spots?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Def! Thanks for the pointer.

@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
module @globals {
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting to see a corresponding change to memref-to-emitc-failed.mlir. Looking I suppose it isn't there, but are any of the cases, like

going to work now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not yet. This would require changing those particular ops. I could circle back to this once we can completely lower the inliner model.


Type resultTy;
if (type.getRank() == 0)
resultTy = getTypeConverter()->convertType(type.getElementType());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to just do

resultTy = getTypeConverter()->convertType(getElementTypeOrSelf(type));

@@ -50,6 +52,8 @@ module @globals {
func.func @use_global() {
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
%0 = memref.get_global @public_global : memref<3x7xf32>
// CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
%1 = memref.get_global @__constant_xi32 : memref<i32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to check, so previously memref<1xi32> worked, but memref<i32> didn't work before this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

@@ -114,6 +122,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
bool externSpecifier = !staticSpecifier;

Attribute initialValue = operands.getInitialValueAttr();
if (type.getRank() == 0) {
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
initialValue = elementsAttr.getSplatValue<Attribute>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the initialValue before this vs splat value returned?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without using getSplatValue, we have the initial value as: initial_value = dense<-1> : tensor<i32>
After getting the splat value, we have emitc.global static const @__constant_xi32 : i32 = -1

MemRefType type = op.getType();
Type resultTy;
if (type.getRank() == 0)
resultTy = emitc::LValueType::get(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to be LValue, while one below not?

Copy link
Contributor Author

@Jaddyen Jaddyen Jul 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

memref.getGlobal gets converted to emitc.getglobal but emitc.getglobal only returns LValue or Array. So in the case that we have a constant, we create an LValue, else we return an array.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that in general lowering getglobal to lvalue doesn't work correctly when the result is passed to function calls for example. So I would expect rank 0 memrefs to be lowered to pointers (at least when it might escape).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out. I've updated the conversion to reflect this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry für being unclear. I think globals should still be lowered to lvalues so that it allocates the necessary storage. But the get_global may be lowered to EmitC.get_global + EmitC.apply "&" to get a Pointer to the variable. But I haven't tried this out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I've tried and been able to implement is the conversion:
From:

memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
func.func @globals() {
    memref.get_global @__constant_xi32 : memref<i32>
}

To:

emitc.global static const @__constant_xi32 : i32 = -1
    emitc.func @globals() {
      %0 = get_global @__constant_xi32 : !emitc.lvalue<i32>
      %1 = apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
      return
    }

resultTy = getTypeConverter()->convertType(type);

if (!resultTy)
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you have this check just below too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do! Thanks for the pointer.

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good in general. I believe you also tested it with compiling the output too.

@Jaddyen Jaddyen requested a review from simon-camp July 17, 2025 19:30
Copy link
Contributor

@simon-camp simon-camp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Jaddyen
Copy link
Contributor Author

Jaddyen commented Jul 18, 2025

Looks good in general. I believe you also tested it with compiling the output too.

Yes, I did!

@Jaddyen Jaddyen changed the title Expand the MemRefToEmitC pass - Adding scalars [mlir][emitc]Expand the MemRefToEmitC pass - Adding scalars Jul 18, 2025
@Jaddyen Jaddyen changed the title [mlir][emitc]Expand the MemRefToEmitC pass - Adding scalars [mlir][EmitC]Expand the MemRefToEmitC pass - Adding scalars Jul 18, 2025
@Jaddyen Jaddyen merged commit 7fd91bb into llvm:main Jul 18, 2025
9 checks passed
@Jaddyen Jaddyen deleted the memref-ops branch July 18, 2025 17:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants