[AArch64][Codegen]Transform saturating smull to sqdmulh #143671
Conversation
@llvm/pr-subscribers-backend-aarch64
Author: Nashe Mncube (nasherm)
Changes: This patch adds a pattern for recognizing saturating vector smull. Prior to this patch these were performed using a combination of smull+smull2+uzp+smin. The sqdmull instruction performs the saturation, removing the need for the smin calls.
Full diff: https://github.com/llvm/llvm-project/pull/143671.diff
2 Files Affected:
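For context, a reduced IR sketch of the Q31 shape the new patterns target, mirroring the inner loop of the @arm_mult_q31 test below (the function name here is illustrative, not part of the patch): a widening signed multiply whose high half is clamped with smin at 0x3FFFFFFF.

```
define <4 x i32> @q31_saturating_mul_sketch(<4 x i32> %a, <4 x i32> %b) {
  ; sext + mul + lshr 32 + trunc is the mulhs the pattern matches; smin provides the Q31 clamp
  %as = sext <4 x i32> %a to <4 x i64>
  %bs = sext <4 x i32> %b to <4 x i64>
  %m  = mul nsw <4 x i64> %bs, %as
  %hi = lshr <4 x i64> %m, splat (i64 32)
  %t  = trunc <4 x i64> %hi to <4 x i32>
  %s  = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %t, <4 x i32> splat (i32 1073741823))
  ret <4 x i32> %s
}
declare <4 x i32> @llvm.smin.v4i32(<4 x i32>, <4 x i32>)
```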
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 727831896737d..3ab3e6fda524c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -9349,6 +9349,25 @@ def : Pat<(v4i32 (mulhs V128:$Rn, V128:$Rm)),
(EXTRACT_SUBREG V128:$Rm, dsub)),
(SMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
+// Saturating vector multiplications on signed integers
+// follow a smull + smull2 + uzp + smin pattern. It would
+// be more efficient to make use of sqdmull instructions, which
+// negate the need for a saturating smin call.
+def : Pat<(v8i16(smin (mulhs V128:$Rn, V128:$Rm),
+ (v8i16 (AArch64mvni_shift (i32 192), (i32 8))))),
+ (UZP2v8i16
+ (SQDMULLv4i16_v4i32 (EXTRACT_SUBREG V128:$Rn, dsub),
+ (EXTRACT_SUBREG V128:$Rm, dsub)),
+ (SQDMULLv8i16_v4i32 V128:$Rn, V128:$Rm))>;
+
+
+def : Pat<(v4i32 (smin (mulhs V128:$Rn, V128:$Rm),
+ (v4i32 (AArch64mvni_shift (i32 192), (i32 24))))),
+ (UZP2v4i32
+ (SQDMULLv2i32_v2i64 (EXTRACT_SUBREG V128:$Rn, dsub),
+ (EXTRACT_SUBREG V128:$Rm, dsub)),
+ (SQDMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
+
def : Pat<(v16i8 (mulhu V128:$Rn, V128:$Rm)),
(UZP2v16i8
(UMULLv8i8_v8i16 (EXTRACT_SUBREG V128:$Rn, dsub),
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
new file mode 100644
index 0000000000000..8cdb71a41eb4c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -0,0 +1,202 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
+
+define void @arm_mult_q31(ptr %0, ptr %1, ptr %2, i32 %3) {
+; CHECK-LABEL: arm_mult_q31:
+; CHECK: // %bb.0:
+; CHECK-NEXT: cbz w3, .LBB0_4
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: cmp w3, #8
+; CHECK-NEXT: b.lo .LBB0_4
+; CHECK-NEXT: // %bb.2:
+; CHECK-NEXT: mov w8, w3
+; CHECK-NEXT: add x9, x2, #16
+; CHECK-NEXT: add x10, x1, #16
+; CHECK-NEXT: and x8, x8, #0xfffffff8
+; CHECK-NEXT: add x11, x0, #16
+; CHECK-NEXT: .LBB0_3: // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT: ldp q0, q3, [x10, #-16]
+; CHECK-NEXT: subs x8, x8, #8
+; CHECK-NEXT: ldp q1, q2, [x11, #-16]
+; CHECK-NEXT: add x10, x10, #32
+; CHECK-NEXT: add x11, x11, #32
+; CHECK-NEXT: sqdmull2 v4.2d, v0.4s, v1.4s
+; CHECK-NEXT: sqdmull v0.2d, v0.2s, v1.2s
+; CHECK-NEXT: sqdmull2 v1.2d, v3.4s, v2.4s
+; CHECK-NEXT: sqdmull v2.2d, v3.2s, v2.2s
+; CHECK-NEXT: uzp2 v0.4s, v0.4s, v4.4s
+; CHECK-NEXT: uzp2 v1.4s, v2.4s, v1.4s
+; CHECK-NEXT: stp q0, q1, [x9, #-16]
+; CHECK-NEXT: add x9, x9, #32
+; CHECK-NEXT: b.ne .LBB0_3
+; CHECK-NEXT: .LBB0_4:
+; CHECK-NEXT: ret
+ %5 = icmp eq i32 %3, 0
+ br i1 %5, label %48, label %6
+
+6:
+ %7 = zext i32 %3 to i64
+ %8 = icmp ult i32 %3, 8
+ br i1 %8, label %48, label %9
+
+9:
+ %10 = and i64 %7, 4294967288
+ %11 = shl nuw nsw i64 %10, 2
+ %12 = getelementptr i8, ptr %0, i64 %11
+ %13 = trunc nuw i64 %10 to i32
+ %14 = sub i32 %3, %13
+ %15 = shl nuw nsw i64 %10, 2
+ %16 = getelementptr i8, ptr %1, i64 %15
+ %17 = shl nuw nsw i64 %10, 2
+ %18 = getelementptr i8, ptr %2, i64 %17
+ br label %19
+
+19:
+ %20 = phi i64 [ 0, %9 ], [ %46, %19 ]
+ %21 = shl i64 %20, 2
+ %22 = getelementptr i8, ptr %0, i64 %21
+ %23 = shl i64 %20, 2
+ %24 = getelementptr i8, ptr %1, i64 %23
+ %25 = shl i64 %20, 2
+ %26 = getelementptr i8, ptr %2, i64 %25
+ %27 = getelementptr i8, ptr %22, i64 16
+ %28 = load <4 x i32>, ptr %22, align 4
+ %29 = load <4 x i32>, ptr %27, align 4
+ %30 = sext <4 x i32> %28 to <4 x i64>
+ %31 = sext <4 x i32> %29 to <4 x i64>
+ %32 = getelementptr i8, ptr %24, i64 16
+ %33 = load <4 x i32>, ptr %24, align 4
+ %34 = load <4 x i32>, ptr %32, align 4
+ %35 = sext <4 x i32> %33 to <4 x i64>
+ %36 = sext <4 x i32> %34 to <4 x i64>
+ %37 = mul nsw <4 x i64> %35, %30
+ %38 = mul nsw <4 x i64> %36, %31
+ %39 = lshr <4 x i64> %37, splat (i64 32)
+ %40 = lshr <4 x i64> %38, splat (i64 32)
+ %41 = trunc nuw <4 x i64> %39 to <4 x i32>
+ %42 = trunc nuw <4 x i64> %40 to <4 x i32>
+ %43 = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %41, <4 x i32> splat (i32 1073741823))
+ %44 = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %42, <4 x i32> splat (i32 1073741823))
+ %45 = getelementptr i8, ptr %26, i64 16
+ store <4 x i32> %43, ptr %26, align 4
+ store <4 x i32> %44, ptr %45, align 4
+ %46 = add nuw i64 %20, 8
+ %47 = icmp eq i64 %46, %10
+ br i1 %47, label %48, label %19
+
+48:
+ ret void
+}
+
+define void @arm_mult_q15(ptr %0, ptr %1, ptr %2, i16 %3) {
+; CHECK-LABEL: arm_mult_q15:
+; CHECK: // %bb.0:
+; CHECK-NEXT: and w8, w3, #0xffff
+; CHECK-NEXT: cmp w8, #4
+; CHECK-NEXT: b.lo .LBB1_7
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: ubfx w8, w3, #2, #14
+; CHECK-NEXT: sub w8, w8, #1
+; CHECK-NEXT: cmp w8, #3
+; CHECK-NEXT: b.lo .LBB1_7
+; CHECK-NEXT: // %bb.2:
+; CHECK-NEXT: sub x9, x2, x0
+; CHECK-NEXT: cmp x9, #32
+; CHECK-NEXT: b.lo .LBB1_7
+; CHECK-NEXT: // %bb.3:
+; CHECK-NEXT: sub x9, x2, x1
+; CHECK-NEXT: cmp x9, #32
+; CHECK-NEXT: b.lo .LBB1_7
+; CHECK-NEXT: // %bb.4:
+; CHECK-NEXT: cmp w8, #15
+; CHECK-NEXT: b.lo .LBB1_7
+; CHECK-NEXT: // %bb.5:
+; CHECK-NEXT: and x8, x8, #0xffff
+; CHECK-NEXT: add x9, x2, #16
+; CHECK-NEXT: add x10, x1, #16
+; CHECK-NEXT: add x8, x8, #1
+; CHECK-NEXT: add x11, x0, #16
+; CHECK-NEXT: and x8, x8, #0x1fff0
+; CHECK-NEXT: .LBB1_6: // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT: ldp q0, q3, [x10, #-16]
+; CHECK-NEXT: subs x8, x8, #16
+; CHECK-NEXT: ldp q1, q2, [x11, #-16]
+; CHECK-NEXT: add x10, x10, #32
+; CHECK-NEXT: add x11, x11, #32
+; CHECK-NEXT: sqdmull2 v4.4s, v0.8h, v1.8h
+; CHECK-NEXT: sqdmull v0.4s, v0.4h, v1.4h
+; CHECK-NEXT: sqdmull2 v1.4s, v3.8h, v2.8h
+; CHECK-NEXT: sqdmull v2.4s, v3.4h, v2.4h
+; CHECK-NEXT: uzp2 v0.8h, v0.8h, v4.8h
+; CHECK-NEXT: uzp2 v1.8h, v2.8h, v1.8h
+; CHECK-NEXT: stp q0, q1, [x9, #-16]
+; CHECK-NEXT: add x9, x9, #32
+; CHECK-NEXT: b.ne .LBB1_6
+; CHECK-NEXT: .LBB1_7:
+; CHECK-NEXT: ret
+ %5 = ptrtoint ptr %1 to i64
+ %6 = ptrtoint ptr %0 to i64
+ %7 = ptrtoint ptr %2 to i64
+ %8 = icmp ult i16 %3, 4
+ br i1 %8, label %54, label %9
+
+9:
+ %10 = lshr i16 %3, 2
+ %11 = add nsw i16 %10, -1
+ %12 = zext i16 %11 to i64
+ %13 = add nuw nsw i64 %12, 1
+ %14 = icmp ult i16 %11, 3
+ br i1 %14, label %54, label %15
+
+15:
+ %16 = sub i64 %7, %6
+ %17 = icmp ult i64 %16, 32
+ %18 = sub i64 %7, %5
+ %19 = icmp ult i64 %18, 32
+ %20 = or i1 %17, %19
+ br i1 %20, label %54, label %21
+
+21:
+ %22 = icmp ult i16 %11, 15
+ br i1 %22, label %54, label %23
+
+23:
+ %24 = and i64 %13, 131056
+ br label %25
+
+25:
+ %26 = phi i64 [ 0, %23 ], [ %52, %25 ]
+ %27 = shl i64 %26, 1
+ %28 = getelementptr i8, ptr %0, i64 %27
+ %29 = shl i64 %26, 1
+ %30 = getelementptr i8, ptr %1, i64 %29
+ %31 = shl i64 %26, 1
+ %32 = getelementptr i8, ptr %2, i64 %31
+ %33 = getelementptr i8, ptr %28, i64 16
+ %34 = load <8 x i16>, ptr %28, align 2
+ %35 = load <8 x i16>, ptr %33, align 2
+ %36 = sext <8 x i16> %34 to <8 x i32>
+ %37 = sext <8 x i16> %35 to <8 x i32>
+ %38 = getelementptr i8, ptr %30, i64 16
+ %39 = load <8 x i16>, ptr %30, align 2
+ %40 = load <8 x i16>, ptr %38, align 2
+ %41 = sext <8 x i16> %39 to <8 x i32>
+ %42 = sext <8 x i16> %40 to <8 x i32>
+ %43 = mul nsw <8 x i32> %41, %36
+ %44 = mul nsw <8 x i32> %42, %37
+ %45 = lshr <8 x i32> %43, splat (i32 16)
+ %46 = lshr <8 x i32> %44, splat (i32 16)
+ %47 = trunc nuw <8 x i32> %45 to <8 x i16>
+ %48 = trunc nuw <8 x i32> %46 to <8 x i16>
+ %49 = tail call <8 x i16> @llvm.smin.v8i16(<8 x i16> %47, <8 x i16> splat (i16 16383))
+ %50 = tail call <8 x i16> @llvm.smin.v8i16(<8 x i16> %48, <8 x i16> splat (i16 16383))
+ %51 = getelementptr i8, ptr %32, i64 16
+ store <8 x i16> %49, ptr %32, align 2
+ store <8 x i16> %50, ptr %51, align 2
+ %52 = add nuw i64 %26, 16
+ %53 = icmp eq i64 %52, %24
+ br i1 %53, label %54, label %25
+
+54:
+ ret void
+}
Force-pushed: fdef5a6 to 30c4193
Could we get a test precommit to see the changes in action?
def : Pat<(v4i32 (smin (mulhs V128:$Rn, V128:$Rm),
                       (v4i32 (AArch64mvni_shift (i32 192), (i32 24))))),
          (USHRv4i32_shift (SQDMULHv4i32 V128:$Rn, V128:$Rm), (i32 1))>;
Could we support more than only v4i32, and maybe scalable vectors?
I'm not 100% sure. I'd have to have a think. I know that SQDMULH only supports v8i16 and v4i32. SVE I'm unsure of.
Force-pushed: 3761374 to db2aaab
Ping @davemgreen @NickGuy-Arm. I have made changes to AArch64ISelLowering to enable this transform. As for SVE, I haven't had the bandwidth to look at enabling this, so I don't currently know whether an SVE equivalent pattern would be possible or beneficial.
Hi - sorry for the delay, I was trying to re-remember how this instruction works. This has three functions: https://godbolt.org/z/cxb7osTP1, the first of which I think is the most basic form of sqdmulh (notice the >2x bitwidth extend to allow the mul and the x2 not to wrap, and the min+max to saturate). That is equivalent to @Updated, which is what llvm will optimize it to (the mul x2 is folded into the shift, and the only value that can actually saturate is -0x8000*-0x8000). It is equivalent to the third, I believe, because we only need 2x the bitwidth in this form. That feels like the most basic form of sqdmulh. Any reason not to add that one first? I didn't look into this pattern a huge amount, but do you know if the bottom or top bits require the shifts? Or both?
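To make that shape concrete, here is a hedged scalar sketch (not the exact code behind the godbolt link) of the 2x-bitwidth form described above: the doubling is folded into the shift, and only smin is needed because -0x8000 * -0x8000 is the only product that can overflow the upper bound.

```
define i16 @sqdmulh_basic_sketch(i16 %a, i16 %b) {
  ; (2*a*b) >> 16 expressed as (a*b) >> 15, then clamped to 0x7FFF
  %as = sext i16 %a to i32
  %bs = sext i16 %b to i32
  %m  = mul i32 %as, %bs
  %sh = ashr i32 %m, 15
  %s  = call i32 @llvm.smin.i32(i32 %sh, i32 32767)
  %t  = trunc i32 %s to i16
  ret i16 %t
}
declare i32 @llvm.smin.i32(i32, i32)
```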
The bottom and top bits don't require shifts. This was a mistake on my part with my test.c program. My most recent patch uses your godbolt example to find the pattern. It makes sense to me and feels pretty generic, such that SVE support would be feasible with some tinkering. Do let me know what you think.
There's a rebase error in my most recent patch that I need to fix.
Force-pushed: 720aedf to 7d05aa4
Force-pushed: 7d05aa4 to e5269c5
Force-pushed: e5269c5 to 3568431
✅ With the latest revision this PR passed the C/C++ code formatter.
Can you add this test case and fix what is wrong with it? Sometimes it is best to just check for all the types that are valid.
define <6 x i16> @saturating_6xi16(<6 x i16> %a, <6 x i16> %b) {
%as = sext <6 x i16> %a to <6 x i32>
%bs = sext <6 x i16> %b to <6 x i32>
%m = mul <6 x i32> %bs, %as
%sh = ashr <6 x i32> %m, splat (i32 15)
%ma = tail call <6 x i32> @llvm.smin.v6i32(<6 x i32> %sh, <6 x i32> splat (i32 32767))
%t = trunc <6 x i32> %ma to <6 x i16>
ret <6 x i16> %t
}
This patch adds a pattern for recognizing saturating vector smull. Prior to this patch these were performed using a combination of smull+smull2+uzp+smin like the following

```
smull2  v5.2d, v1.4s, v2.4s
smull   v1.2d, v1.2s, v2.2s
uzp2    v1.4s, v1.4s, v5.4s
smin    v1.4s, v1.4s, v0.4s
add     v1.4s, v1.4s, v1.4s
```

which now optimizes to

```
sqdmulh v0.4s, v1.4s, v0.4s
sshr    v0.4s, v0.4s, #1
add     v0.4s, v0.4s, v0.4s
```

This only operates on vectors containing Q31 data types.

Change-Id: Ib7d4d5284d1bd3fdd0907365f9e2f37f4da14671
Based on the most recent PR comments I've
- refactored the change to work on a reduced pattern which is truer to the actual SQDMULH instruction
- written pattern matches for q31, q15 and int32, int16 data types
- rewritten and extended the tests

Change-Id: I18c05e56b3979b8dd757d533e44a65496434937b
Spotted and fixed an arithmetic error when working with Q types

Change-Id: I80f8e04bca08d3e6bc2740201bdd4978446a397f
- support for v2i32 and v4i16 patterns
- extra type checking on sext
- matching on smin over sext
- cleaning trailing lines

Change-Id: I9f61b8d77a61f3d44ad5073b41555c9ad5653e1a
- minor cleanup
- allow optimizing concat_vectors(sqdmulh, sqdmulh) -> sqdmulh
- testing EVTs better

Change-Id: I0404fb9900896050baac372b7f7ce3a5b03517b9
- making sure transform only operates on smin nodes
- adding extra tests dealing with interesting edge cases

Change-Id: Ia1114ec9b93c4de3552b867e0d745beccdae69f1
- check for scalar type
- check for sign extends
- legalise vector inputs to sqdmulh
- always return sext(sqdmulh)

Change-Id: Ic58b7f267e94bc2592942fc29b829ffb6221770f
Change-Id: Ifeba39acab171c75df496f89688fda701bd7dd85
- added testcase for v6i16 and fixed issues
- added testcases for v1i16 and fixed issues

Change-Id: I4694c48ff9f12ee6048efd2394d5b710df7ebbea
Force-pushed: 3810477 to a6fc487
Thanks. LGTM
This patch adds a pattern for recognizing saturating vector smull. Prior to this patch these were performed using a combination of smull+smull2+uzp+smin, which now optimizes to a sqdmulh-based sequence. This only operates on vectors containing int32 and int16 types.
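For reference, a hedged IR sketch of the i16 shape this covers, modeled on the v6i16 test requested above (the function name and the v8i16 width here are illustrative, not taken from the patch):

```
define <8 x i16> @q15_saturating_mul_sketch(<8 x i16> %a, <8 x i16> %b) {
  ; widen, multiply, take (a*b) >> 15, clamp to 0x7FFF, and narrow back
  %as = sext <8 x i16> %a to <8 x i32>
  %bs = sext <8 x i16> %b to <8 x i32>
  %m  = mul <8 x i32> %bs, %as
  %sh = ashr <8 x i32> %m, splat (i32 15)
  %s  = call <8 x i32> @llvm.smin.v8i32(<8 x i32> %sh, <8 x i32> splat (i32 32767))
  %t  = trunc <8 x i32> %s to <8 x i16>
  ret <8 x i16> %t
}
declare <8 x i32> @llvm.smin.v8i32(<8 x i32>, <8 x i32>)
```

With the patch applied, llc for AArch64 would be expected to select sqdmulh for this shape, as in the CHECK lines of the tests above.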