summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAart Bik <ajcbik@google.com>2022-03-09 11:04:00 -0800
committerAart Bik <ajcbik@google.com>2022-03-09 15:10:44 -0800
commit52fb4f53c29e2a8cd304f391ba27b127fb5a1cfe (patch)
tree6706ff4c07a7ad99fda089bdcb010922f6154962
parentfc9e07873f0cecb875fbef3407ded61e6747e5da (diff)
[mlir][sparse] added linalg.dot to sparse kernel collection
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D121315
-rw-r--r--mlir/test/Dialect/SparseTensor/sparse_kernels.mlir63
1 files changed, 63 insertions, 0 deletions
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index 6d427d5824b3..05e895a262f6 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -3,6 +3,8 @@
// RUN: --linalg-generalize-named-ops --linalg-fuse-elementwise-ops \
// RUN: --sparsification | FileCheck %s
+#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
// CHECK-LABEL: func @matmul1(
@@ -255,3 +257,64 @@ func @quantized_matmul(%input1: tensor<5x3xi8>,
outs(%output : tensor<5x6xi64>) -> tensor<5x6xi64>
return %0: tensor<5x6xi64>
}
+
+// CHECK-LABEL: func @sparse_dot(
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0:.*]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1:.*]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<1024xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
+// CHECK-DAG: %[[VAL_11:.*]] = memref.alloc() : memref<f32>
+// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2:.*]] : memref<f32>
+// CHECK-DAG: memref.copy %[[VAL_12]], %[[VAL_11]] : memref<f32> to memref<f32>
+// CHECK-DAG: %[[VAL_13:.*]] = memref.load %[[VAL_11]][] : memref<f32>
+// CHECK-DAG: %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_15:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK-DAG: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]]:3 = scf.while (%[[VAL_19:.*]] = %[[VAL_14]], %[[VAL_20:.*]] = %[[VAL_16]], %[[VAL_21:.*]] = %[[VAL_13]]) : (index, index, f32) -> (index, index, f32) {
+// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_19]], %[[VAL_15]] : index
+// CHECK: %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_17]] : index
+// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_22]], %[[VAL_23]] : i1
+// CHECK: scf.condition(%[[VAL_24]]) %[[VAL_19]], %[[VAL_20]], %[[VAL_21]] : index, index, f32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index, %[[VAL_27:.*]]: f32):
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref<?xindex>
+// CHECK: %[[VAL_30:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_28]] : index
+// CHECK: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index
+// CHECK: %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
+// CHECK: %[[VAL_33:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
+// CHECK: %[[VAL_34:.*]] = arith.andi %[[VAL_32]], %[[VAL_33]] : i1
+// CHECK: %[[VAL_35:.*]] = scf.if %[[VAL_34]] -> (f32) {
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_25]]] : memref<?xf32>
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xf32>
+// CHECK: %[[VAL_38:.*]] = arith.mulf %[[VAL_36]], %[[VAL_37]] : f32
+// CHECK: %[[VAL_39:.*]] = arith.addf %[[VAL_27]], %[[VAL_38]] : f32
+// CHECK: scf.yield %[[VAL_39]] : f32
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_27]] : f32
+// CHECK: }
+// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_31]] : index
+// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_25]], %[[VAL_4]] : index
+// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_40]], %[[VAL_41]], %[[VAL_25]] : index
+// CHECK: %[[VAL_43:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_31]] : index
+// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_26]], %[[VAL_4]] : index
+// CHECK: %[[VAL_45:.*]] = arith.select %[[VAL_43]], %[[VAL_44]], %[[VAL_26]] : index
+// CHECK: scf.yield %[[VAL_42]], %[[VAL_45]], %[[VAL_46:.*]] : index, index, f32
+// CHECK: }
+// CHECK: memref.store %[[VAL_47:.*]]#2, %[[VAL_11]][] : memref<f32>
+// CHECK: %[[VAL_48:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<f32>
+// CHECK: return %[[VAL_48]] : tensor<f32>
+// CHECK: }
+func @sparse_dot(%a: tensor<1024xf32, #SparseVector>,
+ %b: tensor<1024xf32, #SparseVector>,
+ %x: tensor<f32>) -> tensor<f32> {
+ %dot = linalg.dot ins(%a, %b: tensor<1024xf32, #SparseVector>,
+ tensor<1024xf32, #SparseVector>)
+ outs(%x: tensor<f32>) -> tensor<f32>
+ return %dot : tensor<f32>
+}