From 988b67ec707896080216e07ed29609531f307cce Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 8 Sep 2023 09:40:39 +0000 Subject: [PATCH 1/2] [mlir][vector] Extend mask calculation for vector.contract Make sure that when calculating the expected mask for `vector.contract`, scalable sizes are correctly taken into account. --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 17 +++++++++++------ mlir/test/Dialect/Vector/ops.mlir | 19 +++++++++++++++++++ ...contract-to-parallel-arith-transforms.mlir | 1 + 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 6473c92a91aa6..e753562c3fbd3 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() { unsigned numVecDims = lhsIdxMap.getNumDims(); SmallVector maskShape(numVecDims, ShapedType::kDynamic); + SmallVector maskShapeScalabledims(numVecDims, false); // Using the information in the indexing maps, extract the size of each // dimension in the vector.contract operation from the two input operands. - for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) + for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) { maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize; - for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) + maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] = + lhsType.getScalableDims()[dimIdx]; + } + for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) { maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize; + maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] = + rhsType.getScalableDims()[dimIdx]; + } assert(!ShapedType::isDynamicShape(maskShape) && "Mask shape couldn't be computed"); - // TODO: Extend the scalable vector type representation with a bit map. - assert(!lhsType.isScalable() && !rhsType.isScalable() && - "Scalable vectors are not supported yet"); return VectorType::get(maskShape, - IntegerType::get(lhsType.getContext(), /*width=*/1)); + IntegerType::get(lhsType.getContext(), /*width=*/1), + maskShapeScalabledims); } SmallVector ContractionOp::getTraitAttrNames() { diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index f00bc6e97b350..61118a35922f4 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -979,3 +979,22 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) { %2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32> return } + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} +func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>, + %arg1: vector<4x[8]xf32>, + %arg2: vector<3x[8]xf32>, + %m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> { + %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> + return %0 : vector<3x[8]xf32> +} + diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir index 147f3ae921991..b0e48c4e85142 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir @@ -60,3 +60,4 @@ transform.sequence failures(propagate) { transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith" } : !transform.any_op } + From 77b8395f88777b7f6f49ebd61f9faf4f6d5e040c Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 11 Sep 2023 09:08:43 +0000 Subject: [PATCH 2/2] fixup! [mlir][vector] Extend mask calculation for vector.contract --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 ++++---- mlir/test/Dialect/Vector/ops.mlir | 19 ++++++++++++------- ...contract-to-parallel-arith-transforms.mlir | 1 - 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index e753562c3fbd3..1222542ee39fd 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -912,18 +912,18 @@ Type ContractionOp::getExpectedMaskType() { unsigned numVecDims = lhsIdxMap.getNumDims(); SmallVector maskShape(numVecDims, ShapedType::kDynamic); - SmallVector maskShapeScalabledims(numVecDims, false); + SmallVector maskShapeScalableDims(numVecDims, false); // Using the information in the indexing maps, extract the size of each // dimension in the vector.contract operation from the two input operands. for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) { maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize; - maskShapeScalabledims[lhsIdxMap.getDimPosition(dimIdx)] = + maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] = lhsType.getScalableDims()[dimIdx]; } for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) { maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize; - maskShapeScalabledims[rhsIdxMap.getDimPosition(dimIdx)] = + maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] = rhsType.getScalableDims()[dimIdx]; } @@ -932,7 +932,7 @@ Type ContractionOp::getExpectedMaskType() { return VectorType::get(maskShape, IntegerType::get(lhsType.getContext(), /*width=*/1), - maskShapeScalabledims); + maskShapeScalableDims); } SmallVector ContractionOp::getTraitAttrNames() { diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 61118a35922f4..d41cee5ea67b0 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -989,12 +989,17 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) { indexing_maps = #matmat_accesses, iterator_types = ["parallel", "parallel", "reduction"] } -func.func @matmul_masked_scalable(%arg0: vector<3x4xf32>, - %arg1: vector<4x[8]xf32>, - %arg2: vector<3x[8]xf32>, - %m : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> { - %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2 - : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> +// CHECK-LABEL: func.func @contraction_masked_scalable( +// CHECK-SAME: %[[A:.*]]: vector<3x4xf32>, +// CHECK-SAME: %[[B:.*]]: vector<4x[8]xf32>, +// CHECK-SAME: %[[C:.*]]: vector<3x[8]xf32>, +// CHECK-SAME: %[[M:.*]]: vector<3x[8]x4xi1>) -> vector<3x[8]xf32> { +func.func @contraction_masked_scalable(%A: vector<3x4xf32>, + %B: vector<4x[8]xf32>, + %C: vector<3x[8]xf32>, + %M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> { + // CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32> + %0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } + : vector<3x[8]x4xi1> -> vector<3x[8]xf32> return %0 : vector<3x[8]xf32> } - diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir index b0e48c4e85142..147f3ae921991 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir @@ -60,4 +60,3 @@ transform.sequence failures(propagate) { transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith" } : !transform.any_op } -