diff --git a/src/abstractops.h b/src/abstractops.h
index abaef7d7..46995e1e 100644
--- a/src/abstractops.h
+++ b/src/abstractops.h
@@ -2,7 +2,7 @@
 #include "smt.h"
 #include "llvm/ADT/APFloat.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/BuiltinOps.h"
 #include
 #include
diff --git a/src/analysis.cpp b/src/analysis.cpp
index 7ecbff1d..ace50f4f 100644
--- a/src/analysis.cpp
+++ b/src/analysis.cpp
@@ -220,13 +220,13 @@ void analyzeRegion(mlir::Region &region, AnalysisResult &res) {
 template<>
 bool analyzeOp(mlir::memref::GetGlobalOp op, AnalysisResult &res) {
-  llvm::StringRef glbName = op.name();
+  llvm::StringRef glbName = op.getName();
   auto mop = op.getOperation()->getParentOfType<mlir::ModuleOp>();
   auto glb = mlir::cast<mlir::memref::GlobalOp>(mop.lookupSymbol(glbName));
   res.memref.usedGlobals[glbName.str()] = glb;
-  if (glb.constant() && glb.initial_value()) {
-    analyzeElemAttr(*glb.initial_value(), res);
+  if (glb.getConstant() && glb.getInitialValue()) {
+    analyzeElemAttr(glb.getInitialValue()->cast<mlir::ElementsAttr>(), res);
   }
   return true;
 }
@@ -282,13 +282,13 @@ bool analyzeOp(mlir::tosa::ClampOp op, AnalysisResult &res) {
 template<>
 bool analyzeOp(mlir::linalg::GenericOp op, AnalysisResult &res) {
   // If generic loop has reduction loops, then result is not elementwise
-  auto indexingMaps = op.indexing_maps().getValue();
+  auto indexingMaps = op.getIndexingMaps().getValue();
   auto outputMap = indexingMaps.back().cast<mlir::AffineMapAttr>().getValue();
   bool isReudctionLoop = !outputMap.isPermutation();
   if (isReudctionLoop)
     res.isElementwiseFPOps = false;
-  analyzeRegion(op.region(), res);
+  analyzeRegion(op.getRegion(), res);
   return true;
 }
@@ -338,7 +338,7 @@ void analyzeBlock(
   // and newly created FPs can be stored to output memref.
   if (auto op2 = mlir::dyn_cast<mlir::linalg::LinalgOp>(op)) {
     if (op2.hasBufferSemantics()) {
-      for (const auto &operand: op2.outputs()) {
+      for (const auto &operand: op2.getOutputs()) {
         analyzeVariable(operand, res, VarAnalysisConfig::operand());
       }
     }
diff --git a/src/encode.cpp b/src/encode.cpp
index 391964ec..25b23691 100644
--- a/src/encode.cpp
+++ b/src/encode.cpp
@@ -8,7 +8,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -302,6 +302,14 @@ static vector<ValTy> getFromMixedOps(
   return vec;
 }
 
+template<class ValTy>
+static vector<ValTy> getFromArrayI64(const mlir::ArrayRef<int64_t> &attr) {
+  vector<ValTy> vec;
+  for (auto s : attr) {
+    vec.push_back(ValTy(s));
+  }
+  return vec;
+}
 
 template<class ValTy>
 static vector<ValTy> getFromArrayAttr(const mlir::ArrayAttr &attr) {
@@ -392,8 +400,8 @@ broadcastTensors(State &st, mlir::Value arg0, mlir::Value arg1) {
     auto d1 = getDimSize(ty0, idx0);
     auto d2 = getDimSize(ty1, idx1);
-    bool dyn0 = d1 == mlir::ShapedType::kDynamicSize;
-    bool dyn1 = d2 == mlir::ShapedType::kDynamicSize;
+    bool dyn0 = d1 == mlir::ShapedType::kDynamic;
+    bool dyn1 = d2 == mlir::ShapedType::kDynamic;
     if (dyn0 ^ dyn1)
       return nullopt;
@@ -897,7 +905,7 @@ void encodeOp(State &st, mlir::arith::TruncIOp op, bool) {
 template<>
 void encodeOp(State &st, mlir::linalg::IndexOp op, bool) {
-  uint64_t i = op.dim();
+  uint64_t i = op.getDim();
   assert(i < st.linalgGenericScopes.top().indVars.size());
   Expr idxvar = st.linalgGenericScopes.top().indVars[i];
   st.regs.add(op, Index(idxvar));
@@ -958,7 +966,7 @@ void encodeOp(State &st, mlir::arith::IndexCastOp op, bool) {
 }
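Note on the dominant pattern in these hunks: the MLIR bump switches every TableGen-generated accessor to its prefixed form, so the snake_case getters (op.name(), glb.constant(), glb.initial_value(), op.indexing_maps(), ...) become getName(), getConstant(), and so on. A minimal sketch of the new spellings, assuming the usual generated return types (the function itself is hypothetical):

```cpp
#include "mlir/Dialect/MemRef/IR/MemRef.h"

// Hypothetical helper, for illustration only: the prefixed accessors
// used throughout this patch.
void queryGlobal(mlir::memref::GetGlobalOp op, mlir::memref::GlobalOp glb) {
  llvm::StringRef name = op.getName();  // was op.name()
  bool isConst = glb.getConstant();     // was glb.constant()
  auto init = glb.getInitialValue();    // was glb.initial_value(); optional attribute
  (void)name; (void)isConst; (void)init;
}
```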
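The new getFromArrayI64 helper exists because several TOSA accessors (getStride(), getDilation(), getKernel(), getPad(), getMultiples()) now return ArrayRef<int64_t> rather than an ArrayAttr of IntegerAttrs, so the getFromArrayAttr path no longer applies to them. A self-contained sketch of the same shape (fromI64Array and the value type are stand-ins, not the project's actual names):

```cpp
#include <cstdint>
#include <vector>
#include "llvm/ADT/ArrayRef.h"

// Stand-in for the helper added above: copy an i64 array attribute's
// contents into a vector of some value type constructible from int64_t.
template <class ValTy>
std::vector<ValTy> fromI64Array(llvm::ArrayRef<int64_t> arr) {
  std::vector<ValTy> vec;
  vec.reserve(arr.size());
  for (int64_t s : arr)
    vec.push_back(ValTy(s));  // per-element conversion
  return vec;
}
// e.g. auto strides = fromI64Array<Index>(op.getStride());
```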
 template<>
-void encodeOp(State &st, mlir::AffineApplyOp op, bool) {
+void encodeOp(State &st, mlir::affine::AffineApplyOp op, bool) {
   auto m = op.getAffineMap();
   if (m.getNumResults() != 1)
     throw UnsupportedException(
@@ -1236,8 +1244,11 @@ void encodeOp(State &st, mlir::tosa::TileOp op, bool) {
   auto t = st.regs.get<Tensor>(op.getInput1());
   vector<unsigned> repeat;
-  for (mlir::Attribute val: op.getMultiples())
-    repeat.push_back(val.cast<mlir::IntegerAttr>().getValue().getSExtValue());
+  for (long val: op.getMultiples()) {
+    if (val < 0)
+      throw UnsupportedException(op.getOperation(), "Negative multiple");
+    repeat.push_back(val);
+  }
   st.regs.add(op, t.tile(repeat));
   st.wellDefined(op, t.isFullyInitialized(), "the input is initialized");
@@ -1314,13 +1325,13 @@ void encodeOp(State &st, mlir::tosa::BitwiseXorOp op, bool) {
 }
 
 static Tensor getPaddedTensor2D(mlir::Type elemTy,
-    Tensor input,
-    mlir::ArrayAttr padding) {
-  if (!llvm::all_of(padding, [](mlir::Attribute a) {
-      return a.cast<mlir::IntegerAttr>().getInt() == 0; })) {
+    Tensor input,
+    mlir::ArrayRef<int64_t> padding) {
+  if (!llvm::all_of(padding, [](int64_t a) {
+      return a == 0; })) {
     // pad = [top, bottom, left, right], filled with zero
-    vector<Index> pad = getFromArrayAttr<Index>(padding);
+    vector<Index> pad = getFromArrayI64<Index>(padding);
     assert(pad.size() == 4);
 
     // input rank should be 4
@@ -1371,9 +1382,9 @@ void encodeOp(State &st, mlir::tosa::DepthwiseConv2DOp op, bool) {
   // bias: a 1-dim array whose size is C * M
   auto bias = st.regs.get<Tensor>(op.getBias());
   // strides = [strides_y, strides_x]
-  vector<Index> strides = getFromArrayAttr<Index>(op.getStride());
+  vector<Index> strides = getFromArrayI64<Index>(op.getStride());
   // dilations = [dilations_y, dilations_x]
-  vector<Index> dilations = getFromArrayAttr<Index>(op.getDilation());
+  vector<Index> dilations = getFromArrayI64<Index>(op.getDilation());
 
   auto elemTy = getElemTy(op.getResult());
   if (!elemTy.isa<mlir::FloatType>())
     throw UnsupportedException(op.getOperation(),
@@ -1409,9 +1420,9 @@ void encodeOp(State &st, mlir::tosa::Conv2DOp op, bool) {
   // bias: a 1-dim array whose size is F
   auto bias = st.regs.get<Tensor>(op.getBias());
   // strides = [strides_y, strides_x]
-  vector<Index> strides = getFromArrayAttr<Index>(op.getStride());
+  vector<Index> strides = getFromArrayI64<Index>(op.getStride());
   // dilations = [dilations_y, dilations_x]
-  vector<Index> dilations = getFromArrayAttr<Index>(op.getDilation());
+  vector<Index> dilations = getFromArrayI64<Index>(op.getDilation());
 
   // Check whether C is identical
   st.wellDefined(op, input.getDim(3) == weight.getDim(3),
@@ -1559,9 +1570,9 @@ void encodeOp(State &st, mlir::tosa::GatherOp op, bool) {
 template<>
 void encodeOp(State &st, mlir::tosa::AvgPool2dOp op, bool) {
   auto input = st.regs.get<Tensor>(op.getInput());
-  auto kernelDims = getFromArrayAttr<Index>(op.getKernel());
-  auto paddings = getFromArrayAttr<Index>(op.getPad());
-  auto strides = getFromArrayAttr<Index>(op.getStride());
+  auto kernelDims = getFromArrayI64<Index>(op.getKernel());
+  auto paddings = getFromArrayI64<Index>(op.getPad());
+  auto strides = getFromArrayI64<Index>(op.getStride());
 
   if (!input.getElemType().isa<mlir::FloatType>()) {
     throw UnsupportedException(op.getOperation(),
@@ -1578,6 +1589,8 @@
         "Zero-padded pooling is supported only.");
   }
 
+  // TODO: The current modeling ignores the acc_type attribute.
+ auto result = input.avgPool(kernelDims, strides); st.regs.add(op.getResult(), move(result)); st.wellDefined(op, input.isFullyInitialized(), "source tensor initialized"); @@ -1586,9 +1599,9 @@ void encodeOp(State &st, mlir::tosa::AvgPool2dOp op, bool) { template<> void encodeOp(State &st, mlir::tosa::MaxPool2dOp op, bool) { auto input = st.regs.get(op.getInput()); - auto kernelDims = getFromArrayAttr(op.getKernel()); - auto paddings = getFromArrayAttr(op.getPad()); - auto strides = getFromArrayAttr(op.getStride()); + auto kernelDims = getFromArrayI64(op.getKernel()); + auto paddings = getFromArrayI64(op.getPad()); + auto strides = getFromArrayI64(op.getStride()); if (!input.getElemType().isa()) { throw UnsupportedException(op.getOperation(), @@ -1695,15 +1708,15 @@ template static void encodeConv(State &st, T op, ShapedValue::ConvLayout clayout) { vector strides, dilations; // TODO: The result may not fit in Index::BITS - for (auto s: op.strides()) + for (auto s: op.getStrides()) strides.push_back(Index(s.getSExtValue())); - for (auto d: op.dilations()) + for (auto d: op.getDilations()) dilations.push_back(Index(d.getSExtValue())); if (op.hasTensorSemantics()) { auto t_input = st.regs.get(op.image()); auto t_filter = st.regs.get(op.filter()); - auto output = st.regs.get(op.outputs()[0]); + auto output = st.regs.get(op.getOutputs()[0]); auto t_res = t_input .conv(t_filter, strides, dilations, clayout, output); @@ -1712,11 +1725,11 @@ static void encodeConv(State &st, T op, ShapedValue::ConvLayout clayout) { st.wellDefined(op, t_filter.isFullyInitialized(), "filter is initialized"); st.wellDefined(op, output.isFullyInitialized(), "output is initialized"); } else { - auto outputTy = op.outputs()[0].getType().template cast(); + auto outputTy = op.getOutputs()[0].getType().template cast(); auto elemTy = outputTy.getElementType(); auto input = st.regs.get(op.image()); auto filter = st.regs.get(op.filter()); - MemRef output = st.regs.get(op.outputs()[0]); + MemRef output = st.regs.get(op.getOutputs()[0]); if (!output.isIdentityMap()) throw UnsupportedException(op.getOperation(), @@ -1765,15 +1778,15 @@ encodeOp(State &st, mlir::linalg::DepthwiseConv2DNhwcHwcmOp op, vector strides, dilations; - for (auto s: op.strides()) + for (auto s: op.getStrides()) strides.push_back(Index(s.getSExtValue())); - for (auto d: op.dilations()) + for (auto d: op.getDilations()) dilations.push_back(Index(d.getSExtValue())); if (op.hasTensorSemantics()) { auto t_input = st.regs.get(op.image()); auto t_filter = st.regs.get(op.filter()); - auto t_output = st.regs.get(op.outputs()[0]); + auto t_output = st.regs.get(op.getOutputs()[0]); auto t_res = t_input.depthwiseConv2D(t_filter, strides, dilations, /* bias */ nullopt, /* output */ t_output); @@ -1784,10 +1797,10 @@ encodeOp(State &st, mlir::linalg::DepthwiseConv2DNhwcHwcmOp op, } else { auto mi = st.regs.get(op.image()); auto mf = st.regs.get(op.filter()); - auto mo = st.regs.get(op.outputs()[0]); + auto mo = st.regs.get(op.getOutputs()[0]); auto iTy = op.image().getType().cast(); auto fTy = op.filter().getType().cast(); - auto oTy = op.outputs()[0].getType().cast(); + auto oTy = op.getOutputs()[0].getType().cast(); Tensor t_input = loadTensor(st, op, mi, iTy); Tensor t_filter = loadTensor(st, op, mf, fTy); Tensor t_output = loadTensor(st, op, mo, oTy); @@ -1815,32 +1828,6 @@ encodeOp(State &st, mlir::linalg::Conv2DNhwcHwcfOp op, bool encodeMemWriteOp) { encodeConv(st, op, ShapedValue::ConvLayout::NHWC_HWCF); } -template<> -void encodeOp(State &st, 
mlir::linalg::InitTensorOp op, bool) { - auto res = op.getResult(); - auto ty = res.getType().dyn_cast(); - if (!ty || !Tensor::isTypeSupported(ty)) - throw UnsupportedException(op.getOperation(), "Unsupported tensor type"); - - vector sizes; - if (ty.getRank() == 0) { - sizes.push_back(Index(1)); - } else { - for (unsigned i = 0; i < ty.getRank(); ++i) { - if (op.isDynamicSize(i)) - sizes.push_back(st.regs.get(op.getDynamicSize(i))); - else - sizes.push_back(Index(op.getStaticSize(i))); - } - } - - // FIXME: can we use res's name? - static int new_var_idx = 0; - st.regs.add(res, - Tensor::var(ty.getElementType(), - ("init_tensor#") + to_string(new_var_idx++), sizes, false)); -} - template<> void encodeOp(State &st, mlir::tensor::CollapseShapeOp op, bool) { Tensor t = st.regs.get(op.getOperand()); @@ -1859,7 +1846,7 @@ void encodeOp(State &st, mlir::tensor::CollapseShapeOp op, bool) { for (auto &idx: reassocExprs[i]) size = size * t.getDim(idx); - if (resTy.getDimSize(i) != mlir::ShapedType::kDynamicSize) + if (resTy.getDimSize(i) != mlir::ShapedType::kDynamic) st.wellDefined(op, size == resTy.getDimSize(i), "size check"); newDims.push_back(move(size)); @@ -1891,7 +1878,7 @@ void encodeOp(State &st, mlir::tensor::ExpandShapeOp op, bool) { int unknown_dim = -1; int64_t const_size = 1; for (auto id: ids) { - if (op.getResultType().getDimSize(id) == mlir::ShapedType::kDynamicSize) { + if (op.getResultType().getDimSize(id) == mlir::ShapedType::kDynamic) { if (unknown_dim != -1) throw UnsupportedException(op.getOperation(), "it has more than one unknown dimension size in one group"); @@ -1929,7 +1916,7 @@ void encodeOp(State &st, mlir::linalg::MatmulOp op, bool encodeMemWriteOp) { throw UnsupportedException(op.getOperation(), "We do not support memory writes in this scope"); - if (op.getNumInputs() != 2 || op.getNumOutputs() != 1) + if (op.getInputs().size() != 2 || op.getOutputs().size() != 1) throw UnsupportedException(op.getOperation(), "unsupported form"); @@ -2043,8 +2030,8 @@ void encodeOp(State &st, mlir::tensor::PadOp op, bool) { template static void encodeLinalgPooling(State &st, T op) { - mlir::DenseIntElementsAttr strideAttr = op.strides(); - mlir::DenseIntElementsAttr dilationAttr = op.dilations(); + mlir::DenseIntElementsAttr strideAttr = op.getStrides(); + mlir::DenseIntElementsAttr dilationAttr = op.getDilations(); if (!strideAttr.isSplat() || !dilationAttr.isSplat()) throw UnsupportedException(op.getOperation(), @@ -2064,10 +2051,10 @@ static void encodeLinalgPooling(State &st, T op) { if (!elemTy.isa()) throw UnsupportedException(op.getOperation(), "Unsupported type"); - vector kernelDims = st.regs.get(op.inputs()[1]).getDims(); + vector kernelDims = st.regs.get(op.getInputs()[1]).getDims(); vector strides = {Index(stride), Index(stride)}; - auto input = st.regs.get(op.inputs()[0]); - auto output = st.regs.get(op.outputs()[0]); + auto input = st.regs.get(op.getInputs()[0]); + auto output = st.regs.get(op.getOutputs()[0]); bool isMaxPool = std::is_same::value; auto result = isMaxPool ? 
                input.maxPool(kernelDims, strides, output) :
                input.sumPool(kernelDims, strides, output);
@@ -2076,18 +2063,18 @@ static void encodeLinalgPooling(State &st, T op) {
     st.wellDefined(op, input.isFullyInitialized(), "input tensor initialized");
     st.wellDefined(op, output.isFullyInitialized(), "output tensor initialized");
   } else {
-    mlir::Type elemTy = op.outputs()[0].getType()
+    mlir::Type elemTy = op.getOutputs()[0].getType()
         .template cast<mlir::MemRefType>()
         .getElementType();
     if (!elemTy.isa<mlir::FloatType>())
       throw UnsupportedException(op.getOperation(), "Unsupported type");
 
-    vector<Expr> kernelDims = st.regs.get<MemRef>(op.inputs()[1]).getDims();
+    vector<Expr> kernelDims = st.regs.get<MemRef>(op.getInputs()[1]).getDims();
     vector<Expr> strides = {Index(stride), Index(stride)};
-    MemRef minput = st.regs.get<MemRef>(op.inputs()[0]);
-    MemRef moutput = st.regs.get<MemRef>(op.outputs()[0]);
-    auto inputTy = op.inputs()[0].getType().template cast<mlir::MemRefType>();
-    auto outputTy = op.outputs()[0].getType().template cast<mlir::MemRefType>();
+    MemRef minput = st.regs.get<MemRef>(op.getInputs()[0]);
+    MemRef moutput = st.regs.get<MemRef>(op.getOutputs()[0]);
+    auto inputTy = op.getInputs()[0].getType().template cast<mlir::MemRefType>();
+    auto outputTy = op.getOutputs()[0].getType().template cast<mlir::MemRefType>();
     Tensor input = loadTensor(st, op, minput, inputTy);
     Tensor output = loadTensor(st, op, moutput, outputTy);
@@ -2132,6 +2119,31 @@ void encodeOp(State &st, mlir::tensor::DimOp op, bool) {
   // DimOp does not look into elements, so initialization check is not necessary
 }
 
+template <> void encodeOp(State &st, mlir::tensor::EmptyOp op, bool) {
+  auto res = op.getResult();
+  auto ty = res.getType().dyn_cast<mlir::RankedTensorType>();
+  if (!ty || !Tensor::isTypeSupported(ty))
+    throw UnsupportedException(op.getOperation(), "Unsupported tensor type");
+
+  vector<Expr> sizes;
+  if (ty.getRank() == 0) {
+    sizes.push_back(Index(1));
+  } else {
+    for (unsigned i = 0; i < ty.getRank(); ++i) {
+      if (ty.isDynamicDim(i))
+        sizes.push_back(st.regs.get<Index>(op.getDynamicSize(i)));
+      else
+        sizes.push_back(Index(ty.getDimSize(i)));
+    }
+  }
+
+  // FIXME: can we use res's name?
+  static int new_var_idx = 0;
+  st.regs.add(res, Tensor::var(ty.getElementType(),
+                               ("init_tensor#") + to_string(new_var_idx++),
+                               sizes, false));
+}
+
 template<>
 void encodeOp(State &st, mlir::tensor::CastOp op, bool) {
   auto tty = op.getType().dyn_cast<mlir::RankedTensorType>();
@@ -2196,7 +2208,7 @@ void encodeOp(State &st, mlir::tensor::GenerateOp op, bool) {
   int j = 0;
   for (int i = 0; i < retty.getRank(); ++i) {
     auto d = retty.getDimSize(i);
-    if (d == mlir::ShapedType::kDynamicSize) {
+    if (d == mlir::ShapedType::kDynamic) {
       auto newd = exts[j++];
       upperbound.push_back(st.regs.get<Index>(newd).ofs(-1));
     } else {
@@ -2539,11 +2551,10 @@ void encodeOp(State &st, mlir::tosa::ReshapeOp op, bool) {
   vector<Expr> newDims;
   mlir::Operation *oper = op.getOperation();
-  for (auto a: attrs) {
-    auto ia = a.cast<mlir::IntegerAttr>();
-    if (ia.getInt() == -1)
+  for (auto ia: attrs) {
+    if (ia == -1)
       throw UnsupportedException(oper, "Dynamic shape is unsupported");
-    newDims.push_back(Index(ia.getInt()));
+    newDims.push_back(Index(ia));
   }
   st.wellDefined(oper, t.get1DSize() == smt::get1DSize(newDims));
   st.regs.add(op.getResult(), t.reshape(newDims));
@@ -2576,7 +2587,7 @@ static void encodeAllocLikeOp(State &st, T op) {
     throw UnsupportedException(op.getOperation(),
         "unsupported memref type for alloc: it has a non-identity layout map");
 
-  auto dsizes = op.dynamicSizes();
+  auto dsizes = op.getDynamicSizes();
   vector<Expr> dszExprs;
   for (const auto &sz: dsizes) {
     dszExprs.push_back(st.regs.get<Index>(sz));
   }
@@ -2601,7 +2612,7 @@ void encodeOp(State &st, mlir::memref::AllocaOp op, bool) {
 template<>
 void encodeOp(State &st, mlir::memref::DimOp op, bool) {
   auto [res, wf] = encodeDimOp(
-      st, st.regs.get<MemRef>(op.source()).getDims(), op.index());
+      st, st.regs.get<MemRef>(op.getSource()).getDims(), op.getIndex());
   st.regs.add(op, Index(res));
   st.wellDefined(op, move(wf));
 }
@@ -2612,7 +2623,7 @@ void encodeOp(State &st, mlir::memref::LoadOp op, bool) {
   // out-of-bounds. It is currently encoded as UB.
   auto m = st.regs.get<MemRef>(op.getOperand(0));
   vector<Expr> indices;
-  for (auto idx0: op.indices())
+  for (auto idx0: op.getIndices())
     indices.emplace_back(st.regs.get<Index>(idx0));
 
   auto [val, info] = m.getWithAccessInfo(indices);
@@ -2625,7 +2636,7 @@
 template<>
 void encodeOp(State &st, mlir::memref::GetGlobalOp op, bool encodeMemWriteOp) {
-  auto name = op.name().str();
+  auto name = op.getName().str();
   auto bid = Expr::mkBV(st.m->getBidForGlobalVar(name), st.m->getBIDBits());
   auto type = op.getType();
   assert(type.getLayout().isIdentity() &&
@@ -2648,7 +2659,7 @@ void encodeOp(State &st, mlir::memref::StoreOp op, bool encodeMemWriteOp) {
   // out-of-bounds. It is currently encoded as UB.
   auto m = st.regs.get<MemRef>(op.getOperand(1));
   vector<Expr> indices;
-  for (auto idx0: op.indices())
+  for (auto idx0: op.getIndices())
     indices.emplace_back(st.regs.get<Index>(idx0));
 
   auto value = op.getOperand(0);
@@ -2675,16 +2686,14 @@ void encodeOp(State &st, mlir::memref::SubViewOp op, bool) {
     ADD(strides, Stride);
 #undef ADD
   }
-  auto src = st.regs.get<MemRef>(op.source());
+  auto src = st.regs.get<MemRef>(op.getSource());
   int rankDiff = op.getSourceType().getRank() - op.getType().getRank();
   assert(rankDiff >= 0); // only reducing rank is allowed
 
   // This reduction logic mainly from MLIR SubViewOp verify function.
   // See 'Dialect/MemRef/IR/MemRefOps.cpp'.
auto expectedType = mlir::memref::SubViewOp::inferResultType( - op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()), - extractFromI64ArrayAttr(op.static_sizes()), - extractFromI64ArrayAttr(op.static_strides())); + op.getSourceType(), op.static_offsets(), op.static_sizes(), op.static_strides()); auto originalShapedType = expectedType.cast(); auto candidateReducedShapedType = op.getType().cast(); @@ -2788,8 +2797,8 @@ void encodeOp(State &st, mlir::memref::TensorStoreOp op, bool encodeMemWrite) { throw UnsupportedException(op.getOperation(), "We do not support memory writes in this scope"); - auto t = st.regs.get(op.tensor()); - auto m = st.regs.get(op.memref()); + auto t = st.regs.get(op.getTensor()); + auto m = st.regs.get(op.getMemref()); // Src and tgt's shapes & element types must match // Memref may have its layout, though. @@ -2797,19 +2806,19 @@ void encodeOp(State &st, mlir::memref::TensorStoreOp op, bool encodeMemWrite) { st.wellDefined(op, (Expr)t.getDim(i) == (Expr)m.getDim(i)); storeTensorTo(st, op.getOperation(), move(t), m, - op.memref().getType().cast(), true); + op.getMemref().getType().cast(), true); } template<> void encodeOp(State &st, mlir::memref::ExpandShapeOp op, bool encodeMemWrite) { - auto srcType = op.src().getType().cast(); - auto resType = op.result().getType().cast(); + auto srcType = op.getSrc().getType().cast(); + auto resType = op.getResult().getType().cast(); if (!srcType.getLayout().isIdentity() || !resType.getLayout().isIdentity()) throw UnsupportedException(op.getOperation(), "We do not support non-identity layout memref"); - MemRef m = st.regs.get(op.src()); + MemRef m = st.regs.get(op.getSrc()); // The fresh variables created by ShapedValue::getDims will be ignored // by the for loop below. auto newdims = ShapedValue::getDims(op.getResultType(), true); @@ -2824,7 +2833,7 @@ void encodeOp(State &st, mlir::memref::ExpandShapeOp op, bool encodeMemWrite) { int unknown_dim = -1; int64_t const_size = 1; for (auto id: ids) { - if (op.getResultType().getDimSize(id) == mlir::ShapedType::kDynamicSize) { + if (op.getResultType().getDimSize(id) == mlir::ShapedType::kDynamic) { if (unknown_dim != -1) throw UnsupportedException(op.getOperation(), "it has more than one unknown dimension size in one group"); @@ -2855,8 +2864,8 @@ void encodeOp(State &st, mlir::memref::ExpandShapeOp op, bool encodeMemWrite) { template<> void encodeOp(State &st, mlir::memref::CollapseShapeOp op, bool) { - auto srcType = op.src().getType().cast(); - auto resType = op.result().getType().cast(); + auto srcType = op.getSrc().getType().cast(); + auto resType = op.getResult().getType().cast(); if (!srcType.getLayout().isIdentity() || !resType.getLayout().isIdentity()) throw UnsupportedException(op.getOperation(), @@ -2878,7 +2887,7 @@ void encodeOp(State &st, mlir::memref::CollapseShapeOp op, bool) { for (auto &idx: reassocExprs[i]) size = size * m.getDim(idx); - if (resTy.getDimSize(i) != mlir::ShapedType::kDynamicSize) + if (resTy.getDimSize(i) != mlir::ShapedType::kDynamic) st.wellDefined(op, size == resTy.getDimSize(i), "size check"); newDims.push_back(move(size)); @@ -2950,8 +2959,8 @@ void encodeOp(State &st, mlir::linalg::DotOp op, bool encodeMemWrite) { throw UnsupportedException(op.getOperation(), "tensor semantics is supported only"); - auto inputOps = op.getInputOperands(); - auto outputOps = op.getOutputOperands(); + auto inputOps = op.getInputs(); + auto outputOps = op.getOutputs(); auto outputTy = op.getType(0).dyn_cast(); // This must be same. 
@@ -2967,13 +2976,13 @@ void encodeOp(State &st, mlir::linalg::DotOp op, bool encodeMemWrite) { "unknown dot format; shouldn't the result tensor have one element?"); if (outputTy.getElementType() != - inputOps[0]->get().getType().dyn_cast() + inputOps[0].getType().dyn_cast() .getElementType()) throw UnsupportedException(op.getOperation(), "casting is not supported"); - auto t1 = st.regs.get(inputOps[0]->get()); - auto t2 = st.regs.get(inputOps[1]->get()); - auto t3 = st.regs.get(outputOps[0]->get()); + auto t1 = st.regs.get(inputOps[0]); + auto t2 = st.regs.get(inputOps[1]); + auto t3 = st.regs.get(outputOps[0]); st.wellDefined(op, t1.isFullyInitialized()); st.wellDefined(op, t2.isFullyInitialized()); st.wellDefined(op, t3.isFullyInitialized()); @@ -3009,31 +3018,28 @@ void encodeOp(State &st, mlir::sparse_tensor::ConvertOp op, bool) { vector findLoopBounds(State &st, mlir::linalg::GenericOp op) { // The size of the loop is calculated (analogous to what // LinalgOp::createLoopRanges does). - // The process of getting the size of the loop seems fishy; // LinalgOp::createLoopRanges relies on the "first" dimension that is - // matched, and it isn't clear what happens if there are multiple matching - // dimensions. For example, + // matched. If there are multiple matching dimensions, for example: // linalg.generic { // indexing_maps = [affine_map<(n) -> (n)>, // affine_map<(n) -> (n)>, // affine_map<(n) -> (n)>] } // ins(%A, %B: , ) outs(%C: ) { .. } - // The size of the loop is either %A, %B, or %C's dimension, but the current - // algorithm mandates the result to be %A's dimension. + // The current algorithm mandates the result to be %A's dimension. vector viewSizes; - for (auto *opOperand : op.getInputAndOutputOperands()) { - unsigned r = op.getRank(opOperand); + for (auto &opOperand : op.getOperation()->getOpOperands()) { + unsigned r = op.getRank(&opOperand); if (!r) continue; - if (opOperand->get().getType().isa()) { - auto t = st.regs.get(opOperand->get()); + if (opOperand.get().getType().isa()) { + auto t = st.regs.get(opOperand.get()); for (int64_t i = 0, e = r; i < e; ++i) { viewSizes.push_back(t.getDim(i)); } - } else if (opOperand->get().getType().isa()) { - auto t = st.regs.get(opOperand->get()); + } else if (opOperand.get().getType().isa()) { + auto t = st.regs.get(opOperand.get()); for (int64_t i = 0, e = r; i < e; ++i) { viewSizes.push_back(t.getDim(i)); } @@ -3095,7 +3101,8 @@ encodeUBForTensorShapeMatch(State &st, mlir::linalg::GenericOp op, unsigned numRes = map.getNumResults(); vector viewSizes; - for (auto *opOperand : op.getInputAndOutputOperands()) { + for (auto &oo : op.getOperation()->getOpOperands()) { + auto *opOperand = &oo; unsigned r = op.getRank(opOperand); if (!r) continue; @@ -3129,27 +3136,25 @@ static void initInputStateForLoopBody( State &st, mlir::linalg::GenericOp op, map &welldefs, bool isParallelLoop) { - auto indexingMaps = op.indexing_maps().getValue(); - auto &block = *op.region().begin(); + auto indexingMaps = op.getIndexingMaps().getValue(); + auto &block = *op.getRegion().begin(); const vector &inductionVars = st.linalgGenericScopes.top().indVars; - assert(op.getInputOperands().size() + op.getNumOutputs() == - indexingMaps.size()); - assert(op.getNumInputs() == op.getInputOperands().size()); + auto nInputs = op.getInputs().size(); + auto nOutputs = op.getOutputs().size(); + assert(nInputs + nOutputs == indexingMaps.size()); // The output variables contain the initial value of the tensor // (see github issue #164) // For parallel loops: whole 
iterations contain the initial value
   // For reduction loops: only the first iteration contains the value
-  size_t upperbound = op.getNumInputs() + op.getNumOutputs();
+  size_t upperbound = nInputs + nOutputs;
 
   for (size_t arg_i = 0; arg_i < upperbound; ++arg_i) {
     auto indexMap = indexingMaps[arg_i].cast<mlir::AffineMapAttr>().getValue();
-    mlir::Value op_i = arg_i >= op.getNumInputs() ?
-        op.getOutputOperand(arg_i - op.getNumInputs())->get() :
-        op.getInputOperand(arg_i)->get();
-    bool isInput = arg_i < op.getNumInputs();
+    mlir::Value op_i = op->getOperand(arg_i);
+    bool isInput = arg_i < nInputs;
     bool isOutputAndHasUse = !isInput && !block.getArgument(arg_i).use_empty();
 
     if (op_i.getType().isa<mlir::TensorType>()) {
@@ -3417,7 +3422,7 @@ void encodeOp(State &st, mlir::linalg::GenericOp op, bool encodeMemWriteOp) {
     throw UnsupportedException(op.getOperation(),
         "We do not support memory writes in this scope");
 
-  auto &region = op.region();
+  auto &region = op.getRegion();
   if (!llvm::hasSingleElement(region))
     throw UnsupportedException(op.getOperation(),
         "a single block is supported only");
@@ -3428,11 +3433,9 @@
         "unsupported block arguments");
 
-  if (llvm::any_of(op.iterator_types(), [](mlir::Attribute attr) {
-      auto str = attr.cast<mlir::StringAttr>().getValue();
-      return str != mlir::getParallelIteratorTypeName() &&
-             str != mlir::getReductionIteratorTypeName() &&
-             str != mlir::getWindowIteratorTypeName();
+  if (llvm::any_of(op.getIteratorTypesArray(), [](auto itrty) {
+      return itrty != mlir::utils::IteratorType::parallel &&
+             itrty != mlir::utils::IteratorType::reduction;
       }))
     throw UnsupportedException(op.getOperation(),
         "unsupported iterator type");
@@ -3450,7 +3453,7 @@
   State newst = st;
   newst.linalgGenericScopes.push(State::LinalgGenericScope{loopBounds});
 
-  auto indexingMaps = op.indexing_maps().getValue();
+  auto indexingMaps = op.getIndexingMaps().getValue();
   auto outputMap = indexingMaps.back().cast<mlir::AffineMapAttr>().getValue();
   bool isParallelLoop = outputMap.isPermutation();
@@ -3467,13 +3470,12 @@
   } else {
     // Reduction loops returning multiple values is not supported by MLIR-TV
     // yet.
-    if (op.getNumOutputs() > 1)
+    if (op.getOutputs().size() > 1)
       throw UnsupportedException(op.getOperation(),
           "unsupported reduction form");
 
     optional<Tensor> t_res;
-    auto outputType = op.getOutputOperand(0)->get().getType()
-        .cast<mlir::ShapedType>();
+    auto outputType = op.getOutputs().front().getType().cast<mlir::ShapedType>();
     // Reduction loops returning memref is not supported by MLIR-TV yet.
     if (outputType.isa<mlir::MemRefType>())
       throw UnsupportedException(op.getOperation(),
@@ -3510,28 +3512,33 @@
       st.regs.add(op.getResult(i), move(tvec_res->at(i)));
     }
   } else if (op.hasBufferSemantics()) {
-    for(unsigned i = 0; i < tvec_res->size(); i++) {
-      auto opi = op.getOutputOperand(i)->get();
+    unsigned i = 0;
+    assert(op.getOutputs().size() == tvec_res->size());
+
+    for(auto opi: op.getOutputs()) {
       auto m_res = st.regs.get<MemRef>(opi);
       storeTensorTo(st, op, move(tvec_res->at(i)), m_res,
           opi.getType().cast<mlir::MemRefType>(), true);
       // Noalias with input operands
-      for (unsigned j = 0; j < op.getNumInputs(); j ++) {
-        auto opj = op.getInputOperand(j)->get();
+      for (auto opj: op.getInputs()) {
         if (!opj.getType().isa<mlir::MemRefType>())
           continue;
         auto input = st.regs.get<MemRef>(opj);
         st.wellDefined(op, input.noalias(m_res));
       }
       // Noalias with other output operands
-      for (unsigned j = 0; j < i; j ++) {
-        auto opj = op.getOutputOperand(j)->get();
+      unsigned j = 0;
+      for (auto opj: op.getOutputs()) {
+        if (j >= i) break;
         if (!opj.getType().isa<mlir::MemRefType>())
           continue;
         auto output = st.regs.get<MemRef>(opj);
         st.wellDefined(op, output.noalias(m_res));
+        ++j;
       }
+      ++i;
     }
   } else {
     llvm_unreachable("Unknown linalg::generic semantics");
@@ -3602,8 +3609,7 @@ static void encodeBlock(
     if (checkBeforeEnc && checkBeforeEnc(&op, index)) continue;
 
     // Encode ops. Alphabetically sorted.
-    ENCODE(st, op, mlir::AffineApplyOp, encodeMemWriteOps);
-    ENCODE(st, op, mlir::arith::SelectOp, encodeMemWriteOps);
+    ENCODE(st, op, mlir::affine::AffineApplyOp, encodeMemWriteOps);
 
     ENCODE(st, op, mlir::arith::AddFOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::arith::AddIOp, encodeMemWriteOps);
@@ -3621,6 +3627,7 @@
     ENCODE(st, op, mlir::arith::MulFOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::arith::MulIOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::arith::NegFOp, encodeMemWriteOps);
+    ENCODE(st, op, mlir::arith::SelectOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::arith::ShLIOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::arith::ShRSIOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::arith::ShRUIOp, encodeMemWriteOps);
@@ -3645,6 +3652,7 @@
     ENCODE(st, op, mlir::memref::AllocOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::memref::AllocaOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::memref::CollapseShapeOp, encodeMemWriteOps);
+    ENCODE(st, op, mlir::memref::CopyOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::memref::DeallocOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::memref::DimOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::memref::ExpandShapeOp, encodeMemWriteOps);
@@ -3657,14 +3665,11 @@
     ENCODE(st, op, mlir::linalg::DepthwiseConv2DNhwcHwcmOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::linalg::Conv2DNchwFchwOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::linalg::Conv2DNhwcHwcfOp, encodeMemWriteOps);
-    ENCODE(st, op, mlir::memref::CopyOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::linalg::DotOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::linalg::FillOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::linalg::GenericOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::linalg::IndexOp, encodeMemWriteOps);
-    ENCODE(st, op, mlir::linalg::InitTensorOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::linalg::MatmulOp, encodeMemWriteOps);
-    ENCODE(st, op, mlir::tensor::PadOp, encodeMemWriteOps);
 
     ENCODE(st, op, mlir::linalg::PoolingNhwcMaxOp, encodeMemWriteOps);
     ENCODE(st, op, mlir::linalg::PoolingNhwcSumOp, encodeMemWriteOps);
@@ -3676,6 +3681,7 @@
ENCODE(st, op, mlir::tensor::CastOp, encodeMemWriteOps); ENCODE(st, op, mlir::tensor::CollapseShapeOp, encodeMemWriteOps); ENCODE(st, op, mlir::tensor::DimOp, encodeMemWriteOps); + ENCODE(st, op, mlir::tensor::EmptyOp, encodeMemWriteOps); ENCODE(st, op, mlir::tensor::ExpandShapeOp, encodeMemWriteOps); ENCODE(st, op, mlir::tensor::InsertOp, encodeMemWriteOps); ENCODE(st, op, mlir::tensor::ExtractOp, encodeMemWriteOps); @@ -3683,6 +3689,7 @@ static void encodeBlock( ENCODE(st, op, mlir::tensor::FromElementsOp, encodeMemWriteOps); ENCODE(st, op, mlir::tensor::GenerateOp, encodeMemWriteOps); ENCODE(st, op, mlir::tensor::InsertSliceOp, encodeMemWriteOps); + ENCODE(st, op, mlir::tensor::PadOp, encodeMemWriteOps); ENCODE(st, op, mlir::tosa::AbsOp, encodeMemWriteOps); ENCODE(st, op, mlir::tosa::AddOp, encodeMemWriteOps); diff --git a/src/main.cpp b/src/main.cpp index abaea4d7..5a467224 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -9,7 +9,7 @@ #include "llvm/Support/Signals.h" #include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -110,8 +110,8 @@ int main(int argc, char* argv[]) { DialectRegistry registry; // NOTE: we cannot use mlir::registerAllDialects because IREE does not have // dependency on some of those dialects - registry.insert(); - registry.insert(); + registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/src/memory.cpp b/src/memory.cpp index 813d24f4..5d2dbd80 100644 --- a/src/memory.cpp +++ b/src/memory.cpp @@ -125,7 +125,7 @@ Memory::Memory(const TypeMap &numGlobalBlocksPerType, vector globalsForTy; for (auto glb: globals) { - if (glb.type().getElementType() == elemTy) + if (glb.getType().getElementType() == elemTy) globalsForTy.push_back(glb); } @@ -145,18 +145,19 @@ Memory::Memory(const TypeMap &numGlobalBlocksPerType, verbose("memory init") << "Assigning bid = " << i << " to global var " << glb.getName() << "...\n"; - if (glb.constant()) { - auto tensorTy = mlir::RankedTensorType::get(glb.type().getShape(), - glb.type().getElementType()); - Tensor t = Tensor::fromElemsAttr(tensorTy, *glb.initial_value()); + if (glb.getConstant()) { + auto tensorTy = mlir::RankedTensorType::get(glb.getType().getShape(), + glb.getType().getElementType()); + Tensor t = Tensor::fromElemsAttr( + tensorTy, glb.getInitialValue()->cast()); newArrs.push_back(t.asArray()); } else { string name = "#" + glb.getName().str() + "_array"; newArrs.push_back(Expr::mkFreshVar(arrSort, name)); } newInits.push_back(Expr::mkSplatArray(Index::sort(), Expr::mkBool(true))); - newWrit.push_back(Expr::mkBool(!glb.constant())); - newNumElems.push_back(Index(glb.type().getNumElements())); + newWrit.push_back(Expr::mkBool(!glb.getConstant())); + newNumElems.push_back(Index(glb.getType().getNumElements())); newLiveness.push_back(Expr::mkBool(true)); newCreatedByAllocs.push_back(Expr::mkBool(false)); } diff --git a/src/value.cpp b/src/value.cpp index 130ba1e2..2afe6237 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -77,8 +77,8 @@ vector ShapedValue::getDims( dims.reserve(rank); unsigned unknownVarIdx = 0; for (unsigned i = 0; i < rank; ++i) { - uint64_t sz = shapedTy.getDimSize(i); - if (sz == (uint64_t)-1ull) { + int64_t sz = shapedTy.getDimSize(i); + if (sz == mlir::ShapedType::kDynamic) { if 
(freshVarForUnknownSize) { dims.emplace_back(Index::var("dim", VarType::FRESH)); } else if (valsForUnknownSz) { @@ -87,8 +87,10 @@ vector ShapedValue::getDims( llvm_unreachable("Don't know what to do with a dimension of " "an unknown size"); } - } else - dims.push_back(Index(sz)); + } else { + assert(sz >= 0); + dims.push_back(Index((uint64_t)sz)); + } } return dims; @@ -1336,7 +1338,7 @@ Tensor Tensor::fromElemsAttr(mlir::RankedTensorType tensorty, int64_t totalSize = 1; for (int i = 0; i < rank; ++i) { auto dsize = tensorty.getDimSize(i); - assert(dsize != mlir::ShapedType::kDynamicSize); + assert(dsize != mlir::ShapedType::kDynamic); dims.push_back(dsize); dimExprs.push_back(Index(dsize)); totalSize *= dsize; @@ -1588,7 +1590,7 @@ MemRef::Layout MemRef::getLayout( return MemRef::Layout(dims); auto getConstOrFreshVar = [](int64_t val, string &&name) -> Expr { - return (val == mlir::ShapedType::kDynamicStrideOrOffset) ? + return (val == mlir::ShapedType::kDynamic) ? Index::var(move(name), VarType::FRESH) : Index(val); }; diff --git a/src/vcgen.cpp b/src/vcgen.cpp index c50a3ae4..41880f46 100644 --- a/src/vcgen.cpp +++ b/src/vcgen.cpp @@ -713,15 +713,15 @@ static vector mergeGlobals( } auto glbTgt = tgtItr->second; - if (glbSrc.type() != glbTgt.type() || + if (glbSrc.getType() != glbTgt.getType() || glbSrc.isPrivate() != glbTgt.isPrivate() || - glbSrc.constant() != glbTgt.constant() || - glbSrc.initial_value() != glbTgt.initial_value()) { + glbSrc.getConstant() != glbTgt.getConstant() || + glbSrc.getInitialValue() != glbTgt.getInitialValue()) { throw UnsupportedException( name + " has different signatures in src and tgt"); } - assert(glbSrc.type().hasStaticShape() && + assert(glbSrc.getType().hasStaticShape() && "Global var must be statically shaped"); mergedGlbs.push_back(glbSrc); @@ -731,7 +731,7 @@ static vector mergeGlobals( auto glbTgt = glbTgt0; auto tgtItr = srcGlobals.find(name); if (tgtItr == srcGlobals.end()) { - if (glbTgt.constant()) { + if (glbTgt.getConstant()) { mergedGlbs.push_back(glbTgt); } else throw UnsupportedException("Introducing new non-const globals " diff --git a/tests/litmus/abstraction/dot.src.mlir b/tests/litmus/abstraction/dot.src.mlir index 2e592e12..1764336f 100644 --- a/tests/litmus/abstraction/dot.src.mlir +++ b/tests/litmus/abstraction/dot.src.mlir @@ -2,7 +2,7 @@ func.func @f(%a: tensor, %b: tensor) -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %e = linalg.dot ins(%a, %b : tensor,tensor) outs(%outty: tensor) -> tensor diff --git a/tests/litmus/abstraction/dot.tgt.mlir b/tests/litmus/abstraction/dot.tgt.mlir index c443d09b..8de3f481 100644 --- a/tests/litmus/abstraction/dot.tgt.mlir +++ b/tests/litmus/abstraction/dot.tgt.mlir @@ -1,5 +1,5 @@ func.func @f(%a: tensor, %b: tensor) -> tensor { - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %zero = arith.constant -0.0 : f32 %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %result = linalg.generic { diff --git a/tests/litmus/abstraction/dot_concat.src.mlir b/tests/litmus/abstraction/dot_concat.src.mlir index 6003a160..1ce27659 100644 --- a/tests/litmus/abstraction/dot_concat.src.mlir +++ b/tests/litmus/abstraction/dot_concat.src.mlir @@ -4,7 +4,7 @@ // dot (A, B) + dot(C, D) → dot(A::C, B::D) func.func @f(%a: tensor<5xf32>, %b: tensor<5xf32>, %c: tensor<5xf32>, %d: tensor<5xf32>) -> f32 { %identity = arith.constant -0.0 : f32 - %i = 
linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %outty = linalg.fill ins(%identity: f32) outs(%i: tensor) -> tensor %rt1 = linalg.dot ins(%a, %b : tensor<5xf32>, tensor<5xf32>) outs(%outty: tensor) -> tensor diff --git a/tests/litmus/abstraction/dot_concat.tgt.mlir b/tests/litmus/abstraction/dot_concat.tgt.mlir index 1e4ccf60..ed645127 100644 --- a/tests/litmus/abstraction/dot_concat.tgt.mlir +++ b/tests/litmus/abstraction/dot_concat.tgt.mlir @@ -1,6 +1,6 @@ func.func @f(%a: tensor<5xf32>, %b: tensor<5xf32>, %c: tensor<5xf32>, %d: tensor<5xf32>) -> f32 { %identity = arith.constant -0.0 : f32 - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %outty = linalg.fill ins(%identity: f32) outs(%i: tensor) -> tensor %ca = "tosa.concat"(%a, %c) {axis = 0: i64}: (tensor<5xf32>, tensor<5xf32>) -> tensor<10xf32> diff --git a/tests/litmus/abstraction/dot_concat_multiset.src.mlir b/tests/litmus/abstraction/dot_concat_multiset.src.mlir index b2bbd0de..fd2d72fc 100644 --- a/tests/litmus/abstraction/dot_concat_multiset.src.mlir +++ b/tests/litmus/abstraction/dot_concat_multiset.src.mlir @@ -4,7 +4,7 @@ // dot (A, B) + dot(C, D) → dot(A::C, B::D) func.func @f(%a: tensor<5xf32>, %b: tensor<5xf32>, %c: tensor<5xf32>, %d: tensor<5xf32>) -> f32 { %identity = arith.constant -0.0 : f32 - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %outty = linalg.fill ins(%identity: f32) outs(%i: tensor) -> tensor %rt1 = linalg.dot ins(%a, %b : tensor<5xf32>, tensor<5xf32>) outs(%outty: tensor) -> tensor diff --git a/tests/litmus/abstraction/dot_concat_multiset.tgt.mlir b/tests/litmus/abstraction/dot_concat_multiset.tgt.mlir index 1e4ccf60..ed645127 100644 --- a/tests/litmus/abstraction/dot_concat_multiset.tgt.mlir +++ b/tests/litmus/abstraction/dot_concat_multiset.tgt.mlir @@ -1,6 +1,6 @@ func.func @f(%a: tensor<5xf32>, %b: tensor<5xf32>, %c: tensor<5xf32>, %d: tensor<5xf32>) -> f32 { %identity = arith.constant -0.0 : f32 - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %outty = linalg.fill ins(%identity: f32) outs(%i: tensor) -> tensor %ca = "tosa.concat"(%a, %c) {axis = 0: i64}: (tensor<5xf32>, tensor<5xf32>) -> tensor<10xf32> diff --git a/tests/litmus/abstraction/dotint.src.mlir b/tests/litmus/abstraction/dotint.src.mlir index bc18b132..8daf97dc 100644 --- a/tests/litmus/abstraction/dotint.src.mlir +++ b/tests/litmus/abstraction/dotint.src.mlir @@ -2,7 +2,7 @@ func.func @f(%a: tensor, %b: tensor) -> tensor { %zero = arith.constant 0 : i32 - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %outty = linalg.fill ins(%zero: i32) outs(%i: tensor) -> tensor %e = linalg.dot ins(%a, %b : tensor,tensor) outs(%outty: tensor) -> tensor diff --git a/tests/litmus/abstraction/dotint.tgt.mlir b/tests/litmus/abstraction/dotint.tgt.mlir index 03764cab..4a9af6d1 100644 --- a/tests/litmus/abstraction/dotint.tgt.mlir +++ b/tests/litmus/abstraction/dotint.tgt.mlir @@ -1,5 +1,5 @@ func.func @f(%a: tensor, %b: tensor) -> tensor { - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %zero = arith.constant 0 : i32 %outty = linalg.fill ins(%zero: i32) outs(%i: tensor) -> tensor %result = linalg.generic { diff --git a/tests/litmus/linalg-loops/convert-elementwise-cmpf.tgt.mlir b/tests/litmus/linalg-loops/convert-elementwise-cmpf.tgt.mlir index 127c7012..d30fbfdf 100644 --- a/tests/litmus/linalg-loops/convert-elementwise-cmpf.tgt.mlir +++ b/tests/litmus/linalg-loops/convert-elementwise-cmpf.tgt.mlir @@ -1,7 +1,7 @@ #map = affine_map<() -> ()> module 
{ func.func @cmpf(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty () : tensor %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%arg0, %arg1 : tensor, tensor) outs(%0 : tensor) { ^bb0(%arg2: f32, %arg3: f32, %arg4: i1): // no predecessors %2 = arith.cmpf olt, %arg2, %arg3 : f32 diff --git a/tests/litmus/linalg-loops/nested.src.mlir b/tests/litmus/linalg-loops/nested.src.mlir index d290c5ef..db567c5b 100644 --- a/tests/litmus/linalg-loops/nested.src.mlir +++ b/tests/litmus/linalg-loops/nested.src.mlir @@ -4,7 +4,7 @@ func.func @dumb_loop(%arg0: tensor) -> tensor { %c0 = arith.constant 0: index %sz = tensor.dim %arg0, %c0: tensor - %outty = linalg.init_tensor [%sz] : tensor + %outty = tensor.empty (%sz) : tensor %res = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0: tensor) diff --git a/tests/litmus/linalg-loops/output-value-bad.src.mlir b/tests/litmus/linalg-loops/output-value-bad.src.mlir index c4643cc9..c7e4f220 100644 --- a/tests/litmus/linalg-loops/output-value-bad.src.mlir +++ b/tests/litmus/linalg-loops/output-value-bad.src.mlir @@ -2,7 +2,7 @@ func.func @f(%arg0: tensor<10x10xi32>) -> tensor<10x10xi32> { %cst = arith.constant 1 : i32 - %init_tensor = linalg.init_tensor [10, 10] : tensor<10x10xi32> + %init_tensor = tensor.empty () : tensor<10x10xi32> %filled = linalg.fill ins(%cst: i32) outs(%init_tensor: tensor<10x10xi32>) -> tensor<10x10xi32> %res = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], diff --git a/tests/litmus/linalg-loops/output-value-bad.tgt.mlir b/tests/litmus/linalg-loops/output-value-bad.tgt.mlir index 799dbde9..0083dc6c 100644 --- a/tests/litmus/linalg-loops/output-value-bad.tgt.mlir +++ b/tests/litmus/linalg-loops/output-value-bad.tgt.mlir @@ -1,6 +1,6 @@ func.func @f(%arg0: tensor<10x10xi32>) -> tensor<10x10xi32> { %not_three= arith.constant 4 : i32 - %init_tensor = linalg.init_tensor [10, 10] : tensor<10x10xi32> + %init_tensor = tensor.empty () : tensor<10x10xi32> %filled = linalg.fill ins(%not_three: i32) outs(%init_tensor: tensor<10x10xi32>) -> tensor<10x10xi32> return %filled : tensor<10x10xi32> } diff --git a/tests/litmus/linalg-loops/output-value.src.mlir b/tests/litmus/linalg-loops/output-value.src.mlir index b1620426..7f35c3b6 100644 --- a/tests/litmus/linalg-loops/output-value.src.mlir +++ b/tests/litmus/linalg-loops/output-value.src.mlir @@ -2,7 +2,7 @@ func.func @f(%arg0: tensor<10x10xi32>) -> tensor<10x10xi32> { %cst = arith.constant 1 : i32 - %init_tensor = linalg.init_tensor [10, 10] : tensor<10x10xi32> + %init_tensor = tensor.empty () : tensor<10x10xi32> %filled = linalg.fill ins(%cst: i32) outs(%init_tensor: tensor<10x10xi32>) -> tensor<10x10xi32> %res = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], diff --git a/tests/litmus/linalg-loops/output-value.tgt.mlir b/tests/litmus/linalg-loops/output-value.tgt.mlir index bd346b69..d8b4a649 100644 --- a/tests/litmus/linalg-loops/output-value.tgt.mlir +++ b/tests/litmus/linalg-loops/output-value.tgt.mlir @@ -1,6 +1,6 @@ func.func @f(%arg0: tensor<10x10xi32>) -> tensor<10x10xi32> { %three= arith.constant 3 : i32 - %init_tensor = linalg.init_tensor [10, 10] : tensor<10x10xi32> + %init_tensor = tensor.empty () : tensor<10x10xi32> %filled = linalg.fill ins(%three: i32) outs(%init_tensor: tensor<10x10xi32>) -> tensor<10x10xi32> return %filled : tensor<10x10xi32> } diff --git a/tests/litmus/linalg-loops/sum-assoc-bad.src.mlir 
b/tests/litmus/linalg-loops/sum-assoc-bad.src.mlir index a3c47d45..5017497a 100644 --- a/tests/litmus/linalg-loops/sum-assoc-bad.src.mlir +++ b/tests/litmus/linalg-loops/sum-assoc-bad.src.mlir @@ -4,7 +4,7 @@ func.func @sum(%mat: tensor<100x100xf32>) -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %mat_transposed = linalg.generic { diff --git a/tests/litmus/linalg-loops/sum-assoc-bad.tgt.mlir b/tests/litmus/linalg-loops/sum-assoc-bad.tgt.mlir index da1d736e..8184d994 100644 --- a/tests/litmus/linalg-loops/sum-assoc-bad.tgt.mlir +++ b/tests/litmus/linalg-loops/sum-assoc-bad.tgt.mlir @@ -5,7 +5,7 @@ func.func @sum(%mat0: tensor<100x100xf32>) -> tensor %i2 = arith.constant 2: index %mat = tensor.insert %c0 into %mat0[%i0,%i2]: tensor<100x100xf32> %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %mat_col = tensor.collapse_shape %mat [[0, 1]] : tensor<100x100xf32> into tensor<10000xf32> %result = linalg.generic { diff --git a/tests/litmus/linalg-loops/sum-assoc-f64.src.mlir b/tests/litmus/linalg-loops/sum-assoc-f64.src.mlir index 2d117270..dc7508a3 100644 --- a/tests/litmus/linalg-loops/sum-assoc-f64.src.mlir +++ b/tests/litmus/linalg-loops/sum-assoc-f64.src.mlir @@ -4,7 +4,7 @@ func.func @sum(%mat: tensor<100x100xf64>) -> tensor { %zero = arith.constant -0.0 : f64 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f64) outs(%i: tensor) -> tensor %mat_transposed = linalg.generic { diff --git a/tests/litmus/linalg-loops/sum-assoc-f64.tgt.mlir b/tests/litmus/linalg-loops/sum-assoc-f64.tgt.mlir index 8e746e8b..5ff89951 100644 --- a/tests/litmus/linalg-loops/sum-assoc-f64.tgt.mlir +++ b/tests/litmus/linalg-loops/sum-assoc-f64.tgt.mlir @@ -4,7 +4,7 @@ func.func @sum(%mat: tensor<100x100xf64>) -> tensor %i0 = arith.constant 0: index %i2 = arith.constant 2: index %zero = arith.constant -0.0 : f64 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f64) outs(%i: tensor) -> tensor %mat_col = tensor.collapse_shape %mat [[0, 1]] : tensor<100x100xf64> into tensor<10000xf64> %result = linalg.generic { diff --git a/tests/litmus/linalg-loops/sum-assoc.src.mlir b/tests/litmus/linalg-loops/sum-assoc.src.mlir index 101a568f..f6367857 100644 --- a/tests/litmus/linalg-loops/sum-assoc.src.mlir +++ b/tests/litmus/linalg-loops/sum-assoc.src.mlir @@ -4,7 +4,7 @@ func.func @sum(%mat: tensor<100x100xf32>) -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %mat_transposed = linalg.generic { diff --git a/tests/litmus/linalg-loops/sum-assoc.tgt.mlir b/tests/litmus/linalg-loops/sum-assoc.tgt.mlir index d8940439..ff54d00b 100644 --- a/tests/litmus/linalg-loops/sum-assoc.tgt.mlir +++ b/tests/litmus/linalg-loops/sum-assoc.tgt.mlir @@ -4,7 +4,7 @@ func.func @sum(%mat: tensor<100x100xf32>) -> tensor %i0 = arith.constant 0: index %i2 = arith.constant 2: index %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %mat_col = tensor.collapse_shape %mat [[0, 1]] : tensor<100x100xf32> into tensor<10000xf32> %result = linalg.generic { diff --git 
a/tests/litmus/linalg-loops/sum-bad.src.mlir b/tests/litmus/linalg-loops/sum-bad.src.mlir index d4eb9076..56a51dcb 100644 --- a/tests/litmus/linalg-loops/sum-bad.src.mlir +++ b/tests/litmus/linalg-loops/sum-bad.src.mlir @@ -4,7 +4,7 @@ func.func @sum(%mat: tensor<5x5xf32>) -> tensor<5xf32> { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [5] : tensor<5xf32> + %i = tensor.empty () : tensor<5xf32> %outty = linalg.fill ins(%zero: f32) outs(%i: tensor<5xf32>) -> tensor<5xf32> %result = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, diff --git a/tests/litmus/linalg-loops/sum-bad.tgt.mlir b/tests/litmus/linalg-loops/sum-bad.tgt.mlir index ad5101bf..9064deb9 100644 --- a/tests/litmus/linalg-loops/sum-bad.tgt.mlir +++ b/tests/litmus/linalg-loops/sum-bad.tgt.mlir @@ -1,7 +1,7 @@ func.func @sum(%mat: tensor<5x5xf32>) -> tensor<5xf32> { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [5] : tensor<5xf32> + %i = tensor.empty () : tensor<5xf32> %outty = linalg.fill ins(%zero: f32) outs(%i: tensor<5xf32>) -> tensor<5xf32> %result = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, diff --git a/tests/litmus/linalg-loops/sum-int-unroll.src.mlir b/tests/litmus/linalg-loops/sum-int-unroll.src.mlir index e3405cdb..6c1f0c14 100644 --- a/tests/litmus/linalg-loops/sum-int-unroll.src.mlir +++ b/tests/litmus/linalg-loops/sum-int-unroll.src.mlir @@ -5,7 +5,7 @@ func.func @sum() -> tensor { %cst = arith.constant dense<10> : tensor<5xi8> %zero = arith.constant 0 : i8 - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %outty = linalg.fill ins(%zero: i8) outs(%i: tensor) -> tensor %result = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, diff --git a/tests/litmus/linalg-loops/sum-int-unroll.tgt.mlir b/tests/litmus/linalg-loops/sum-int-unroll.tgt.mlir index 8c90eb75..6914edf1 100644 --- a/tests/litmus/linalg-loops/sum-int-unroll.tgt.mlir +++ b/tests/litmus/linalg-loops/sum-int-unroll.tgt.mlir @@ -2,7 +2,7 @@ func.func @sum() -> tensor { %fifty = arith.constant 50: i8 %zero = arith.constant 0 : i8 - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %t = linalg.fill ins(%zero: i8) outs(%i: tensor) -> tensor %t2 = tensor.insert %fifty into %t[]: tensor return %t2: tensor diff --git a/tests/litmus/linalg-loops/sum-mul.src.mlir b/tests/litmus/linalg-loops/sum-mul.src.mlir index 6814dd93..c3e4a4ff 100644 --- a/tests/litmus/linalg-loops/sum-mul.src.mlir +++ b/tests/litmus/linalg-loops/sum-mul.src.mlir @@ -3,7 +3,7 @@ func.func @sum(%mat: tensor<5x5xf32>) -> tensor<5xf32> { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [5] : tensor<5xf32> + %i = tensor.empty () : tensor<5xf32> %outty = linalg.fill ins(%zero: f32) outs(%i: tensor<5xf32>) -> tensor<5xf32> %result = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, diff --git a/tests/litmus/linalg-loops/sum-mul.tgt.mlir b/tests/litmus/linalg-loops/sum-mul.tgt.mlir index ff1aa18c..2050933f 100644 --- a/tests/litmus/linalg-loops/sum-mul.tgt.mlir +++ b/tests/litmus/linalg-loops/sum-mul.tgt.mlir @@ -1,7 +1,7 @@ func.func @sum(%mat: tensor<5x5xf32>) -> tensor<5xf32> { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [5] : tensor<5xf32> + %i = tensor.empty () : tensor<5xf32> %outty = linalg.fill ins(%zero: f32) outs(%i: tensor<5xf32>) -> tensor<5xf32> %result = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, diff --git a/tests/litmus/linalg-loops/sum-unit-tensor-ub.src.mlir 
b/tests/litmus/linalg-loops/sum-unit-tensor-ub.src.mlir index b8f6b83a..446e9618 100644 --- a/tests/litmus/linalg-loops/sum-unit-tensor-ub.src.mlir +++ b/tests/litmus/linalg-loops/sum-unit-tensor-ub.src.mlir @@ -2,7 +2,7 @@ func.func @sum(%x: tensor<1xf32>) -> f32 { - %outty = linalg.init_tensor [] : tensor + %outty = tensor.empty () : tensor %result = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], diff --git a/tests/litmus/linalg-loops/sum-unit-tensor.src.mlir b/tests/litmus/linalg-loops/sum-unit-tensor.src.mlir index 87444590..0638f0b6 100644 --- a/tests/litmus/linalg-loops/sum-unit-tensor.src.mlir +++ b/tests/litmus/linalg-loops/sum-unit-tensor.src.mlir @@ -3,7 +3,7 @@ func.func @sum(%x: tensor<1xf32>) -> f32 { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %result = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, diff --git a/tests/litmus/linalg-loops/sum-with-identity-assoc.src.mlir b/tests/litmus/linalg-loops/sum-with-identity-assoc.src.mlir index 4041617a..e76a3b60 100644 --- a/tests/litmus/linalg-loops/sum-with-identity-assoc.src.mlir +++ b/tests/litmus/linalg-loops/sum-with-identity-assoc.src.mlir @@ -4,7 +4,7 @@ func.func @sum() -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %cst = arith.constant sparse<[[0], [1], [2], [3], [4]], [2.000000e+00, -0.000000e+00, 3.000000e+00, -0.000000e+00, -1.200000e+01]> : tensor<5xf32> %result = linalg.generic { diff --git a/tests/litmus/linalg-loops/sum-with-identity-assoc.tgt.mlir b/tests/litmus/linalg-loops/sum-with-identity-assoc.tgt.mlir index 5a00dea7..281933b1 100644 --- a/tests/litmus/linalg-loops/sum-with-identity-assoc.tgt.mlir +++ b/tests/litmus/linalg-loops/sum-with-identity-assoc.tgt.mlir @@ -1,7 +1,7 @@ func.func @sum() -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %cst = arith.constant sparse<[[0], [1], [2]], [-1.200000e+01, 3.000000e+00, 2.000000e+00]> : tensor<3xf32> %result = linalg.generic { diff --git a/tests/litmus/linalg-loops/sum-with-identity.src.mlir b/tests/litmus/linalg-loops/sum-with-identity.src.mlir index b8a5bbeb..20b916e6 100644 --- a/tests/litmus/linalg-loops/sum-with-identity.src.mlir +++ b/tests/litmus/linalg-loops/sum-with-identity.src.mlir @@ -4,7 +4,7 @@ func.func @sum() -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %cst = arith.constant sparse<[[0], [1], [2], [3], [4]], [-1.200000e+01, -0.000000e+00, 3.000000e+00, 2.000000e+00, -0.000000e+00]> : tensor<5xf32> %result = linalg.generic { diff --git a/tests/litmus/linalg-loops/sum-with-identity.tgt.mlir b/tests/litmus/linalg-loops/sum-with-identity.tgt.mlir index 5a00dea7..281933b1 100644 --- a/tests/litmus/linalg-loops/sum-with-identity.tgt.mlir +++ b/tests/litmus/linalg-loops/sum-with-identity.tgt.mlir @@ -1,7 +1,7 @@ func.func @sum() -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %cst = arith.constant sparse<[[0], [1], [2]], [-1.200000e+01, 3.000000e+00, 
2.000000e+00]> : tensor<3xf32> %result = linalg.generic { diff --git a/tests/litmus/linalg-ops/convolution_constfold.src.mlir b/tests/litmus/linalg-ops/convolution_constfold.src.mlir index 51e45b29..70048f7e 100644 --- a/tests/litmus/linalg-ops/convolution_constfold.src.mlir +++ b/tests/litmus/linalg-ops/convolution_constfold.src.mlir @@ -4,7 +4,7 @@ func.func @conv() -> tensor<1x1x1x1xf32> { %img = arith.constant dense<[[[[1.0,-0.0],[-0.0,-0.0],[-0.0,-0.0]],[[-0.0,-0.0],[-0.0,-0.0],[-0.0,-0.0]],[[-0.0,-0.0],[-0.0,-0.0],[-0.0,-0.0]]]]> : tensor<1x3x3x2xf32> %fil = arith.constant dense<[[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]],[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]],[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]]]> : tensor<3x3x2x1xf32> - %out = linalg.init_tensor [1,1,1,1] : tensor<1x1x1x1xf32> + %out = tensor.empty () : tensor<1x1x1x1xf32> %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } ins(%img, %fil: tensor<1x3x3x2xf32>, tensor<3x3x2x1xf32>) diff --git a/tests/litmus/linalg-ops/convolution_int.src.mlir b/tests/litmus/linalg-ops/convolution_int.src.mlir index 3c305d09..5e7f8bf5 100644 --- a/tests/litmus/linalg-ops/convolution_int.src.mlir +++ b/tests/litmus/linalg-ops/convolution_int.src.mlir @@ -3,7 +3,7 @@ func.func @conv() -> tensor<1x1x1x1xf32> { %img = arith.constant dense<[[[[1.0,-0.0],[-0.0,-0.0],[-0.0,-0.0]],[[-0.0,-0.0],[-0.0,-0.0],[-0.0,-0.0]],[[-0.0,-0.0],[-0.0,-0.0],[-0.0,-0.0]]]]> : tensor<1x3x3x2xf32> %fil = arith.constant dense<[[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]],[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]],[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]]]> : tensor<3x3x2x1xf32> - %out = linalg.init_tensor [1,1,1,1] : tensor<1x1x1x1xf32> + %out = tensor.empty () : tensor<1x1x1x1xf32> %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } ins(%img, %fil: tensor<1x3x3x2xf32>, tensor<3x3x2x1xf32>) diff --git a/tests/litmus/linalg-ops/convolution_int.tgt.mlir b/tests/litmus/linalg-ops/convolution_int.tgt.mlir index 5689bc62..4a85c51e 100644 --- a/tests/litmus/linalg-ops/convolution_int.tgt.mlir +++ b/tests/litmus/linalg-ops/convolution_int.tgt.mlir @@ -1,7 +1,7 @@ func.func @conv() -> tensor<1x1x1x1xf32> { %img = arith.constant dense<[[[[1.0,-0.0],[-0.0,-0.0],[-0.0,-0.0]],[[-0.0,-0.0],[-0.0,-0.0],[-0.0,-0.0]],[[-0.0,-0.0],[-0.0,-0.0],[-0.0,-0.0]]]]> : tensor<1x3x3x2xf32> %fil = arith.constant dense<[[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]],[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]],[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]]]> : tensor<3x3x2x1xf32> - %out = linalg.init_tensor [1,1,1,1] : tensor<1x1x1x1xf32> + %out = tensor.empty () : tensor<1x1x1x1xf32> %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } ins(%img, %fil: tensor<1x3x3x2xf32>, tensor<3x3x2x1xf32>) diff --git a/tests/litmus/linalg-ops/dot.src.mlir b/tests/litmus/linalg-ops/dot.src.mlir index b23f7442..dec9d97a 100644 --- a/tests/litmus/linalg-ops/dot.src.mlir +++ b/tests/litmus/linalg-ops/dot.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func @f(%a: tensor<100xf32>, %b: tensor<100xf32>) -> tensor { - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %zero = arith.constant -0.0 : f32 %filled = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %e = linalg.dot ins(%a, %b : tensor<100xf32>,tensor<100xf32>) diff --git a/tests/litmus/linalg-ops/dot.tgt.mlir b/tests/litmus/linalg-ops/dot.tgt.mlir index c3840ddb..0b279c45 100644 --- 
a/tests/litmus/linalg-ops/dot.tgt.mlir +++ b/tests/litmus/linalg-ops/dot.tgt.mlir @@ -1,5 +1,5 @@ func.func @f(%a: tensor<100xf32>, %b: tensor<100xf32>) -> tensor { - %outty = linalg.init_tensor [] : tensor + %outty = tensor.empty () : tensor %zero = arith.constant -0.0 : f32 %filled = linalg.fill ins(%zero: f32) outs(%outty: tensor) -> tensor %result = linalg.generic { diff --git a/tests/litmus/linalg-ops/dot2.src.mlir b/tests/litmus/linalg-ops/dot2.src.mlir index 2f38973b..b754e42d 100644 --- a/tests/litmus/linalg-ops/dot2.src.mlir +++ b/tests/litmus/linalg-ops/dot2.src.mlir @@ -2,7 +2,7 @@ func.func @f(%a: tensor, %b: tensor) -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %e = linalg.dot ins(%a, %b : tensor,tensor) outs(%outty: tensor) -> tensor diff --git a/tests/litmus/linalg-ops/dot2.tgt.mlir b/tests/litmus/linalg-ops/dot2.tgt.mlir index 3d21f828..7f906e77 100644 --- a/tests/litmus/linalg-ops/dot2.tgt.mlir +++ b/tests/litmus/linalg-ops/dot2.tgt.mlir @@ -1,6 +1,6 @@ func.func @f(%a: tensor, %b: tensor) -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %result = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, diff --git a/tests/litmus/linalg-ops/dot_assoc_varlen.src.mlir b/tests/litmus/linalg-ops/dot_assoc_varlen.src.mlir index e780e1da..1afd50e7 100644 --- a/tests/litmus/linalg-ops/dot_assoc_varlen.src.mlir +++ b/tests/litmus/linalg-ops/dot_assoc_varlen.src.mlir @@ -2,7 +2,7 @@ // ARGS: --associative func.func @f(%a: tensor, %b: tensor) -> tensor { - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %res = linalg.dot ins(%a, %b : tensor,tensor) outs(%i: tensor) -> tensor return %res : tensor diff --git a/tests/litmus/linalg-ops/dot_assoc_varlen.tgt.mlir b/tests/litmus/linalg-ops/dot_assoc_varlen.tgt.mlir index a3a9b6d7..80289a3f 100644 --- a/tests/litmus/linalg-ops/dot_assoc_varlen.tgt.mlir +++ b/tests/litmus/linalg-ops/dot_assoc_varlen.tgt.mlir @@ -1,5 +1,5 @@ func.func @f(%a: tensor, %b: tensor) -> tensor { - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %res = linalg.dot ins(%a, %a : tensor,tensor) outs(%i: tensor) -> tensor return %res : tensor diff --git a/tests/litmus/linalg-ops/dot_associativity.src.mlir b/tests/litmus/linalg-ops/dot_associativity.src.mlir index d9a7145f..a4ff3aa4 100644 --- a/tests/litmus/linalg-ops/dot_associativity.src.mlir +++ b/tests/litmus/linalg-ops/dot_associativity.src.mlir @@ -9,7 +9,7 @@ // dot based on multiset theroy (to be precise, the argument of 'sum'). 
func.func @f() -> tensor { - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a = arith.constant sparse<[[0], [1], [2], [3], [4]], [-12.0, 3.0, 2.0, 5.0, 4.0]> : tensor<5xf32> %b = arith.constant sparse<[[0], [1], [2], [3], [4]], [1.0, 8.0, 5.0, 6.0, 0.0]> : tensor<5xf32> %res = linalg.dot ins(%a, %b : tensor<5xf32>,tensor<5xf32>) diff --git a/tests/litmus/linalg-ops/dot_associativity.tgt.mlir b/tests/litmus/linalg-ops/dot_associativity.tgt.mlir index b09da77f..f1d0d174 100644 --- a/tests/litmus/linalg-ops/dot_associativity.tgt.mlir +++ b/tests/litmus/linalg-ops/dot_associativity.tgt.mlir @@ -1,5 +1,5 @@ func.func @f() -> tensor { - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a = arith.constant sparse<[[4], [3], [2], [1], [0]], [-12.0, 3.0, 2.0, 5.0, 4.0]> : tensor<5xf32> %b = arith.constant sparse<[[4], [3], [2], [1], [0]], [1.0, 8.0, 5.0, 6.0, 0.0]> : tensor<5xf32> %res = linalg.dot ins(%a, %b : tensor<5xf32>,tensor<5xf32>) diff --git a/tests/litmus/linalg-ops/dot_associativity2.src.mlir b/tests/litmus/linalg-ops/dot_associativity2.src.mlir index 02ba3ed8..43bf6dbd 100644 --- a/tests/litmus/linalg-ops/dot_associativity2.src.mlir +++ b/tests/litmus/linalg-ops/dot_associativity2.src.mlir @@ -10,7 +10,7 @@ func.func @f() -> f32 { %c0 = arith.constant 0 : index - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a1 = arith.constant sparse<[[0], [1]], [-12.0, 3.0]> : tensor<2xf32> %a2 = arith.constant sparse<[[0], [1], [2]], [2.0, 5.0, 4.0]> : tensor<3xf32> %b1 = arith.constant sparse<[[0], [1]], [1.0, 8.0]> : tensor<2xf32> diff --git a/tests/litmus/linalg-ops/dot_associativity2.tgt.mlir b/tests/litmus/linalg-ops/dot_associativity2.tgt.mlir index 4d5339e0..a8620ee0 100644 --- a/tests/litmus/linalg-ops/dot_associativity2.tgt.mlir +++ b/tests/litmus/linalg-ops/dot_associativity2.tgt.mlir @@ -1,6 +1,6 @@ func.func @f() -> f32 { %c0 = arith.constant 0 : index - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a1 = arith.constant sparse<[[1], [0]], [-12.0, 3.0]> : tensor<2xf32> %a2 = arith.constant sparse<[[2], [1], [0]], [2.0, 5.0, 4.0]> : tensor<3xf32> %b1 = arith.constant sparse<[[1], [0]], [1.0, 8.0]> : tensor<2xf32> diff --git a/tests/litmus/linalg-ops/dot_associativity2_multiset.src.mlir b/tests/litmus/linalg-ops/dot_associativity2_multiset.src.mlir index ce01281e..9edc7f9c 100644 --- a/tests/litmus/linalg-ops/dot_associativity2_multiset.src.mlir +++ b/tests/litmus/linalg-ops/dot_associativity2_multiset.src.mlir @@ -10,7 +10,7 @@ func.func @f() -> f32 { %c0 = arith.constant 0 : index - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a1 = arith.constant sparse<[[0], [1]], [-12.0, 3.0]> : tensor<2xf32> %a2 = arith.constant sparse<[[0], [1], [2]], [2.0, 5.0, 4.0]> : tensor<3xf32> %b1 = arith.constant sparse<[[0], [1]], [1.0, 8.0]> : tensor<2xf32> diff --git a/tests/litmus/linalg-ops/dot_associativity2_multiset.tgt.mlir b/tests/litmus/linalg-ops/dot_associativity2_multiset.tgt.mlir index 4d5339e0..a8620ee0 100644 --- a/tests/litmus/linalg-ops/dot_associativity2_multiset.tgt.mlir +++ b/tests/litmus/linalg-ops/dot_associativity2_multiset.tgt.mlir @@ -1,6 +1,6 @@ func.func @f() -> f32 { %c0 = arith.constant 0 : index - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a1 = arith.constant sparse<[[1], [0]], [-12.0, 3.0]> : tensor<2xf32> %a2 = arith.constant sparse<[[2], [1], [0]], [2.0, 5.0, 4.0]> : tensor<3xf32> %b1 = arith.constant sparse<[[1], [0]], [1.0, 8.0]> : tensor<2xf32> diff 
--git a/tests/litmus/linalg-ops/dot_associativity3_multiset.src.mlir b/tests/litmus/linalg-ops/dot_associativity3_multiset.src.mlir index e8c27cf4..14f52a94 100644 --- a/tests/litmus/linalg-ops/dot_associativity3_multiset.src.mlir +++ b/tests/litmus/linalg-ops/dot_associativity3_multiset.src.mlir @@ -10,7 +10,7 @@ func.func @f() -> (f32, f32) { %c0 = arith.constant 0 : index - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a1 = arith.constant sparse<[[0], [1]], [-12.0, 3.0]> : tensor<2xf32> %a2 = arith.constant sparse<[[0], [1], [2]], [2.0, 5.0, 4.0]> : tensor<3xf32> %b1 = arith.constant sparse<[[0], [1]], [1.0, 8.0]> : tensor<2xf32> diff --git a/tests/litmus/linalg-ops/dot_associativity3_multiset.tgt.mlir b/tests/litmus/linalg-ops/dot_associativity3_multiset.tgt.mlir index a3cdc5be..8dd6bdce 100644 --- a/tests/litmus/linalg-ops/dot_associativity3_multiset.tgt.mlir +++ b/tests/litmus/linalg-ops/dot_associativity3_multiset.tgt.mlir @@ -1,6 +1,6 @@ func.func @f() -> (f32, f32) { %c0 = arith.constant 0 : index - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a1 = arith.constant sparse<[[1], [0]], [-12.0, 3.0]> : tensor<2xf32> %a2 = arith.constant sparse<[[2], [1], [0]], [2.0, 5.0, 4.0]> : tensor<3xf32> %b1 = arith.constant sparse<[[1], [0]], [1.0, 8.0]> : tensor<2xf32> diff --git a/tests/litmus/linalg-ops/dot_commutative.src.mlir b/tests/litmus/linalg-ops/dot_commutative.src.mlir index 785429c8..342cf305 100644 --- a/tests/litmus/linalg-ops/dot_commutative.src.mlir +++ b/tests/litmus/linalg-ops/dot_commutative.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func @f(%a: tensor<100xf32>, %b: tensor<100xf32>) -> tensor { - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %e = linalg.dot ins(%a, %b : tensor<100xf32>,tensor<100xf32>) outs(%i: tensor) -> tensor return %e : tensor diff --git a/tests/litmus/linalg-ops/dot_commutative.tgt.mlir b/tests/litmus/linalg-ops/dot_commutative.tgt.mlir index 4f7d6448..e84bc7e5 100644 --- a/tests/litmus/linalg-ops/dot_commutative.tgt.mlir +++ b/tests/litmus/linalg-ops/dot_commutative.tgt.mlir @@ -1,6 +1,6 @@ func.func @f(%a: tensor<100xf32>, %b: tensor<100xf32>) -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor [] : tensor + %i = tensor.empty () : tensor %out = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %result = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, diff --git a/tests/litmus/linalg-ops/dot_constfold.src.mlir b/tests/litmus/linalg-ops/dot_constfold.src.mlir index 62fb58d9..f1892169 100644 --- a/tests/litmus/linalg-ops/dot_constfold.src.mlir +++ b/tests/litmus/linalg-ops/dot_constfold.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func @f() -> f32 { - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a1 = arith.constant sparse<[[0], [1], [2]], [1.0, -0.0, -0.0]> : tensor<3xf32> %a2 = arith.constant sparse<[[0], [1], [2]], [1.0, 2.0, 3.0]> : tensor<3xf32> %o1 = linalg.dot ins(%a1, %a2 : tensor<3xf32>,tensor<3xf32>) diff --git a/tests/litmus/linalg-ops/dot_rewrite_manually.src.mlir b/tests/litmus/linalg-ops/dot_rewrite_manually.src.mlir index 8a44b227..ba428a6b 100644 --- a/tests/litmus/linalg-ops/dot_rewrite_manually.src.mlir +++ b/tests/litmus/linalg-ops/dot_rewrite_manually.src.mlir @@ -20,7 +20,7 @@ func.func @f() -> tensor { %r2 = arith.addf %r1, %c2 : f32 %r3 = arith.addf %r2, %c3 : f32 %r4 = arith.addf %r3, %c4 : f32 - %res = linalg.init_tensor []: tensor + %res = tensor.empty (): tensor %res2 = linalg.fill ins(%r4: f32) outs(%res: tensor) -> 
tensor return %res2 : tensor } diff --git a/tests/litmus/linalg-ops/dot_rewrite_manually.tgt.mlir b/tests/litmus/linalg-ops/dot_rewrite_manually.tgt.mlir index 4a87607a..d8e0d77b 100644 --- a/tests/litmus/linalg-ops/dot_rewrite_manually.tgt.mlir +++ b/tests/litmus/linalg-ops/dot_rewrite_manually.tgt.mlir @@ -1,6 +1,6 @@ func.func @f() -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %a = arith.constant sparse<[[0], [1], [2], [3], [4]], [-12.0, 3.0, 2.0, 5.0, 4.0]> : tensor<5xf32> %b = arith.constant sparse<[[0], [1], [2], [3], [4]], [1.0, 8.0, 5.0, 6.0, 0.0]> : tensor<5xf32> diff --git a/tests/litmus/linalg-ops/dot_rewrite_manually_assoc.src.mlir b/tests/litmus/linalg-ops/dot_rewrite_manually_assoc.src.mlir index 2e837c82..66a2e397 100644 --- a/tests/litmus/linalg-ops/dot_rewrite_manually_assoc.src.mlir +++ b/tests/litmus/linalg-ops/dot_rewrite_manually_assoc.src.mlir @@ -28,7 +28,7 @@ func.func @f() -> tensor { %r2 = arith.addf %r1, %c2 : f32 %r3 = arith.addf %r2, %c3 : f32 %r4 = arith.addf %r3, %c4 : f32 - %res = linalg.init_tensor []: tensor + %res = tensor.empty (): tensor %res2 = linalg.fill ins(%r4: f32) outs(%res: tensor) -> tensor return %res2 : tensor } diff --git a/tests/litmus/linalg-ops/dot_rewrite_manually_assoc.tgt.mlir b/tests/litmus/linalg-ops/dot_rewrite_manually_assoc.tgt.mlir index 74d1ca70..2e201bc7 100644 --- a/tests/litmus/linalg-ops/dot_rewrite_manually_assoc.tgt.mlir +++ b/tests/litmus/linalg-ops/dot_rewrite_manually_assoc.tgt.mlir @@ -1,6 +1,6 @@ func.func @f() -> tensor { %zero = arith.constant -0.0 : f32 - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %outty = linalg.fill ins(%zero: f32) outs(%i: tensor) -> tensor %a = arith.constant sparse<[[4], [3], [2], [1], [0]], [-12.0, 3.0, 2.0, 5.0, 4.0]> : tensor<5xf32> %b = arith.constant sparse<[[4], [3], [2], [1], [0]], [1.0, 8.0, 5.0, 6.0, 0.0]> : tensor<5xf32> diff --git a/tests/litmus/linalg-ops/fill.src.mlir b/tests/litmus/linalg-ops/fill.src.mlir index f8f38e1a..cfee51a2 100644 --- a/tests/litmus/linalg-ops/fill.src.mlir +++ b/tests/litmus/linalg-ops/fill.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func @f(%arg: tensor<3x3xf32>) -> tensor<3x3xf32> { - %t = linalg.init_tensor [3, 3]: tensor<3x3xf32> + %t = tensor.empty (): tensor<3x3xf32> %c0 = arith.constant 0.0: f32 %res = linalg.fill ins(%c0: f32) outs(%t: tensor<3x3xf32>) -> tensor<3x3xf32> return %res: tensor<3x3xf32> diff --git a/tests/litmus/linalg-ops/init_tensor.src.mlir b/tests/litmus/linalg-ops/init_tensor.src.mlir index 79f980ad..40dae75e 100644 --- a/tests/litmus/linalg-ops/init_tensor.src.mlir +++ b/tests/litmus/linalg-ops/init_tensor.src.mlir @@ -3,7 +3,7 @@ func.func @f() -> index { %c0 = arith.constant 0: index %c10 = arith.constant 10: index - %v = linalg.init_tensor [%c10]: tensor + %v = tensor.empty (%c10): tensor %d = tensor.dim %v, %c0: tensor return %d: index } diff --git a/tests/litmus/linalg-ops/init_tensor.tgt.mlir b/tests/litmus/linalg-ops/init_tensor.tgt.mlir index a9643d09..59af5df7 100644 --- a/tests/litmus/linalg-ops/init_tensor.tgt.mlir +++ b/tests/litmus/linalg-ops/init_tensor.tgt.mlir @@ -1,7 +1,7 @@ func.func @f() -> index { %c0 = arith.constant 0: index %c10 = arith.constant 20: index - %v = linalg.init_tensor [%c10]: tensor + %v = tensor.empty (%c10): tensor %d = tensor.dim %v, %c0: tensor return %d: index } diff --git a/tests/litmus/linalg-ops/init_tensor_cast.src.mlir 
b/tests/litmus/linalg-ops/init_tensor_cast.src.mlir index 418bae82..fff9a039 100644 --- a/tests/litmus/linalg-ops/init_tensor_cast.src.mlir +++ b/tests/litmus/linalg-ops/init_tensor_cast.src.mlir @@ -2,6 +2,6 @@ func.func @f() -> (tensor<4x5x?xf32>) { %c6 = arith.constant 6 : index - %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32> + %0 = tensor.empty (%c6) : tensor<4x5x?xf32> return %0 : tensor<4x5x?xf32> } diff --git a/tests/litmus/linalg-ops/init_tensor_cast.tgt.mlir b/tests/litmus/linalg-ops/init_tensor_cast.tgt.mlir index 69366206..48abdea0 100644 --- a/tests/litmus/linalg-ops/init_tensor_cast.tgt.mlir +++ b/tests/litmus/linalg-ops/init_tensor_cast.tgt.mlir @@ -1,5 +1,5 @@ func.func @f() -> tensor<4x5x?xf32> { - %0 = linalg.init_tensor [4, 5, 6] : tensor<4x5x6xf32> + %0 = tensor.empty () : tensor<4x5x6xf32> %1 = tensor.cast %0 : tensor<4x5x6xf32> to tensor<4x5x?xf32> return %1 : tensor<4x5x?xf32> } diff --git a/tests/litmus/linalg-ops/memref_matmul.tgt.mlir b/tests/litmus/linalg-ops/memref_matmul.tgt.mlir index 4cae872a..8dc778ec 100644 --- a/tests/litmus/linalg-ops/memref_matmul.tgt.mlir +++ b/tests/litmus/linalg-ops/memref_matmul.tgt.mlir @@ -6,7 +6,7 @@ func.func @f() -> tensor<8x16xf32> { %b = memref.get_global @constant_b : memref<4x16xf32> %ta = bufferization.to_tensor %a : memref<8x4xf32> %tb = bufferization.to_tensor %b : memref<4x16xf32> - %c = linalg.init_tensor [8, 16] : tensor<8x16xf32> + %c = tensor.empty () : tensor<8x16xf32> %cst = arith.constant -0.0 : f32 %tc = linalg.fill ins(%cst: f32) outs(%c: tensor<8x16xf32>) -> tensor<8x16xf32> %mat = linalg.matmul ins(%ta, %tb: tensor<8x4xf32>, tensor<4x16xf32>) outs(%tc: tensor<8x16xf32>) -> tensor<8x16xf32> diff --git a/tests/litmus/linalg-ops/pooling_unsupported2.src.mlir b/tests/litmus/linalg-ops/pooling_unsupported2.src.mlir index 4303680b..27fb126b 100644 --- a/tests/litmus/linalg-ops/pooling_unsupported2.src.mlir +++ b/tests/litmus/linalg-ops/pooling_unsupported2.src.mlir @@ -1,8 +1,8 @@ // UNSUPPORTED func.func @pooling_nhwc_i8_max_tensor(%input: tensor<1x4x4x1xi8>) -> tensor<1x2x2x1xi8> { - %fake = linalg.init_tensor [3, 3] : tensor<3x3xi8> - %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi8> + %fake = tensor.empty () : tensor<3x3xi8> + %init = tensor.empty () : tensor<1x2x2x1xi8> %cst = arith.constant 0 : i8 %fill = linalg.fill ins(%cst: i8) outs(%init: tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8> %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} diff --git a/tests/litmus/linalg-ops/pooling_unsupported2.tgt.mlir b/tests/litmus/linalg-ops/pooling_unsupported2.tgt.mlir index 4955faae..1b86d042 100644 --- a/tests/litmus/linalg-ops/pooling_unsupported2.tgt.mlir +++ b/tests/litmus/linalg-ops/pooling_unsupported2.tgt.mlir @@ -1,7 +1,7 @@ module { func.func @pooling_nhwc_i8_max_tensor(%arg0: tensor<1x4x4x1xi8>) -> tensor<1x2x2x1xi8> { - %0 = linalg.init_tensor [3, 3] : tensor<3x3xi8> - %1 = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi8> + %0 = tensor.empty () : tensor<3x3xi8> + %1 = tensor.empty () : tensor<1x2x2x1xi8> %c0_i8 = arith.constant 0 : i8 %2 = linalg.fill ins(%c0_i8: i8) outs(%1: tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8> %3 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %0 : tensor<1x4x4x1xi8>, tensor<3x3xi8>) outs(%2 : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8> diff --git a/tests/litmus/memref-ops/copy-fill-bad.src.mlir b/tests/litmus/memref-ops/copy-fill-bad.src.mlir index 
e5e8c12d..23e536be 100644 --- a/tests/litmus/memref-ops/copy-fill-bad.src.mlir +++ b/tests/litmus/memref-ops/copy-fill-bad.src.mlir @@ -2,7 +2,7 @@ func.func @copy(%m1: memref<10x10xf32>, %m2: memref<10x10xf32>) { - %t = linalg.init_tensor [9, 9]: tensor<9x9xf32> + %t = tensor.empty (): tensor<9x9xf32> %c0 = arith.constant 0.0: f32 %zerotensor = linalg.fill ins(%c0: f32) outs(%t: tensor<9x9xf32>) -> tensor<9x9xf32> diff --git a/tests/litmus/memref-ops/copy-fill-bad.tgt.mlir b/tests/litmus/memref-ops/copy-fill-bad.tgt.mlir index 6983d5e7..824f2c71 100644 --- a/tests/litmus/memref-ops/copy-fill-bad.tgt.mlir +++ b/tests/litmus/memref-ops/copy-fill-bad.tgt.mlir @@ -1,6 +1,6 @@ func.func @copy(%m1: memref<10x10xf32>, %m2: memref<10x10xf32>) { - %t = linalg.init_tensor [9, 9]: tensor<9x9xf32> + %t = tensor.empty (): tensor<9x9xf32> %c0 = arith.constant 0.0: f32 %zerotensor = linalg.fill ins(%c0: f32) outs(%t: tensor<9x9xf32>) -> tensor<9x9xf32> diff --git a/tests/litmus/memref-ops/copy-fill.src.mlir b/tests/litmus/memref-ops/copy-fill.src.mlir index 7a44a288..8681aebd 100644 --- a/tests/litmus/memref-ops/copy-fill.src.mlir +++ b/tests/litmus/memref-ops/copy-fill.src.mlir @@ -2,7 +2,7 @@ func.func @copy(%m1: memref<10x10xf32>, %m2: memref<10x10xf32>) { - %t = linalg.init_tensor [10, 10]: tensor<10x10xf32> + %t = tensor.empty (): tensor<10x10xf32> %c0 = arith.constant 0.0: f32 %zerotensor = linalg.fill ins(%c0: f32) outs(%t: tensor<10x10xf32>) -> tensor<10x10xf32> diff --git a/tests/litmus/memref-ops/copy-fill.tgt.mlir b/tests/litmus/memref-ops/copy-fill.tgt.mlir index 0b9f1bfc..2bd01563 100644 --- a/tests/litmus/memref-ops/copy-fill.tgt.mlir +++ b/tests/litmus/memref-ops/copy-fill.tgt.mlir @@ -1,6 +1,6 @@ func.func @copy(%m1: memref<10x10xf32>, %m2: memref<10x10xf32>) { - %t = linalg.init_tensor [10, 10]: tensor<10x10xf32> + %t = tensor.empty (): tensor<10x10xf32> %c0 = arith.constant 0.0: f32 %zerotensor = linalg.fill ins(%c0: f32) outs(%t: tensor<10x10xf32>) -> tensor<10x10xf32> diff --git a/tests/litmus/refinement/size-mismatch.src.mlir b/tests/litmus/refinement/size-mismatch.src.mlir index ad746357..385d619e 100644 --- a/tests/litmus/refinement/size-mismatch.src.mlir +++ b/tests/litmus/refinement/size-mismatch.src.mlir @@ -3,6 +3,6 @@ func.func @f() -> tensor { %c10 = arith.constant 10: index - %v = linalg.init_tensor [%c10]: tensor + %v = tensor.empty (%c10): tensor return %v: tensor } diff --git a/tests/litmus/refinement/size-mismatch.tgt.mlir b/tests/litmus/refinement/size-mismatch.tgt.mlir index 489c06fb..9b0c46b6 100644 --- a/tests/litmus/refinement/size-mismatch.tgt.mlir +++ b/tests/litmus/refinement/size-mismatch.tgt.mlir @@ -1,5 +1,5 @@ func.func @f() -> tensor { %c20 = arith.constant 20: index - %v = linalg.init_tensor [%c20]: tensor + %v = tensor.empty (%c20): tensor return %v: tensor } diff --git a/tests/litmus/tensor-constant/transpose.src.mlir b/tests/litmus/tensor-constant/transpose.src.mlir index 03116b9b..25e0e40f 100644 --- a/tests/litmus/tensor-constant/transpose.src.mlir +++ b/tests/litmus/tensor-constant/transpose.src.mlir @@ -5,7 +5,7 @@ #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> func.func @transpose() -> tensor<1x1x2x2xf32> { %cst = arith.constant dense<[[[[1.0, 3.0]]], [[[2.0, 4.0]]]]> : tensor<2x1x1x2xf32> - %1 = linalg.init_tensor [1, 1, 2, 2] : tensor<1x1x2x2xf32> + %1 = tensor.empty () : tensor<1x1x2x2xf32> %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst : 
tensor<2x1x1x2xf32>) outs(%1 : tensor<1x1x2x2xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors linalg.yield %arg1 : f32 diff --git a/tests/litmus/tensor-ops/extract_tensor_sum.src.mlir b/tests/litmus/tensor-ops/extract_tensor_sum.src.mlir index 9069df34..7df44d56 100644 --- a/tests/litmus/tensor-ops/extract_tensor_sum.src.mlir +++ b/tests/litmus/tensor-ops/extract_tensor_sum.src.mlir @@ -2,7 +2,7 @@ // ARGS: --associative func.func @f(%a: tensor<1000xf32>, %b: tensor<1000xf32>) -> tensor { - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %a1 = tensor.extract_slice %a[0][500][1]: tensor<1000xf32> to tensor<500xf32> %a2 = tensor.extract_slice %a[500][500][1]: tensor<1000xf32> to tensor<500xf32> %b1 = tensor.extract_slice %b[0][500][1]: tensor<1000xf32> to tensor<500xf32> diff --git a/tests/litmus/tensor-ops/extract_tensor_sum.tgt.mlir b/tests/litmus/tensor-ops/extract_tensor_sum.tgt.mlir index 15a9ce73..ec61f43e 100644 --- a/tests/litmus/tensor-ops/extract_tensor_sum.tgt.mlir +++ b/tests/litmus/tensor-ops/extract_tensor_sum.tgt.mlir @@ -1,5 +1,5 @@ func.func @f(%a: tensor<1000xf32>, %b: tensor<1000xf32>) -> tensor { - %i = linalg.init_tensor []: tensor + %i = tensor.empty (): tensor %e = linalg.dot ins(%a, %b : tensor<1000xf32>,tensor<1000xf32>) outs(%i: tensor) -> tensor return %e : tensor diff --git a/tests/litmus/tensor-ops/extract_ub.src.mlir b/tests/litmus/tensor-ops/extract_ub.src.mlir index b559912a..8cbe9f2b 100644 --- a/tests/litmus/tensor-ops/extract_ub.src.mlir +++ b/tests/litmus/tensor-ops/extract_ub.src.mlir @@ -3,7 +3,7 @@ func.func @f() -> () { %c10 = arith.constant 10 : index - %v = linalg.init_tensor [%c10]: tensor + %v = tensor.empty (%c10): tensor tensor.extract %v[%c10]: tensor return } diff --git a/tests/litmus/tensor-ops/extract_ub.tgt.mlir b/tests/litmus/tensor-ops/extract_ub.tgt.mlir index 6690418c..4fc30687 100644 --- a/tests/litmus/tensor-ops/extract_ub.tgt.mlir +++ b/tests/litmus/tensor-ops/extract_ub.tgt.mlir @@ -1,7 +1,7 @@ func.func @f() -> () { %c10 = arith.constant 10 : index - %v = linalg.init_tensor [%c10]: tensor + %v = tensor.empty (%c10): tensor tensor.extract %v[%c10]: tensor return } diff --git a/tests/litmus/tensor-ops/from_elements.tgt.mlir b/tests/litmus/tensor-ops/from_elements.tgt.mlir index e15942a1..0b64dd32 100644 --- a/tests/litmus/tensor-ops/from_elements.tgt.mlir +++ b/tests/litmus/tensor-ops/from_elements.tgt.mlir @@ -1,6 +1,6 @@ func.func @from_elem() -> tensor<3xf32> { - %v = linalg.init_tensor[3]: tensor<3xf32> + %v = tensor.empty (): tensor<3xf32> %c0 = arith.constant 3.0: f32 %res = linalg.fill ins(%c0: f32) outs(%v: tensor<3xf32>) -> tensor<3xf32> return %res : tensor<3xf32> diff --git a/tests/litmus/tensor-ops/from_elements_arg.tgt.mlir b/tests/litmus/tensor-ops/from_elements_arg.tgt.mlir index cf737cf0..556ffd17 100644 --- a/tests/litmus/tensor-ops/from_elements_arg.tgt.mlir +++ b/tests/litmus/tensor-ops/from_elements_arg.tgt.mlir @@ -1,6 +1,6 @@ func.func @from_elem(%x:f32) -> tensor<3xf32> { - %v = linalg.init_tensor[3]: tensor<3xf32> + %v = tensor.empty (): tensor<3xf32> %res = linalg.fill ins(%x: f32) outs(%v: tensor<3xf32>) -> tensor<3xf32> return %res : tensor<3xf32> } diff --git a/tests/litmus/tosa-ops/add_broadcast1.src.mlir b/tests/litmus/tosa-ops/add_broadcast1.src.mlir index 900ab676..6f6a788c 100644 --- a/tests/litmus/tosa-ops/add_broadcast1.src.mlir +++ b/tests/litmus/tosa-ops/add_broadcast1.src.mlir @@ -3,7 +3,7 @@ func.func @add(%arg0: tensor<1xf32>, %arg1: tensor<10x9x8x7xf32>) -> 
tensor<10x9x8x7xf32> { %c0 = arith.constant 0 : index %c = tensor.extract %arg0[%c0] : tensor<1xf32> - %t1 = linalg.init_tensor [10, 9, 8, 7] : tensor<10x9x8x7xf32> + %t1 = tensor.empty () : tensor<10x9x8x7xf32> %t2 = linalg.fill ins(%c: f32) outs(%t1: tensor<10x9x8x7xf32>) -> tensor<10x9x8x7xf32> %0 = "tosa.add"(%t2, %arg1) : (tensor<10x9x8x7xf32>, tensor<10x9x8x7xf32>) -> tensor<10x9x8x7xf32> diff --git a/tests/litmus/tosa-ops/add_broadcast3.src.mlir b/tests/litmus/tosa-ops/add_broadcast3.src.mlir index b3bb4466..e0e25f8b 100644 --- a/tests/litmus/tosa-ops/add_broadcast3.src.mlir +++ b/tests/litmus/tosa-ops/add_broadcast3.src.mlir @@ -9,8 +9,8 @@ func.func @add(%arg0: tensor<2x1xf32>, %arg1: tensor<1x3xf32>) -> tensor<2x3xf32 %y11 = tensor.extract %arg1[%c0, %c0] : tensor<1x3xf32> %y12 = tensor.extract %arg1[%c0, %c1] : tensor<1x3xf32> %y13 = tensor.extract %arg1[%c0, %c2] : tensor<1x3xf32> - %tx0 = linalg.init_tensor [2, 3] : tensor<2x3xf32> - %ty0 = linalg.init_tensor [2, 3] : tensor<2x3xf32> + %tx0 = tensor.empty () : tensor<2x3xf32> + %ty0 = tensor.empty () : tensor<2x3xf32> %tx1 = tensor.insert %x11 into %tx0[%c0, %c0] : tensor<2x3xf32> %tx2 = tensor.insert %x11 into %tx1[%c0, %c1] : tensor<2x3xf32> %tx3 = tensor.insert %x11 into %tx2[%c0, %c2] : tensor<2x3xf32> diff --git a/tests/litmus/tosa-ops/avgpool2d.src.mlir b/tests/litmus/tosa-ops/avgpool2d.src.mlir index df8c41bd..95fb6bca 100644 --- a/tests/litmus/tosa-ops/avgpool2d.src.mlir +++ b/tests/litmus/tosa-ops/avgpool2d.src.mlir @@ -2,6 +2,6 @@ // ARGS: --use-neg-zero func.func @avgpool(%arg0: tensor<1x13x13x1001xf32>) -> tensor<1x1x1x1001xf32> { - %0 = "tosa.avg_pool2d"(%arg0) {kernel = [13, 13], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x13x13x1001xf32>) -> tensor<1x1x1x1001xf32> + %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : (tensor<1x13x13x1001xf32>) -> tensor<1x1x1x1001xf32> return %0 : tensor<1x1x1x1001xf32> } diff --git a/tests/litmus/tosa-ops/avgpool2d.tgt.mlir b/tests/litmus/tosa-ops/avgpool2d.tgt.mlir index 49c4ce5c..9f6177b8 100644 --- a/tests/litmus/tosa-ops/avgpool2d.tgt.mlir +++ b/tests/litmus/tosa-ops/avgpool2d.tgt.mlir @@ -1,11 +1,11 @@ #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> func.func @avgpool(%arg0: tensor<1x13x13x1001xf32>) -> tensor<1x1x1x1001xf32> { %cst_194 = arith.constant 0.000000e+00 : f32 - %373 = linalg.init_tensor [1, 1, 1, 1001] : tensor<1x1x1x1001xf32> + %373 = tensor.empty () : tensor<1x1x1x1001xf32> %374 = linalg.fill ins(%cst_194: f32) outs(%373: tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32> - %375 = linalg.init_tensor [13, 13] : tensor<13x13xf32> + %375 = tensor.empty () : tensor<13x13xf32> %376 = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %375 : tensor<1x13x13x1001xf32>, tensor<13x13xf32>) outs(%374 : tensor<1x1x1x1001xf32>) -> tensor<1x1x1x1001xf32> - %377 = linalg.init_tensor [1, 1, 1, 1001] : tensor<1x1x1x1001xf32> + %377 = tensor.empty () : tensor<1x1x1x1001xf32> %378 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%376 : tensor<1x1x1x1001xf32>) outs(%377 : tensor<1x1x1x1001xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors %c0_196 = arith.constant 0 : index diff --git a/tests/litmus/tosa-ops/avgpool2d_memref.src.mlir b/tests/litmus/tosa-ops/avgpool2d_memref.src.mlir index 183209ca..a259e67d 100644 --- a/tests/litmus/tosa-ops/avgpool2d_memref.src.mlir +++ 
b/tests/litmus/tosa-ops/avgpool2d_memref.src.mlir @@ -5,11 +5,11 @@ func.func @avgpool(%arg0: tensor<1x7x7x1280xf32>) -> tensor<1x1x1x1280xf32> { %cst_0 = arith.constant 0.000000e+00 : f32 %c49_i32 = arith.constant 49 : i32 - %507 = linalg.init_tensor [1, 1, 1, 1280] : tensor<1x1x1x1280xf32> + %507 = tensor.empty () : tensor<1x1x1x1280xf32> %508 = linalg.fill ins(%cst_0: f32) outs(%507: tensor<1x1x1x1280xf32>) -> tensor<1x1x1x1280xf32> - %509 = linalg.init_tensor [7, 7] : tensor<7x7xf32> + %509 = tensor.empty () : tensor<7x7xf32> %510 = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %509 : tensor<1x7x7x1280xf32>, tensor<7x7xf32>) outs(%508 : tensor<1x1x1x1280xf32>) -> tensor<1x1x1x1280xf32> - %511 = linalg.init_tensor [1, 1, 1, 1280] : tensor<1x1x1x1280xf32> + %511 = tensor.empty () : tensor<1x1x1x1280xf32> %512 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%510 : tensor<1x1x1x1280xf32>) outs(%511 : tensor<1x1x1x1280xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors %526 = arith.sitofp %c49_i32 : i32 to f32 diff --git a/tests/litmus/tosa-ops/conv2d1.src.mlir b/tests/litmus/tosa-ops/conv2d1.src.mlir index d5ddb7c8..b2a1da84 100644 --- a/tests/litmus/tosa-ops/conv2d1.src.mlir +++ b/tests/litmus/tosa-ops/conv2d1.src.mlir @@ -5,6 +5,6 @@ func.func @conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> te %bias = tensor.from_elements %c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0: tensor<16xf32> %filperms = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> %arg3 = "tosa.transpose"(%arg1, %filperms) : (tensor<3x3x4x16xf32>, tensor<4xi64>) -> tensor<16x3x3x4xf32> - %0 = "tosa.conv2d"(%arg0, %arg3, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> + %0 = "tosa.conv2d"(%arg0, %arg3, %bias) {dilation = array, pad = array, stride = array} : (tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> return %0 : tensor<1x14x14x16xf32> } diff --git a/tests/litmus/tosa-ops/conv2d1.tgt.mlir b/tests/litmus/tosa-ops/conv2d1.tgt.mlir index 2ae3bde7..84157206 100644 --- a/tests/litmus/tosa-ops/conv2d1.tgt.mlir +++ b/tests/litmus/tosa-ops/conv2d1.tgt.mlir @@ -1,5 +1,5 @@ func.func @conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { - %i = linalg.init_tensor [1,14,14,16] : tensor<1x14x14x16xf32> + %i = tensor.empty () : tensor<1x14x14x16xf32> %zero = arith.constant -0.0 : f32 %out = linalg.fill ins(%zero: f32) outs(%i: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> %0 = linalg.conv_2d_nhwc_hwcf diff --git a/tests/litmus/tosa-ops/conv2d2.src.mlir b/tests/litmus/tosa-ops/conv2d2.src.mlir index a2d24c3b..fa189325 100644 --- a/tests/litmus/tosa-ops/conv2d2.src.mlir +++ b/tests/litmus/tosa-ops/conv2d2.src.mlir @@ -5,6 +5,6 @@ func.func @conv() -> tensor<1x1x1x1xf32> { %filter = arith.constant dense<[[[[1.0,1.0],[1.0,1.0],[1.0,1.0]],[[1.0,1.0],[1.0,1.0],[1.0,1.0]],[[1.0,1.0],[1.0,1.0],[1.0,1.0]]]]> : tensor<1x3x3x2xf32> %c0 = arith.constant -0.0 : f32 %bias = tensor.from_elements %c0: tensor<1xf32> - %0 = "tosa.conv2d"(%img, %filter, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x3x3x2xf32>, tensor<1x3x3x2xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> + %0 = "tosa.conv2d"(%img, %filter, %bias) {dilation = array, 
pad = array, stride = array} : (tensor<1x3x3x2xf32>, tensor<1x3x3x2xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> return %0 : tensor<1x1x1x1xf32> } \ No newline at end of file diff --git a/tests/litmus/tosa-ops/conv2d3.src.mlir b/tests/litmus/tosa-ops/conv2d3.src.mlir index 57f110db..87a8e153 100644 --- a/tests/litmus/tosa-ops/conv2d3.src.mlir +++ b/tests/litmus/tosa-ops/conv2d3.src.mlir @@ -5,6 +5,6 @@ func.func @conv() -> tensor<1x1x1x1xf32> { %fil = arith.constant dense<[[[[1.0,1.0],[1.0,1.0]],[[1.0,1.0],[1.0,1.0]]]]> : tensor<1x2x2x2xf32> %c0 = arith.constant -0.0 : f32 %bias = tensor.from_elements %c0: tensor<1xf32> - %0 = "tosa.conv2d"(%img, %fil, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> + %0 = "tosa.conv2d"(%img, %fil, %bias) {dilation = array, pad = array, stride = array} : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> return %0 : tensor<1x1x1x1xf32> } \ No newline at end of file diff --git a/tests/litmus/tosa-ops/depthwise1.src.mlir b/tests/litmus/tosa-ops/depthwise1.src.mlir index fc4a9c5d..ac575239 100644 --- a/tests/litmus/tosa-ops/depthwise1.src.mlir +++ b/tests/litmus/tosa-ops/depthwise1.src.mlir @@ -5,6 +5,6 @@ func.func @conv() -> tensor<1x1x1x1xf32> { %filter = arith.constant dense<[[[[1.0]],[[1.0]],[[1.0]]],[[[1.0]],[[1.0]],[[1.0]]],[[[1.0]],[[1.0]],[[1.0]]]]> : tensor<3x3x1x1xf32> %c0 = arith.constant -0.0 : f32 %bias = tensor.from_elements %c0: tensor<1xf32> - %0 = "tosa.depthwise_conv2d"(%img, %filter, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x3x3x1xf32>, tensor<3x3x1x1xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> + %0 = "tosa.depthwise_conv2d"(%img, %filter, %bias) {dilation = array, pad = array, stride = array} : (tensor<1x3x3x1xf32>, tensor<3x3x1x1xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> return %0 : tensor<1x1x1x1xf32> } \ No newline at end of file diff --git a/tests/litmus/tosa-ops/depthwise2.src.mlir b/tests/litmus/tosa-ops/depthwise2.src.mlir index 70a9ca28..bc562e42 100644 --- a/tests/litmus/tosa-ops/depthwise2.src.mlir +++ b/tests/litmus/tosa-ops/depthwise2.src.mlir @@ -5,6 +5,6 @@ func.func @conv() -> tensor<1x1x1x2xf32> { %filter = arith.constant dense<[[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]],[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]],[[[1.0],[1.0]],[[1.0],[1.0]],[[1.0],[1.0]]]]> : tensor<3x3x2x1xf32> %c0 = arith.constant -0.0 : f32 %bias = tensor.from_elements %c0, %c0: tensor<2xf32> - %0 = "tosa.depthwise_conv2d"(%img, %filter, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x3x3x2xf32>, tensor<3x3x2x1xf32>, tensor<2xf32>) -> tensor<1x1x1x2xf32> + %0 = "tosa.depthwise_conv2d"(%img, %filter, %bias) {dilation = array, pad = array, stride = array} : (tensor<1x3x3x2xf32>, tensor<3x3x2x1xf32>, tensor<2xf32>) -> tensor<1x1x1x2xf32> return %0 : tensor<1x1x1x2xf32> } \ No newline at end of file diff --git a/tests/litmus/tosa-ops/depthwise3.src.mlir b/tests/litmus/tosa-ops/depthwise3.src.mlir index 2d4fba03..f0ef726c 100644 --- a/tests/litmus/tosa-ops/depthwise3.src.mlir +++ b/tests/litmus/tosa-ops/depthwise3.src.mlir @@ -4,7 +4,7 @@ func.func @depthwise_conv() -> tensor<1x1x1x2xf32> { %arg0 = arith.constant dense<[[[[1.0,1.0,1.0]]]]>: tensor<1x1x1x3xf32> %arg1 = arith.constant dense<[[[[1.0,2.0],[3.0,4.0],[5.0,6.0]]]]>: tensor<1x1x3x2xf32> %arg2 = arith.constant dense<[7.0,8.0,9.0,10.0,11.0,12.0]>: tensor<6xf32> - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, 
%arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x1x1x3xf32>, tensor<1x1x3x2xf32>, tensor<6xf32>) -> (tensor<1x1x1x6xf32>) + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = array, stride = array, dilation = array } : (tensor<1x1x1x3xf32>, tensor<1x1x3x2xf32>, tensor<6xf32>) -> (tensor<1x1x1x6xf32>) %1 = tensor.extract_slice %0[0,0,0,4][1,1,1,2][1,1,1,1]: tensor<1x1x1x6xf32> to tensor<2xf32> %2 = tensor.expand_shape %1 [[0,1,2,3]] : tensor<2xf32> into tensor<1x1x1x2xf32> return %2 : tensor<1x1x1x2xf32> diff --git a/tests/litmus/tosa-ops/depthwise3.tgt.mlir b/tests/litmus/tosa-ops/depthwise3.tgt.mlir index 0dac743a..3901d027 100644 --- a/tests/litmus/tosa-ops/depthwise3.tgt.mlir +++ b/tests/litmus/tosa-ops/depthwise3.tgt.mlir @@ -9,6 +9,6 @@ func.func @depthwise_conv() -> tensor<1x1x1x2xf32> { %bias = tensor.extract_slice %arg2[4][2][1]: tensor<6xf32> to tensor<2xf32> %filperms = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> %fil3 = "tosa.transpose"(%fil2, %filperms) : (tensor<1x1x1x2xf32>, tensor<4xi64>) -> tensor<2x1x1x1xf32> - %0 = "tosa.conv2d"(%in2, %fil3, %bias) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x1x1x1xf32>, tensor<2x1x1x1xf32>, tensor<2xf32>) -> (tensor<1x1x1x2xf32>) + %0 = "tosa.conv2d"(%in2, %fil3, %bias) { pad = array, stride = array, dilation = array } : (tensor<1x1x1x1xf32>, tensor<2x1x1x1xf32>, tensor<2xf32>) -> (tensor<1x1x1x2xf32>) return %0 : tensor<1x1x1x2xf32> } \ No newline at end of file diff --git a/tests/litmus/tosa-ops/depthwise4.src.mlir b/tests/litmus/tosa-ops/depthwise4.src.mlir index a3679af2..71a6f4c7 100644 --- a/tests/litmus/tosa-ops/depthwise4.src.mlir +++ b/tests/litmus/tosa-ops/depthwise4.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func @depthwise_conv(%arg0 : tensor<1x1x1x3xf32>, %arg1 : tensor<1x1x3x2xf32>, %arg2 : tensor<6xf32>) -> tensor<1x1x1x2xf32> { - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x1x1x3xf32>, tensor<1x1x3x2xf32>, tensor<6xf32>) -> (tensor<1x1x1x6xf32>) + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = array, stride = array, dilation = array } : (tensor<1x1x1x3xf32>, tensor<1x1x3x2xf32>, tensor<6xf32>) -> (tensor<1x1x1x6xf32>) %1 = tensor.extract_slice %0[0,0,0,4][1,1,1,2][1,1,1,1]: tensor<1x1x1x6xf32> to tensor<2xf32> %2 = tensor.expand_shape %1 [[0,1,2,3]] : tensor<2xf32> into tensor<1x1x1x2xf32> return %2 : tensor<1x1x1x2xf32> diff --git a/tests/litmus/tosa-ops/depthwise4.tgt.mlir b/tests/litmus/tosa-ops/depthwise4.tgt.mlir index 563e35c4..9a818aab 100644 --- a/tests/litmus/tosa-ops/depthwise4.tgt.mlir +++ b/tests/litmus/tosa-ops/depthwise4.tgt.mlir @@ -6,6 +6,6 @@ func.func @depthwise_conv(%arg0 : tensor<1x1x1x3xf32>, %arg1 : tensor<1x1x3x2xf3 %bias = tensor.extract_slice %arg2[4][2][1]: tensor<6xf32> to tensor<2xf32> %filperms = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> %fil3 = "tosa.transpose"(%fil2, %filperms) : (tensor<1x1x1x2xf32>, tensor<4xi64>) -> tensor<2x1x1x1xf32> - %0 = "tosa.conv2d"(%in2, %fil3, %bias) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x1x1x1xf32>, tensor<2x1x1x1xf32>, tensor<2xf32>) -> (tensor<1x1x1x2xf32>) + %0 = "tosa.conv2d"(%in2, %fil3, %bias) { pad = array, stride = array, dilation = array } : (tensor<1x1x1x1xf32>, tensor<2x1x1x1xf32>, tensor<2xf32>) -> (tensor<1x1x1x2xf32>) return %0 : tensor<1x1x1x2xf32> } \ No newline at end of file diff --git 
a/tests/litmus/tosa-ops/depthwise5.src.mlir b/tests/litmus/tosa-ops/depthwise5.src.mlir index 8aac3945..c5f7f7bd 100644 --- a/tests/litmus/tosa-ops/depthwise5.src.mlir +++ b/tests/litmus/tosa-ops/depthwise5.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func @depthwise_conv(%arg0 : tensor<2x7x5x3xf32>, %arg1 : tensor<3x1x3x2xf32>, %arg2 : tensor<6xf32>) -> tensor<2x5x5x2xf32> { - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<2x7x5x3xf32>, tensor<3x1x3x2xf32>, tensor<6xf32>) -> (tensor<2x5x5x6xf32>) + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = array, stride = array, dilation = array } : (tensor<2x7x5x3xf32>, tensor<3x1x3x2xf32>, tensor<6xf32>) -> (tensor<2x5x5x6xf32>) %1 = tensor.extract_slice %0[0,0,0,4][2,5,5,2][1,1,1,1]: tensor<2x5x5x6xf32> to tensor<2x5x5x2xf32> return %1 : tensor<2x5x5x2xf32> } \ No newline at end of file diff --git a/tests/litmus/tosa-ops/depthwise5.tgt.mlir b/tests/litmus/tosa-ops/depthwise5.tgt.mlir index 569700a6..53358d19 100644 --- a/tests/litmus/tosa-ops/depthwise5.tgt.mlir +++ b/tests/litmus/tosa-ops/depthwise5.tgt.mlir @@ -1,6 +1,6 @@ func.func @depthwise_conv(%arg0 : tensor<2x7x5x3xf32>, %arg1 : tensor<3x1x3x2xf32>, %arg2 : tensor<6xf32>) -> tensor<2x5x5x2xf32> { - %i = linalg.init_tensor [2, 7, 5, 1] : tensor<2x7x5x1xf32> - %i2 = linalg.init_tensor [3, 1, 1, 2] : tensor<3x1x1x2xf32> + %i = tensor.empty () : tensor<2x7x5x1xf32> + %i2 = tensor.empty () : tensor<3x1x1x2xf32> %in = tensor.extract_slice %arg0[0,0,0,2][2,7,5,1][1,1,1,1]: tensor<2x7x5x3xf32> to tensor<2x7x5xf32> %fil = tensor.extract_slice %arg1[0,0,2,0][3,1,1,2][1,1,1,1]: tensor<3x1x3x2xf32> to tensor<3x2xf32> %in2 = tensor.expand_shape %in [[0],[1],[2,3]] : tensor<2x7x5xf32> into tensor<2x7x5x1xf32> @@ -8,6 +8,6 @@ func.func @depthwise_conv(%arg0 : tensor<2x7x5x3xf32>, %arg1 : tensor<3x1x3x2xf3 %bias = tensor.extract_slice %arg2[4][2][1]: tensor<6xf32> to tensor<2xf32> %filperms = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> %fil3 = "tosa.transpose"(%fil2, %filperms) : (tensor<3x1x1x2xf32>, tensor<4xi64>) -> tensor<2x3x1x1xf32> - %0 = "tosa.conv2d"(%in2, %fil3, %bias) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<2x7x5x1xf32>, tensor<2x3x1x1xf32>, tensor<2xf32>) -> (tensor<2x5x5x2xf32>) + %0 = "tosa.conv2d"(%in2, %fil3, %bias) { pad = array, stride = array, dilation = array } : (tensor<2x7x5x1xf32>, tensor<2x3x1x1xf32>, tensor<2xf32>) -> (tensor<2x5x5x2xf32>) return %0 : tensor<2x5x5x2xf32> } \ No newline at end of file diff --git a/tests/litmus/tosa-ops/gather-const-bad.src.mlir b/tests/litmus/tosa-ops/gather-const-bad.src.mlir index 7ee9e386..a2dc33ef 100644 --- a/tests/litmus/tosa-ops/gather-const-bad.src.mlir +++ b/tests/litmus/tosa-ops/gather-const-bad.src.mlir @@ -1,6 +1,6 @@ // VERIFY-INCORRECT func.func @test_gather() -> tensor<2x1x1xf32> { - %v = linalg.init_tensor [2,2,1]: tensor<2x2x1xf32> + %v = tensor.empty (): tensor<2x2x1xf32> %zero = arith.constant 0: index %one = arith.constant 1: index %c1 = arith.constant 10.0: f32 diff --git a/tests/litmus/tosa-ops/gather-const.src.mlir b/tests/litmus/tosa-ops/gather-const.src.mlir index d541c662..3d3b1d9a 100644 --- a/tests/litmus/tosa-ops/gather-const.src.mlir +++ b/tests/litmus/tosa-ops/gather-const.src.mlir @@ -1,6 +1,6 @@ // VERIFY func.func @test_gather() -> tensor<2x1x1xf32> { - %v = linalg.init_tensor [2,2,1]: tensor<2x2x1xf32> + %v = tensor.empty (): tensor<2x2x1xf32> %zero = arith.constant 0: 
index %one = arith.constant 1: index %c1 = arith.constant 10.0: f32 diff --git a/tests/litmus/tosa-ops/gather-oob.src.mlir b/tests/litmus/tosa-ops/gather-oob.src.mlir index 04f64af9..5f34b7a8 100644 --- a/tests/litmus/tosa-ops/gather-oob.src.mlir +++ b/tests/litmus/tosa-ops/gather-oob.src.mlir @@ -1,7 +1,7 @@ // EXPECT: "correct (source is always undefined)" func.func @test_gather() -> tensor<2x1x1xf32> { - %v = linalg.init_tensor [2,2,1]: tensor<2x2x1xf32> + %v = tensor.empty (): tensor<2x2x1xf32> %zero = arith.constant 0: index %one = arith.constant 1: index %c1 = arith.constant 10.0: f32 diff --git a/tests/litmus/tosa-ops/gather-oob.tgt.mlir b/tests/litmus/tosa-ops/gather-oob.tgt.mlir index 352bab09..1dfc3f09 100644 --- a/tests/litmus/tosa-ops/gather-oob.tgt.mlir +++ b/tests/litmus/tosa-ops/gather-oob.tgt.mlir @@ -1,4 +1,4 @@ func.func @test_gather() -> tensor<2x1x1xf32> { - %v = linalg.init_tensor [2,1,1]: tensor<2x1x1xf32> + %v = tensor.empty (): tensor<2x1x1xf32> return %v: tensor<2x1x1xf32> } diff --git a/tests/litmus/tosa-ops/gather-uninit-index.src.mlir b/tests/litmus/tosa-ops/gather-uninit-index.src.mlir index a390528b..f93dbd2e 100644 --- a/tests/litmus/tosa-ops/gather-uninit-index.src.mlir +++ b/tests/litmus/tosa-ops/gather-uninit-index.src.mlir @@ -1,7 +1,7 @@ // EXPECT: "correct (source is always undefined)" func.func @test_gather() -> tensor<2x1x1xf32> { - %v = linalg.init_tensor [2,2,1]: tensor<2x2x1xf32> + %v = tensor.empty (): tensor<2x2x1xf32> %zero = arith.constant 0: index %one = arith.constant 1: index %c1 = arith.constant 10.0: f32 @@ -9,7 +9,7 @@ func.func @test_gather() -> tensor<2x1x1xf32> { %v1 = tensor.insert %c1 into %v [%zero, %zero, %zero]: tensor<2x2x1xf32> %v2 = tensor.insert %c2 into %v1[%one, %zero, %zero]: tensor<2x2x1xf32> - %indices = linalg.init_tensor [2,1]: tensor<2x1xi32> + %indices = tensor.empty (): tensor<2x1xi32> %0 = "tosa.gather"(%v2, %indices) : (tensor<2x2x1xf32>, tensor<2x1xi32>) -> tensor<2x1x1xf32> return %0 : tensor<2x1x1xf32> diff --git a/tests/litmus/tosa-ops/gather-uninit-index.tgt.mlir b/tests/litmus/tosa-ops/gather-uninit-index.tgt.mlir index 352bab09..1dfc3f09 100644 --- a/tests/litmus/tosa-ops/gather-uninit-index.tgt.mlir +++ b/tests/litmus/tosa-ops/gather-uninit-index.tgt.mlir @@ -1,4 +1,4 @@ func.func @test_gather() -> tensor<2x1x1xf32> { - %v = linalg.init_tensor [2,1,1]: tensor<2x1x1xf32> + %v = tensor.empty (): tensor<2x1x1xf32> return %v: tensor<2x1x1xf32> } diff --git a/tests/litmus/tosa-ops/gather-uninit.src.mlir b/tests/litmus/tosa-ops/gather-uninit.src.mlir index 7538a754..589ff69c 100644 --- a/tests/litmus/tosa-ops/gather-uninit.src.mlir +++ b/tests/litmus/tosa-ops/gather-uninit.src.mlir @@ -1,7 +1,7 @@ // EXPECT: "correct (source is always undefined)" func.func @test_gather() -> tensor<2x1x1xf32> { - %v = linalg.init_tensor [2,2,1]: tensor<2x2x1xf32> + %v = tensor.empty (): tensor<2x2x1xf32> %zero = arith.constant 0: index %one = arith.constant 1: index %c1 = arith.constant 10.0: f32 diff --git a/tests/litmus/tosa-ops/gather-uninit.tgt.mlir b/tests/litmus/tosa-ops/gather-uninit.tgt.mlir index 352bab09..1dfc3f09 100644 --- a/tests/litmus/tosa-ops/gather-uninit.tgt.mlir +++ b/tests/litmus/tosa-ops/gather-uninit.tgt.mlir @@ -1,4 +1,4 @@ func.func @test_gather() -> tensor<2x1x1xf32> { - %v = linalg.init_tensor [2,1,1]: tensor<2x1x1xf32> + %v = tensor.empty (): tensor<2x1x1xf32> return %v: tensor<2x1x1xf32> } diff --git a/tests/litmus/tosa-ops/maxpool2d.src.mlir b/tests/litmus/tosa-ops/maxpool2d.src.mlir index 
60fe0f04..18bba364 100644 --- a/tests/litmus/tosa-ops/maxpool2d.src.mlir +++ b/tests/litmus/tosa-ops/maxpool2d.src.mlir @@ -2,6 +2,6 @@ // ARGS: --use-neg-zero func.func @maxpool(%arg0: tensor<1x7x7x1280xf32>) -> tensor<1x1x1x1280xf32> { - %0 = "tosa.max_pool2d"(%arg0) {kernel = [7, 7], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x7x7x1280xf32>) -> tensor<1x1x1x1280xf32> + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x1280xf32>) -> tensor<1x1x1x1280xf32> return %0 : tensor<1x1x1x1280xf32> } diff --git a/tests/litmus/tosa-ops/maxpool2d.tgt.mlir b/tests/litmus/tosa-ops/maxpool2d.tgt.mlir index 6865259c..ea69fbce 100644 --- a/tests/litmus/tosa-ops/maxpool2d.tgt.mlir +++ b/tests/litmus/tosa-ops/maxpool2d.tgt.mlir @@ -1,8 +1,8 @@ func.func @maxpool(%arg0: tensor<1x7x7x1280xf32>) -> tensor<1x1x1x1280xf32> { %cst = arith.constant -3.40282347E+38 : f32 - %0 = linalg.init_tensor [1, 1, 1, 1280] : tensor<1x1x1x1280xf32> + %0 = tensor.empty () : tensor<1x1x1x1280xf32> %1 = linalg.fill ins(%cst: f32) outs(%0: tensor<1x1x1x1280xf32>) -> tensor<1x1x1x1280xf32> - %2 = linalg.init_tensor [7, 7] : tensor<7x7xf32> + %2 = tensor.empty () : tensor<7x7xf32> %3 = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %2 : tensor<1x7x7x1280xf32>, tensor<7x7xf32>) outs(%1 : tensor<1x1x1x1280xf32>) -> tensor<1x1x1x1280xf32> return %3 : tensor<1x1x1x1280xf32> } diff --git a/tests/litmus/tosa-ops/maxpool2d_memref.src.mlir b/tests/litmus/tosa-ops/maxpool2d_memref.src.mlir index 60fe0f04..18bba364 100644 --- a/tests/litmus/tosa-ops/maxpool2d_memref.src.mlir +++ b/tests/litmus/tosa-ops/maxpool2d_memref.src.mlir @@ -2,6 +2,6 @@ // ARGS: --use-neg-zero func.func @maxpool(%arg0: tensor<1x7x7x1280xf32>) -> tensor<1x1x1x1280xf32> { - %0 = "tosa.max_pool2d"(%arg0) {kernel = [7, 7], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x7x7x1280xf32>) -> tensor<1x1x1x1280xf32> + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x1280xf32>) -> tensor<1x1x1x1280xf32> return %0 : tensor<1x1x1x1280xf32> } diff --git a/tests/litmus/tosa-ops/maxpool_noop.src.mlir b/tests/litmus/tosa-ops/maxpool_noop.src.mlir index 0213cb1d..e7775661 100644 --- a/tests/litmus/tosa-ops/maxpool_noop.src.mlir +++ b/tests/litmus/tosa-ops/maxpool_noop.src.mlir @@ -1,6 +1,6 @@ // VERIFY func.func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> { - %0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array, dilation = array} : (tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> return %0 : tensor<10x1x1x3xf32> } diff --git a/tests/litmus/tosa-ops/reshape.src.mlir b/tests/litmus/tosa-ops/reshape.src.mlir index 5c640da5..0aae65c7 100644 --- a/tests/litmus/tosa-ops/reshape.src.mlir +++ b/tests/litmus/tosa-ops/reshape.src.mlir @@ -2,7 +2,7 @@ func.func @f() -> (i32, i32) { %t = "tosa.const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<4xi32> - %t2 = "tosa.reshape"(%t) {new_shape = [2, 2]} : (tensor<4xi32>) -> tensor<2x2xi32> + %t2 = "tosa.reshape"(%t) {new_shape = array} : (tensor<4xi32>) -> tensor<2x2xi32> %c0 = arith.constant 0: index %c1 = arith.constant 1: index %two = tensor.extract %t2[%c0,%c1]: tensor<2x2xi32> diff --git a/tests/litmus/tosa-ops/tile-bad.src.mlir b/tests/litmus/tosa-ops/tile-bad.src.mlir index 
4a9b9c9b..05a9c56b 100644 --- a/tests/litmus/tosa-ops/tile-bad.src.mlir +++ b/tests/litmus/tosa-ops/tile-bad.src.mlir @@ -2,7 +2,7 @@ func.func @f(%x0: tensor<3x3xf32>) -> tensor<6x9xf32> { %x = "tosa.reverse"(%x0) {axis = 0: i64}: (tensor<3x3xf32>) -> tensor<3x3xf32> - %a = "tosa.tile"(%x) {multiples = [2, 1]} : (tensor<3x3xf32>) -> (tensor<6x3xf32>) - %b = "tosa.tile"(%a) {multiples = [1, 3]} : (tensor<6x3xf32>) -> (tensor<6x9xf32>) + %a = "tosa.tile"(%x) {multiples = array} : (tensor<3x3xf32>) -> (tensor<6x3xf32>) + %b = "tosa.tile"(%a) {multiples = array} : (tensor<6x3xf32>) -> (tensor<6x9xf32>) return %b: tensor<6x9xf32> } diff --git a/tests/litmus/tosa-ops/tile-bad.tgt.mlir b/tests/litmus/tosa-ops/tile-bad.tgt.mlir index f2507752..eece2ad2 100644 --- a/tests/litmus/tosa-ops/tile-bad.tgt.mlir +++ b/tests/litmus/tosa-ops/tile-bad.tgt.mlir @@ -1,4 +1,4 @@ func.func @f(%x: tensor<3x3xf32>) -> tensor<6x9xf32> { - %a = "tosa.tile"(%x) {multiples = [2, 3]} : (tensor<3x3xf32>) -> (tensor<6x9xf32>) + %a = "tosa.tile"(%x) {multiples = array} : (tensor<3x3xf32>) -> (tensor<6x9xf32>) return %a: tensor<6x9xf32> } diff --git a/tests/litmus/tosa-ops/tile.src.mlir b/tests/litmus/tosa-ops/tile.src.mlir index e56bef6e..ed0a083a 100644 --- a/tests/litmus/tosa-ops/tile.src.mlir +++ b/tests/litmus/tosa-ops/tile.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func @f(%x: tensor<3x3xf32>) -> tensor<6x9xf32> { - %a = "tosa.tile"(%x) {multiples = [2, 1]} : (tensor<3x3xf32>) -> (tensor<6x3xf32>) - %b = "tosa.tile"(%a) {multiples = [1, 3]} : (tensor<6x3xf32>) -> (tensor<6x9xf32>) + %a = "tosa.tile"(%x) {multiples = array} : (tensor<3x3xf32>) -> (tensor<6x3xf32>) + %b = "tosa.tile"(%a) {multiples = array} : (tensor<6x3xf32>) -> (tensor<6x9xf32>) return %b: tensor<6x9xf32> } diff --git a/tests/litmus/tosa-ops/tile.tgt.mlir b/tests/litmus/tosa-ops/tile.tgt.mlir index f2507752..eece2ad2 100644 --- a/tests/litmus/tosa-ops/tile.tgt.mlir +++ b/tests/litmus/tosa-ops/tile.tgt.mlir @@ -1,4 +1,4 @@ func.func @f(%x: tensor<3x3xf32>) -> tensor<6x9xf32> { - %a = "tosa.tile"(%x) {multiples = [2, 3]} : (tensor<3x3xf32>) -> (tensor<6x9xf32>) + %a = "tosa.tile"(%x) {multiples = array} : (tensor<3x3xf32>) -> (tensor<6x9xf32>) return %a: tensor<6x9xf32> } diff --git a/tests/litmus/tosa-ops/transpose3.src.mlir b/tests/litmus/tosa-ops/transpose3.src.mlir index 2187e843..5ae11cbe 100644 --- a/tests/litmus/tosa-ops/transpose3.src.mlir +++ b/tests/litmus/tosa-ops/transpose3.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func @conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { - %out = linalg.init_tensor [1,14,14,16] : tensor<1x14x14x16xf32> + %out = tensor.empty () : tensor<1x14x14x16xf32> %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) diff --git a/tests/litmus/tosa-ops/transpose3.tgt.mlir b/tests/litmus/tosa-ops/transpose3.tgt.mlir index f23c40dd..2f76443b 100644 --- a/tests/litmus/tosa-ops/transpose3.tgt.mlir +++ b/tests/litmus/tosa-ops/transpose3.tgt.mlir @@ -1,5 +1,5 @@ func.func @conv(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { - %out = linalg.init_tensor [1,16,14,14] : tensor<1x16x14x14xf32> + %out = tensor.empty () : tensor<1x16x14x14xf32> %inperms = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> %filperms = "tosa.const"() {value = dense<[3, 2, 0, 1]> : tensor<4xi64>} : () -> tensor<4xi64> %outperms = 
"tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> diff --git a/tests/litmus/verbose/conv2d1.src.mlir b/tests/litmus/verbose/conv2d1.src.mlir index 17885f82..d6b53667 100644 --- a/tests/litmus/verbose/conv2d1.src.mlir +++ b/tests/litmus/verbose/conv2d1.src.mlir @@ -4,6 +4,6 @@ func.func @conv(%img: tensor<1x16x16x4xf32>, %filtr: tensor<16x3x3x4xf32>) -> tensor<1x14x14x16xf32> { %c0 = arith.constant 0.0 : f32 %bias = tensor.from_elements %c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0: tensor<16xf32> - %0 = "tosa.conv2d"(%img, %filtr, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> + %0 = "tosa.conv2d"(%img, %filtr, %bias) {dilation = array, pad = array, stride = array} : (tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> return %0 : tensor<1x14x14x16xf32> } diff --git a/tests/litmus/verbose/conv2d1.tgt.mlir b/tests/litmus/verbose/conv2d1.tgt.mlir index 32aefe9e..c1ae5bbb 100644 --- a/tests/litmus/verbose/conv2d1.tgt.mlir +++ b/tests/litmus/verbose/conv2d1.tgt.mlir @@ -1,6 +1,6 @@ func.func @conv(%img: tensor<1x16x16x4xf32>, %filtr: tensor<16x3x3x4xf32>) -> tensor<1x14x14x16xf32> { %c0 = arith.constant 0.0 : f32 %bias = tensor.from_elements %c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0: tensor<16xf32> - %0 = "tosa.conv2d"(%img, %filtr, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> + %0 = "tosa.conv2d"(%img, %filtr, %bias) {dilation = array, pad = array, stride = array} : (tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> return %0 : tensor<1x14x14x16xf32> } diff --git a/tests/long-opts/conv2d-to-img2col/nhwc_filter-bad.tgt.mlir b/tests/long-opts/conv2d-to-img2col/nhwc_filter-bad.tgt.mlir index c520fad8..4d940a22 100644 --- a/tests/long-opts/conv2d-to-img2col/nhwc_filter-bad.tgt.mlir +++ b/tests/long-opts/conv2d-to-img2col/nhwc_filter-bad.tgt.mlir @@ -2,7 +2,7 @@ #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> module { func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { - %0 = linalg.init_tensor [1, 14, 14, 3, 3, 4] : tensor<1x14x14x3x3x4xf32> + %0 = tensor.empty () : tensor<1x14x14x3x3x4xf32> %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x16x16x4xf32>) outs(%0 : tensor<1x14x14x3x3x4xf32>) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors linalg.yield %arg3 : f32 diff --git a/tests/long-opts/conv2d/conv2d2.src.mlir b/tests/long-opts/conv2d/conv2d2.src.mlir index b58d02fa..9971e98e 100644 --- a/tests/long-opts/conv2d/conv2d2.src.mlir +++ b/tests/long-opts/conv2d/conv2d2.src.mlir @@ -7,6 +7,6 @@ func.func @conv(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> te %bias = tensor.from_elements %c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0: tensor<16xf32> %filperms = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64> %arg3 = "tosa.transpose"(%arg1, %filperms) : (tensor<3x3x4x16xf32>, tensor<4xi64>) -> tensor<16x3x3x4xf32> - %0 = "tosa.conv2d"(%arg0, %arg3, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [2, 2]} : (tensor<1x29x29x4xf32>, 
tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> + %0 = "tosa.conv2d"(%arg0, %arg3, %bias) {dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> return %0 : tensor<1x14x14x16xf32> } diff --git a/tests/long-opts/conv2d/conv2d2.tgt.mlir b/tests/long-opts/conv2d/conv2d2.tgt.mlir index 92771480..0c4f4b84 100644 --- a/tests/long-opts/conv2d/conv2d2.tgt.mlir +++ b/tests/long-opts/conv2d/conv2d2.tgt.mlir @@ -1,5 +1,5 @@ func.func @conv(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { - %out = linalg.init_tensor [1,14,14,16] : tensor<1x14x14x16xf32> + %out = tensor.empty () : tensor<1x14x14x16xf32> %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64> } ins(%arg0, %arg1: tensor<1x29x29x4xf32>, tensor<3x3x4x16xf32>) diff --git a/tests/long-opts/tosa-to-linalg/conv2d2.src.mlir b/tests/long-opts/tosa-to-linalg/conv2d2.src.mlir index 8fb974b7..a0244005 100644 --- a/tests/long-opts/tosa-to-linalg/conv2d2.src.mlir +++ b/tests/long-opts/tosa-to-linalg/conv2d2.src.mlir @@ -7,6 +7,6 @@ func.func @conv(%img: tensor<1x29x29x4xf32>, %filter: tensor<16x3x3x4xf32>) -> tensor<1x14x14x16xf32> { %c0 = arith.constant 0.0 : f32 %bias = tensor.from_elements %c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0,%c0: tensor<16xf32> - %0 = "tosa.conv2d"(%img, %filter, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [2, 2]} : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> + %0 = "tosa.conv2d"(%img, %filter, %bias) {dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x14x14x16xf32> return %0 : tensor<1x14x14x16xf32> } diff --git a/tests/long-opts/tosa-to-linalg/conv2d2.tgt.mlir b/tests/long-opts/tosa-to-linalg/conv2d2.tgt.mlir index b0eaccf3..95a0cb3f 100644 --- a/tests/long-opts/tosa-to-linalg/conv2d2.tgt.mlir +++ b/tests/long-opts/tosa-to-linalg/conv2d2.tgt.mlir @@ -6,15 +6,15 @@ module { %cst = arith.constant 0.000000e+00 : f32 %0 = tensor.from_elements %cst, %cst, %cst, %cst, %cst, %cst, %cst, %cst, %cst, %cst, %cst, %cst, %cst, %cst, %cst, %cst : tensor<16xf32> %cst_0 = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64> - %1 = linalg.init_tensor [3, 3, 4, 16] : tensor<3x3x4x16xf32> + %1 = tensor.empty () : tensor<3x3x4x16xf32> %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<16x3x3x4xf32>) outs(%1 : tensor<3x3x4x16xf32>) { ^bb0(%arg2: f32, %arg3: f32): // no predecessors linalg.yield %arg2 : f32 } -> tensor<3x3x4x16xf32> - %3 = linalg.init_tensor [1, 14, 14, 16] : tensor<1x14x14x16xf32> + %3 = tensor.empty () : tensor<1x14x14x16xf32> %cst_1 = arith.constant 0.000000e+00 : f32 %4 = linalg.fill ins(%cst_1: f32) outs(%3: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> - %5 = linalg.init_tensor [1, 14, 14, 16] : tensor<1x14x14x16xf32> + %5 = tensor.empty () : tensor<1x14x14x16xf32> %6 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %2 : tensor<1x29x29x4xf32>, tensor<3x3x4x16xf32>) outs(%4 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> %7 = linalg.generic {indexing_maps = [#map2, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0, %6 : tensor<16xf32>, tensor<1x14x14x16xf32>) outs(%5 : tensor<1x14x14x16xf32>) { 
^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors diff --git a/tests/long-opts/tosa-to-linalg/conv_pad.src.mlir b/tests/long-opts/tosa-to-linalg/conv_pad.src.mlir index f84a4570..e6f57b04 100644 --- a/tests/long-opts/tosa-to-linalg/conv_pad.src.mlir +++ b/tests/long-opts/tosa-to-linalg/conv_pad.src.mlir @@ -5,6 +5,6 @@ // tgt is filling a non-identity value (+0.0) to the output tensor. func.func @conv(%arg0: tensor<2x4x4x3xf32>, %arg1: tensor<16x3x6x3xf32>, %arg2: tensor<16xf32>) -> tensor<2x6x9x16xf32> { - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [2, 2, 5, 5], stride = [1, 1]} : (tensor<2x4x4x3xf32>, tensor<16x3x6x3xf32>, tensor<16xf32>) -> tensor<2x6x9x16xf32> + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1>} : (tensor<2x4x4x3xf32>, tensor<16x3x6x3xf32>, tensor<16xf32>) -> tensor<2x6x9x16xf32> return %0 : tensor<2x6x9x16xf32> } diff --git a/tests/long-opts/tosa-to-linalg/conv_pad.tgt.mlir b/tests/long-opts/tosa-to-linalg/conv_pad.tgt.mlir index d89a1be4..f0d32485 100644 --- a/tests/long-opts/tosa-to-linalg/conv_pad.tgt.mlir +++ b/tests/long-opts/tosa-to-linalg/conv_pad.tgt.mlir @@ -9,15 +9,15 @@ module { tensor.yield %cst : f32 } : tensor<2x4x4x3xf32> to tensor<2x8x14x3xf32> %cst_0 = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64> - %1 = linalg.init_tensor [3, 6, 3, 16] : tensor<3x6x3x16xf32> + %1 = tensor.empty () : tensor<3x6x3x16xf32> %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<16x3x6x3xf32>) outs(%1 : tensor<3x6x3x16xf32>) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors linalg.yield %arg3 : f32 } -> tensor<3x6x3x16xf32> - %3 = linalg.init_tensor [2, 6, 9, 16] : tensor<2x6x9x16xf32> + %3 = tensor.empty () : tensor<2x6x9x16xf32> %cst_1 = arith.constant 0.000000e+00 : f32 %4 = linalg.fill ins(%cst_1: f32) outs(%3: tensor<2x6x9x16xf32>) -> tensor<2x6x9x16xf32> - %5 = linalg.init_tensor [2, 6, 9, 16] : tensor<2x6x9x16xf32> + %5 = tensor.empty () : tensor<2x6x9x16xf32> %6 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%0, %2 : tensor<2x8x14x3xf32>, tensor<3x6x3x16xf32>) outs(%4 : tensor<2x6x9x16xf32>) -> tensor<2x6x9x16xf32> %7 = linalg.generic {indexing_maps = [#map2, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %6 : tensor<16xf32>, tensor<2x6x9x16xf32>) outs(%5 : tensor<2x6x9x16xf32>) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors diff --git a/tests/long-opts/tosa-to-linalg/depthwise2.src.mlir b/tests/long-opts/tosa-to-linalg/depthwise2.src.mlir index 04fea76a..0d8ab79a 100644 --- a/tests/long-opts/tosa-to-linalg/depthwise2.src.mlir +++ b/tests/long-opts/tosa-to-linalg/depthwise2.src.mlir @@ -5,6 +5,6 @@ // filling a non-identity value (+0.0) to the output tensor.
func.func @depthwise2(%arg0: tensor<2x5x5x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<2x6x6x6xf32> { - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]} : (tensor<2x5x5x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<2x6x6x6xf32> + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x5x5x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<2x6x6x6xf32> return %0 : tensor<2x6x6x6xf32> } diff --git a/tests/long-opts/tosa-to-linalg/depthwise2.tgt.mlir b/tests/long-opts/tosa-to-linalg/depthwise2.tgt.mlir index 3a09bba9..068ccd58 100644 --- a/tests/long-opts/tosa-to-linalg/depthwise2.tgt.mlir +++ b/tests/long-opts/tosa-to-linalg/depthwise2.tgt.mlir @@ -7,10 +7,10 @@ module { ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): // no predecessors tensor.yield %cst : f32 } : tensor<2x5x5x2xf32> to tensor<2x7x7x2xf32> - %1 = linalg.init_tensor [2, 6, 6, 2, 3] : tensor<2x6x6x2x3xf32> + %1 = tensor.empty () : tensor<2x6x6x2x3xf32> %cst_0 = arith.constant 0.000000e+00 : f32 %2 = linalg.fill ins(%cst_0: f32) outs(%1: tensor<2x6x6x2x3xf32>) -> tensor<2x6x6x2x3xf32> - %3 = linalg.init_tensor [2, 6, 6, 6] : tensor<2x6x6x6xf32> + %3 = tensor.empty () : tensor<2x6x6x6xf32> %4 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%0, %arg1 : tensor<2x7x7x2xf32>, tensor<2x2x2x3xf32>) outs(%2 : tensor<2x6x6x2x3xf32>) -> tensor<2x6x6x2x3xf32> %5 = tensor.collapse_shape %4 [[0], [1], [2], [3, 4]] : tensor<2x6x6x2x3xf32> into tensor<2x6x6x6xf32> %6 = linalg.generic {indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %5 : tensor<6xf32>, tensor<2x6x6x6xf32>) outs(%3 : tensor<2x6x6x6xf32>) { diff --git a/tests/opts/conv2d-to-img2col/nhwc_filter.tgt.mlir b/tests/opts/conv2d-to-img2col/nhwc_filter.tgt.mlir index 7f3a159f..58cac599 100644 --- a/tests/opts/conv2d-to-img2col/nhwc_filter.tgt.mlir +++ b/tests/opts/conv2d-to-img2col/nhwc_filter.tgt.mlir @@ -2,7 +2,7 @@ #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> module { func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> { - %0 = linalg.init_tensor [1, 14, 14, 3, 3, 4] : tensor<1x14x14x3x3x4xf32> + %0 = tensor.empty () : tensor<1x14x14x3x3x4xf32> %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x16x16x4xf32>) diff --git a/tests/opts/fusion-tensor/const-bad.src.mlir b/tests/opts/fusion-tensor/const-bad.src.mlir index 3e125c03..546e771c 100644 --- a/tests/opts/fusion-tensor/const-bad.src.mlir +++ b/tests/opts/fusion-tensor/const-bad.src.mlir @@ -10,7 +10,7 @@ func.func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?
%cst = arith.constant dense<42.0> : tensor<5xf32> %0 = tensor.dim %arg0, %c1 : tensor<5x?x?xf32> %1 = tensor.dim %arg0, %c2 : tensor<5x?x?xf32> - %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<5x?x?xf32> %3 = linalg.generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} diff --git a/tests/opts/fusion-tensor/const-bad.tgt.mlir b/tests/opts/fusion-tensor/const-bad.tgt.mlir index 190ff0cf..ba06a8fb 100644 --- a/tests/opts/fusion-tensor/const-bad.tgt.mlir +++ b/tests/opts/fusion-tensor/const-bad.tgt.mlir @@ -6,7 +6,7 @@ module { %cst = arith.constant 5.200000e+01 : f32 // wrong constant %0 = tensor.dim %arg0, %c1 : tensor<5x?x?xf32> %1 = tensor.dim %arg0, %c2 : tensor<5x?x?xf32> - %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<5x?x?xf32> %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<5x?x?xf32>) outs(%2 : tensor<5x?x?xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors %4 = arith.mulf %cst, %arg1 : f32 diff --git a/tests/opts/fusion-tensor/const.src.mlir b/tests/opts/fusion-tensor/const.src.mlir index c7504d20..89507c59 100644 --- a/tests/opts/fusion-tensor/const.src.mlir +++ b/tests/opts/fusion-tensor/const.src.mlir @@ -10,7 +10,7 @@ func.func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x? %cst = arith.constant dense<42.0> : tensor<5xf32> %0 = tensor.dim %arg0, %c1 : tensor<5x?x?xf32> %1 = tensor.dim %arg0, %c2 : tensor<5x?x?xf32> - %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<5x?x?xf32> %3 = linalg.generic { indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel"]} diff --git a/tests/opts/fusion-tensor/const.tgt.mlir b/tests/opts/fusion-tensor/const.tgt.mlir index c69567ee..342adadf 100644 --- a/tests/opts/fusion-tensor/const.tgt.mlir +++ b/tests/opts/fusion-tensor/const.tgt.mlir @@ -6,7 +6,7 @@ module { %cst = arith.constant 4.200000e+01 : f32 %0 = tensor.dim %arg0, %c1 : tensor<5x?x?xf32> %1 = tensor.dim %arg0, %c2 : tensor<5x?x?xf32> - %2 = linalg.init_tensor [5, %0, %1] : tensor<5x?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<5x?x?xf32> %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<5x?x?xf32>) outs(%2 : tensor<5x?x?xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors %4 = arith.mulf %cst, %arg1 : f32 diff --git a/tests/opts/fusion-tensor/i32-bad.src.mlir b/tests/opts/fusion-tensor/i32-bad.src.mlir index 97bd7bb5..0fad9651 100644 --- a/tests/opts/fusion-tensor/i32-bad.src.mlir +++ b/tests/opts/fusion-tensor/i32-bad.src.mlir @@ -7,7 +7,7 @@ func.func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>, %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor<?x?xi32> %1 = tensor.dim %arg0, %c1 : tensor<?x?xi32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32> + %2 = tensor.empty (%0, %1) : tensor<?x?xi32> %3 = linalg.generic { indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"] } diff --git a/tests/opts/fusion-tensor/i32-bad.tgt.mlir b/tests/opts/fusion-tensor/i32-bad.tgt.mlir index 0f48a26e..dbfb6803 100644 --- a/tests/opts/fusion-tensor/i32-bad.tgt.mlir +++ b/tests/opts/fusion-tensor/i32-bad.tgt.mlir @@ -5,7 +5,7 @@ module { %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor<?x?xi32> %1 = tensor.dim %arg0, %c1 : tensor<?x?xi32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32> +
%2 = tensor.empty (%0, %1) : tensor<?x?xi32> %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%2 : tensor<?x?xi32>) { ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors %4 = arith.addi %arg2, %arg3 : i32 diff --git a/tests/opts/fusion-tensor/i32.src.mlir b/tests/opts/fusion-tensor/i32.src.mlir index e1ca29b1..51644ffc 100644 --- a/tests/opts/fusion-tensor/i32.src.mlir +++ b/tests/opts/fusion-tensor/i32.src.mlir @@ -8,7 +8,7 @@ func.func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>, %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor<?x?xi32> %1 = tensor.dim %arg0, %c1 : tensor<?x?xi32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32> + %2 = tensor.empty (%0, %1) : tensor<?x?xi32> %3 = linalg.generic { indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"] } diff --git a/tests/opts/fusion-tensor/i32.tgt.mlir b/tests/opts/fusion-tensor/i32.tgt.mlir index 3e120a60..cc8d0396 100644 --- a/tests/opts/fusion-tensor/i32.tgt.mlir +++ b/tests/opts/fusion-tensor/i32.tgt.mlir @@ -5,7 +5,7 @@ module { %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor<?x?xi32> %1 = tensor.dim %arg0, %c1 : tensor<?x?xi32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xi32> + %2 = tensor.empty (%0, %1) : tensor<?x?xi32> %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>) outs(%2 : tensor<?x?xi32>) { ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors %4 = arith.addi %arg2, %arg3 : i32 diff --git a/tests/opts/fusion-tensor/nontensor-bad.src.mlir b/tests/opts/fusion-tensor/nontensor-bad.src.mlir index 0e016ea8..e156cd29 100644 --- a/tests/opts/fusion-tensor/nontensor-bad.src.mlir +++ b/tests/opts/fusion-tensor/nontensor-bad.src.mlir @@ -9,7 +9,7 @@ func.func @scalar_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : f32, %arg2 : f3 %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<?x?xf32> %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, f32) outs(%2 : tensor<?x?xf32>) { diff --git a/tests/opts/fusion-tensor/nontensor-bad.tgt.mlir b/tests/opts/fusion-tensor/nontensor-bad.tgt.mlir index 23ce8bcd..a6e402f2 100644 --- a/tests/opts/fusion-tensor/nontensor-bad.tgt.mlir +++ b/tests/opts/fusion-tensor/nontensor-bad.tgt.mlir @@ -7,7 +7,7 @@ module { %cf1 = arith.constant 1.0 : f32 %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<?x?xf32> %3 = linalg.generic {indexing_maps = [#map0, #map1, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, f32, f32) outs(%2 : tensor<?x?xf32>) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32): // no predecessors %4 = arith.addf %arg3, %arg4 : f32 diff --git a/tests/opts/fusion-tensor/nontensor.src.mlir b/tests/opts/fusion-tensor/nontensor.src.mlir index 02d903d6..619667fd 100644 --- a/tests/opts/fusion-tensor/nontensor.src.mlir +++ b/tests/opts/fusion-tensor/nontensor.src.mlir @@ -10,7 +10,7 @@ func.func @scalar_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : f32, %arg2 : f3 %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<?x?xf32> %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types =
["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, f32) outs(%2 : tensor) { diff --git a/tests/opts/fusion-tensor/nontensor.tgt.mlir b/tests/opts/fusion-tensor/nontensor.tgt.mlir index eac0d310..5a78ee66 100644 --- a/tests/opts/fusion-tensor/nontensor.tgt.mlir +++ b/tests/opts/fusion-tensor/nontensor.tgt.mlir @@ -6,7 +6,7 @@ module { %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty (%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map1, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor, f32, f32) outs(%2 : tensor) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32): // no predecessors %4 = arith.addf %arg3, %arg4 : f32 diff --git a/tests/opts/fusion-tensor/simple-bad.src.mlir b/tests/opts/fusion-tensor/simple-bad.src.mlir index 50a32da9..ce51c664 100644 --- a/tests/opts/fusion-tensor/simple-bad.src.mlir +++ b/tests/opts/fusion-tensor/simple-bad.src.mlir @@ -8,7 +8,7 @@ func.func @add_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty (%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%2 : tensor) { diff --git a/tests/opts/fusion-tensor/simple-bad.tgt.mlir b/tests/opts/fusion-tensor/simple-bad.tgt.mlir index eeb86623..fba6de2b 100644 --- a/tests/opts/fusion-tensor/simple-bad.tgt.mlir +++ b/tests/opts/fusion-tensor/simple-bad.tgt.mlir @@ -4,7 +4,7 @@ func.func @add_mul_fusion(%arg0: tensor, %arg1: tensor, %arg2: %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c0 : tensor // wrong arith.constant - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty (%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor, tensor, tensor) outs(%2 : tensor) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32): // no predecessors %4 = arith.addf %arg3, %arg4 : f32 diff --git a/tests/opts/fusion-tensor/simple.src.mlir b/tests/opts/fusion-tensor/simple.src.mlir index 2837a736..f37abb27 100644 --- a/tests/opts/fusion-tensor/simple.src.mlir +++ b/tests/opts/fusion-tensor/simple.src.mlir @@ -9,7 +9,7 @@ func.func @add_mul_fusion(%arg0: tensor, %arg1 : tensor, %arg2 %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty (%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%2 : tensor) { diff --git a/tests/opts/fusion-tensor/simple.tgt.mlir b/tests/opts/fusion-tensor/simple.tgt.mlir index 5d40df5f..07196737 100644 --- a/tests/opts/fusion-tensor/simple.tgt.mlir +++ b/tests/opts/fusion-tensor/simple.tgt.mlir @@ -4,7 +4,7 @@ func.func @add_mul_fusion(%arg0: tensor, %arg1: tensor, %arg2: %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty (%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor, 
tensor<?x?xf32>, tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32): // no predecessors %4 = arith.addf %arg3, %arg4 : f32 diff --git a/tests/opts/fusion-tensor/sum.src.mlir b/tests/opts/fusion-tensor/sum.src.mlir index 443ee68b..f8eed13a 100644 --- a/tests/opts/fusion-tensor/sum.src.mlir +++ b/tests/opts/fusion-tensor/sum.src.mlir @@ -6,7 +6,7 @@ func.func @consumer_with_reduction(%arg0: tensor<1x10xf32>, %arg1: tensor<1x10xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> { - %init = linalg.init_tensor [1, 10] : tensor<1x10xf32> + %init = tensor.empty () : tensor<1x10xf32> %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} diff --git a/tests/opts/fusion-tensor/sum_comm.src.mlir b/tests/opts/fusion-tensor/sum_comm.src.mlir index 2709eca6..e1d7739b 100644 --- a/tests/opts/fusion-tensor/sum_comm.src.mlir +++ b/tests/opts/fusion-tensor/sum_comm.src.mlir @@ -6,7 +6,7 @@ func.func @consumer_with_reduction(%arg0: tensor<1x10xf32>, %arg1: tensor<1x10xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> { - %init = linalg.init_tensor [1, 10] : tensor<1x10xf32> + %init = tensor.empty () : tensor<1x10xf32> %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} diff --git a/tests/opts/fusion-tensor/tensor_extract.src.mlir b/tests/opts/fusion-tensor/tensor_extract.src.mlir index 9b5c2267..f6f264fa 100644 --- a/tests/opts/fusion-tensor/tensor_extract.src.mlir +++ b/tests/opts/fusion-tensor/tensor_extract.src.mlir @@ -6,7 +6,7 @@ func.func @sigmoid_dynamic_dim(%0: tensor<?x1xf32>) -> tensor<?x1xf32> { %shape = shape.shape_of %0 : tensor<?x1xf32> -> tensor<?xindex> %extend = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<2xindex> %extracted = tensor.extract %extend[%c0] : tensor<2xindex> - %init0 = linalg.init_tensor [%extracted, 1] : tensor<?x1xf32> + %init0 = tensor.empty (%extracted) : tensor<?x1xf32> %1 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"] @@ -16,7 +16,7 @@ func.func @sigmoid_dynamic_dim(%0: tensor<?x1xf32>) -> tensor<?x1xf32> { linalg.yield %cp5 : f32 } -> tensor<?x1xf32> %d0 = tensor.dim %0, %c0 : tensor<?x1xf32> - %init1 = linalg.init_tensor [%d0, 1] : tensor<?x1xf32> + %init1 = tensor.empty (%d0) : tensor<?x1xf32> %2 = linalg.generic {indexing_maps = [ affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, diff --git a/tests/opts/fusion-tensor/tensor_extract.tgt.mlir b/tests/opts/fusion-tensor/tensor_extract.tgt.mlir index 54cda991..7c420aaa 100644 --- a/tests/opts/fusion-tensor/tensor_extract.tgt.mlir +++ b/tests/opts/fusion-tensor/tensor_extract.tgt.mlir @@ -4,7 +4,7 @@ module { %cst = arith.constant 5.000000e-01 : f32 %c0 = arith.constant 0 : index %0 = tensor.dim %arg0, %c0 : tensor<?x1xf32> - %1 = linalg.init_tensor [%0, 1] : tensor<?x1xf32> + %1 = tensor.empty (%0) : tensor<?x1xf32> %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x1xf32>) outs(%1 : tensor<?x1xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors %3 = arith.mulf %arg1, %cst : f32 diff --git a/tests/opts/fusion-tensor/transpose-bad.src.mlir b/tests/opts/fusion-tensor/transpose-bad.src.mlir index 288c9cc5..73ae3957 100644 --- a/tests/opts/fusion-tensor/transpose-bad.src.mlir +++ b/tests/opts/fusion-tensor/transpose-bad.src.mlir @@ -9,7 +9,7 @@ func.func @transpose_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32> %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<?x?xf32> %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel",
"parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%2 : tensor) { diff --git a/tests/opts/fusion-tensor/transpose-bad.tgt.mlir b/tests/opts/fusion-tensor/transpose-bad.tgt.mlir index 828c3766..0b25dfd4 100644 --- a/tests/opts/fusion-tensor/transpose-bad.tgt.mlir +++ b/tests/opts/fusion-tensor/transpose-bad.tgt.mlir @@ -6,7 +6,7 @@ module { %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty (%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor, tensor, tensor) diff --git a/tests/opts/fusion-tensor/transpose.src.mlir b/tests/opts/fusion-tensor/transpose.src.mlir index 764f7388..c10fe77f 100644 --- a/tests/opts/fusion-tensor/transpose.src.mlir +++ b/tests/opts/fusion-tensor/transpose.src.mlir @@ -10,7 +10,7 @@ func.func @transpose_add_mul_fusion(%arg0: tensor, %arg1 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty (%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor, tensor) outs(%2 : tensor) { diff --git a/tests/opts/fusion-tensor/transpose.tgt.mlir b/tests/opts/fusion-tensor/transpose.tgt.mlir index 1f8cfe58..66441042 100644 --- a/tests/opts/fusion-tensor/transpose.tgt.mlir +++ b/tests/opts/fusion-tensor/transpose.tgt.mlir @@ -6,7 +6,7 @@ module { %c1 = arith.constant 1 : index %0 = tensor.dim %arg0, %c0 : tensor %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor + %2 = tensor.empty (%0, %1) : tensor %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1, %arg2 : tensor, tensor, tensor) outs(%2 : tensor) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32): // no predecessors %4 = arith.addf %arg3, %arg4 : f32 diff --git a/tests/opts/fusion-tensor/zerodim.src.mlir b/tests/opts/fusion-tensor/zerodim.src.mlir index 70b54524..a202eba6 100644 --- a/tests/opts/fusion-tensor/zerodim.src.mlir +++ b/tests/opts/fusion-tensor/zerodim.src.mlir @@ -4,7 +4,7 @@ func.func @add_mul_scalar_fusion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty () : tensor %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []} ins(%arg0, %arg1 : tensor, tensor) outs(%0 : tensor) { diff --git a/tests/opts/fusion-tensor/zerodim.tgt.mlir b/tests/opts/fusion-tensor/zerodim.tgt.mlir index ac9a8299..176804f3 100644 --- a/tests/opts/fusion-tensor/zerodim.tgt.mlir +++ b/tests/opts/fusion-tensor/zerodim.tgt.mlir @@ -1,7 +1,7 @@ #map = affine_map<() -> ()> module { func.func @add_mul_scalar_fusion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = linalg.init_tensor [] : tensor + %0 = tensor.empty () : tensor %1 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%arg0, %arg1, %arg2 : tensor, tensor, tensor) outs(%0 : tensor) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32): // no predecessors %2 = arith.addf %arg3, %arg4 : f32 diff --git a/tests/opts/linalg-bufferize/depthwise1.src.mlir b/tests/opts/linalg-bufferize/depthwise1.src.mlir index 793f6791..ecf2cd88 100644 --- a/tests/opts/linalg-bufferize/depthwise1.src.mlir +++ b/tests/opts/linalg-bufferize/depthwise1.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func 
@depthwise1(%arg0: tensor<2x5x5x2xf32>, %arg1: tensor<2x2x2x3xf32>) -> tensor<2x4x4x2x3xf32> { - %0 = linalg.init_tensor [2, 4, 4, 2, 3] : tensor<2x4x4x2x3xf32> + %0 = tensor.empty () : tensor<2x4x4x2x3xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst: f32) outs(%0: tensor<2x4x4x2x3xf32>) -> tensor<2x4x4x2x3xf32> %2 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x5x5x2xf32>, tensor<2x2x2x3xf32>) outs(%1 : tensor<2x4x4x2x3xf32>) -> tensor<2x4x4x2x3xf32> diff --git a/tests/opts/linalg-bufferize/depthwise2.src.mlir b/tests/opts/linalg-bufferize/depthwise2.src.mlir index fdd6f4fe..bffb4004 100644 --- a/tests/opts/linalg-bufferize/depthwise2.src.mlir +++ b/tests/opts/linalg-bufferize/depthwise2.src.mlir @@ -1,7 +1,7 @@ // VERIFY func.func @depthwise2(%arg0: tensor<1x11x9x3xf32>, %arg1: tensor<3x1x3x11xf32>) -> tensor<1x5x5x3x11xf32> { - %0 = linalg.init_tensor [1, 5, 5, 3, 11] : tensor<1x5x5x3x11xf32> + %0 = tensor.empty () : tensor<1x5x5x3x11xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst: f32) outs(%0: tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32> %2 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs(%1 : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32> diff --git a/tests/opts/linalg-fold-unit-extent-dims/drop-unit-extent-dims.src.mlir b/tests/opts/linalg-fold-unit-extent-dims/drop-unit-extent-dims.src.mlir index d0429c6b..14abfe88 100644 --- a/tests/opts/linalg-fold-unit-extent-dims/drop-unit-extent-dims.src.mlir +++ b/tests/opts/linalg-fold-unit-extent-dims/drop-unit-extent-dims.src.mlir @@ -2,7 +2,7 @@ func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> { %cst = arith.constant 1.000000e+00 : f32 - %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32> + %1 = tensor.empty () : tensor<1x1xf32> %2 = linalg.fill ins(%cst: f32) outs(%1: tensor<1x1xf32>) -> tensor<1x1xf32> %3 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, diff --git a/tests/opts/linalg-fold-unit-extent-dims/drop-unit-extent-dims.tgt.mlir b/tests/opts/linalg-fold-unit-extent-dims/drop-unit-extent-dims.tgt.mlir index 6e367bf8..c7aeb549 100644 --- a/tests/opts/linalg-fold-unit-extent-dims/drop-unit-extent-dims.tgt.mlir +++ b/tests/opts/linalg-fold-unit-extent-dims/drop-unit-extent-dims.tgt.mlir @@ -3,7 +3,7 @@ module { func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> { %cst = arith.constant 1.000000e+00 : f32 %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<1x?x1x1xf32> into tensor<?xf32> - %1 = linalg.init_tensor [1] : tensor<1xf32> + %1 = tensor.empty () : tensor<1xf32> %2 = linalg.fill ins(%cst: f32) outs(%1: tensor<1xf32>) -> tensor<1xf32> %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%0 : tensor<?xf32>) outs(%2 : tensor<1xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors diff --git a/tests/opts/linalg-fold-unit-extent-dims/one-trip-loop.src.mlir b/tests/opts/linalg-fold-unit-extent-dims/one-trip-loop.src.mlir index d898c1a2..c97e9d3b 100644 --- a/tests/opts/linalg-fold-unit-extent-dims/one-trip-loop.src.mlir +++ b/tests/opts/linalg-fold-unit-extent-dims/one-trip-loop.src.mlir @@ -4,7 +4,7 @@ func.func @f(%arg0: tensor<1x?x1x1xi32>) -> tensor<1x1xi32> { %cst = arith.constant 1 : i32 - %init_tensor =
linalg.init_tensor [1, 1] : tensor<1x1xi32> + %init_tensor = tensor.empty () : tensor<1x1xi32> %filled = linalg.fill ins(%cst: i32) outs(%init_tensor: tensor<1x1xi32>) -> tensor<1x1xi32> %res = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, diff --git a/tests/opts/linalg-fold-unit-extent-dims/one-trip-loop.tgt.mlir b/tests/opts/linalg-fold-unit-extent-dims/one-trip-loop.tgt.mlir index af744599..9442656f 100644 --- a/tests/opts/linalg-fold-unit-extent-dims/one-trip-loop.tgt.mlir +++ b/tests/opts/linalg-fold-unit-extent-dims/one-trip-loop.tgt.mlir @@ -3,7 +3,7 @@ module { func.func @f(%arg0: tensor<1x?x1x1xi32>) -> tensor<1x1xi32> { %cst = arith.constant 1 : i32 %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<1x?x1x1xi32> into tensor<?xi32> - %1 = linalg.init_tensor [1] : tensor<1xi32> + %1 = tensor.empty () : tensor<1xi32> %2 = linalg.fill ins(%cst: i32) outs(%1: tensor<1xi32>) -> tensor<1xi32> %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%0 : tensor<?xi32>) outs(%2 : tensor<1xi32>) { ^bb0(%arg1: i32, %arg2: i32): // no predecessors diff --git a/tests/opts/tosa-make-broadcastable/broadcast1.tgt.mlir b/tests/opts/tosa-make-broadcastable/broadcast1.tgt.mlir index b5bfe982..b05d94d9 100644 --- a/tests/opts/tosa-make-broadcastable/broadcast1.tgt.mlir +++ b/tests/opts/tosa-make-broadcastable/broadcast1.tgt.mlir @@ -1,7 +1,7 @@ module { func.func @broadcast(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> { - %0 = "tosa.reshape"(%arg1) {new_shape = [1, 1, 15, 1]} : (tensor<15x1xf32>) -> tensor<1x1x15x1xf32> + %0 = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 15, 1>} : (tensor<15x1xf32>) -> tensor<1x1x15x1xf32> %1 = "tosa.add"(%arg0, %0) : (tensor<17x16x15x14xf32>, tensor<1x1x15x1xf32>) -> tensor<17x16x15x14xf32> return %1 : tensor<17x16x15x14xf32> } -} \ No newline at end of file +} diff --git a/tests/opts/tosa-to-linalg/abs.tgt.mlir b/tests/opts/tosa-to-linalg/abs.tgt.mlir index c0c334a1..3f42e965 100644 --- a/tests/opts/tosa-to-linalg/abs.tgt.mlir +++ b/tests/opts/tosa-to-linalg/abs.tgt.mlir @@ -3,7 +3,7 @@ module { func.func @test_abs(%arg0: tensor<?xf32>) -> tensor<?xf32> { %c0 = arith.constant 0: index %sz = tensor.dim %arg0, %c0: tensor<?xf32> - %0 = linalg.init_tensor [%sz] : tensor<?xf32> + %0 = tensor.empty (%sz) : tensor<?xf32> %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors diff --git a/tests/opts/tosa-to-linalg/add.tgt.mlir b/tests/opts/tosa-to-linalg/add.tgt.mlir index faffb010..a928dcb9 100644 --- a/tests/opts/tosa-to-linalg/add.tgt.mlir +++ b/tests/opts/tosa-to-linalg/add.tgt.mlir @@ -5,7 +5,7 @@ module { %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> %c1 = arith.constant 1 : index %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<?x?xf32> %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) { ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors %4 = arith.addf %arg2, %arg3 : f32 diff --git a/tests/opts/tosa-to-linalg/clamp.tgt.mlir b/tests/opts/tosa-to-linalg/clamp.tgt.mlir index 33def2b9..ad15595e 100644 --- a/tests/opts/tosa-to-linalg/clamp.tgt.mlir +++ b/tests/opts/tosa-to-linalg/clamp.tgt.mlir @@ -5,7 +5,7 @@ module { %0 = tensor.dim %arg0, %c0 : tensor<?x?xi32> %c1 = arith.constant 1 : index %1 = tensor.dim %arg0, %c1 : tensor<?x?xi32> - %2 =
linalg.init_tensor [%0, %1] : tensor<?x?xi32> + %2 = tensor.empty (%0, %1) : tensor<?x?xi32> %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xi32>) outs(%2 : tensor<?x?xi32>) { ^bb0(%arg1: i32, %arg2: i32): // no predecessors %c-127_i32 = arith.constant -127 : i32 @@ -23,7 +23,7 @@ module { %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> %c1 = arith.constant 1 : index %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> - %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32> + %2 = tensor.empty (%0, %1) : tensor<?x?xf32> %3 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors %cst = arith.constant 1.000000e+00 : f32 diff --git a/tests/opts/tosa-to-linalg/concat.tgt.mlir b/tests/opts/tosa-to-linalg/concat.tgt.mlir index 9897cd79..86f82375 100644 --- a/tests/opts/tosa-to-linalg/concat.tgt.mlir +++ b/tests/opts/tosa-to-linalg/concat.tgt.mlir @@ -11,7 +11,7 @@ module { %c4 = arith.constant 4 : index %c5 = arith.constant 5 : index %c9 = arith.constant 9 : index - %0 = linalg.init_tensor [9, 2] : tensor<9x2xf32> + %0 = tensor.empty () : tensor<9x2xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst: f32) outs(%0: tensor<9x2xf32>) -> tensor<9x2xf32> %c1_4 = arith.constant 1 : index diff --git a/tests/opts/tosa-to-linalg/conv2d1.src.mlir b/tests/opts/tosa-to-linalg/conv2d1.src.mlir index c0906b00..6e97c34c 100644 --- a/tests/opts/tosa-to-linalg/conv2d1.src.mlir +++ b/tests/opts/tosa-to-linalg/conv2d1.src.mlir @@ -7,6 +7,6 @@ func.func @conv(%img: tensor<1x2x2x2xf32>, %filter: tensor<1x2x2x2xf32>) -> tensor<1x1x1x1xf32> { %c0 = arith.constant -0.0 : f32 %bias = tensor.from_elements %c0: tensor<1xf32> - %0 = "tosa.conv2d"(%img, %filter, %bias) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> + %0 = "tosa.conv2d"(%img, %filter, %bias) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32> return %0 : tensor<1x1x1x1xf32> } diff --git a/tests/opts/tosa-to-linalg/conv2d1.tgt.mlir b/tests/opts/tosa-to-linalg/conv2d1.tgt.mlir index ec8b0db4..8a49d850 100644 --- a/tests/opts/tosa-to-linalg/conv2d1.tgt.mlir +++ b/tests/opts/tosa-to-linalg/conv2d1.tgt.mlir @@ -6,15 +6,15 @@ module { %cst = arith.constant -0.000000e+00 : f32 %0 = tensor.from_elements %cst : tensor<1xf32> %cst_0 = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64> - %1 = linalg.init_tensor [2, 2, 2, 1] : tensor<2x2x2x1xf32> + %1 = tensor.empty () : tensor<2x2x2x1xf32> %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<1x2x2x2xf32>) outs(%1 : tensor<2x2x2x1xf32>) { ^bb0(%arg2: f32, %arg3: f32): // no predecessors linalg.yield %arg2 : f32 } -> tensor<2x2x2x1xf32> - %3 = linalg.init_tensor [1, 1, 1, 1] : tensor<1x1x1x1xf32> + %3 = tensor.empty () : tensor<1x1x1x1xf32> %cst_1 = arith.constant 0.000000e+00 : f32 %4 = linalg.fill ins(%cst_1: f32) outs(%3: tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> - %5 = linalg.init_tensor [1, 1, 1, 1] : tensor<1x1x1x1xf32> + %5 = tensor.empty () : tensor<1x1x1x1xf32> %6 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %2 : tensor<1x2x2x2xf32>, tensor<2x2x2x1xf32>) outs(%4 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> %7 = linalg.generic {indexing_maps = [#map2, #map1,
#map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0, %6 : tensor<1xf32>, tensor<1x1x1x1xf32>) outs(%5 : tensor<1x1x1x1xf32>) { ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors diff --git a/tests/opts/tosa-to-linalg/depthwise1.src.mlir b/tests/opts/tosa-to-linalg/depthwise1.src.mlir index f95d4cfa..5e57327e 100644 --- a/tests/opts/tosa-to-linalg/depthwise1.src.mlir +++ b/tests/opts/tosa-to-linalg/depthwise1.src.mlir @@ -5,6 +5,6 @@ // a non-identity value (+0.0) to the output tensor. func.func @depthwise1(%arg0: tensor<2x5x5x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<2x4x4x6xf32> { - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<2x5x5x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<2x4x4x6xf32> + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x5x5x2xf32>, tensor<2x2x2x3xf32>, tensor<6xf32>) -> tensor<2x4x4x6xf32> return %0 : tensor<2x4x4x6xf32> } diff --git a/tests/opts/tosa-to-linalg/depthwise1.tgt.mlir b/tests/opts/tosa-to-linalg/depthwise1.tgt.mlir index 98a8deb5..8cb1ab21 100644 --- a/tests/opts/tosa-to-linalg/depthwise1.tgt.mlir +++ b/tests/opts/tosa-to-linalg/depthwise1.tgt.mlir @@ -2,10 +2,10 @@ #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> module { func.func @depthwise1(%arg0: tensor<2x5x5x2xf32>, %arg1: tensor<2x2x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<2x4x4x6xf32> { - %0 = linalg.init_tensor [2, 4, 4, 2, 3] : tensor<2x4x4x2x3xf32> + %0 = tensor.empty () : tensor<2x4x4x2x3xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst: f32) outs(%0: tensor<2x4x4x2x3xf32>) -> tensor<2x4x4x2x3xf32> - %2 = linalg.init_tensor [2, 4, 4, 6] : tensor<2x4x4x6xf32> + %2 = tensor.empty () : tensor<2x4x4x6xf32> %3 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x5x5x2xf32>, tensor<2x2x2x3xf32>) outs(%1 : tensor<2x4x4x2x3xf32>) -> tensor<2x4x4x2x3xf32> %4 = tensor.collapse_shape %3 [[0], [1], [2], [3, 4]] : tensor<2x4x4x2x3xf32> into tensor<2x4x4x6xf32> %5 = linalg.generic {indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %4 : tensor<6xf32>, tensor<2x4x4x6xf32>) outs(%2 : tensor<2x4x4x6xf32>) { diff --git a/tests/opts/tosa-to-linalg/depthwise3.src.mlir b/tests/opts/tosa-to-linalg/depthwise3.src.mlir index ee47e02f..3bd781d5 100644 --- a/tests/opts/tosa-to-linalg/depthwise3.src.mlir +++ b/tests/opts/tosa-to-linalg/depthwise3.src.mlir @@ -5,6 +5,6 @@ // a non-identity value (+0.0) to the output tensor.
func.func @depthwise3(%arg0 : tensor<1x11x9x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<33xf32>) -> (tensor<1x5x5x33xf32>) { - %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1] } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>) + %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>, dilation = array<i64: 1, 1> } : (tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>) return %2 : tensor<1x5x5x33xf32> } diff --git a/tests/opts/tosa-to-linalg/depthwise3.tgt.mlir b/tests/opts/tosa-to-linalg/depthwise3.tgt.mlir index c9b0d58a..0c3a5ac7 100644 --- a/tests/opts/tosa-to-linalg/depthwise3.tgt.mlir +++ b/tests/opts/tosa-to-linalg/depthwise3.tgt.mlir @@ -2,10 +2,10 @@ #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> module { func.func @depthwise3(%arg0: tensor<1x11x9x3xf32>, %arg1: tensor<3x1x3x11xf32>, %arg2: tensor<33xf32>) -> tensor<1x5x5x33xf32> { - %0 = linalg.init_tensor [1, 5, 5, 3, 11] : tensor<1x5x5x3x11xf32> + %0 = tensor.empty () : tensor<1x5x5x3x11xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst: f32) outs(%0: tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32> - %2 = linalg.init_tensor [1, 5, 5, 33] : tensor<1x5x5x33xf32> + %2 = tensor.empty () : tensor<1x5x5x33xf32> %3 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x11x9x3xf32>, tensor<3x1x3x11xf32>) outs(%1 : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32> %4 = tensor.collapse_shape %3 [[0], [1], [2], [3, 4]] : tensor<1x5x5x3x11xf32> into tensor<1x5x5x33xf32> %5 = linalg.generic {indexing_maps = [#map0, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %4 : tensor<33xf32>, tensor<1x5x5x33xf32>) outs(%2 : tensor<1x5x5x33xf32>) { diff --git a/tests/opts/tosa-to-linalg/fully_connected.tgt.mlir b/tests/opts/tosa-to-linalg/fully_connected.tgt.mlir index 31b4c61d..9afe92e3 100644 --- a/tests/opts/tosa-to-linalg/fully_connected.tgt.mlir +++ b/tests/opts/tosa-to-linalg/fully_connected.tgt.mlir @@ -3,16 +3,16 @@ #map2 = affine_map<(d0, d1) -> (d1)> module { func.func @f(%arg0: tensor<10x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> tensor<10x6xf32> { - %0 = linalg.init_tensor [10, 6] : tensor<10x6xf32> + %0 = tensor.empty () : tensor<10x6xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst: f32) outs(%0: tensor<10x6xf32>) -> tensor<10x6xf32> %cst_0 = arith.constant dense<[1, 0]> : tensor<2xi64> - %2 = linalg.init_tensor [3, 6] : tensor<3x6xf32> + %2 = tensor.empty () : tensor<3x6xf32> %3 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs(%2 : tensor<3x6xf32>) { ^bb0(%arg3: f32, %arg4: f32): // no predecessors linalg.yield %arg3 : f32 } -> tensor<3x6xf32> - %4 = linalg.init_tensor [10, 6] : tensor<10x6xf32> + %4 = tensor.empty () : tensor<10x6xf32> %5 = linalg.matmul ins(%arg0, %3 : tensor<10x3xf32>, tensor<3x6xf32>) outs(%1 : tensor<10x6xf32>) -> tensor<10x6xf32> %6 = linalg.generic {indexing_maps = [#map2, #map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg2, %5 : tensor<6xf32>, tensor<10x6xf32>) outs(%4 : tensor<10x6xf32>) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors diff --git a/tests/opts/tosa-to-linalg/gather.tgt.mlir b/tests/opts/tosa-to-linalg/gather.tgt.mlir index 74e32d42..20f5a213
100644 --- a/tests/opts/tosa-to-linalg/gather.tgt.mlir +++ b/tests/opts/tosa-to-linalg/gather.tgt.mlir @@ -2,7 +2,7 @@ #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> module { func.func @gather_float(%arg0: tensor<2x3x2xf32>, %arg1: tensor<2x3xi32>) { - %0 = linalg.init_tensor [2, 3, 2] : tensor<2x3x2xf32> + %0 = tensor.empty () : tensor<2x3x2xf32> %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x3xi32>) outs(%0 : tensor<2x3x2xf32>) { ^bb0(%arg2: i32, %arg3: f32): // no predecessors %2 = linalg.index 0 : index diff --git a/tests/opts/tosa-to-linalg/mul.tgt.mlir b/tests/opts/tosa-to-linalg/mul.tgt.mlir index fc6d3fe9..660528dd 100644 --- a/tests/opts/tosa-to-linalg/mul.tgt.mlir +++ b/tests/opts/tosa-to-linalg/mul.tgt.mlir @@ -1,7 +1,7 @@ #map = affine_map<(d0) -> (d0)> module { func.func @f(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> { - %0 = linalg.init_tensor [8] : tensor<8xf32> + %0 = tensor.empty () : tensor<8xf32> %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xf32>, tensor<8xf32>) outs(%0 : tensor<8xf32>) { ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors %2 = arith.mulf %arg2, %arg3 : f32 diff --git a/tests/opts/tosa-to-linalg/muli.tgt.mlir b/tests/opts/tosa-to-linalg/muli.tgt.mlir index 8af3b82c..b2b19279 100644 --- a/tests/opts/tosa-to-linalg/muli.tgt.mlir +++ b/tests/opts/tosa-to-linalg/muli.tgt.mlir @@ -1,7 +1,7 @@ #map = affine_map<(d0) -> (d0)> module { func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> { - %0 = linalg.init_tensor [8] : tensor<8xi32> + %0 = tensor.empty () : tensor<8xi32> %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<8xi32>, tensor<8xi32>) outs(%0 : tensor<8xi32>) { ^bb0(%arg2: i32, %arg3: i32, %arg4: i32): // no predecessors %2 = arith.muli %arg2, %arg3 : i32 diff --git a/tests/opts/tosa-to-linalg/reciprocal.tgt.mlir b/tests/opts/tosa-to-linalg/reciprocal.tgt.mlir index 2c7edc5d..d255435a 100644 --- a/tests/opts/tosa-to-linalg/reciprocal.tgt.mlir +++ b/tests/opts/tosa-to-linalg/reciprocal.tgt.mlir @@ -1,7 +1,7 @@ #map = affine_map<(d0, d1) -> (d0, d1)> module { func.func @f(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { - %0 = linalg.init_tensor [10, 10] : tensor<10x10xf32> + %0 = tensor.empty () : tensor<10x10xf32> %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<10x10xf32>) outs(%0 : tensor<10x10xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors %cst = arith.constant 1.000000e+00 : f32 diff --git a/tests/opts/tosa-to-linalg/reduce_sum.tgt.mlir b/tests/opts/tosa-to-linalg/reduce_sum.tgt.mlir index 51eed77a..8dce61ec 100644 --- a/tests/opts/tosa-to-linalg/reduce_sum.tgt.mlir +++ b/tests/opts/tosa-to-linalg/reduce_sum.tgt.mlir @@ -2,7 +2,7 @@ #map1 = affine_map<(d0, d1, d2) -> (d1, d2)> module { func.func @f(%arg0: tensor<3x4x5xf32>) -> tensor<1x4x5xf32> { - %0 = linalg.init_tensor [4, 5] : tensor<4x5xf32> + %0 = tensor.empty () : tensor<4x5xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst: f32) outs(%0: tensor<4x5xf32>) -> tensor<4x5xf32> %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<3x4x5xf32>) outs(%1 : tensor<4x5xf32>) { diff --git a/tests/opts/tosa-to-linalg/reduce_sum2.tgt.mlir 
b/tests/opts/tosa-to-linalg/reduce_sum2.tgt.mlir index 7b1058d5..4e952ebe 100644 --- a/tests/opts/tosa-to-linalg/reduce_sum2.tgt.mlir +++ b/tests/opts/tosa-to-linalg/reduce_sum2.tgt.mlir @@ -2,7 +2,7 @@ #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> module { func.func @f(%arg0: tensor<3x4x5xf32>) -> tensor<3x1x5xf32> { - %0 = linalg.init_tensor [3, 5] : tensor<3x5xf32> + %0 = tensor.empty () : tensor<3x5xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst: f32) outs(%0:tensor<3x5xf32>) -> tensor<3x5xf32> %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<3x4x5xf32>) outs(%1 : tensor<3x5xf32>) { diff --git a/tests/opts/tosa-to-linalg/reduce_sum3.tgt.mlir b/tests/opts/tosa-to-linalg/reduce_sum3.tgt.mlir index 5e1c21ab..eb3f74de 100644 --- a/tests/opts/tosa-to-linalg/reduce_sum3.tgt.mlir +++ b/tests/opts/tosa-to-linalg/reduce_sum3.tgt.mlir @@ -2,7 +2,7 @@ #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> module { func.func @f(%arg0: tensor<3x1000x5xf32>) -> tensor<3x1x5xf32> { - %0 = linalg.init_tensor [3, 5] : tensor<3x5xf32> + %0 = tensor.empty () : tensor<3x5xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst: f32) outs(%0:tensor<3x5xf32>) -> tensor<3x5xf32> %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<3x1000x5xf32>) outs(%1 : tensor<3x5xf32>) { diff --git a/tests/opts/tosa-to-linalg/reverse.tgt.mlir b/tests/opts/tosa-to-linalg/reverse.tgt.mlir index dde0f280..6cdbf652 100644 --- a/tests/opts/tosa-to-linalg/reverse.tgt.mlir +++ b/tests/opts/tosa-to-linalg/reverse.tgt.mlir @@ -9,7 +9,7 @@ module { %2 = tensor.dim %arg0, %c2 : tensor<?x?x?xi32> %c1_0 = arith.constant 1 : index %3 = tensor.dim %arg0, %c1_0 : tensor<?x?x?xi32> - %4 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xi32> + %4 = tensor.empty (%0, %1, %2) : tensor<?x?x?xi32> %5 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%4 : tensor<?x?x?xi32>) { ^bb0(%arg1: i32): // no predecessors %6 = linalg.index 0 : index diff --git a/tests/opts/tosa-to-linalg/tile.src.mlir b/tests/opts/tosa-to-linalg/tile.src.mlir index 60f15141..facd453f 100644 --- a/tests/opts/tosa-to-linalg/tile.src.mlir +++ b/tests/opts/tosa-to-linalg/tile.src.mlir @@ -1,8 +1,8 @@ // VERIFY func.func @f(%x: tensor<3x3xf32>) -> tensor<6x9xf32> { - %a = "tosa.tile"(%x) {multiples = [2, 1]} : (tensor<3x3xf32>) -> (tensor<6x3xf32>) - %b = "tosa.tile"(%a) {multiples = [1, 3]} : (tensor<6x3xf32>) -> (tensor<6x9xf32>) + %a = "tosa.tile"(%x) {multiples = array<i64: 2, 1>} : (tensor<3x3xf32>) -> (tensor<6x3xf32>) + %b = "tosa.tile"(%a) {multiples = array<i64: 1, 3>} : (tensor<6x3xf32>) -> (tensor<6x9xf32>) return %b: tensor<6x9xf32> } diff --git a/tests/opts/tosa-to-linalg/tile.tgt.mlir b/tests/opts/tosa-to-linalg/tile.tgt.mlir index 13eeaf0e..7bb981ad 100644 --- a/tests/opts/tosa-to-linalg/tile.tgt.mlir +++ b/tests/opts/tosa-to-linalg/tile.tgt.mlir @@ -2,13 +2,13 @@ #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> module { func.func @f(%arg0: tensor<3x3xf32>) -> tensor<6x9xf32> { - %0 = linalg.init_tensor [2, 3, 1, 3] : tensor<2x3x1x3xf32> + %0 = tensor.empty () : tensor<2x3x1x3xf32> %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x3xf32>) outs(%0 : tensor<2x3x1x3xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors linalg.yield %arg1 : f32 } -> tensor<2x3x1x3xf32> %2 = tensor.collapse_shape %1 [[0, 1, 2],
[3]] : tensor<2x3x1x3xf32> into tensor<6x3xf32> - %3 = linalg.init_tensor [1, 6, 3, 3] : tensor<1x6x3x3xf32> + %3 = tensor.empty () : tensor<1x6x3x3xf32> %4 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<6x3xf32>) outs(%3 : tensor<1x6x3x3xf32>) { ^bb0(%arg1: f32, %arg2: f32): // no predecessors linalg.yield %arg1 : f32