diff --git a/SPIKE.hash b/SPIKE.hash
index ce15e697..a96811da 100644
--- a/SPIKE.hash
+++ b/SPIKE.hash
@@ -1 +1 @@
-34741e07bc6b56f1762ce579537948d58e28cd5a
+02e2d983cc8e2c385ebe920302c427b9167bd76e
diff --git a/software/gemmini-rocc-tests b/software/gemmini-rocc-tests
index 3aaa2307..21713ec6 160000
--- a/software/gemmini-rocc-tests
+++ b/software/gemmini-rocc-tests
@@ -1 +1 @@
-Subproject commit 3aaa230733a9eba6edf4d14243d84595e017522f
+Subproject commit 21713ec6e9dbbf2477b092e04eb8970776a5da72
diff --git a/src/main/scala/gemmini/AccumulatorMem.scala b/src/main/scala/gemmini/AccumulatorMem.scala
index 89a39182..8f3fbaf5 100644
--- a/src/main/scala/gemmini/AccumulatorMem.scala
+++ b/src/main/scala/gemmini/AccumulatorMem.scala
@@ -44,17 +44,59 @@ class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends
   override def cloneType: this.type = new AccumulatorWriteReq(n, t).asInstanceOf[this.type]
 }

-class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], scale_t: U) extends Bundle {
+
+class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], scale_t: U,
+  acc_sub_banks: Int, use_shared_ext_mem: Boolean
+) extends Bundle {
   val read = Flipped(new AccumulatorReadIO(n, log2Ceil(t.head.head.getWidth), t, scale_t))
-  // val write = Flipped(new AccumulatorWriteIO(n, t))
   val write = Flipped(Decoupled(new AccumulatorWriteReq(n, t)))
-  override def cloneType: this.type = new AccumulatorMemIO(n, t, scale_t).asInstanceOf[this.type]
+  val ext_mem = if (use_shared_ext_mem) Some(Vec(acc_sub_banks, new ExtMemIO)) else None
+
+  val adder = new Bundle {
+    val valid = Output(Bool())
+    val op1 = Output(t.cloneType)
+    val op2 = Output(t.cloneType)
+    val sum = Input(t.cloneType)
+  }
+
+  override def cloneType: this.type = new AccumulatorMemIO(n, t, scale_t, acc_sub_banks, use_shared_ext_mem).asInstanceOf[this.type]
+}
+
+class AccPipe[T <: Data : Arithmetic](latency: Int, t: T)(implicit ev: Arithmetic[T]) extends Module {
+  val io = IO(new Bundle {
+    val op1 = Input(t.cloneType)
+    val op2 = Input(t.cloneType)
+    val sum = Output(t.cloneType)
+  })
+  import ev._
+  io.sum := ShiftRegister(io.op1 + io.op2, latency)
+}
+
+class AccPipeShared[T <: Data : Arithmetic](latency: Int, t: Vec[Vec[T]], banks: Int) extends Module {
+  val io = IO(new Bundle {
+    val in_sel = Input(Vec(banks, Bool()))
+    val ina = Input(Vec(banks, t.cloneType))
+    val inb = Input(Vec(banks, t.cloneType))
+    val out = Output(t.cloneType)
+  })
+  val ina = Mux1H(io.in_sel, io.ina)
+  val inb = Mux1H(io.in_sel, io.inb)
+  io.out := VecInit((ina zip inb).map { case (rv, wv) =>
+    VecInit((rv zip wv).map { case (re, we) =>
+      val m = Module(new AccPipe(latency, t.head.head.cloneType))
+      m.io.op1 := re
+      m.io.op2 := we
+      m.io.sum
+    })
+  })
 }
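The two classes just added pull the accumulate adder out of the accumulator banks: AccPipe is a single ShiftRegister-pipelined add, and AccPipeShared one-hot-muxes the per-bank operands in front of it, so one adder pipeline serves every sub-bank (only one sub-bank starts a read-modify-write in any given cycle). A minimal standalone sketch of the same structure, with UInt standing in for the generic Arithmetic element type and illustrative sizes:

    import chisel3._
    import chisel3.util._

    // One pipelined adder shared by several requestors through a one-hot
    // select, mirroring AccPipe/AccPipeShared above (sketch, not Gemmini code)
    class SharedPipelinedAdder(latency: Int, w: Int, banks: Int) extends Module {
      val io = IO(new Bundle {
        val in_sel = Input(Vec(banks, Bool())) // one-hot: at most one bank accumulates per cycle
        val ina    = Input(Vec(banks, UInt(w.W)))
        val inb    = Input(Vec(banks, UInt(w.W)))
        val out    = Output(UInt(w.W))
      })
      // Mux1H collapses the operands before the adder, so only one adder
      // pipeline is instantiated regardless of the number of banks
      val a = Mux1H(io.in_sel, io.ina)
      val b = Mux1H(io.in_sel, io.inb)
      io.out := ShiftRegister(a + b, latency) // result arrives `latency` cycles later
    }

Because the sum only emerges `latency` cycles after the operands, the surrounding bank has to keep the rest of the write request in flight for the same number of cycles, which is what the pipelined_writes vector below does.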

 class AccumulatorMem[T <: Data, U <: Data](
-  n: Int, t: Vec[Vec[T]], scale_func: (T, U) => T, scale_t: U,
-  acc_singleported: Boolean, acc_sub_banks: Int
+  n: Int, t: Vec[Vec[T]], scale_func: (T, U) => T, scale_t: U,
+  acc_singleported: Boolean, acc_sub_banks: Int,
+  use_shared_ext_mem: Boolean,
+  acc_latency: Int, acc_type: T
 )
 (implicit ev: Arithmetic[T]) extends Module {
   // TODO Do writes in this module work with matrices of size 2? If we try to read from an address right after writing
@@ -69,54 +111,91 @@ class AccumulatorMem[T <: Data, U <: Data](
   import ev._

   // TODO unify this with TwoPortSyncMemIO
-  val io = IO(new AccumulatorMemIO(n, t, scale_t))
-
-
-  // For any write operation, we spend 2 cycles reading the existing address out, buffering it in a register, and then
-  // accumulating on top of it (if necessary)
-  val wdata_buf = ShiftRegister(io.write.bits.data, 2)
-  val waddr_buf = ShiftRegister(io.write.bits.addr, 2)
-  val acc_buf = ShiftRegister(io.write.bits.acc, 2)
-  val mask_buf = ShiftRegister(io.write.bits.mask, 2)
-  val w_buf_valid = ShiftRegister(io.write.fire(), 2)
-  val acc_rdata = Wire(t)
-  acc_rdata := DontCare
-  val read_rdata = Wire(t)
-  read_rdata := DontCare
+  val io = IO(new AccumulatorMemIO(n, t, scale_t, acc_sub_banks, use_shared_ext_mem))
+
+  require (acc_latency >= 2)
+
+  val pipelined_writes = Reg(Vec(acc_latency, Valid(new AccumulatorWriteReq(n, t))))
+  val oldest_pipelined_write = pipelined_writes(acc_latency-1)
+  pipelined_writes(0).valid := io.write.fire()
+  pipelined_writes(0).bits := io.write.bits
+  for (i <- 1 until acc_latency) {
+    pipelined_writes(i) := pipelined_writes(i-1)
+  }
+
+  val rdata_for_adder = Wire(t)
+  rdata_for_adder := DontCare
+  val rdata_for_read_resp = Wire(t)
+  rdata_for_read_resp := DontCare
+
+  val adder_sum = io.adder.sum
+  io.adder.valid := pipelined_writes(0).valid && pipelined_writes(0).bits.acc
+  io.adder.op1 := rdata_for_adder
+  io.adder.op2 := pipelined_writes(0).bits.data
+
   val block_read_req = WireInit(false.B)
-  val w_sum = VecInit((RegNext(acc_rdata) zip wdata_buf).map { case (rv, wv) =>
-    VecInit((rv zip wv).map(t => t._1 + t._2))
-  })
+  val block_write_req = WireInit(false.B)
+
+  val mask_len = t.getWidth / 8
+  val mask_elem = UInt((t.getWidth / mask_len).W)
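With the adder pipelined, the old fixed two-cycle write buffers become an acc_latency-deep vector of in-flight writes, and any new request whose address matches an in-flight write has to stall (the io.write.ready and io.read.req.ready logic further down checks exactly this). A standalone sketch of that scoreboard pattern, with illustrative widths and depth:

    import chisel3._
    import chisel3.util._

    class InFlightTracker(depth: Int, addrBits: Int) extends Module {
      val io = IO(new Bundle {
        val enq      = Flipped(Decoupled(UInt(addrBits.W))) // new write's address
        val rd_addr  = Input(UInt(addrBits.W))              // incoming read's address
        val rd_block = Output(Bool())                       // stall the read on a conflict
      })
      // Shift-register of the addresses of writes still in the pipeline
      val inflight = Reg(Vec(depth, Valid(UInt(addrBits.W))))
      inflight(0).valid := io.enq.fire()
      inflight(0).bits  := io.enq.bits
      for (i <- 1 until depth) { inflight(i) := inflight(i - 1) }

      // Reads conflict with any stage that still holds their address
      io.rd_block := inflight.map(e => e.valid && e.bits === io.rd_addr).reduce(_ || _)
      // So do new writes to an address that is still being written
      io.enq.ready := !inflight.map(e => e.valid && e.bits === io.enq.bits).reduce(_ || _)

      when (reset.asBool) { inflight.foreach(_.valid := false.B) }
    }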
   if (!acc_singleported) {
-    val mem = TwoPortSyncMem(n, t, t.getWidth / 8) // TODO We assume byte-alignment here. Use aligned_to instead
-    mem.io.waddr := waddr_buf
-    mem.io.wen := w_buf_valid
-    mem.io.wdata := Mux(acc_buf, w_sum, wdata_buf)
-    mem.io.mask := mask_buf
-    acc_rdata := mem.io.rdata
-    read_rdata := mem.io.rdata
+    require(!use_shared_ext_mem)
+    val mem = TwoPortSyncMem(n, t, mask_len) // TODO We assume byte-alignment here. Use aligned_to instead
+    mem.io.waddr := oldest_pipelined_write.bits.addr
+    mem.io.wen := oldest_pipelined_write.valid
+    mem.io.wdata := Mux(oldest_pipelined_write.bits.acc, adder_sum, oldest_pipelined_write.bits.data)
+    mem.io.mask := oldest_pipelined_write.bits.mask
+    rdata_for_adder := mem.io.rdata
+    rdata_for_read_resp := mem.io.rdata

     mem.io.raddr := Mux(io.write.fire() && io.write.bits.acc, io.write.bits.addr, io.read.req.bits.addr)
     mem.io.ren := io.read.req.fire() || (io.write.fire() && io.write.bits.acc)
   } else {
-    val mask_len = t.getWidth / 8
-    val mask_elem = UInt((t.getWidth / mask_len).W)
-    val reads = Wire(Vec(2, Decoupled(UInt())))
-    reads(0).valid := io.write.valid && io.write.bits.acc
-    reads(0).bits := io.write.bits.addr
-    reads(0).ready := true.B
-    reads(1).valid := io.read.req.valid
-    reads(1).bits := io.read.req.bits.addr
-    reads(1).ready := true.B
-    block_read_req := !reads(1).ready
+    val rmw_req = Wire(Decoupled(UInt()))
+    rmw_req.valid := io.write.valid && io.write.bits.acc
+    rmw_req.bits := io.write.bits.addr
+    rmw_req.ready := true.B
+
+    block_write_req := !rmw_req.ready
+
+    val only_read_req = Wire(Decoupled(UInt()))
+    only_read_req.valid := io.read.req.valid
+    only_read_req.bits := io.read.req.bits.addr
+    only_read_req.ready := true.B
+
+    block_read_req := !only_read_req.ready
+
     for (i <- 0 until acc_sub_banks) {
       def isThisBank(addr: UInt) = addr(log2Ceil(acc_sub_banks)-1,0) === i.U
-      def getBankIdx(addr: UInt): UInt = (addr >> log2Ceil(acc_sub_banks)).asUInt()
-      val mem = SyncReadMem(n / acc_sub_banks, Vec(mask_len, mask_elem))
+      def getBankIdx(addr: UInt) = addr >> log2Ceil(acc_sub_banks)
+      val (read, write) = if (use_shared_ext_mem) {
+        def read(addr: UInt, ren: Bool): Data = {
+          io.ext_mem.get(i).read_en := ren
+          io.ext_mem.get(i).read_addr := addr
+          io.ext_mem.get(i).read_data
+        }
+        io.ext_mem.get(i).write_en := false.B
+        io.ext_mem.get(i).write_addr := DontCare
+        io.ext_mem.get(i).write_data := DontCare
+        io.ext_mem.get(i).write_mask := DontCare
+        def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = {
+          io.ext_mem.get(i).write_en := true.B
+          io.ext_mem.get(i).write_addr := addr
+          io.ext_mem.get(i).write_data := wdata.asUInt
+          io.ext_mem.get(i).write_mask := wmask.asUInt
+        }
+        (read _, write _)
+      } else {
+        val mem = SyncReadMem(n / acc_sub_banks, Vec(mask_len, mask_elem))
+        def read(addr: UInt, ren: Bool): Data = mem.read(addr, ren)
+        def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = mem.write(addr, wdata, wmask)
+        (read _, write _)
+      }

       val ren = WireInit(false.B)
-      val raddr = WireInit(getBankIdx(reads(0).bits))
+      val raddr = WireInit(getBankIdx(rmw_req.bits))
       val nEntries = 3
+
       // Writes coming 2 cycles after read leads to bad bank behavior
       // Add another buffer here
       class W_Q_Entry[T <: Data](mask_len: Int, mask_elem: T) extends Bundle {
@@ -126,25 +205,32 @@ class AccumulatorMem[T <: Data, U <: Data](
         val addr = UInt(log2Ceil(n/acc_sub_banks).W)
         override def cloneType: this.type = new W_Q_Entry(mask_len, mask_elem).asInstanceOf[this.type]
       }
+
       val w_q = Reg(Vec(nEntries, new W_Q_Entry(mask_len, mask_elem)))
       for (e <- w_q) {
         when (e.valid) {
           assert(!(
-            io.write.valid && io.write.bits.acc &&
+            io.write.fire() && io.write.bits.acc &&
             isThisBank(io.write.bits.addr) && getBankIdx(io.write.bits.addr) === e.addr &&
             ((io.write.bits.mask.asUInt & e.mask.asUInt) =/= 0.U)
-          ))
+          ), "you cannot accumulate to an AccumulatorMem address until previous writes to that address have completed")
+
+          when (io.write.bits.acc && isThisBank(io.write.bits.addr) && getBankIdx(io.write.bits.addr) === e.addr) {
+            rmw_req.ready := false.B
+          }

-          when (io.read.req.valid && isThisBank(io.read.req.bits.addr) && getBankIdx(io.read.req.bits.addr) === e.addr) {
-            reads(1).ready := false.B
+          when (isThisBank(io.read.req.bits.addr) && getBankIdx(io.read.req.bits.addr) === e.addr) {
+            only_read_req.ready := false.B
           }
         }
       }
+
       val w_q_head = RegInit(1.U(nEntries.W))
       val w_q_tail = RegInit(1.U(nEntries.W))
-      when (reset.asBool) {
-        w_q.foreach(_.valid := false.B)
-      }
+
+      val w_q_full = (w_q_tail.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)
+      val w_q_empty = !(w_q_head.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)
+
       val wen = WireInit(false.B)
       val wdata = Mux1H(w_q_head.asBools, w_q.map(_.data))
       val wmask = Mux1H(w_q_head.asBools, w_q.map(_.mask))
@@ -158,49 +244,61 @@ class AccumulatorMem[T <: Data, U <: Data](
        }
      }

-      when (w_buf_valid && isThisBank(waddr_buf)) {
-        assert(!((w_q_tail.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)))
+      val w_q_push = oldest_pipelined_write.valid && isThisBank(oldest_pipelined_write.bits.addr)
+
+      when (w_q_push) {
+        assert(!w_q_full || wen, "we ran out of acc-sub-bank write q entries")
+
        w_q_tail := (w_q_tail << 1).asUInt() | w_q_tail(nEntries-1)
        for (i <- 0 until nEntries) {
          when (w_q_tail(i)) {
            w_q(i).valid := true.B
-            w_q(i).data := Mux(acc_buf, w_sum, wdata_buf).asTypeOf(Vec(mask_len, mask_elem))
-            w_q(i).mask := mask_buf
-            w_q(i).addr := getBankIdx(waddr_buf)
+            w_q(i).data := Mux(oldest_pipelined_write.bits.acc, adder_sum, oldest_pipelined_write.bits.data).asTypeOf(Vec(mask_len, mask_elem))
+            w_q(i).mask := oldest_pipelined_write.bits.mask
+            w_q(i).addr := getBankIdx(oldest_pipelined_write.bits.addr)
          }
        }
-      }
-      val bank_rdata = mem.read(raddr, ren && !wen).asTypeOf(t)
-      when (RegNext(ren && reads(0).valid && isThisBank(reads(0).bits))) {
-        acc_rdata := bank_rdata
+
+      val bank_rdata = read(raddr, ren && !wen).asTypeOf(t)
+      when (RegNext(ren && rmw_req.valid && isThisBank(rmw_req.bits))) {
+        rdata_for_adder := bank_rdata
      } .elsewhen (RegNext(ren)) {
-        read_rdata := bank_rdata
+        rdata_for_read_resp := bank_rdata
      }
+
      when (wen) {
-        mem.write(waddr, wdata, wmask)
+        write(waddr, wdata, wmask)
      }
+
      // Three requestors, 1 slot
-      // Priority is incoming reads for RMW > writes from RMW > incoming reads
-      when (reads(0).valid && isThisBank(reads(0).bits)) {
+      // Priority is (in descending order):
+      //   1. incoming reads for RMW
+      //   2. writes from RMW
+      //   3. incoming reads
+      when (rmw_req.fire() && isThisBank(rmw_req.bits)) {
        ren := true.B
-        when (isThisBank(reads(1).bits)) {
-          reads(1).ready := false.B
+        when (isThisBank(only_read_req.bits)) {
+          only_read_req.ready := false.B
        }
-      } .elsewhen ((w_q_head.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)) {
+      } .elsewhen (!w_q_empty) {
        wen := true.B
-        when (isThisBank(reads(1).bits)) {
-          reads(1).ready := false.B
+        when (isThisBank(only_read_req.bits)) {
+          only_read_req.ready := false.B
        }
      } .otherwise {
-        ren := isThisBank(reads(1).bits)
-        raddr := getBankIdx(reads(1).bits)
+        ren := isThisBank(only_read_req.bits) && only_read_req.fire()
+        raddr := getBankIdx(only_read_req.bits)
+      }
+
+      when (reset.asBool) {
+        w_q.foreach(_.valid := false.B)
      }
    }
  }
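The per-sub-bank logic above multiplexes one SRAM port among three requestors with the fixed priority the rewritten comment spells out. A condensed sketch of that arbitration, decoupled from the accumulator details:

    import chisel3._
    import chisel3.util._

    // Fixed priority: RMW read > queued write > plain read (illustrative sketch)
    class OnePortArbiter(addrBits: Int) extends Module {
      val io = IO(new Bundle {
        val rmw_read   = Flipped(Decoupled(UInt(addrBits.W)))
        val queued_wr  = Flipped(Decoupled(UInt(addrBits.W)))
        val plain_read = Flipped(Decoupled(UInt(addrBits.W)))
        val addr = Output(UInt(addrBits.W))
        val wen  = Output(Bool())
      })
      io.rmw_read.ready   := true.B // highest priority always wins the port
      io.queued_wr.ready  := !io.rmw_read.valid
      io.plain_read.ready := !io.rmw_read.valid && !io.queued_wr.valid
      io.wen  := io.queued_wr.fire()
      io.addr := Mux(io.rmw_read.valid, io.rmw_read.bits,
                 Mux(io.queued_wr.valid, io.queued_wr.bits, io.plain_read.bits))
    }

Giving RMW reads top priority keeps the accumulate pipeline moving; draining the write queue before plain reads presumably keeps w_q from overflowing, which the "ran out of acc-sub-bank write q entries" assertion above guards against.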
   val q = Module(new Queue(new AccumulatorReadResp(t, scale_t, log2Ceil(t.head.head.getWidth)), 1, true, true))
-  q.io.enq.bits.data := read_rdata
+  q.io.enq.bits.data := rdata_for_read_resp
   q.io.enq.bits.scale := RegNext(io.read.req.bits.scale)
   q.io.enq.bits.relu6_shift := RegNext(io.read.req.bits.relu6_shift)
   q.io.enq.bits.act := RegNext(io.read.req.bits.act)
@@ -222,17 +320,18 @@ class AccumulatorMem[T <: Data, U <: Data](
   val q_will_be_empty = (q.io.count +& q.io.enq.fire()) - q.io.deq.fire() === 0.U
   io.read.req.ready := q_will_be_empty && (
     // Make sure we aren't accumulating, which would take over both ports
-    !(io.write.fire() && io.write.bits.acc) &&
-    // Make sure we aren't reading something that is still being written
-    !(RegNext(io.write.fire()) && RegNext(io.write.bits.addr) === io.read.req.bits.addr) &&
-    !(w_buf_valid && waddr_buf === io.read.req.bits.addr) &&
+    !(io.write.valid && io.write.bits.acc) &&
+    !pipelined_writes.map(r => r.valid && r.bits.addr === io.read.req.bits.addr).reduce(_||_) &&
     !block_read_req
   )
-  io.write.ready := !io.write.bits.acc || (!(io.write.bits.addr === waddr_buf && w_buf_valid) &&
-    !(io.write.bits.addr === RegNext(io.write.bits.addr) && RegNext(io.write.fire())))
+  io.write.ready := !block_write_req &&
+    !pipelined_writes.map(r => r.valid && r.bits.addr === io.write.bits.addr && io.write.bits.acc).reduce(_||_)
+
+  when (reset.asBool()) {
+    pipelined_writes.foreach(_.valid := false.B)
+  }

   // assert(!(io.read.req.valid && io.write.en && io.write.acc), "reading and accumulating simultaneously is not supported")
   assert(!(io.read.req.fire() && io.write.fire() && io.read.req.bits.addr === io.write.bits.addr), "reading from and writing to same address is not supported")
-  assert(!(io.read.req.fire() && w_buf_valid && waddr_buf === io.read.req.bits.addr), "reading from an address immediately after writing to it is not supported")
 }
diff --git a/src/main/scala/gemmini/Configs.scala b/src/main/scala/gemmini/Configs.scala
index 2e172adf..a4094d88 100644
--- a/src/main/scala/gemmini/Configs.scala
+++ b/src/main/scala/gemmini/Configs.scala
@@ -5,7 +5,12 @@ import chisel3._
 import freechips.rocketchip.config.{Config, Parameters}
 import freechips.rocketchip.diplomacy.LazyModule
 import freechips.rocketchip.subsystem._
-import freechips.rocketchip.tile.{BuildRoCC, OpcodeSet}
+import freechips.rocketchip.tile.{BuildRoCC, OpcodeSet, XLen}
+import freechips.rocketchip.rocket._
+import freechips.rocketchip.tile._
+import freechips.rocketchip.system._
+import freechips.rocketchip.diplomacy._
+
 import gemmini.Arithmetic.SIntArithmetic
 import hardfloat._
@@ -162,11 +167,14 @@ object GemminiConfigs {
     acc_scale_args=Some(defaultConfig.acc_scale_args.get.copy(latency=4)),
     acc_singleported=true,
     acc_sub_banks=2,
+    mesh_output_delay = 2,
     ex_read_from_acc=false,
-    ex_write_to_spad=false
+    ex_write_to_spad=false,
+    hardcode_d_to_garbage_addr = true
   )

   val largeChipConfig = chipConfig.copy(sp_capacity=CapacityInKilobytes(128), acc_capacity=CapacityInKilobytes(64),
+    tileRows=1, tileColumns=1,
     meshRows=32, meshColumns=32
   )
@@ -190,3 +198,76 @@ class DefaultGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
   )
   case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16)
 })
+
+// This Gemmini config has both an Int and an FP Gemmini side-by-side, sharing
+// the same scratchpad.
+class DualGemminiConfig extends Config((site, here, up) => {
+  case SystemBusKey => up(SystemBusKey).copy(beatBytes = 16)
+  case BuildRoCC => {
+    var int_gemmini: Gemmini[_,_,_] = null
+    var fp_gemmini: Gemmini[_,_,_] = null
+    val int_fn = (p: Parameters) => {
+      implicit val q = p
+      int_gemmini = LazyModule(new Gemmini(GemminiConfigs.chipConfig.copy(
+        opcodes = OpcodeSet.custom3,
+        use_shared_ext_mem = true,
+        clock_gate = true
+      )))
+      int_gemmini
+    }
+    val fp_fn = (p: Parameters) => {
+      implicit val q = p
+      fp_gemmini = LazyModule(new Gemmini(GemminiFPConfigs.BF16DefaultConfig.copy(
+        opcodes = OpcodeSet.custom2,
+        sp_capacity=CapacityInKilobytes(64), acc_capacity=CapacityInKilobytes(32),
+        tileColumns = 1, tileRows = 1,
+        meshColumns = 8, meshRows = 8,
+        acc_singleported = true, acc_banks = 2, acc_sub_banks = 2,
+        use_shared_ext_mem = true,
+        ex_read_from_acc=false,
+        ex_write_to_spad=false,
+        hardcode_d_to_garbage_addr = true,
+        headerFileName = "gemmini_params_bf16.h",
+        acc_latency = 3,
+        dataflow = Dataflow.WS,
+        mesh_output_delay = 3,
+        clock_gate = true
+      )))
+      InModuleBody {
+        require(int_gemmini.config.sp_banks == fp_gemmini.config.sp_banks)
+        require(int_gemmini.config.acc_banks == fp_gemmini.config.acc_banks)
+        require(int_gemmini.config.acc_sub_banks == fp_gemmini.config.acc_sub_banks)
+        require(int_gemmini.config.sp_singleported && fp_gemmini.config.sp_singleported)
+        require(int_gemmini.config.acc_singleported && fp_gemmini.config.acc_singleported)
+
+        require(int_gemmini.config.sp_bank_entries == fp_gemmini.config.sp_bank_entries)
+        require(int_gemmini.spad.module.spad_mems(0).mask_len == fp_gemmini.spad.module.spad_mems(0).mask_len)
+        require(int_gemmini.spad.module.spad_mems(0).mask_elem.getWidth == fp_gemmini.spad.module.spad_mems(0).mask_elem.getWidth)
+
+        println(int_gemmini.config.acc_bank_entries, fp_gemmini.config.acc_bank_entries)
+        println(int_gemmini.spad.module.acc_mems(0).mask_len, fp_gemmini.spad.module.acc_mems(0).mask_len)
+        println(int_gemmini.spad.module.acc_mems(0).mask_elem.getWidth, fp_gemmini.spad.module.acc_mems(0).mask_elem.getWidth)
+
+        require(int_gemmini.config.acc_bank_entries == fp_gemmini.config.acc_bank_entries / 2)
+        require(int_gemmini.config.acc_sub_banks == fp_gemmini.config.acc_sub_banks)
+        require(int_gemmini.spad.module.acc_mems(0).mask_len == fp_gemmini.spad.module.acc_mems(0).mask_len * 2)
+        require(int_gemmini.spad.module.acc_mems(0).mask_elem.getWidth == fp_gemmini.spad.module.acc_mems(0).mask_elem.getWidth)
+
+        val spad_mask_len = int_gemmini.spad.module.spad_mems(0).mask_len
+        val spad_data_len = int_gemmini.spad.module.spad_mems(0).mask_elem.getWidth
+        val acc_mask_len = int_gemmini.spad.module.acc_mems(0).mask_len
+        val acc_data_len = int_gemmini.spad.module.acc_mems(0).mask_elem.getWidth
+
+        val shared_mem = Module(new SharedExtMem(
+          int_gemmini.config.sp_banks, int_gemmini.config.acc_banks, int_gemmini.config.acc_sub_banks,
+          int_gemmini.config.sp_bank_entries, spad_mask_len, spad_data_len,
+          int_gemmini.config.acc_bank_entries / int_gemmini.config.acc_sub_banks, acc_mask_len, acc_data_len
+        ))
+        shared_mem.io.in(0) <> int_gemmini.module.ext_mem_io.get
+        shared_mem.io.in(1) <> fp_gemmini.module.ext_mem_io.get
+      }
+      fp_gemmini
+    }
+    up(BuildRoCC) ++ Seq(int_fn, fp_fn)
+  }
+})
diff --git a/src/main/scala/gemmini/ConfigsFP.scala b/src/main/scala/gemmini/ConfigsFP.scala
index a54c2853..35ecf821 100644
--- a/src/main/scala/gemmini/ConfigsFP.scala
+++ b/src/main/scala/gemmini/ConfigsFP.scala
@@ -30,6 +30,7 @@ object GemminiFPConfigs {
     sp_banks = 4,
     sp_singleported = true,
     acc_banks = 1,
+    acc_latency = 2,
     acc_singleported = false,
     acc_sub_banks = -1,
     sp_capacity = CapacityInKilobytes(256),
@@ -45,7 +46,7 @@ object GemminiFPConfigs {
     use_tlb_register_filter = true,
     max_in_flight_mem_reqs = 16,
     use_dedicated_tl_port = false,
-
+    use_shared_ext_mem = false,
     inputType = Float(8, 24),
     spatialArrayOutputType = Float(8, 24),
     accType = Float(8, 24),
@@ -60,7 +61,7 @@ object GemminiFPConfigs {
     acc_read_full_width = true,
     acc_read_small_width = true,

-    pe_latency = 1,
+    tile_latency = 1,

     ex_read_from_spad = true,
     ex_read_from_acc = true,
@@ -80,21 +81,21 @@ object GemminiFPConfigs {
   //FP32 Single Precision Configuration
   val FP32DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 24), spatialArrayOutputType = Float(8, 24), accType = Float(8, 24),
-    pe_latency = 2,
+    tile_latency = 2,
     mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
     mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
   )
-
+
   //FP16 Half Precision Configuration
   val FP16DefaultConfig = defaultFPConfig.copy(inputType = Float(5, 11), spatialArrayOutputType = Float(5, 11), accType = Float(8, 24),
-    pe_latency = 2,
+    tile_latency = 2,
     mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")),
     mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(5, 11), -1, identity = "1.0", c_str="((x) * (scale))")),
   )

   //Bfloat16 Brain-half Precision Configuration
   val BF16DefaultConfig = defaultFPConfig.copy(inputType = Float(8, 8), spatialArrayOutputType = Float(8, 8), accType = Float(8, 24),
-    pe_latency = 2,
+    tile_latency = 2,
     mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
     mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
   )
@@ -102,7 +103,7 @@ object GemminiFPConfigs {
   //Bfloat16 Brain-half Precision Configuration 8x8 array
   val BF16Default8Config = defaultFPConfig.copy(inputType = Float(8, 8), spatialArrayOutputType = Float(8, 8), accType = Float(8, 24),
     meshRows = 8, meshColumns = 8,
-    pe_latency = 2,
+    tile_latency = 2,
     mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
     mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
   )
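Across the FP configs, pe_latency has been renamed tile_latency, and the new acc_latency and use_shared_ext_mem knobs are now set explicitly. Purely as an illustration (the values are arbitrary, not recommendations from this patch), a derived config would now look like:

    // Hypothetical derived config, not part of this patch
    val myFPConfig = GemminiFPConfigs.BF16DefaultConfig.copy(
      tile_latency = 2,           // was pe_latency: pipeline registers inside each tile
      acc_latency = 3,            // depth of the accumulator read-modify-write pipeline
      use_shared_ext_mem = false  // true only when SRAMs come from outside, as in DualGemminiConfig
    )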
diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala
index 08481a5c..74f23b4c 100644
--- a/src/main/scala/gemmini/Controller.scala
+++ b/src/main/scala/gemmini/Controller.scala
@@ -9,6 +9,7 @@ import chisel3.util._
 import freechips.rocketchip.config._
 import freechips.rocketchip.diplomacy._
 import freechips.rocketchip.tile._
+import freechips.rocketchip.util.ClockGate
 import freechips.rocketchip.tilelink.TLIdentityNode
 import GemminiISA._
 import Util._
@@ -24,7 +25,7 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
 (implicit p: Parameters)
   extends LazyRoCC (
     opcodes = config.opcodes,
-    nPTWPorts = 1) {
+    nPTWPorts = if (config.use_shared_tlb) 1 else 2) {

   Files.write(Paths.get(config.headerFilePath), config.generateHeader().getBytes(StandardCharsets.UTF_8))
   if (System.getenv("GEMMINI_ONLY_GENERATE_GEMMINI_H") == "1") {
@@ -49,6 +50,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   import outer.config._
   import outer.spad

+  val ext_mem_io = if (use_shared_ext_mem) Some(IO(new ExtSpadMemIO(sp_banks, acc_banks, acc_sub_banks))) else None
+  ext_mem_io.foreach(_ <> outer.spad.module.io.ext_mem.get)
+
   val tagWidth = 32

   // Counters
@@ -62,15 +66,21 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   // TLB
   implicit val edge = outer.node.edges.out.head
-  val tlb = Module(new FrontendTLB(2, tlb_size, dma_maxbytes, use_tlb_register_filter, use_firesim_simulation_counters))
+  val tlb = Module(new FrontendTLB(2, tlb_size, dma_maxbytes, use_tlb_register_filter, use_firesim_simulation_counters, use_shared_tlb))
   (tlb.io.clients zip outer.spad.module.io.tlb).foreach(t => t._1 <> t._2)
-  tlb.io.exp.flush_skip := false.B
-  tlb.io.exp.flush_retry := false.B
+
+  tlb.io.exp.foreach(_.flush_skip := false.B)
+  tlb.io.exp.foreach(_.flush_retry := false.B)
+
+  io.ptw <> tlb.io.ptw
+
   counters.io.event_io.collect(tlb.io.counter)
-  io.ptw.head <> tlb.io.ptw
+  spad.module.io.flush := tlb.io.exp.map(_.flush()).reduce(_ || _)
-  spad.module.io.flush := tlb.io.exp.flush()
+  val clock_en_reg = RegInit(true.B)
+  val gated_clock = if (clock_gate) ClockGate(clock, clock_en_reg, "gemmini_clock_gate") else clock
+  outer.spad.module.clock := gated_clock

 /*
   //=========================================================================
@@ -111,10 +121,12 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   val unrolled_cmd = LoopUnroller(raw_risc_cmd, outer.config.meshRows * outer.config.tileRows)
 */

-  // Incoming commands and reservation station
-  val reservation_station = Module(new ReservationStation(outer.config, new RoCCCommand))
+  val reservation_station = withClock (gated_clock) { Module(new ReservationStation(outer.config, new RoCCCommand)) }
   counters.io.event_io.collect(reservation_station.io.counter)

+  when (io.cmd.valid && io.cmd.bits.inst.funct === CLKGATE_EN && !io.busy) {
+    clock_en_reg := io.cmd.bits.rs1(0)
+  }

   val raw_cmd = Queue(io.cmd)

   val max_lds = reservation_station_partial_entries
@@ -122,22 +134,22 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   val max_sts = reservation_station_partial_entries / 2

   // TODO replace 4,12,2 with parameters based on ROB size
-  val (conv_cmd, loop_conv_unroller_busy) = LoopConv(raw_cmd, reservation_station.io.ld_utilization, reservation_station.io.st_utilization, reservation_station.io.ex_utilization,
+  val (conv_cmd, loop_conv_unroller_busy) = withClock (gated_clock) { LoopConv(raw_cmd, reservation_station.io.ld_utilization, reservation_station.io.st_utilization, reservation_station.io.ex_utilization,
     meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, inputType.getWidth, accType.getWidth, dma_maxbytes,
-    new ConfigMvinRs1(mvin_scale_t_bits, block_stride_bits), new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
+    new ConfigMvinRs1(mvin_scale_t_bits, block_stride_bits, pixel_repeats_bits), new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t),
     new ConfigMvoutRs2(acc_scale_t_bits, 32), new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t),
     new ConfigExRs1(acc_scale_t_bits), new PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
     new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
-    has_training_convs, has_max_pool)
+    has_training_convs, has_max_pool, has_first_layer_optimizations) }

-  val (loop_cmd, loop_matmul_unroller_busy) = LoopMatmul(conv_cmd, reservation_station.io.ld_utilization, reservation_station.io.st_utilization, reservation_station.io.ex_utilization,
+  val (loop_cmd, loop_matmul_unroller_busy) = withClock (gated_clock) { LoopMatmul(conv_cmd, reservation_station.io.ld_utilization, reservation_station.io.st_utilization, reservation_station.io.ex_utilization,
     meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries, inputType.getWidth, accType.getWidth, dma_maxbytes,
     new MvinRs2(mvin_rows_bits, mvin_cols_bits, local_addr_t), new PreloadRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
     new PreloadRs(mvout_rows_bits, mvout_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t), new ComputeRs(mvin_rows_bits, mvin_cols_bits, local_addr_t),
-    new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t))
+    new MvoutRs2(mvout_rows_bits, mvout_cols_bits, local_addr_t)) }

   val unrolled_cmd = Queue(loop_cmd)
   unrolled_cmd.ready := false.B
@@ -165,9 +177,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   //=========================================================================
   // Controllers
   //=========================================================================
-  val load_controller = Module(new LoadController(outer.config, coreMaxAddrBits, local_addr_t))
-  val store_controller = Module(new StoreController(outer.config, coreMaxAddrBits, local_addr_t))
-  val ex_controller = Module(new ExecuteController(xLen, tagWidth, outer.config))
+  val load_controller = withClock (gated_clock) { Module(new LoadController(outer.config, coreMaxAddrBits, local_addr_t)) }
+  val store_controller = withClock (gated_clock) { Module(new StoreController(outer.config, coreMaxAddrBits, local_addr_t)) }
+  val ex_controller = withClock (gated_clock) { Module(new ExecuteController(xLen, tagWidth, outer.config)) }

   counters.io.event_io.collect(load_controller.io.counter)
   counters.io.event_io.collect(store_controller.io.counter)
@@ -238,7 +250,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   ex_controller.io.acc.write <> spad.module.io.acc.write

   // Im2Col unit
-  val im2col = Module(new Im2Col(outer.config))
+  val im2col = withClock (gated_clock) { Module(new Im2Col(outer.config)) }

   // Wire up Im2col
   counters.io.event_io.collect(im2col.io.counter)
@@ -311,7 +323,8 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   // Wire up global RoCC signals
   io.busy := raw_cmd.valid || loop_conv_unroller_busy || loop_matmul_unroller_busy || reservation_station.io.busy || spad.module.io.busy || unrolled_cmd.valid || loop_cmd.valid || conv_cmd.valid
-  io.interrupt := tlb.io.exp.interrupt
+
+  io.interrupt := tlb.io.exp.map(_.interrupt).reduce(_ || _)

   reservation_station.io.solitary_preload := ex_controller.io.solitary_preload
@@ -347,6 +360,8 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   val is_flush = risc_funct === FLUSH_CMD
   val is_counter_op = risc_funct === COUNTER_OP

+  val is_clock_gate_en = risc_funct === CLKGATE_EN
+
   /*
   val is_load = (funct === LOAD_CMD) || (funct === CONFIG_CMD && config_cmd_type === CONFIG_LOAD)
   val is_store = (funct === STORE_CMD) || (funct === CONFIG_CMD && config_cmd_type === CONFIG_STORE)
@@ -356,8 +371,8 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   when (is_flush) {
     val skip = unrolled_cmd.bits.rs1(0)
-    tlb.io.exp.flush_skip := skip
-    tlb.io.exp.flush_retry := !skip
+    tlb.io.exp.foreach(_.flush_skip := skip)
+    tlb.io.exp.foreach(_.flush_retry := !skip)

     unrolled_cmd.ready := true.B // TODO should we wait for an acknowledgement from the TLB?
   }
@@ -367,6 +382,10 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
     counters.io.in <> unrolled_cmd
   }

+  .elsewhen (is_clock_gate_en) {
+    unrolled_cmd.ready := true.B
+  }
+
   .otherwise {
     reservation_station.io.alloc.valid := true.B
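Controller.scala now instantiates the reservation station, both loop unrollers, the three controllers, and im2col under a gated clock, with a new CLKGATE_EN command (decoded from the always-on domain, and honored only while the accelerator is idle) toggling the enable register. A self-contained sketch of that pattern; the behavioral gate below is a simulation-only stand-in for the rocket-chip ClockGate cell the patch actually uses:

    import chisel3._
    import chisel3.util._

    class CoreLogic extends Module { // placeholder for the gated sub-blocks
      val io = IO(new Bundle { val out = Output(UInt(8.W)) })
      val c = RegInit(0.U(8.W)); c := c + 1.U; io.out := c
    }

    class GatedWrapper extends Module {
      val io = IO(new Bundle {
        val cmd_valid = Input(Bool())
        val cmd_en    = Input(Bool()) // e.g. rs1(0) of a CLKGATE_EN command
        val busy      = Input(Bool())
        val out       = Output(UInt(8.W))
      })
      val clock_en_reg = RegInit(true.B)
      // Only toggle while idle, so no in-flight work gets frozen
      when (io.cmd_valid && !io.busy) { clock_en_reg := io.cmd_en }
      // Behavioral stand-in for freechips.rocketchip.util.ClockGate;
      // real silicon should use a proper integrated clock-gating cell
      val gated_clock = (clock.asUInt.asBool && clock_en_reg).asClock
      val core = withClock (gated_clock) { Module(new CoreLogic) }
      io.out := core.io.out
    }

Note that the command decode itself stays on the free-running clock; otherwise a gated-off accelerator could never be woken back up.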
diff --git a/src/main/scala/gemmini/DMA.scala b/src/main/scala/gemmini/DMA.scala
index c1cb51ef..5952be5b 100644
--- a/src/main/scala/gemmini/DMA.scala
+++ b/src/main/scala/gemmini/DMA.scala
@@ -27,6 +27,7 @@ class StreamReadRequest[U <: Data](spad_rows: Int, acc_rows: Int, mvin_scale_t_b
   val status = new MStatus
   val len = UInt(16.W) // TODO magic number
   val repeats = UInt(16.W) // TODO magic number
+  val pixel_repeats = UInt(8.W) // TODO magic number
   val block_stride = UInt(16.W) // TODO magic number
   val cmd_id = UInt(8.W) // TODO magic number
@@ -43,6 +44,8 @@ class StreamReadResponse[U <: Data](spadWidth: Int, accWidth: Int, spad_rows: In
   val has_acc_bitwidth = Bool()
   val scale = UInt(mvin_scale_t_bits.W)
   val repeats = UInt(16.W) // TODO magic number
+  val pixel_repeats = UInt(16.W) // TODO magic number
+  val len = UInt(16.W) // TODO magic number
   val last = Bool()
   val bytes_read = UInt(8.W) // TODO magic number
   val cmd_id = UInt(8.W) // TODO magic number
@@ -100,6 +103,8 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T
   io.resp.bits.has_acc_bitwidth := beatPacker.io.out.bits.has_acc_bitwidth
   io.resp.bits.scale := RegEnable(xactTracker.io.peek.entry.scale, beatPacker.io.req.fire())
   io.resp.bits.repeats := RegEnable(xactTracker.io.peek.entry.repeats, beatPacker.io.req.fire())
+  io.resp.bits.pixel_repeats := RegEnable(xactTracker.io.peek.entry.pixel_repeats, beatPacker.io.req.fire())
+  io.resp.bits.len := RegEnable(xactTracker.io.peek.entry.len, beatPacker.io.req.fire())
   io.resp.bits.cmd_id := RegEnable(xactTracker.io.peek.entry.cmd_id, beatPacker.io.req.fire())
   io.resp.bits.bytes_read := RegEnable(xactTracker.io.peek.entry.bytes_to_read, beatPacker.io.req.fire())
   io.resp.bits.last := beatPacker.io.out.bits.last
@@ -250,6 +255,8 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
   io.reserve.entry.has_acc_bitwidth := req.has_acc_bitwidth
   io.reserve.entry.scale := req.scale
   io.reserve.entry.repeats := req.repeats
+  io.reserve.entry.pixel_repeats := req.pixel_repeats
+  io.reserve.entry.len := req.len
   io.reserve.entry.block_stride := req.block_stride
   io.reserve.entry.lg_len_req := DontCare // TODO just remove this from the IO completely
   io.reserve.entry.bytes_to_read := read_bytes_read
diff --git a/src/main/scala/gemmini/DSEConfigs.scala b/src/main/scala/gemmini/DSEConfigs.scala
index 37fc70f4..f00297e3 100644
--- a/src/main/scala/gemmini/DSEConfigs.scala
+++ b/src/main/scala/gemmini/DSEConfigs.scala
@@ -27,7 +27,7 @@ object DSEBaseConfig {
     sp_banks = 4, // TODO support one-bank designs
     acc_banks = 1,
     acc_singleported = false,
-    acc_sub_banks = -1,
+    acc_latency = 2,
     sp_capacity = CapacityInKilobytes(64),
     sp_singleported = false,
     shifter_banks = 1, // TODO add separate parameters for left and up shifter banks
@@ -59,7 +59,9 @@ object DSEBaseConfig {
     acc_read_full_width = true,
     acc_read_small_width = true,
     use_dedicated_tl_port = false,
-    pe_latency = 0,
+
+    use_shared_ext_mem = true,
+    tile_latency = 0,

     ex_read_from_spad = true,
     ex_read_from_acc = true,
@@ -79,6 +81,8 @@ object DSEBaseConfig {
     has_nonlinear_activations = true,

     num_counter = 8,
+
+    clock_gate = false,
   )
 }
diff --git a/src/main/scala/gemmini/ExecuteController.scala b/src/main/scala/gemmini/ExecuteController.scala
index 9d1cf094..6891c09b 100644
--- a/src/main/scala/gemmini/ExecuteController.scala
+++ b/src/main/scala/gemmini/ExecuteController.scala
@@ -187,7 +187,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
   val cntl = mesh_cntl_signals_q.io.deq.bits

   // Instantiate the actual mesh
-  val mesh = Module(new MeshWithDelays(inputType, spatialArrayOutputType, accType, mesh_tag, dataflow, pe_latency, mesh_output_delay,
+  val mesh = Module(new MeshWithDelays(inputType, spatialArrayOutputType, accType, mesh_tag, dataflow, tree_reduction, tile_latency, mesh_output_delay,
     tileRows, tileColumns, meshRows, meshColumns, shifter_banks, shifter_banks))

   mesh.io.a.valid := false.B
@@ -891,12 +891,12 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
   when (cntl_valid && cntl.perform_single_preload) {
     mesh.io.a.bits := Mux(a_should_be_fed_into_transposer, dataA.asUInt, 0.U).asTypeOf(Vec(meshRows, Vec(tileRows, inputType)))
-    mesh.io.b.bits := Mux(b_should_be_fed_into_transposer, dataB.asUInt, 0.U).asTypeOf(Vec(meshRows, Vec(tileRows, inputType)))
+    mesh.io.b.bits := Mux(b_should_be_fed_into_transposer, dataB.asUInt, 0.U).asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))
   }

   when (cntl_valid && cntl.perform_single_mul) {
     mesh.io.a.bits := Mux(a_should_be_fed_into_transposer, 0.U, dataA.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType)))
-    mesh.io.b.bits := Mux(b_should_be_fed_into_transposer, 0.U, dataB.asUInt).asTypeOf(Vec(meshRows, Vec(tileRows, inputType)))
+    mesh.io.b.bits := Mux(b_should_be_fed_into_transposer, 0.U, dataB.asUInt).asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))
     mesh.io.req.bits.tag.addr.make_this_garbage()
   }
diff --git a/src/main/scala/gemmini/FrontendTLB.scala b/src/main/scala/gemmini/FrontendTLB.scala
index 50c393b5..bc028ee9 100644
--- a/src/main/scala/gemmini/FrontendTLB.scala
+++ b/src/main/scala/gemmini/FrontendTLB.scala
@@ -66,12 +66,12 @@ class DecoupledTLB(entries: Int, maxSize: Int, use_firesim_simulation_counters:
   assert(!io.exp.flush_retry || !io.exp.flush_skip, "TLB: flushing with both retry and skip at same time")

   CounterEventIO.init(io.counter)
-  io.counter.connectEventSignal(CounterEvent.DMA_TLB_HIT_REQ, RegNext(io.req.fire()) && !tlb.io.resp.miss)
+  io.counter.connectEventSignal(CounterEvent.DMA_TLB_HIT_REQ, io.req.fire() && !tlb.io.resp.miss)
   io.counter.connectEventSignal(CounterEvent.DMA_TLB_TOTAL_REQ, io.req.fire())
   io.counter.connectEventSignal(CounterEvent.DMA_TLB_MISS_CYCLE, tlb.io.resp.miss)

   if (use_firesim_simulation_counters) {
-    PerfCounter(RegNext(io.req.fire()) && !tlb.io.resp.miss, "tlb_hits", "total number of tlb hits")
+    PerfCounter(io.req.fire() && !tlb.io.resp.miss, "tlb_hits", "total number of tlb hits")
     PerfCounter(io.req.fire(), "tlb_reqs", "total number of tlb reqs")
     PerfCounter(tlb.io.resp.miss, "tlb_miss_cycles", "total number of cycles where the tlb is resolving a miss")
   }
@@ -84,51 +84,66 @@ class FrontendTLBIO(implicit p: Parameters) extends CoreBundle {
   val resp = Flipped(new TLBResp)
 }

-class FrontendTLB(nClients: Int, entries: Int, maxSize: Int, use_tlb_register_filter: Boolean, use_firesim_simulation_counters: Boolean)
+class FrontendTLB(nClients: Int, entries: Int, maxSize: Int, use_tlb_register_filter: Boolean, use_firesim_simulation_counters: Boolean, use_shared_tlb: Boolean)
                  (implicit edge: TLEdgeOut, p: Parameters) extends CoreModule {
+
+  val num_tlbs = if (use_shared_tlb) 1 else nClients
+  val lgMaxSize = log2Ceil(coreDataBytes)
+
   val io = IO(new Bundle {
     val clients = Flipped(Vec(nClients, new FrontendTLBIO))
-    val ptw = new TLBPTWIO
-    val exp = new TLBExceptionIO
+    val ptw = Vec(num_tlbs, new TLBPTWIO)
+    val exp = Vec(num_tlbs, new TLBExceptionIO)
     val counter = new CounterEventIO()
   })

-  val lgMaxSize = log2Ceil(coreDataBytes)
-  val tlbArb = Module(new RRArbiter(new DecoupledTLBReq(lgMaxSize), nClients))
-  val tlb = Module(new DecoupledTLB(entries, maxSize, use_firesim_simulation_counters))
-  tlb.io.req.valid := tlbArb.io.out.valid
-  tlb.io.req.bits := tlbArb.io.out.bits
-  tlbArb.io.out.ready := true.B
+  val tlbs = Seq.fill(num_tlbs)(Module(new DecoupledTLB(entries, maxSize, use_firesim_simulation_counters)))

-  io.ptw <> tlb.io.ptw
-  io.exp <> tlb.io.exp
+  io.ptw <> VecInit(tlbs.map(_.io.ptw))
+  io.exp <> VecInit(tlbs.map(_.io.exp))
+
+  val tlbArbOpt = if (use_shared_tlb) Some(Module(new RRArbiter(new DecoupledTLBReq(lgMaxSize), nClients))) else None
+
+  if (use_shared_tlb) {
+    val tlbArb = tlbArbOpt.get
+    val tlb = tlbs.head
+    tlb.io.req.valid := tlbArb.io.out.valid
+    tlb.io.req.bits := tlbArb.io.out.bits
+    tlbArb.io.out.ready := true.B
+  }

-  io.clients.zip(tlbArb.io.in).foreach { case (client, req) =>
+  io.clients.zipWithIndex.foreach { case (client, i) =>
     val last_translated_valid = RegInit(false.B)
     val last_translated_vpn = RegInit(0.U(vaddrBits.W))
     val last_translated_ppn = RegInit(0.U(paddrBits.W))

-    val l0_tlb_hit = last_translated_valid && ((client.req.bits.tlb_req.vaddr >> pgIdxBits) === (last_translated_vpn >> pgIdxBits))
+    val l0_tlb_hit = last_translated_valid && ((client.req.bits.tlb_req.vaddr >> pgIdxBits).asUInt() === (last_translated_vpn >> pgIdxBits).asUInt())
     val l0_tlb_paddr = Cat(last_translated_ppn >> pgIdxBits, client.req.bits.tlb_req.vaddr(pgIdxBits-1,0))

-    when (req.fire() && !tlb.io.resp.miss) {
+    val tlb = if (use_shared_tlb) tlbs.head else tlbs(i)
+    val tlbReq = if (use_shared_tlb) tlbArbOpt.get.io.in(i).bits else tlb.io.req.bits
+    val tlbReqValid = if (use_shared_tlb) tlbArbOpt.get.io.in(i).valid else tlb.io.req.valid
+    val tlbReqFire = if (use_shared_tlb) tlbArbOpt.get.io.in(i).fire() else tlb.io.req.fire()
+
+    tlbReqValid := RegNext(client.req.valid && !l0_tlb_hit)
+    tlbReq := RegNext(client.req.bits)
+
+    when (tlbReqFire && !tlb.io.resp.miss) {
       last_translated_valid := true.B
-      last_translated_vpn := req.bits.tlb_req.vaddr
+      last_translated_vpn := tlbReq.tlb_req.vaddr
       last_translated_ppn := tlb.io.resp.paddr
     }
-    when (io.exp.flush()) {
+
+    when (tlb.io.exp.flush()) {
       last_translated_valid := false.B
     }

-    req.valid := RegNext(client.req.valid && !l0_tlb_hit)
-    req.bits := RegNext(client.req.bits)
-
-    when (!req.fire()) {
+    when (tlbReqFire) {
+      client.resp := tlb.io.resp
+    }.otherwise {
       client.resp := DontCare
       client.resp.paddr := RegNext(l0_tlb_paddr)
       client.resp.miss := !RegNext(l0_tlb_hit)
-    } .otherwise {
-      client.resp := tlb.io.resp
     }

     // If we're not using the TLB filter register, then we set this value to always be false
@@ -137,16 +152,8 @@ class FrontendTLB(nClients: Int, entries: Int, maxSize: Int, use_tlb_register_fi
     }
   }

-  io.counter.collect(tlb.io.counter)
+  // TODO Return the sum of the TLB counters, rather than just the counters of the first TLB. This only matters if we're
+  // not using the shared TLB
+  tlbs.foreach(_.io.counter.external_reset := false.B)
+  io.counter.collect(tlbs.head.io.counter)
 }
-
-/*class TLBArb (nClients: Int, lgMaxSize: Int)(implicit p: Parameters) extends CoreModule {
-  val io = IO(new Bundle {
-    val in_req = Vec(nClients, Flipped(Decoupled(new TLBReq(lgMaxSize))))
-    val in_resp = Vec(nClients, Flipped(Valid(new TLBResp)))
-    val out_req = Decoupled(new TLBReq(lgMaxSize))
-    val out_resp = Valid(new TLBResp)
-  })
-
-  val priority = Reg(UInt(log2Up(nClients).W))
-}*/
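FrontendTLB can now build either one shared TLB behind a round-robin arbiter or a private TLB per client, but the per-client single-entry filter in front of it is unchanged in spirit: the last translated page is cached in registers, and a hit bypasses the TLB entirely. A standalone sketch of that L0 filter:

    import chisel3._
    import chisel3.util._

    class L0TlbFilter(vaddrBits: Int, paddrBits: Int, pgIdxBits: Int) extends Module {
      val io = IO(new Bundle {
        val vaddr = Input(UInt(vaddrBits.W))
        val fill  = Input(Valid(UInt(paddrBits.W))) // completed translation of vaddr
        val flush = Input(Bool())
        val hit   = Output(Bool())
        val paddr = Output(UInt(paddrBits.W))
      })
      val valid = RegInit(false.B)
      val vpn   = Reg(UInt(vaddrBits.W))
      val ppn   = Reg(UInt(paddrBits.W))

      // Hit when the requested virtual page matches the cached one
      io.hit   := valid && (io.vaddr >> pgIdxBits) === (vpn >> pgIdxBits)
      io.paddr := Cat(ppn >> pgIdxBits, io.vaddr(pgIdxBits - 1, 0))

      when (io.fill.valid) { valid := true.B; vpn := io.vaddr; ppn := io.fill.bits }
      when (io.flush)      { valid := false.B } // a flush wins over a same-cycle fill
    }

In the real module requests are RegNext'd into the TLB (or the arbiter), and on a filter hit the response is likewise returned one cycle later from the cached registers.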
diff --git a/src/main/scala/gemmini/GemminiConfigs.scala b/src/main/scala/gemmini/GemminiConfigs.scala
index 45c481ce..567ef060 100644
--- a/src/main/scala/gemmini/GemminiConfigs.scala
+++ b/src/main/scala/gemmini/GemminiConfigs.scala
@@ -15,12 +15,12 @@ case class ScaleArguments[T <: Data, U <: Data](scale_func: (T, U) => T, latency
   identity: String="0", c_str: String="ROUNDING_RIGHT_SHIFT(x, scale)")

 case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
+  opcodes: OpcodeSet = OpcodeSet.custom3,
+
   inputType: T,
   spatialArrayOutputType: T,
   accType: T,

-  opcodes: OpcodeSet = OpcodeSet.custom3,
-
   dataflow: Dataflow.Value = Dataflow.BOTH,

   tileRows: Int = 1,
@@ -44,6 +44,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
   acc_singleported: Boolean = false,
   acc_sub_banks: Int = -1,
   acc_capacity: GemminiMemCapacity = CapacityInKilobytes(64),
+  acc_latency: Int = 2,

   dma_maxbytes: Int = 64, // TODO get this from cacheblockbytes
   dma_buswidth: Int = 128, // TODO get this from SystemBusKey
@@ -57,8 +58,6 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
   mvin_scale_shared: Boolean = false,
   acc_scale_args: Option[ScaleArguments[T, V]] = None,

-  pe_latency: Int = 0,
-
   acc_read_full_width: Boolean = true,
   acc_read_small_width: Boolean = true,
   use_dedicated_tl_port: Boolean = true,
@@ -73,17 +72,26 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
   ex_write_to_acc: Boolean = true,

   hardcode_d_to_garbage_addr: Boolean = false,
+  use_shared_tlb: Boolean = true,
+
+  tile_latency: Int = 0,
   mesh_output_delay: Int = 1,

+  use_tree_reduction_if_possible: Boolean = true,
+
   num_counter: Int = 8,

   has_training_convs: Boolean = true,
   has_max_pool: Boolean = true,
   has_nonlinear_activations: Boolean = true,
+  has_first_layer_optimizations: Boolean = true,
+
   use_firesim_simulation_counters: Boolean = false,
+
+  use_shared_ext_mem: Boolean = false,
+  clock_gate: Boolean = false,
+
   headerFileName: String = "gemmini_params.h"
 ) {
   val sp_width = meshColumns * tileColumns * inputType.getWidth
@@ -153,10 +161,17 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
   val mvout_rows_bits = log2Up(meshRows * tileRows + 1)

   val load_states = 3
-  val block_stride_bits = 16
+  val block_stride_bits = 16 min (log2Up(acc_banks * acc_bank_entries) max log2Up(sp_banks * sp_bank_entries))
+
+  val a_stride_bits = 16 min (log2Up(acc_banks * acc_bank_entries) max log2Up(sp_banks * sp_bank_entries))
+  val c_stride_bits = 16 min (log2Up(acc_banks * acc_bank_entries) max log2Up(sp_banks * sp_bank_entries))
+
+  val pixel_repeats_bits = 8 min log2Up(meshColumns * tileColumns + 1)

   val hasIm2Col = false

+  val tree_reduction = use_tree_reduction_if_possible && dataflow == Dataflow.WS && tileRows > 1
+
   //==========================================================================
   // sanity check mesh size
   //==========================================================================
@@ -260,7 +275,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
       (dt.expWidth, dt.sigWidth) match {
         case (8, 24) => (scala.Float.MinValue.toString, scala.Float.MaxValue.toString)
         case (11, 53) => (scala.Double.MinValue.toString, scala.Double.MaxValue.toString)
-        case _ => (((Range(-1,-(dt.sigWidth),-1).map(-Math.pow(2, _)).foldLeft(-1.0)(_ + _)) * Math.pow(2, Math.pow(2, dt.expWidth - 1) - 1)).toString, ((Range(-1,-(dt.sigWidth),-1).map(Math.pow(2, _)).foldLeft(1.0)(_ + _)) * Math.pow(2, Math.pow(2, dt.expWidth - 1) - 1)).toString)
+        case (e, s) => (((Range(-1,-(s),-1).map(-Math.pow(2, _)).foldLeft(-1.0)(_ + _)) * Math.pow(2, Math.pow(2, e - 1) - 1)).toString, ((Range(-1,-(s),-1).map(Math.pow(2, _)).foldLeft(1.0)(_ + _)) * Math.pow(2, Math.pow(2, e - 1) - 1)).toString)
       }
     case dt => ("0", BigInt(2).pow(dt.getWidth).-(1).toString)
     // case _ => throw new IllegalArgumentException(s"Data type $dataType is unknown")
@@ -274,7 +289,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
       (dt.expWidth, dt.sigWidth) match {
         case (8, 24) => "float"
         case (11, 53) => "double"
-        case _ => s"uint" + (Math.pow(2, Math.ceil(Math.log(dt.expWidth + dt.sigWidth)/Math.log(2.0)))).toInt.toString + s"_t"
+        case (e, s) => s"uint" + (Math.pow(2, Math.ceil(Math.log(e + s)/Math.log(2.0)))).toInt.toString + s"_t"
       }
     case dt => s"uint${dt.getWidth}_t"
   }
@@ -463,6 +478,10 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
     header ++= s"#define ACC_READ_FULL_WIDTH\n"
     header ++= s"\n"

+    if (has_first_layer_optimizations) {
+      header ++= "#define HAS_FIRST_LAYER_OPTIMIZATIONS\n\n"
+    }
+
     header ++= s"#endif // $guard\n"
     header.toString()
   }
diff --git a/src/main/scala/gemmini/GemminiISA.scala b/src/main/scala/gemmini/GemminiISA.scala
index 554bcdeb..0b28316d 100644
--- a/src/main/scala/gemmini/GemminiISA.scala
+++ b/src/main/scala/gemmini/GemminiISA.scala
@@ -24,7 +24,7 @@ object GemminiISA {
   val LOAD3_CMD = 14.U

   // TODO add orows and ocols to this as well
-  val LOOP_CONV_WS = 15.U // no_bias, wrot180, trans_output_1203, trans_weight_1203, trans_input_3120 | no_pool, downsample, input_dilated, act
+  val LOOP_CONV_WS = 15.U // no_bias, wrot180, trans_output_1203, trans_weight_1203, trans_input_3120, max_pixels_per_row | no_pool, downsample, input_dilated, act
   val LOOP_CONV_WS_CONFIG_1 = 16.U // batch_size, in_dim, in_channels, out_channels | out_dim, pool_out_dim, stride, padding
   val LOOP_CONV_WS_CONFIG_2 = 17.U // kernel_dim, pool_size, pool_stride, pool_padding | batches, porows, pocols, pochs
   val LOOP_CONV_WS_CONFIG_3 = 18.U // krows, kcols, kchs, lpad | rpad, upad, dpad, plpad
@@ -32,6 +32,8 @@ object GemminiISA {
   val LOOP_CONV_WS_CONFIG_5 = 20.U // *weights | *output
   val LOOP_CONV_WS_CONFIG_6 = 21.U // *bias, *input

+  val CLKGATE_EN = 22.U
+
   // rs1[2:0] values
   val CONFIG_EX = 0.U
   val CONFIG_LOAD = 1.U
@@ -93,22 +95,25 @@ object GemminiISA {
   val CONFIG_MVIN_RS1_UNUSED_WIDTH = 2
   val CONFIG_MVIN_RS1_SHRINK_WIDTH = 1
   val CONFIG_MVIN_RS1_STATE_ID_WIDTH = 2
-  val CONFIG_MVIN_RS1_SPACER_WIDTH = (16 - 2 - 1 - 2)
+  val CONFIG_MVIN_RS1_SPACER_WIDTH = 8 - 2 - 1 - 2
+  val CONFIG_MVIN_RS1_PIXEL_REPEAT_WIDTH = 8
   val CONFIG_MVIN_RS1_STRIDE_WIDTH = 16
   val CONFIG_MVIN_RS1_SCALE_WIDTH = 32

-  class ConfigMvinRs1(scale_bits: Int, stride_bits: Int) extends Bundle {
-    val _spacer2 = UInt((CONFIG_MVIN_RS1_SCALE_WIDTH - scale_bits).W)
+  class ConfigMvinRs1(scale_bits: Int, stride_bits: Int, pixel_repeat_bits: Int) extends Bundle {
+    val _spacer3 = UInt((CONFIG_MVIN_RS1_SCALE_WIDTH - scale_bits).W)
     val scale = UInt(scale_bits.W)
-    val _spacer1 = UInt((CONFIG_MVIN_RS1_STRIDE_WIDTH - stride_bits).W)
+    val _spacer2 = UInt((CONFIG_MVIN_RS1_STRIDE_WIDTH - stride_bits).W)
     val stride = UInt(stride_bits.W)
+    val _spacer1 = UInt((CONFIG_MVIN_RS1_PIXEL_REPEAT_WIDTH - pixel_repeat_bits).W)
+    val pixel_repeats = UInt(pixel_repeat_bits.W)
     val _spacer0 = UInt(CONFIG_MVIN_RS1_SPACER_WIDTH.W)
     val state_id = UInt(CONFIG_MVIN_RS1_STATE_ID_WIDTH.W)
     val shrink = UInt(CONFIG_MVIN_RS1_SHRINK_WIDTH.W)
     val _unused = UInt(CONFIG_MVIN_RS1_UNUSED_WIDTH.W)

     override def cloneType: ConfigMvinRs1.this.type =
-      (new ConfigMvinRs1(scale_bits, stride_bits)).asInstanceOf[this.type]
+      (new ConfigMvinRs1(scale_bits, stride_bits, pixel_repeat_bits)).asInstanceOf[this.type]
   }

   val CONFIG_MVOUT_RS1_UNUSED_WIDTH = 2
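The CONFIG_LOAD rs1 word gives up eight of its old spacer bits to carry the new pixel_repeats field. With the full field widths (the _spacerN fields absorb any per-config narrowing), the 64-bit layout is now, from MSB to LSB: scale [63:32], stride [31:16], pixel_repeats [15:8], spacer [7:5], state_id [4:3], shrink [2], unused [1:0]. A hypothetical packing helper, assuming fields no wider than those limits:

    import chisel3._
    import chisel3.util._

    // Illustrative helper, not part of the patch
    def configMvinRs1(scale: UInt, stride: UInt, pixelRepeats: UInt,
                      stateId: UInt, shrink: Bool): UInt =
      Cat(scale.pad(32), stride.pad(16), pixelRepeats.pad(8),
          0.U(3.W), stateId.pad(2), shrink, 0.U(2.W)) // 32+16+8+3+2+1+2 = 64 bits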
diff --git a/src/main/scala/gemmini/LoadController.scala b/src/main/scala/gemmini/LoadController.scala
index 89f7be7c..49d7b409 100644
--- a/src/main/scala/gemmini/LoadController.scala
+++ b/src/main/scala/gemmini/LoadController.scala
@@ -34,6 +34,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
   val scales = Reg(Vec(load_states, UInt(mvin_scale_t_bits.W)))
   val shrinks = Reg(Vec(load_states, Bool())) // Shrink inputs to accumulator
   val block_strides = Reg(Vec(load_states, UInt(block_stride_bits.W))) // Spad stride during block move-ins
+  val pixel_repeats = Reg(Vec(load_states, UInt(pixel_repeats_bits.W)))
   val block_rows = meshRows * tileRows
   val block_cols = meshColumns * tileColumns
   val row_counter = RegInit(0.U(log2Ceil(block_rows).W))
@@ -47,11 +48,13 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
   val rows = mvin_rs2.num_rows

   val config_stride = cmd.bits.cmd.rs2
-  val config_mvin_rs1 = cmd.bits.cmd.rs1.asTypeOf(new ConfigMvinRs1(mvin_scale_t_bits, block_stride_bits))
-  val config_scale = config_mvin_rs1.scale // maybe limit width to `mvin_scale_t_bits`?
+  val config_mvin_rs1 = cmd.bits.cmd.rs1.asTypeOf(new ConfigMvinRs1(mvin_scale_t_bits, block_stride_bits, pixel_repeats_bits))
+
+  val config_scale = config_mvin_rs1.scale
   val config_shrink = config_mvin_rs1.shrink
   val config_block_stride = config_mvin_rs1.stride
+  val config_pixel_repeats = config_mvin_rs1.pixel_repeats

   val mstatus = cmd.bits.cmd.status
@@ -64,6 +67,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
   val scale = scales(state_id)
   val shrink = shrinks(state_id)
   val block_stride = block_strides(state_id)
+  val pixel_repeat = pixel_repeats(state_id)

   val all_zeros = vaddr === 0.U
@@ -104,6 +108,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
   io.dma.req.bits.has_acc_bitwidth := localaddr_plus_row_counter.is_acc_addr && !shrink
   io.dma.req.bits.all_zeros := all_zeros
   io.dma.req.bits.status := mstatus
+  io.dma.req.bits.pixel_repeats := pixel_repeat

   // Command tracker IO
   cmd_tracker.io.alloc.valid := control_state === waiting_for_command && cmd.valid && DoLoad
@@ -140,6 +145,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
     scale := config_scale
     shrink := config_shrink
     block_stride := config_block_stride
+    pixel_repeat := Mux(config_pixel_repeats === 0.U, 1.U, config_pixel_repeats) // TODO this default value was just added to maintain backwards compatibility. we should deprecate and remove it later

     cmd.ready := true.B
   }
@@ -165,6 +171,10 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
     }
   }

+  // Optimizations based on config parameters
+  if (!has_first_layer_optimizations)
+    pixel_repeats.foreach(_ := 1.U)
+
   // Performance counter
   CounterEventIO.init(io.counter)
   io.counter.connectEventSignal(CounterEvent.LOAD_ACTIVE_CYCLE, control_state === sending_rows)
@@ -177,4 +187,5 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
   // Assertions
   assert(!(cmd_tracker.io.alloc.fire() && cmd_tracker.io.alloc.bits.bytes_to_read === 0.U), "A single mvin instruction must load more than 0 bytes")
+  assert(has_first_layer_optimizations.B || !(cmd.valid && DoConfig && config_pixel_repeats > 1.U), "If first-layer optimizations are not enabled, then pixel-repeats cannot be greater than 1")
 }
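LoadController decodes the new field with a backwards-compatible default (a pixel_repeats of 0, which older software would send, is treated as 1) and, when first-layer optimizations are compiled out, ties the registers to 1 at elaboration time. A toy sketch (hypothetical, not part of the patch) of those two gating styles side by side:

    import chisel3._

    class FeatureGate(enableFeature: Boolean) extends Module {
      val io = IO(new Bundle {
        val cfg_in  = Input(UInt(8.W))  // raw field from the command word
        val cfg_out = Output(UInt(8.W))
      })
      val reg = RegInit(1.U(8.W))
      // Hardware default: 0 was never a legal repeat count, so map it to 1
      reg := Mux(io.cfg_in === 0.U, 1.U, io.cfg_in)
      // Elaboration-time gate: last connection wins, so a disabled feature
      // leaves the register constant and lets downstream logic shrink away
      if (!enableFeature) { reg := 1.U }
      io.cfg_out := reg
    }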
diff --git a/src/main/scala/gemmini/LocalAddr.scala b/src/main/scala/gemmini/LocalAddr.scala
index b003fd7b..ac5a1f4a 100644
--- a/src/main/scala/gemmini/LocalAddr.scala
+++ b/src/main/scala/gemmini/LocalAddr.scala
@@ -16,6 +16,8 @@ class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_en
   private val accBankBits = log2Up(acc_banks)
   val accBankRowBits = log2Up(acc_bank_entries)

+  val spRows = sp_banks * sp_bank_entries
+
   val is_acc_addr = Bool()
   val accumulate = Bool()
   val read_full_acc_row = Bool()
@@ -71,6 +73,19 @@ class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_en
     (result, overflow)
   }

+  // This function can only be used with non-accumulator addresses. Returns both new address and underflow
+  def floorSub(other: UInt, floor: UInt): (LocalAddr, Bool) = {
+    require(isPow2(sp_bank_entries)) // TODO remove this requirement
+    require(isPow2(acc_bank_entries)) // TODO remove this requirement
+
+    val underflow = data < (floor +& other)
+
+    val result = WireInit(this)
+    result.data := Mux(underflow, floor, data - other)
+
+    (result, underflow)
+  }
+
   def make_this_garbage(dummy: Int = 0): Unit = {
     is_acc_addr := true.B
     accumulate := true.B
@@ -81,3 +96,13 @@ class LocalAddr(sp_banks: Int, sp_bank_entries: Int, acc_banks: Int, acc_bank_en

   override def cloneType: LocalAddr.this.type = new LocalAddr(sp_banks, sp_bank_entries, acc_banks, acc_bank_entries).asInstanceOf[this.type]
 }
+
+object LocalAddr {
+  def cast_to_local_addr[T <: Data](local_addr_t: LocalAddr, t: T): LocalAddr = {
+    // This convenience function is basically the same as calling "asTypeOf(local_addr_t)". However, this convenience
+    // function will also cast unnecessary garbage bits to 0, which may help reduce multiplier/adder bitwidths
+    val result = WireInit(t.asTypeOf(local_addr_t))
+    if (result.garbage_bit.getWidth > 0) result.garbage := 0.U
+    result
+  }
+}
diff --git a/src/main/scala/gemmini/LoopConv.scala b/src/main/scala/gemmini/LoopConv.scala
index 47cd5a39..d2775a9c 100644
--- a/src/main/scala/gemmini/LoopConv.scala
+++ b/src/main/scala/gemmini/LoopConv.scala
@@ -6,6 +6,7 @@ import chisel3.experimental._
 import freechips.rocketchip.tile.RoCCCommand
 import freechips.rocketchip.config.Parameters
 import GemminiISA._
+import LocalAddr.cast_to_local_addr
 import Util._

 class LoopConvOuterBounds(val large_iterator_bitwidth: Int, val small_iterator_bitwidth: Int, val tiny_iterator_bitwidth: Int) extends Bundle {
@@ -137,6 +138,7 @@ class LoopConvLdBias(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwi
   config_cmd_rs1 := DontCare
   config_cmd_rs1.scale := MVIN_SCALE_IDENTITY
   config_cmd_rs1.stride := req.derived_params.bias_spad_stride
+  config_cmd_rs1.pixel_repeats := 1.U
   config_cmd_rs1.state_id := 2.U
   config_cmd_rs1.shrink := 0.U
   config_cmd_rs1._unused := 1.U
@@ -172,7 +174,7 @@ class LoopConvLdBias(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwi
   mvin_cmd_rs2 := DontCare
   mvin_cmd_rs2.num_rows := o.I.asUInt()
   mvin_cmd_rs2.num_cols := o.J.asUInt()
-  mvin_cmd_rs2.local_addr := o.spad_addr.asTypeOf(mvin_cmd_rs2.local_addr)
+  mvin_cmd_rs2.local_addr := cast_to_local_addr(mvin_cmd_rs2.local_addr, o.spad_addr)
   io.cmd.bits.rs2 := mvin_cmd_rs2.asUInt()
 }
@@ -216,6 +218,7 @@ class LoopConvLdInputReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth:
   val addr_start = UInt(log2Up(max_acc_addr).W)
   val dram_addr = UInt(coreMaxAddrBits.W)
   val downsample = Bool()
+  val max_pixels_per_row = UInt(small_iterator_bitwidth.W)
   val input_dilated = Bool()
   val trans_input_3120 = Bool()
   val loop_id = UInt(log2Up(concurrent_loops).W)
@@ -309,10 +312,12 @@ class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitw
   config_cmd_rs1 := DontCare
   config_cmd_rs1.scale := MVIN_SCALE_IDENTITY
   config_cmd_rs1.stride := input_spad_stride
+  config_cmd_rs1.pixel_repeats := req.max_pixels_per_row
   config_cmd_rs1.state_id := 0.U
   config_cmd_rs1.shrink := 0.U
   config_cmd_rs1._unused := 1.U
   config_cmd.rs1 := config_cmd_rs1.asUInt()
+
   config_cmd.rs2 := dram_stride << req.downsample

   val mvin_cmd = Wire(new RoCCCommand)
@@ -343,7 +348,7 @@ class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitw
   mvin_cmd_rs2 := DontCare
   mvin_cmd_rs2.num_rows := (o.I >> req.downsample).asUInt()
   mvin_cmd_rs2.num_cols := o.K.asUInt()
-  mvin_cmd_rs2.local_addr := o.spad_addr.asTypeOf(mvin_cmd_rs2.local_addr)
+  mvin_cmd_rs2.local_addr := cast_to_local_addr(mvin_cmd_rs2.local_addr, o.spad_addr)
   io.cmd.bits.rs2 := mvin_cmd_rs2.asUInt()
 }
@@ -388,7 +393,7 @@ class LoopConvLdWeightReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth:
   val outer_bounds = new LoopConvOuterBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth)
   val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth)
   val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth)
-  val addr_end = UInt(log2Up(max_addr).W)
+  val addr_end = UInt(log2Up(max_addr+1).W)
   val dram_addr = UInt(coreMaxAddrBits.W)
   val trans_weight_1203 = Bool()
   val trans_weight_0132 = Bool()
@@ -475,14 +480,17 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
   val config_cmd = Wire(new RoCCCommand)
   config_cmd := DontCare
   config_cmd.inst.funct := CONFIG_CMD
+
   val config_cmd_rs1 = Wire(config_mvin_rs1_t.cloneType)
   config_cmd_rs1 := DontCare
   config_cmd_rs1.scale := MVIN_SCALE_IDENTITY
   config_cmd_rs1.stride := req.derived_params.weight_spad_stride
+  config_cmd_rs1.pixel_repeats := 1.U
   config_cmd_rs1.state_id := 1.U
   config_cmd_rs1.shrink := 0.U
   config_cmd_rs1._unused := 1.U
   config_cmd.rs1 := config_cmd_rs1.asUInt
+
   config_cmd.rs2 := dram_stride

   val mvin_cmd = Wire(new RoCCCommand)
@@ -513,7 +521,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
   mvin_cmd_rs2 := DontCare
   mvin_cmd_rs2.num_rows := o.K
   mvin_cmd_rs2.num_cols := o.J
-  mvin_cmd_rs2.local_addr := o.spad_addr.asTypeOf(mvin_cmd_rs2.local_addr)
+  mvin_cmd_rs2.local_addr := cast_to_local_addr(mvin_cmd_rs2.local_addr, o.spad_addr)
   io.cmd.bits.rs2 := mvin_cmd_rs2.asUInt()
 }
@@ -556,10 +564,11 @@ class LoopConvExecuteReq(val large_iterator_bitwidth: Int, val small_iterator_bi
   val inner_bounds = new LoopConvInnerBounds(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth)
   val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth)
   val a_addr_start = UInt(log2Up(max_addr).W)
-  val b_addr_end = UInt(log2Up(max_addr).W)
+  val b_addr_end = UInt(log2Up(max_addr+1).W)
   val c_addr_start = UInt(log2Up(max_acc_addr).W)
   val wrot180 = Bool()
   val downsample = Bool()
+  val max_pixels_per_row = UInt(small_iterator_bitwidth.W)
   val input_dilated = Bool()
   val trans_weight_0132 = Bool()
   val trans_input_3120 = Bool()
@@ -622,6 +631,8 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
   val skip_iteration = state >= pre && req.input_dilated && (((krow * kernel_dilation +& orow -& upad)(0) & req.input_dilated).asBool() ||
     ((kcol * kernel_dilation +& ocol -& lpad)(0) & req.input_dilated).asBool())

+  val pixels = Mux(kcols - kcol > req.max_pixels_per_row, req.max_pixels_per_row, kcols - kcol)
+
   val irow = undilated(orow * stride +& krow * kernel_dilation)
   val icol = undilated(ocol * stride +& kcol * kernel_dilation)

@@ -629,7 +640,7 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
     Mux(batches - b > block_size.U, block_size.U, batches - b),
     undilated(Mux(ocols - ocol > (block_size.U << req.input_dilated).asUInt(), (block_size.U << req.input_dilated).asUInt(), ocols - ocol)))
   val J = Mux(ochs - och > block_size.U, block_size.U, ochs - och)
-  val K = Mux(kchs - kch > block_size.U, block_size.U, kchs - kch)
+  val K = pixels * Mux(kchs - kch > block_size.U, block_size.U, kchs - kch)
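Folding up to max_pixels_per_row kernel columns into a single mesh pass is what makes the first-layer optimization pay off: an RGB first layer has only three input channels, so without folding, the reduction dimension of the systolic array would sit mostly idle. Plain-Scala arithmetic for the K expression above, with example numbers (maxPixelsPerRow = 4 is illustrative):

    // Mirrors: pixels = min(kcols - kcol, max_pixels_per_row)
    //          K      = pixels * min(kchs - kch, block_size)
    def effectiveK(kchs: Int, kch: Int, kcols: Int, kcol: Int,
                   maxPixelsPerRow: Int, blockSize: Int): Int = {
      val pixels = math.min(kcols - kcol, maxPixelsPerRow)
      pixels * math.min(kchs - kch, blockSize)
    }
    // 7x7 RGB first layer on a 16-deep reduction dimension:
    // effectiveK(3, 0, 7, 0, 4, 16) == 4 * 3 == 12 rows busy, instead of 3

Correspondingly, the kcol loop below now advances by max_pixels_per_row instead of 1.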
Mux(kchs - kch > block_size.U, block_size.U, kchs - kch) + val K = pixels * Mux(kchs - kch > block_size.U, block_size.U, kchs - kch) // Addresses val a_addr = Mux(req.trans_input_3120, @@ -719,13 +730,13 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera pre_cmd_rs1 := DontCare pre_cmd_rs1.num_rows := o.K.asUInt() pre_cmd_rs1.num_cols := o.J.asUInt() - pre_cmd_rs1.local_addr := o.pre_addr.asTypeOf(pre_cmd_rs1.local_addr) + pre_cmd_rs1.local_addr := cast_to_local_addr(pre_cmd_rs1.local_addr, o.pre_addr) val pre_cmd_rs2 = Wire(preload_rs2_t.cloneType) pre_cmd_rs2 := DontCare pre_cmd_rs2.num_rows := o.I.asUInt() pre_cmd_rs2.num_cols := o.J.asUInt() - pre_cmd_rs2.local_addr := o.c_addr.asTypeOf(pre_cmd_rs2.local_addr) + pre_cmd_rs2.local_addr := cast_to_local_addr(pre_cmd_rs2.local_addr, o.c_addr) io.cmd.bits.rs1 := pre_cmd_rs1.asUInt() io.cmd.bits.rs2 := pre_cmd_rs2.asUInt() @@ -735,13 +746,13 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera comp_cmd_rs1 := DontCare comp_cmd_rs1.num_rows := o.I.asUInt() comp_cmd_rs1.num_cols := o.K.asUInt() - comp_cmd_rs1.local_addr := o.a_addr.asTypeOf(comp_cmd_rs1.local_addr) + comp_cmd_rs1.local_addr := cast_to_local_addr(comp_cmd_rs1.local_addr, o.a_addr) val comp_cmd_rs2 = Wire(compute_rs2_t.cloneType) comp_cmd_rs2 := DontCare comp_cmd_rs2.num_rows := o.I.asUInt() comp_cmd_rs2.num_cols := o.J.asUInt() - comp_cmd_rs2.local_addr := GARBAGE_ADDR.asTypeOf(comp_cmd_rs2.local_addr) + comp_cmd_rs2.local_addr := cast_to_local_addr(comp_cmd_rs2.local_addr, GARBAGE_ADDR) io.cmd.bits.rs1 := comp_cmd_rs1.asUInt() io.cmd.bits.rs2 := comp_cmd_rs2.asUInt() @@ -767,7 +778,7 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera val next_b = floorAdd(b, b_it, batches, next_orow === 0.U && next_ocol === 0.U) val next_kch = floorAdd(kch, block_size.U, kchs, next_b === 0.U && next_orow === 0.U && next_ocol === 0.U) - val next_kcol = floorAdd(kcol, 1.U, kcols, + val next_kcol = floorAdd(kcol, req.max_pixels_per_row, kcols, next_kch === 0.U && next_b === 0.U && next_orow === 0.U && next_ocol === 0.U) val next_krow = floorAdd(krow, 1.U, krows, next_kcol === 0.U && next_kch === 0.U && next_b === 0.U && next_orow === 0.U && next_ocol === 0.U) @@ -967,7 +978,7 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: val pool_mvout_cmd_rs2 = Wire(mvout_rs2_t.cloneType) pool_mvout_cmd_rs2 := DontCare pool_mvout_cmd_rs2.num_cols := o.channels - pool_mvout_cmd_rs2.local_addr := o.pool_spad_addr.asTypeOf(pool_mvout_cmd_rs2.local_addr) + pool_mvout_cmd_rs2.local_addr := cast_to_local_addr(pool_mvout_cmd_rs2.local_addr, o.pool_spad_addr) io.cmd.bits.rs1 := o.pool_dram_addr io.cmd.bits.rs2 := pool_mvout_cmd_rs2.asUInt() @@ -976,7 +987,7 @@ class LoopConvSt(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwidth: mvout_cmd_rs2 := DontCare mvout_cmd_rs2.num_rows := o.I.asUInt() mvout_cmd_rs2.num_cols := o.J.asUInt() - mvout_cmd_rs2.local_addr := o.spad_addr.asTypeOf(mvout_cmd_rs2.local_addr) + mvout_cmd_rs2.local_addr := cast_to_local_addr(mvout_cmd_rs2.local_addr, o.spad_addr) io.cmd.bits.rs1 := o.dram_addr io.cmd.bits.rs2 := mvout_cmd_rs2.asUInt() @@ -1048,6 +1059,8 @@ class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val s val trans_weight_0132 = Bool() val trans_input_3120 = Bool() + val max_pixels_per_row = UInt(small_iterator_bitwidth.W) + val configured = Bool() val running = Bool() @@ -1067,7 +1080,7 @@ class LoopConvState(val 
block_size: Int, val large_iterator_bitwidth: Int, val s def all_completed(dummy: Int=0): Bool = ld_bias_completed && ld_input_completed && ld_weights_completed && ex_completed && st_completed val a_addr_start = UInt(log2Up(max_addr).W) - val b_addr_end = UInt(log2Up(max_addr).W) + val b_addr_end = UInt(log2Up(max_addr+1).W) def derived_params(dummy: Int=0): LoopConvDerivedParams = { import outer_bounds.{stride, kernel_dilation} @@ -1136,7 +1149,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: I config_mvin_rs1_t: ConfigMvinRs1, mvin_rs2_t: MvinRs2, config_mvout_rs2_t: ConfigMvoutRs2, mvout_rs2_t: MvoutRs2, config_ex_rs1_t: ConfigExRs1, preload_rs1_t: PreloadRs, preload_rs2_t: PreloadRs, compute_rs1_t: ComputeRs, compute_rs2_t: ComputeRs, - has_training_convs: Boolean, has_max_pool: Boolean) + has_training_convs: Boolean, has_max_pool: Boolean, has_first_layer_optimizations: Boolean) (implicit p: Parameters) extends Module { val large_iterator_bitwidth = 16 val small_iterator_bitwidth = 16 // 8 @@ -1288,6 +1301,12 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: I is (LOOP_CONV_WS) { loop_being_configured.no_bias := cmd.bits.rs1(0) + // TODO we added a default value for max_pixels_per_row just to maintain backwards compatibility. we should deprecate and remove it later + val config_max_pixels_per_row = cmd.bits.rs1(15, 8) + loop_being_configured.max_pixels_per_row := Mux( + !has_first_layer_optimizations.B || config_max_pixels_per_row === 0.U, + 1.U, config_max_pixels_per_row) + loop_being_configured.wrot180 := has_training_convs.B && cmd.bits.rs1(1) loop_being_configured.input_dilated := has_training_convs.B && cmd.bits.rs2(2) loop_being_configured.trans_output_1203 := has_training_convs.B && cmd.bits.rs1(2) @@ -1343,6 +1362,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: I ld_input.io.req.bits.addr_start := loop_requesting_ld_input.a_addr_start ld_input.io.req.bits.dram_addr := loop_requesting_ld_input.input_dram_addr ld_input.io.req.bits.downsample := loop_requesting_ld_input.downsample + ld_input.io.req.bits.max_pixels_per_row := loop_requesting_ld_input.max_pixels_per_row ld_input.io.req.bits.input_dilated := loop_requesting_ld_input.input_dilated ld_input.io.req.bits.trans_input_3120 := loop_requesting_ld_input.trans_input_3120 ld_input.io.req.bits.loop_id := loop_requesting_ld_input_id @@ -1382,6 +1402,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: I ex.io.req.bits.c_addr_start := ex_c_addr_start ex.io.req.bits.wrot180 := loop_requesting_ex.wrot180 ex.io.req.bits.downsample := loop_requesting_ex.downsample + ex.io.req.bits.max_pixels_per_row := loop_requesting_ex.max_pixels_per_row ex.io.req.bits.input_dilated := loop_requesting_ex.input_dilated ex.io.req.bits.trans_weight_0132 := loop_requesting_ex.trans_weight_0132 ex.io.req.bits.trans_input_3120 := loop_requesting_ex.trans_input_3120 @@ -1453,7 +1474,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: I loops.zipWithIndex.foreach { case (l, i) => l.reset() l.a_addr_start := (i * (max_addr / concurrent_loops)).U - l.b_addr_end := ((i+1) * (max_addr / concurrent_loops) - block_size).U + l.b_addr_end := ((i+1) * (max_addr / concurrent_loops)).U } } } @@ -1464,13 +1485,14 @@ object LoopConv { max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int, config_mvin_rs1_t: ConfigMvinRs1, mvin_rs2_t: MvinRs2, config_mvout_rs2_t: ConfigMvoutRs2, 
mvout_rs2_t: MvoutRs2, config_ex_rs1_t: ConfigExRs1, preload_rs1_t: PreloadRs, preload_rs2_t: PreloadRs, - compute_rs1_t: ComputeRs, compute_rs2_t: ComputeRs, has_training_convs: Boolean, has_max_pool: Boolean) + compute_rs1_t: ComputeRs, compute_rs2_t: ComputeRs, has_training_convs: Boolean, has_max_pool: Boolean, + has_first_layer_optimizations: Boolean) (implicit p: Parameters): Tuple2[DecoupledIO[RoCCCommand], Bool] = { val mod = Module(new LoopConv(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts, max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes, config_mvin_rs1_t, mvin_rs2_t, config_mvout_rs2_t, mvout_rs2_t, config_ex_rs1_t, preload_rs1_t, preload_rs2_t, - compute_rs1_t, compute_rs2_t, has_training_convs, has_max_pool)) + compute_rs1_t, compute_rs2_t, has_training_convs, has_max_pool, has_first_layer_optimizations)) mod.io.in <> in mod.io.ld_utilization := ld_utilization diff --git a/src/main/scala/gemmini/LoopMatmul.scala b/src/main/scala/gemmini/LoopMatmul.scala index ea1c3ed6..791b43d5 100644 --- a/src/main/scala/gemmini/LoopMatmul.scala +++ b/src/main/scala/gemmini/LoopMatmul.scala @@ -6,6 +6,7 @@ import chisel3.experimental._ import freechips.rocketchip.tile.RoCCCommand import freechips.rocketchip.config.Parameters import GemminiISA._ +import LocalAddr.cast_to_local_addr import Util._ // LdA @@ -75,7 +76,7 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In mvin_cmd_rs2 := DontCare mvin_cmd_rs2.num_rows := rows.asUInt() mvin_cmd_rs2.num_cols := cols.asUInt() - mvin_cmd_rs2.local_addr := sp_addr.asTypeOf(mvin_cmd_rs2.local_addr) + mvin_cmd_rs2.local_addr := cast_to_local_addr(mvin_cmd_rs2.local_addr, sp_addr) mvin_cmd.rs2 := mvin_cmd_rs2.asUInt() io.req.ready := state === idle @@ -122,7 +123,7 @@ class LoopMatmulLdBReq(val block_size: Int, val coreMaxAddrBits: Int, val iterat val dram_addr = UInt(coreMaxAddrBits.W) val dram_stride = UInt(coreMaxAddrBits.W) val transpose = Bool() - val addr_end = UInt(log2Up(max_addr).W) + val addr_end = UInt(log2Up(max_addr+1).W) val loop_id = UInt(log2Up(concurrent_loops).W) } @@ -182,7 +183,7 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In mvin_cmd_rs2 := DontCare mvin_cmd_rs2.num_rows := rows.asUInt() mvin_cmd_rs2.num_cols := cols.asUInt() - mvin_cmd_rs2.local_addr := sp_addr.asTypeOf(mvin_cmd_rs2.local_addr) + mvin_cmd_rs2.local_addr := cast_to_local_addr(mvin_cmd_rs2.local_addr, sp_addr) mvin_cmd.rs2 := mvin_cmd_rs2.asUInt() io.req.ready := state === idle @@ -278,7 +279,7 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In mvin_cmd_rs2 := DontCare mvin_cmd_rs2.num_rows := rows.asUInt() mvin_cmd_rs2.num_cols := cols.asUInt() - mvin_cmd_rs2.local_addr := sp_addr.asTypeOf(mvin_cmd_rs2.local_addr) + mvin_cmd_rs2.local_addr := cast_to_local_addr(mvin_cmd_rs2.local_addr, sp_addr) mvin_cmd.rs2 := mvin_cmd_rs2.asUInt() io.req.ready := state === idle @@ -325,7 +326,7 @@ class LoopMatmulExecuteReq(val block_size: Int, val coreMaxAddrBits: Int, val it val b_tranpose = Bool() val accumulate = Bool() val a_addr_start = UInt(log2Up(max_addr).W) - val b_addr_end = UInt(log2Up(max_addr).W) + val b_addr_end = UInt(log2Up(max_addr+1).W) val c_addr_start = UInt(log2Up(max_acc_addr).W) val loop_id = UInt(log2Up(concurrent_loops).W) } @@ -405,13 +406,13 @@ class LoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth pre_cmd_rs1 := DontCare pre_cmd_rs1.num_rows := b_rows.asUInt() pre_cmd_rs1.num_cols := b_cols.asUInt() - 
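All of the log2Up(max_addr) to log2Up(max_addr+1) widenings in this patch are the same off-by-one fix: b_addr_end / addr_end now hold an exclusive end address, which can equal max_addr itself and therefore needs one more bit than any valid index. With hypothetical numbers:

// Hypothetical config: 16384 scratchpad rows, 2 concurrent loops.
val max_addr = 16384
val concurrent_loops = 2
val ends = (0 until concurrent_loops).map(i => (i + 1) * (max_addr / concurrent_loops))
// ends == Vector(8192, 16384): the largest exclusive end is max_addr itself,
// which needs 15 bits, while log2Up(max_addr) == 14 only covers 0..16383.
// Hence UInt(log2Up(max_addr + 1).W), and the reset value no longer backs off
// by block_size.
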
pre_cmd_rs1.local_addr := pre_addr.asTypeOf(pre_cmd_rs1.local_addr) + pre_cmd_rs1.local_addr := cast_to_local_addr(pre_cmd_rs1.local_addr, pre_addr) val pre_cmd_rs2 = Wire(preload_rs2_t.cloneType) pre_cmd_rs2 := DontCare pre_cmd_rs2.num_rows := c_rows.asUInt() pre_cmd_rs2.num_cols := c_cols.asUInt() - pre_cmd_rs2.local_addr := out_addr.asTypeOf(pre_cmd_rs2.local_addr) + pre_cmd_rs2.local_addr := cast_to_local_addr(pre_cmd_rs2.local_addr, out_addr) pre_cmd.rs1 := pre_cmd_rs1.asUInt() pre_cmd.rs2 := pre_cmd_rs2.asUInt() @@ -424,13 +425,13 @@ class LoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth comp_cmd_rs1 := DontCare comp_cmd_rs1.num_rows := a_rows.asUInt() comp_cmd_rs1.num_cols := a_cols.asUInt() - comp_cmd_rs1.local_addr := a_addr.asTypeOf(comp_cmd_rs1.local_addr) + comp_cmd_rs1.local_addr := cast_to_local_addr(comp_cmd_rs1.local_addr, a_addr) val comp_cmd_rs2 = Wire(compute_rs2_t.cloneType) comp_cmd_rs2 := DontCare comp_cmd_rs2.num_rows := block_size.U comp_cmd_rs2.num_cols := block_size.U - comp_cmd_rs2.local_addr := GARBAGE_ADDR.asTypeOf(comp_cmd_rs2.local_addr) + comp_cmd_rs2.local_addr := cast_to_local_addr(comp_cmd_rs2.local_addr, GARBAGE_ADDR) comp_cmd.rs1 := comp_cmd_rs1.asUInt() comp_cmd.rs2 := comp_cmd_rs2.asUInt() @@ -545,7 +546,7 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In mvout_cmd_rs2 := DontCare mvout_cmd_rs2.num_rows := rows.asUInt() mvout_cmd_rs2.num_cols := cols.asUInt() - mvout_cmd_rs2.local_addr := sp_addr.asTypeOf(mvout_cmd_rs2.local_addr) + mvout_cmd_rs2.local_addr := cast_to_local_addr(mvout_cmd_rs2.local_addr, sp_addr) mvout_cmd.rs2 := mvout_cmd_rs2.asUInt() io.req.ready := state === idle @@ -636,7 +637,7 @@ class LoopMatmulState(val iterator_bitwidth: Int, val coreMaxAddrBits: Int, val def all_completed(dummy: Int=0): Bool = lda_completed && ldb_completed && ldd_completed && ex_completed && st_completed val a_addr_start = UInt(log2Up(max_addr).W) - val b_addr_end = UInt(log2Up(max_addr).W) + val b_addr_end = UInt(log2Up(max_addr+1).W) def reset(): Unit = { configured := false.B @@ -958,7 +959,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: loops.zipWithIndex.foreach { case (l, i) => l.reset() l.a_addr_start := (i * (max_addr / concurrent_loops)).U - l.b_addr_end := ((i+1) * (max_addr / concurrent_loops) - block_size).U + l.b_addr_end := ((i+1) * (max_addr / concurrent_loops)).U } } } diff --git a/src/main/scala/gemmini/Mesh.scala b/src/main/scala/gemmini/Mesh.scala index 5bb924c5..cd056658 100644 --- a/src/main/scala/gemmini/Mesh.scala +++ b/src/main/scala/gemmini/Mesh.scala @@ -15,7 +15,7 @@ import chisel3.experimental._ * @param meshColumns */ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, - df: Dataflow.Value, pe_latency: Int, + df: Dataflow.Value, tree_reduction: Boolean, tile_latency: Int, max_simultaneous_matmuls: Int, output_delay: Int, val tileRows: Int, val tileColumns: Int, val meshRows: Int, val meshColumns: Int) extends Module { @@ -34,43 +34,54 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, val out_id = Output(Vec(meshColumns, Vec(tileColumns, UInt(log2Up(max_simultaneous_matmuls).W)))) val out_last = Output(Vec(meshColumns, Vec(tileColumns, Bool()))) }) + // mesh(r)(c) => Tile at row r, column c - val mesh: Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, outputType, accType, df, pe_latency, max_simultaneous_matmuls, tileRows, tileColumns))) + val mesh: 
Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, outputType, accType, df, tree_reduction, max_simultaneous_matmuls, tileRows, tileColumns))) val meshT = mesh.transpose + + def pipe[T <: Data](valid: Bool, t: T, latency: Int): T = { + // The default "Pipe" function apparently resets the valid signals to false.B. We would like to avoid using global + // signals in the Mesh, so over here, we make it clear that the reset signal will never be asserted + chisel3.withReset(false.B) { Pipe(valid, t, latency).bits } + } + // Chain tile_a_out -> tile_a_in (pipeline a across each row) // TODO clock-gate A signals with in_garbage for (r <- 0 until meshRows) { mesh(r).foldLeft(io.in_a(r)) { case (in_a, tile) => - tile.io.in_a := RegNext(in_a) + tile.io.in_a := ShiftRegister(in_a, tile_latency+1) tile.io.out_a } } + // Chain tile_out_b -> tile_b_in (pipeline b across each column) for (c <- 0 until meshColumns) { meshT(c).foldLeft((io.in_b(c), io.in_valid(c))) { case ((in_b, valid), tile) => - tile.io.in_b := RegEnable(in_b, valid.head) + tile.io.in_b := pipe(valid.head, in_b, tile_latency+1) (tile.io.out_b, tile.io.out_valid) } } + // Chain tile_out -> tile_propag (pipeline output across each column) for (c <- 0 until meshColumns) { meshT(c).foldLeft((io.in_d(c), io.in_valid(c))) { case ((in_propag, valid), tile) => - tile.io.in_d := RegEnable(in_propag, valid.head) + tile.io.in_d := pipe(valid.head, in_propag, tile_latency+1) (tile.io.out_c, tile.io.out_valid) } } + // Chain control signals (pipeline across each column) assert(!(mesh.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_))) for (c <- 0 until meshColumns) { meshT(c).foldLeft((io.in_control(c), io.in_valid(c))) { case ((in_ctrl, valid), tile) => (tile.io.in_control, in_ctrl, valid).zipped.foreach { case (tile_ctrl, ctrl, v) => - tile_ctrl.shift := RegEnable(ctrl.shift, v) - tile_ctrl.dataflow := RegEnable(ctrl.dataflow, v) - tile_ctrl.propagate := RegEnable(ctrl.propagate, v) + tile_ctrl.shift := pipe(v, ctrl.shift, tile_latency+1) + tile_ctrl.dataflow := pipe(v, ctrl.dataflow, tile_latency+1) + tile_ctrl.propagate := pipe(v, ctrl.propagate, tile_latency+1) } (tile.io.out_control, tile.io.out_valid) } @@ -80,7 +91,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, for (c <- 0 until meshColumns) { meshT(c).foldLeft(io.in_valid(c)) { case (in_v, tile) => - tile.io.in_valid := RegNext(in_v) + tile.io.in_valid := ShiftRegister(in_v, tile_latency+1) tile.io.out_valid } } @@ -89,7 +100,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, for (c <- 0 until meshColumns) { meshT(c).foldLeft(io.in_id(c)) { case (in_id, tile) => - tile.io.in_id := RegNext(in_id) + tile.io.in_id := ShiftRegister(in_id, tile_latency+1) tile.io.out_id } } @@ -98,7 +109,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, for (c <- 0 until meshColumns) { meshT(c).foldLeft(io.in_last(c)) { case (in_last, tile) => - tile.io.in_last := RegNext(in_last) + tile.io.in_last := ShiftRegister(in_last, tile_latency+1) tile.io.out_last } } diff --git a/src/main/scala/gemmini/MeshWithDelays.scala b/src/main/scala/gemmini/MeshWithDelays.scala index acab135d..db40debf 100644 --- a/src/main/scala/gemmini/MeshWithDelays.scala +++ b/src/main/scala/gemmini/MeshWithDelays.scala @@ -33,7 +33,7 @@ class MeshWithDelaysResp[T <: Data: Arithmetic, TagT <: TagQueueTag with Data](o class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] (inputType: T, val 
outputType: T, accType: T, - tagType: U, df: Dataflow.Value, pe_latency: Int, output_delay: Int, + tagType: U, df: Dataflow.Value, tree_reduction: Boolean, tile_latency: Int, output_delay: Int, tileRows: Int, tileColumns: Int, meshRows: Int, meshColumns: Int, leftBanks: Int, upBanks: Int, outBanks: Int = 1, n_simultaneous_matmuls: Int = -1) extends Module { @@ -47,12 +47,13 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] assert(meshRows*tileRows == meshColumns*tileColumns) val block_size = meshRows*tileRows + val latency_per_pe = (tile_latency + 1).toFloat / (tileRows min tileColumns) val max_simultaneous_matmuls = if (n_simultaneous_matmuls == -1) { - 5 * (pe_latency + 1) + (5 * latency_per_pe).ceil.toInt } else { n_simultaneous_matmuls } - assert(max_simultaneous_matmuls >= 5 * (pe_latency + 1)) + assert(max_simultaneous_matmuls >= 5 * latency_per_pe) val tagqlen = max_simultaneous_matmuls+1 @@ -70,7 +71,6 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] def shifted[T <: Data](x: Vec[Vec[T]], banks: Int, reverse: Boolean = false) = { assert(x.size % banks == 0, "cannot bank without clean divisors") - assert(pe_latency == 0 || (tileRows == 1 && tileColumns == 1), "If tiles are larger than 1x1, then PEs must have 0 latency") val banked_len = x.size / banks val banked_x = x.grouped(banked_len).toSeq @@ -79,13 +79,13 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] (banked_x zip indexes).flatMap { case (bx, i) => val bxVec = VecInit(bx) - val sram_shift = i * banked_len * (pe_latency+1) + val sram_shift = i * banked_len * (tile_latency+1) val SRAMShifted = Shifter(bxVec, sram_shift, true.B, true) val indexes = if (reverse) SRAMShifted.indices.reverse else SRAMShifted.indices val RegShifted = (SRAMShifted zip indexes).map { case (srs, j) => - ShiftRegister(srs, j*(pe_latency+1)) + ShiftRegister(srs, j*(tile_latency+1)) } RegShifted @@ -166,25 +166,25 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] val transposer_out = VecInit(transposer.io.outCol.bits.grouped(tileRows).map(t => VecInit(t)).toSeq) // Wire up mesh's IO to this module's IO - val mesh = Module(new Mesh(inputType, outputType, accType, df, pe_latency, max_simultaneous_matmuls, output_delay, tileRows, tileColumns, meshRows, meshColumns)) + val mesh = Module(new Mesh(inputType, outputType, accType, df, tree_reduction, tile_latency, max_simultaneous_matmuls, output_delay, tileRows, tileColumns, meshRows, meshColumns)) // TODO wire only to *_buf here, instead of io.*.bits - val a_shifter_in = WireInit(Mux(a_is_from_transposer, transposer_out, a_buf)) - val b_shifter_in = WireInit(Mux(b_is_from_transposer, transposer_out, b_buf)) + val a_shifter_in = WireInit(Mux(a_is_from_transposer, transposer_out.asTypeOf(A_TYPE), a_buf)) + val b_shifter_in = WireInit(Mux(b_is_from_transposer, transposer_out.asTypeOf(B_TYPE), b_buf)) val d_shifter_in = WireInit(Mux(d_is_from_transposer, - VecInit(transposer_out.flatten.reverse.grouped(tileRows).map(VecInit(_)).toSeq), d_buf)) + VecInit(transposer_out.flatten.reverse.grouped(tileRows).map(VecInit(_)).toSeq).asTypeOf(D_TYPE), d_buf)) mesh.io.in_a := shifted(a_shifter_in, leftBanks) mesh.io.in_b := shifted(b_shifter_in, upBanks) mesh.io.in_d := shifted(d_shifter_in, upBanks) mesh.io.in_control.zipWithIndex.foreach { case (ss, i) => - ss.foreach(_.dataflow := ShiftRegister(req.bits.pe_control.dataflow, i * (pe_latency + 1))) - ss.foreach(_.propagate := ShiftRegister(in_prop, i * (pe_latency + 1))) + 
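The latency_per_pe arithmetic above replaces the old per-PE pe_latency with a per-tile tile_latency, so the per-PE share of the pipeline delay can be fractional and the tag-queue sizing rounds it back up. A worked rerun of the math with hypothetical parameters (2x2 tiles, one pipeline register per tile):

// Plain-Scala rerun of the sizing math; all values hypothetical.
val tileRows = 2; val tileColumns = 2; val tile_latency = 1
val latency_per_pe = (tile_latency + 1).toFloat / (tileRows min tileColumns) // 1.0
val max_simultaneous_matmuls = (5 * latency_per_pe).ceil.toInt               // 5
val tagqlen = max_simultaneous_matmuls + 1                                   // 6 tag-queue entries
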
ss.foreach(_.dataflow := ShiftRegister(req.bits.pe_control.dataflow, i * (tile_latency + 1))) + ss.foreach(_.propagate := ShiftRegister(in_prop, i * (tile_latency + 1))) } val result_shift = RegNext(req.bits.pe_control.shift) // TODO will this arrive at the right time if memory isn't pipelined? mesh.io.in_control.zipWithIndex.foreach { case (ctrl, i) => - ctrl.foreach(_.shift := ShiftRegister(result_shift, i * (pe_latency + 1))) + ctrl.foreach(_.shift := ShiftRegister(result_shift, i * (tile_latency + 1))) } val not_paused_vec = VecInit(Seq.fill(meshColumns)(VecInit(Seq.fill(tileColumns)(!pause)))) @@ -198,8 +198,7 @@ class MeshWithDelays[T <: Data: Arithmetic, U <: TagQueueTag with Data] // We want to output C when we're output-stationary, but B when we're weight-stationary // TODO these would actually overlap when we switch from output-stationary to weight-stationary - val out_pe_control = shifted(mesh.io.out_control, outBanks, reverse = true)(0)(0) - io.resp.bits.data := shifted(Mux(out_pe_control.dataflow === Dataflow.OS.id.U, mesh.io.out_c, mesh.io.out_b), outBanks, true) + io.resp.bits.data := shifted(Mux(mesh.io.out_control(0)(0).dataflow === Dataflow.OS.id.U, mesh.io.out_c, mesh.io.out_b), outBanks, true) io.resp.valid := shifted(mesh.io.out_valid, outBanks, reverse = true)(0)(0) diff --git a/src/main/scala/gemmini/PE.scala b/src/main/scala/gemmini/PE.scala index 79944b72..e10318a3 100644 --- a/src/main/scala/gemmini/PE.scala +++ b/src/main/scala/gemmini/PE.scala @@ -17,7 +17,7 @@ class PEControl[T <: Data : Arithmetic](accType: T) extends Bundle { * A PE implementing a MAC operation. Configured as fully combinational when integrated into a Mesh. * @param width Data width of operands */ -class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, latency: Int, max_simultaneous_matmuls: Int) +class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, max_simultaneous_matmuls: Int) (implicit ev: Arithmetic[T]) extends Module { // Debugging variables import ev._ @@ -46,17 +46,17 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, val cType = if (df == Dataflow.WS) inputType else accType - val a = ShiftRegister(io.in_a, latency) - val b = ShiftRegister(io.in_b, latency) - val d = ShiftRegister(io.in_d, latency) + val a = io.in_a + val b = io.in_b + val d = io.in_d val c1 = Reg(cType) val c2 = Reg(cType) - val dataflow = ShiftRegister(io.in_control.dataflow, latency) - val prop = ShiftRegister(io.in_control.propagate, latency) - val shift = ShiftRegister(io.in_control.shift, latency) - val id = ShiftRegister(io.in_id, latency) - val last = ShiftRegister(io.in_last, latency) - val valid = ShiftRegister(io.in_valid, latency) // TODO should we clockgate the rest of the ShiftRegisters based on the values in this ShiftRegisters + val dataflow = io.in_control.dataflow + val prop = io.in_control.propagate + val shift = io.in_control.shift + val id = io.in_id + val last = io.in_last + val valid = io.in_valid io.out_a := a io.out_control.dataflow := dataflow diff --git a/src/main/scala/gemmini/PixelRepeater.scala b/src/main/scala/gemmini/PixelRepeater.scala new file mode 100644 index 00000000..0413304e --- /dev/null +++ b/src/main/scala/gemmini/PixelRepeater.scala @@ -0,0 +1,95 @@ +package gemmini + +import chisel3._ +import chisel3.util._ + +import Util._ + +class PixelRepeaterReq[T <: Data, Tag <: Data](t: T, laddr_t: LocalAddr, block_cols: Int, tag_t: Tag) extends Bundle { + val in: Vec[T] = Vec(block_cols, 
t.cloneType) + val mask: Vec[Bool] = Vec(block_cols, Bool()) + val laddr: LocalAddr = laddr_t.cloneType + val len: UInt = UInt(log2Up(block_cols+1).W) // TODO magic number + val pixel_repeats: UInt = UInt(8.W) // TODO magic number + val last: Bool = Bool() + val tag: Tag = tag_t.cloneType + + assert(block_cols <= 255, "len must be longer") + + override def cloneType: PixelRepeaterReq.this.type = new PixelRepeaterReq(t, laddr_t, block_cols, tag_t).asInstanceOf[this.type] +} + +class PixelRepeaterResp[T <: Data, Tag <: Data](t: T, laddr_t: LocalAddr, block_cols: Int, tag_t: Tag) extends Bundle { + val out: Vec[T] = Vec(block_cols, t.cloneType) + val mask: Vec[Bool] = Vec(block_cols, Bool()) + val laddr: LocalAddr = laddr_t.cloneType + val last: Bool = Bool() + val tag: Tag = tag_t.cloneType + + override def cloneType: PixelRepeaterResp.this.type = new PixelRepeaterResp(t, laddr_t, block_cols, tag_t).asInstanceOf[this.type] +} + +class PixelRepeater[T <: Data, Tag <: Data](t: T, laddr_t: LocalAddr, block_cols: Int, aligned_to: Int, tag_t: Tag, passthrough: Boolean) extends Module { + val io = IO(new Bundle { + val req = Flipped(Decoupled(new PixelRepeaterReq(t, laddr_t, block_cols, tag_t))) + val resp = Decoupled(new PixelRepeaterResp(t, laddr_t, block_cols, tag_t)) + }) + + if (passthrough) { + io.resp.valid := io.req.valid + io.resp.bits.out := io.req.bits.in + io.resp.bits.mask := io.req.bits.mask + io.resp.bits.laddr := io.req.bits.laddr + io.resp.bits.last := io.req.bits.last + io.resp.bits.tag := io.req.bits.tag + + io.req.ready := io.resp.ready + } else { + val req = Reg(UDValid(io.req.bits.cloneType)) + + io.req.ready := !req.valid || (io.resp.ready && req.bits.pixel_repeats === 0.U) + + val out_shift = Wire(UInt(log2Up(block_cols / 2 + 1).W)) + out_shift := req.bits.pixel_repeats * req.bits.len + + io.resp.bits.out := (req.bits.in.asUInt() << (out_shift * t.getWidth.U)).asTypeOf(io.resp.bits.out) + io.resp.bits.mask := (req.bits.mask.asUInt() << (out_shift * ((t.getWidth / 8) / aligned_to).U)).asTypeOf(io.resp.bits.mask) + + io.resp.bits.last := req.bits.last && (req.bits.pixel_repeats === 0.U) + io.resp.bits.tag := req.bits.tag + + val is_acc_addr = req.bits.laddr.is_acc_addr + assert(!(req.valid && is_acc_addr && req.bits.pixel_repeats > 0.U)) + + val sp_addr = Mux(req.bits.laddr.full_sp_addr() < (laddr_t.spRows / 2).U, + req.bits.laddr.floorSub(req.bits.pixel_repeats, 0.U)._1, + req.bits.laddr.floorSub(req.bits.pixel_repeats, (laddr_t.spRows / 2).U)._1, + ) + + val underflow = !is_acc_addr && Mux(req.bits.laddr.full_sp_addr() < (laddr_t.spRows / 2).U, + req.bits.laddr.floorSub(req.bits.pixel_repeats, 0.U)._2, + req.bits.laddr.floorSub(req.bits.pixel_repeats, (laddr_t.spRows / 2).U)._2, + ) + + io.resp.bits.laddr := Mux(is_acc_addr, req.bits.laddr, sp_addr) + + io.resp.valid := req.valid && !underflow + + when(io.resp.fire() || underflow) { + req.bits.pixel_repeats := req.bits.pixel_repeats - 1.U + + when(req.bits.pixel_repeats === 0.U) { + req.pop() + } + } + + when(io.req.fire()) { + req.push(io.req.bits) + req.bits.pixel_repeats := io.req.bits.pixel_repeats - 1.U + } + + when(reset.toBool()) { + req.pop() + } + } +} diff --git a/src/main/scala/gemmini/ReservationStation.scala b/src/main/scala/gemmini/ReservationStation.scala index 929685f6..7135969f 100644 --- a/src/main/scala/gemmini/ReservationStation.scala +++ b/src/main/scala/gemmini/ReservationStation.scala @@ -115,14 +115,14 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G val 
solitary_preload = utilization === 1.U && entries.map(e => e.valid && e.bits.cmd.inst.funct === PRELOAD_CMD).reduce(_ || _) io.busy := !empty && !(solitary_preload && io.solitary_preload) - // Config values set by programmer - val a_stride = Reg(UInt(16.W)) // TODO magic numbers - val c_stride = Reg(UInt(16.W)) // TODO magic numbers + val a_stride = Reg(UInt(a_stride_bits.W)) + val c_stride = Reg(UInt(c_stride_bits.W)) val a_transpose = Reg(Bool()) val ld_block_strides = Reg(Vec(load_states, UInt(block_stride_bits.W))) val st_block_stride = block_rows.U val pooling_is_enabled = Reg(Bool()) + val ld_pixel_repeats = Reg(Vec(load_states, UInt(pixel_repeats_bits.W))) // This is the ld_pixel_repeat MINUS ONE val new_entry = Wire(new Entry) new_entry := DontCare @@ -245,6 +245,7 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G val id = MuxCase(0.U, Seq((new_entry.cmd.inst.funct === LOAD2_CMD) -> 1.U, (new_entry.cmd.inst.funct === LOAD3_CMD) -> 2.U)) val block_stride = ld_block_strides(id) + val pixel_repeats = ld_pixel_repeats(id) val mvin_cols = cmd.rs2(32 + mvin_cols_bits - 1, 32) val mvin_rows = cmd.rs2(48 + mvin_rows_bits - 1, 48) @@ -252,6 +253,18 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G val mvin_mats = mvin_cols / block_cols.U + (mvin_cols % block_cols.U =/= 0.U) val total_mvin_rows = ((mvin_mats - 1.U) * block_stride) + mvin_rows + // TODO We have to know how the LoopConv's internals work here. Our abstractions are leaking + if (has_first_layer_optimizations) { + val start = cmd.rs2(31, 0).asTypeOf(local_addr_t) + // TODO instead of using a floor-sub that's hardcoded to the Scratchpad bank boundaries, we should find some way of letting the programmer specify the start address + dst.bits.start := Mux(start.is_acc_addr, start, + Mux(start.full_sp_addr() > (local_addr_t.spRows / 2).U, + start.floorSub(pixel_repeats, (local_addr_t.spRows / 2).U)._1, + start.floorSub(pixel_repeats, 0.U)._1, + ) + ) + } + dst.bits.end := dst.bits.start + total_mvin_rows dst.bits.wraps_around := dst.bits.start.add_with_overflow(total_mvin_rows)._2 } @@ -365,7 +378,9 @@ class ReservationStation[T <: Data : Arithmetic, U <: Data, V <: Data](config: G }.elsewhen(new_entry.is_config && new_entry.q === ldq) { val id = new_entry.cmd.rs1(4,3) // TODO magic numbers val block_stride = new_entry.cmd.rs1(31, 16) // TODO magic numbers + val repeat_pixels = maxOf(new_entry.cmd.rs1(8 + pixel_repeats_bits - 1, 8), 1.U) // TODO we use a default value of pixel repeats here, for backwards compatibility. 
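The dst.bits.start adjustment above is the reservation station compensating for the rows the PixelRepeater will write below the nominal start address, and it leans on LocalAddr.floorSub from this patch. A software model of floorSub's semantics (hypothetical values; the 8192 stands in for a half-bank boundary such as spRows / 2):

// Software model of LocalAddr.floorSub: subtract, but clamp at `floor`
// and report whether we clamped.
def floorSub(data: Int, other: Int, floor: Int): (Int, Boolean) = {
  val underflow = data < floor + other
  (if (underflow) floor else data - other, underflow)
}
// floorSub(8194, 4, 8192) == (8192, true)   // clamped at the boundary
// floorSub(8200, 4, 8192) == (8196, false)  // ordinary subtraction
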
However, we should deprecate and remove this default value eventually ld_block_strides(id) := block_stride + ld_pixel_repeats(id) := repeat_pixels - 1.U }.elsewhen(new_entry.is_config && new_entry.q === stq) { val pool_stride = new_entry.cmd.rs1(5, 4) // TODO magic numbers pooling_is_enabled := pool_stride =/= 0.U diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala index e3289b7f..764b5d5a 100644 --- a/src/main/scala/gemmini/Scratchpad.scala +++ b/src/main/scala/gemmini/Scratchpad.scala @@ -6,12 +6,11 @@ import freechips.rocketchip.config.Parameters import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp} import freechips.rocketchip.rocket._ import freechips.rocketchip.tile._ -import freechips.rocketchip.tilelink.{TLIdentityNode, TLXbar} +import freechips.rocketchip.tilelink.{TLIdentityNode, TLXbar, TLBuffer} import Util._ -class ScratchpadMemReadRequest[U <: Data](local_addr_t: LocalAddr, scale_t_bits: Int) - (implicit p: Parameters) extends CoreBundle { +class ScratchpadMemReadRequest[U <: Data](local_addr_t: LocalAddr, scale_t_bits: Int)(implicit p: Parameters) extends CoreBundle { val vaddr = UInt(coreMaxAddrBits.W) val laddr = local_addr_t.cloneType @@ -21,6 +20,7 @@ class ScratchpadMemReadRequest[U <: Data](local_addr_t: LocalAddr, scale_t_bits: val has_acc_bitwidth = Bool() val all_zeros = Bool() val block_stride = UInt(16.W) // TODO magic numbers + val pixel_repeats = UInt(8.W) // TODO magic numbers val cmd_id = UInt(8.W) // TODO don't use a magic number here val status = new MStatus @@ -57,15 +57,13 @@ class ScratchpadMemReadResponse extends Bundle { val cmd_id = UInt(8.W) // TODO don't use a magic number here } -class ScratchpadReadMemIO[U <: Data](local_addr_t: LocalAddr, scale_t_bits: Int) - (implicit p: Parameters) extends CoreBundle { +class ScratchpadReadMemIO[U <: Data](local_addr_t: LocalAddr, scale_t_bits: Int)(implicit p: Parameters) extends CoreBundle { val req = Decoupled(new ScratchpadMemReadRequest(local_addr_t, scale_t_bits)) val resp = Flipped(Valid(new ScratchpadMemReadResponse)) override def cloneType: this.type = new ScratchpadReadMemIO(local_addr_t, scale_t_bits).asInstanceOf[this.type] } -// class ScratchpadWriteMemIO(val nBanks: Int, val nRows: Int, val acc_rows: Int) class ScratchpadWriteMemIO(local_addr_t: LocalAddr, scale_t_bits: Int) (implicit p: Parameters) extends CoreBundle { val req = Decoupled(new ScratchpadMemWriteRequest(local_addr_t, scale_t_bits)) @@ -96,7 +94,7 @@ class ScratchpadWriteIO(val n: Int, val w: Int, val mask_len: Int) extends Bundl val data = Output(UInt(w.W)) } -class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean) extends Module { +class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean, use_shared_ext_mem: Boolean) extends Module { // This is essentially a pipelined SRAM with the ability to stall pipeline stages require(w % aligned_to == 0 || w < aligned_to) @@ -106,27 +104,50 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean) ex val io = IO(new Bundle { val read = Flipped(new ScratchpadReadIO(n, w)) val write = Flipped(new ScratchpadWriteIO(n, w, mask_len)) + val ext_mem = if (use_shared_ext_mem) Some(new ExtMemIO) else None }) - val mem = SyncReadMem(n, Vec(mask_len, mask_elem)) + val (read, write) = if (use_shared_ext_mem) { + def read(addr: UInt, ren: Bool): Data = { + io.ext_mem.get.read_en := ren + io.ext_mem.get.read_addr := addr + io.ext_mem.get.read_data + } + io.ext_mem.get.write_en := false.B + 
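For reference, here is the ld-config decode above replayed in plain Scala. The bit positions are the ones this hunk reads; the concrete rs1 value and the 8-bit width assumed for pixel_repeats_bits are made up. A zero pixel-repeat field is floored to 1 for backwards compatibility, and the register stores the count minus one:

// Hypothetical rs1 = 0x00030200: block_stride = 3, pixel repeats = 2, id = 0.
val rs1 = BigInt("00030200", 16)
val id            = (rs1 >> 3) & 0x3                  // rs1(4,3)
val block_stride  = (rs1 >> 16) & 0xffff              // rs1(31,16)
val repeat_pixels = ((rs1 >> 8) & 0xff) max BigInt(1) // rs1(15,8), floored at 1
val stored        = repeat_pixels - 1                 // what ld_pixel_repeats holds
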
io.ext_mem.get.write_addr := DontCare + io.ext_mem.get.write_data := DontCare + io.ext_mem.get.write_mask := DontCare + def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = { + io.ext_mem.get.write_en := true.B + io.ext_mem.get.write_addr := addr + io.ext_mem.get.write_data := wdata.asUInt + io.ext_mem.get.write_mask := wmask.asUInt + } + (read _, write _) + } else { + val mem = SyncReadMem(n, Vec(mask_len, mask_elem)) + def read(addr: UInt, ren: Bool): Data = mem.read(addr, ren) + def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = mem.write(addr, wdata, wmask) + (read _, write _) + } // When the scratchpad is single-ported, the writes take precedence val singleport_busy_with_write = single_ported.B && io.write.en when (io.write.en) { if (aligned_to >= w) - mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem))) + write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), VecInit((~(0.U(mask_len.W))).asBools)) else - mem.write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), io.write.mask) + write(io.write.addr, io.write.data.asTypeOf(Vec(mask_len, mask_elem)), io.write.mask) } val raddr = io.read.req.bits.addr val ren = io.read.req.fire() val rdata = if (single_ported) { assert(!(ren && io.write.en)) - mem.read(raddr, ren && !io.write.en).asUInt() + read(raddr, ren && !io.write.en).asUInt() } else { - mem.read(raddr, ren).asUInt() + read(raddr, ren).asUInt() } val fromDMA = io.read.req.bits.fromDMA @@ -143,6 +164,7 @@ class ScratchpadBank(n: Int, w: Int, aligned_to: Int, single_ported: Boolean) ex io.read.resp <> q.io.deq } + class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, U, V]) (implicit p: Parameters, ev: Arithmetic[T]) extends LazyModule { @@ -171,9 +193,9 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // id_node :=* reader.node // id_node :=* writer.node - xbar_node := reader.node // TODO - xbar_node := writer.node - id_node := xbar_node + xbar_node := TLBuffer() := reader.node // TODO + xbar_node := TLBuffer() := writer.node + id_node := TLBuffer() := xbar_node lazy val module = new LazyModuleImp(this) with HasCoreParameters { @@ -204,6 +226,12 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, )))) } + val ext_mem = if (use_shared_ext_mem) { + Some(new ExtSpadMemIO(sp_banks, acc_banks, acc_sub_banks)) + } else { + None + } + // TLB ports val tlb = Vec(2, new FrontendTLBIO) @@ -229,7 +257,6 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, write_issue_q.io.enq.valid := false.B write_issue_q.io.enq.bits := write_scale_q.io.deq.bits - // Garbage can immediately fire between dispatch_q and scale_q when (write_dispatch_q.bits.laddr.is_garbage()) { write_scale_q.io.enq <> write_dispatch_q @@ -239,7 +266,6 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, write_issue_q.io.enq <> write_scale_q.io.deq } - val writeData = Wire(Valid(UInt((spad_w max acc_w).W))) writeData.valid := write_issue_q.io.deq.bits.laddr.is_garbage() writeData.bits := DontCare @@ -285,7 +311,20 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, zero_writer.io.req.bits.block_stride := io.dma.read.req.bits.block_stride zero_writer.io.req.bits.tag := io.dma.read.req.bits - zero_writer.io.resp.ready := false.B + // zero_writer.io.resp.ready := false.B + + val zero_writer_pixel_repeater = Module(new PixelRepeater(inputType, local_addr_t, block_cols, aligned_to, new 
ScratchpadMemReadRequest(local_addr_t, mvin_scale_t_bits), passthrough = !has_first_layer_optimizations)) + zero_writer_pixel_repeater.io.req.valid := zero_writer.io.resp.valid + zero_writer_pixel_repeater.io.req.bits.in := 0.U.asTypeOf(Vec(block_cols, inputType)) + zero_writer_pixel_repeater.io.req.bits.mask := zero_writer.io.resp.bits.mask + zero_writer_pixel_repeater.io.req.bits.laddr := zero_writer.io.resp.bits.laddr + zero_writer_pixel_repeater.io.req.bits.len := zero_writer.io.resp.bits.tag.cols + zero_writer_pixel_repeater.io.req.bits.pixel_repeats := zero_writer.io.resp.bits.tag.pixel_repeats + zero_writer_pixel_repeater.io.req.bits.last := zero_writer.io.resp.bits.last + zero_writer_pixel_repeater.io.req.bits.tag := zero_writer.io.resp.bits.tag + + zero_writer.io.resp.ready := zero_writer_pixel_repeater.io.req.ready + zero_writer_pixel_repeater.io.resp.ready := false.B reader.module.io.req.valid := read_issue_q.io.deq.valid read_issue_q.io.deq.ready := reader.module.io.req.ready @@ -294,6 +333,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, read_issue_q.io.deq.bits.laddr.full_acc_addr(), read_issue_q.io.deq.bits.laddr.full_sp_addr()) reader.module.io.req.bits.len := read_issue_q.io.deq.bits.cols reader.module.io.req.bits.repeats := read_issue_q.io.deq.bits.repeats + reader.module.io.req.bits.pixel_repeats := read_issue_q.io.deq.bits.pixel_repeats reader.module.io.req.bits.scale := read_issue_q.io.deq.bits.scale reader.module.io.req.bits.is_acc := read_issue_q.io.deq.bits.laddr.is_acc_addr reader.module.io.req.bits.accumulate := read_issue_q.io.deq.bits.laddr.accumulate @@ -321,10 +361,22 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, mvin_scale_in.bits.in := reader.module.io.resp.bits.data.asTypeOf(chiselTypeOf(mvin_scale_in.bits.in)) mvin_scale_in.bits.scale := reader.module.io.resp.bits.scale.asTypeOf(mvin_scale_t) mvin_scale_in.bits.repeats := reader.module.io.resp.bits.repeats + mvin_scale_in.bits.pixel_repeats := reader.module.io.resp.bits.pixel_repeats mvin_scale_in.bits.last := reader.module.io.resp.bits.last mvin_scale_in.bits.tag := reader.module.io.resp.bits - mvin_scale_out.ready := false.B + val mvin_scale_pixel_repeater = Module(new PixelRepeater(inputType, local_addr_t, block_cols, aligned_to, mvin_scale_out.bits.tag.cloneType, passthrough = !has_first_layer_optimizations)) + mvin_scale_pixel_repeater.io.req.valid := mvin_scale_out.valid + mvin_scale_pixel_repeater.io.req.bits.in := mvin_scale_out.bits.out + mvin_scale_pixel_repeater.io.req.bits.mask := mvin_scale_out.bits.tag.mask take mvin_scale_pixel_repeater.io.req.bits.mask.size + mvin_scale_pixel_repeater.io.req.bits.laddr := mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row + mvin_scale_pixel_repeater.io.req.bits.len := mvin_scale_out.bits.tag.len + mvin_scale_pixel_repeater.io.req.bits.pixel_repeats := mvin_scale_out.bits.tag.pixel_repeats + mvin_scale_pixel_repeater.io.req.bits.last := mvin_scale_out.bits.last + mvin_scale_pixel_repeater.io.req.bits.tag := mvin_scale_out.bits.tag + + mvin_scale_out.ready := mvin_scale_pixel_repeater.io.req.ready + mvin_scale_pixel_repeater.io.resp.ready := false.B if (!mvin_scale_shared) { mvin_scale_acc_in.valid := reader.module.io.resp.valid && @@ -332,6 +384,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, mvin_scale_acc_in.bits.in := reader.module.io.resp.bits.data.asTypeOf(chiselTypeOf(mvin_scale_acc_in.bits.in)) 
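Reading the push/decrement logic in PixelRepeater.scala above: a request entering with pixel_repeats = p produces p responses, using offset i = p-1 down to 0; response i shifts the data up by i*len columns and lands i rows below the nominal row, with floorSub clamping at the half-bank boundary and underflowing responses suppressed. A software model of which rows one request writes (my reading of the sequencing, with hypothetical values; not a cycle-accurate simulation):

def floorSub(d: Int, o: Int, f: Int): (Int, Boolean) =
  (if (d < f + o) f else d - o, d < f + o)

// Rows written by one request: offsets count-1 down to 0, underflows dropped.
def pixelRepeatRows(row: Int, count: Int, halfBankFloor: Int): Seq[Int] =
  ((count - 1) to 0 by -1).flatMap { i =>
    val (r, underflow) = floorSub(row, i, halfBankFloor)
    if (underflow) None else Some(r)
  }
// pixelRepeatRows(row = 5, count = 3, halfBankFloor = 4) == Seq(4, 5)
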
mvin_scale_acc_in.bits.scale := reader.module.io.resp.bits.scale.asTypeOf(mvin_scale_acc_t) mvin_scale_acc_in.bits.repeats := reader.module.io.resp.bits.repeats + mvin_scale_acc_in.bits.pixel_repeats := 1.U mvin_scale_acc_in.bits.last := reader.module.io.resp.bits.last mvin_scale_acc_in.bits.tag := reader.module.io.resp.bits @@ -341,23 +394,33 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, reader.module.io.resp.ready := Mux(reader.module.io.resp.bits.is_acc && reader.module.io.resp.bits.has_acc_bitwidth, mvin_scale_acc_in.ready, mvin_scale_in.ready) - val mvin_scale_finished = mvin_scale_out.fire() && mvin_scale_out.bits.last + // val mvin_scale_finished = mvin_scale_out.fire() && mvin_scale_out.bits.last + val mvin_scale_finished = mvin_scale_pixel_repeater.io.resp.fire() && mvin_scale_pixel_repeater.io.resp.bits.last val mvin_scale_acc_finished = mvin_scale_acc_out.fire() && mvin_scale_acc_out.bits.last - val zero_writer_finished = zero_writer.io.resp.fire() && zero_writer.io.resp.bits.last + // val zero_writer_finished = zero_writer.io.resp.fire() && zero_writer.io.resp.bits.last + val zero_writer_finished = zero_writer_pixel_repeater.io.resp.fire() && zero_writer_pixel_repeater.io.resp.bits.last + /* val zero_writer_bytes_read = Mux(zero_writer.io.resp.bits.laddr.is_acc_addr, zero_writer.io.resp.bits.tag.cols * (accType.getWidth / 8).U, zero_writer.io.resp.bits.tag.cols * (inputType.getWidth / 8).U) + */ + val zero_writer_bytes_read = Mux(zero_writer_pixel_repeater.io.resp.bits.laddr.is_acc_addr, + zero_writer_pixel_repeater.io.resp.bits.tag.cols * (accType.getWidth / 8).U, + zero_writer_pixel_repeater.io.resp.bits.tag.cols * (inputType.getWidth / 8).U) // For DMA read responses, mvin_scale gets first priority, then mvin_scale_acc, and then zero_writer io.dma.read.resp.valid := mvin_scale_finished || mvin_scale_acc_finished || zero_writer_finished - io.dma.read.resp.bits.cmd_id := MuxCase(zero_writer.io.resp.bits.tag.cmd_id, Seq( - mvin_scale_finished -> mvin_scale_out.bits.tag.cmd_id, + // io.dma.read.resp.bits.cmd_id := MuxCase(zero_writer.io.resp.bits.tag.cmd_id, Seq( + io.dma.read.resp.bits.cmd_id := MuxCase(zero_writer_pixel_repeater.io.resp.bits.tag.cmd_id, Seq( + // mvin_scale_finished -> mvin_scale_out.bits.tag.cmd_id, + mvin_scale_finished -> mvin_scale_pixel_repeater.io.resp.bits.tag.cmd_id, mvin_scale_acc_finished -> mvin_scale_acc_out.bits.tag.cmd_id)) io.dma.read.resp.bits.bytesRead := MuxCase(zero_writer_bytes_read, Seq( - mvin_scale_finished -> mvin_scale_out.bits.tag.bytes_read, + // mvin_scale_finished -> mvin_scale_out.bits.tag.bytes_read, + mvin_scale_finished -> mvin_scale_pixel_repeater.io.resp.bits.tag.bytes_read, mvin_scale_acc_finished -> mvin_scale_acc_out.bits.tag.bytes_read)) io.tlb(0) <> writer.module.io.tlb @@ -368,12 +431,19 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, io.busy := writer.module.io.busy || reader.module.io.busy || write_issue_q.io.deq.valid || write_scale_q.io.deq.valid || write_dispatch_q.valid - { - val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank(sp_bank_entries, spad_w, aligned_to, config.sp_singleported)) } + val spad_mems = { + val banks = Seq.fill(sp_banks) { Module(new ScratchpadBank( + sp_bank_entries, spad_w, + aligned_to, config.sp_singleported, + use_shared_ext_mem + )) } val bank_ios = VecInit(banks.map(_.io)) - // Reading from the SRAM banks bank_ios.zipWithIndex.foreach { case (bio, i) => + if (use_shared_ext_mem) { + io.ext_mem.get.spad(i) <> 
bio.ext_mem.get + } + val ex_read_req = io.srams.read(i).req val exread = ex_read_req.valid @@ -414,7 +484,6 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, val dma_read_pipe = Pipeline(dma_read_resp, spad_read_delay) val ex_read_pipe = Pipeline(ex_read_resp, spad_read_delay) - bio.read.resp.ready := Mux(bio.read.resp.bits.fromDMA, dma_read_resp.ready, ex_read_resp.ready) dma_read_pipe.ready := writer.module.io.req.ready && @@ -432,16 +501,21 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bank_ios.zipWithIndex.foreach { case (bio, i) => val exwrite = io.srams.write(i).en - val laddr = mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row + // val laddr = mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row + val laddr = mvin_scale_pixel_repeater.io.resp.bits.laddr - val dmaread = mvin_scale_out.valid && !mvin_scale_out.bits.tag.is_acc && + // val dmaread = mvin_scale_out.valid && !mvin_scale_out.bits.tag.is_acc && + val dmaread = mvin_scale_pixel_repeater.io.resp.valid && !mvin_scale_pixel_repeater.io.resp.bits.tag.is_acc && laddr.sp_bank() === i.U // We need to make sure that we don't try to return a dma read resp from both zero_writer and either mvin_scale // or mvin_acc_scale at the same time. The scalers always get priority in those cases - val zerowrite = zero_writer.io.resp.valid && !zero_writer.io.resp.bits.laddr.is_acc_addr && - zero_writer.io.resp.bits.laddr.sp_bank() === i.U && - !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) + /* val zerowrite = zero_writer.io.resp.valid && !zero_writer.io.resp.bits.laddr.is_acc_addr && + zero_writer.io.resp.bits.laddr.sp_bank() === i.U && */ + val zerowrite = zero_writer_pixel_repeater.io.resp.valid && !zero_writer_pixel_repeater.io.resp.bits.laddr.is_acc_addr && + zero_writer_pixel_repeater.io.resp.bits.laddr.sp_bank() === i.U && + // !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) + !((mvin_scale_pixel_repeater.io.resp.valid && mvin_scale_pixel_repeater.io.resp.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) bio.write.en := exwrite || dmaread || zerowrite @@ -451,27 +525,34 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.write.mask := io.srams.write(i).mask }.elsewhen (dmaread) { bio.write.addr := laddr.sp_row() - bio.write.data := mvin_scale_out.bits.out.asUInt() - bio.write.mask := mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) + // bio.write.data := mvin_scale_out.bits.out.asUInt() + // bio.write.mask := mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) + bio.write.data := mvin_scale_pixel_repeater.io.resp.bits.out.asUInt() + bio.write.mask := mvin_scale_pixel_repeater.io.resp.bits.mask take ((spad_w / (aligned_to * 8)) max 1) - mvin_scale_out.ready := true.B // TODO we combinationally couple valid and ready signals + // mvin_scale_out.ready := true.B // TODO we combinationally couple valid and ready signals + mvin_scale_pixel_repeater.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals }.elsewhen (zerowrite) { - bio.write.addr := zero_writer.io.resp.bits.laddr.sp_row() + // bio.write.addr := zero_writer.io.resp.bits.laddr.sp_row() + bio.write.addr := zero_writer_pixel_repeater.io.resp.bits.laddr.sp_row() bio.write.data := 0.U bio.write.mask := { val n 
= inputType.getWidth / 8 - val mask = zero_writer.io.resp.bits.mask + // val mask = zero_writer.io.resp.bits.mask + val mask = zero_writer_pixel_repeater.io.resp.bits.mask val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) expanded } - zero_writer.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals + // zero_writer.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals + zero_writer_pixel_repeater.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals }.otherwise { bio.write.addr := DontCare bio.write.data := DontCare bio.write.mask := DontCare } } + banks } val acc_row_t = Vec(meshColumns, Vec(tileColumns, accType)) @@ -513,11 +594,14 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, } } - { + val acc_adders = Module(new AccPipeShared(acc_latency-1, acc_row_t, acc_banks)) + val acc_mems = { val banks = Seq.fill(acc_banks) { Module(new AccumulatorMem( acc_bank_entries, acc_row_t, acc_scale_func, acc_scale_t.asInstanceOf[V], - acc_singleported, acc_sub_banks + acc_singleported, acc_sub_banks, + use_shared_ext_mem, + acc_latency, accType, )) } val bank_ios = VecInit(banks.map(_.io)) @@ -526,6 +610,15 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // Reading from the Accumulator banks bank_ios.zipWithIndex.foreach { case (bio, i) => + if (use_shared_ext_mem) { + io.ext_mem.get.acc(i) <> bio.ext_mem.get + } + + acc_adders.io.in_sel(i) := bio.adder.valid + acc_adders.io.ina(i) := bio.adder.op1 + acc_adders.io.inb(i) := bio.adder.op2 + bio.adder.sum := acc_adders.io.out + val ex_read_req = io.acc.read_req(i) val exread = ex_read_req.valid @@ -590,10 +683,12 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, io.acc.write(i).ready := true.B assert(!(exwrite && !bio.write.ready), "Execute controller write to AccumulatorMem was skipped") - val from_mvin_scale = mvin_scale_out.valid && mvin_scale_out.bits.tag.is_acc + // val from_mvin_scale = mvin_scale_out.valid && mvin_scale_out.bits.tag.is_acc + val from_mvin_scale = mvin_scale_pixel_repeater.io.resp.valid && mvin_scale_pixel_repeater.io.resp.bits.tag.is_acc val from_mvin_scale_acc = mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.tag.is_acc - val mvin_scale_laddr = mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row + // val mvin_scale_laddr = mvin_scale_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_out.bits.row + val mvin_scale_laddr = mvin_scale_pixel_repeater.io.resp.bits.laddr val mvin_scale_acc_laddr = mvin_scale_acc_out.bits.tag.addr.asTypeOf(local_addr_t) + mvin_scale_acc_out.bits.row val dmaread_bank = Mux(from_mvin_scale, mvin_scale_laddr.acc_bank(), @@ -602,7 +697,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // We need to make sure that we don't try to return a dma read resp from both mvin_scale and mvin_scale_acc // at the same time. 
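The byte-mask expansion a few lines above (and its accumulator twin further down, where n = accType.getWidth / inputType.getWidth) is easy to sanity-check in plain Scala; the widths below are hypothetical:

// Each element-level enable fans out to n lane enables.
val n = 32 / 8                      // e.g. inputType.getWidth / 8
val mask = Seq(true, false, true)   // per-element write enables
val expanded = mask.flatMap(e => Seq.fill(n)(e))
// expanded == Seq(true,true,true,true, false,false,false,false, true,true,true,true)
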
mvin_scale always gets priority in these cases - val spad_last = mvin_scale_out.valid && mvin_scale_out.bits.last && !mvin_scale_out.bits.tag.is_acc + // val spad_last = mvin_scale_out.valid && mvin_scale_out.bits.last && !mvin_scale_out.bits.tag.is_acc + val spad_last = mvin_scale_pixel_repeater.io.resp.valid && mvin_scale_pixel_repeater.io.resp.bits.last && !mvin_scale_pixel_repeater.io.resp.bits.tag.is_acc val dmaread = (from_mvin_scale || from_mvin_scale_acc) && dmaread_bank === i.U /* && @@ -610,9 +706,13 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, // We need to make sure that we don't try to return a dma read resp from both zero_writer and either mvin_scale // or mvin_acc_scale at the same time. The scalers always get priority in those cases - val zerowrite = zero_writer.io.resp.valid && zero_writer.io.resp.bits.laddr.is_acc_addr && - zero_writer.io.resp.bits.laddr.acc_bank() === i.U && - !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) + /* val zerowrite = zero_writer.io.resp.valid && zero_writer.io.resp.bits.laddr.is_acc_addr && + zero_writer.io.resp.bits.laddr.acc_bank() === i.U && */ + val zerowrite = zero_writer_pixel_repeater.io.resp.valid && zero_writer_pixel_repeater.io.resp.bits.laddr.is_acc_addr && + zero_writer_pixel_repeater.io.resp.bits.laddr.acc_bank() === i.U && + // !((mvin_scale_out.valid && mvin_scale_out.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) + !((mvin_scale_pixel_repeater.io.resp.valid && mvin_scale_pixel_repeater.io.resp.bits.last) || (mvin_scale_acc_out.valid && mvin_scale_acc_out.bits.last)) + val consecutive_write_block = RegInit(false.B) if (acc_singleported) { val consecutive_write_sub_bank = RegInit(0.U((1 max log2Ceil(acc_sub_banks)).W)) @@ -628,12 +728,15 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, } bio.write.valid := false.B - bio.write.bits.acc := MuxCase(zero_writer.io.resp.bits.laddr.accumulate, + // bio.write.bits.acc := MuxCase(zero_writer.io.resp.bits.laddr.accumulate, + bio.write.bits.acc := MuxCase(zero_writer_pixel_repeater.io.resp.bits.laddr.accumulate, Seq(exwrite -> io.acc.write(i).bits.acc, - from_mvin_scale -> mvin_scale_out.bits.tag.accumulate, + // from_mvin_scale -> mvin_scale_out.bits.tag.accumulate, + from_mvin_scale -> mvin_scale_pixel_repeater.io.resp.bits.tag.accumulate, from_mvin_scale_acc -> mvin_scale_acc_out.bits.tag.accumulate)) - bio.write.bits.addr := MuxCase(zero_writer.io.resp.bits.laddr.acc_row(), + // bio.write.bits.addr := MuxCase(zero_writer.io.resp.bits.laddr.acc_row(), + bio.write.bits.addr := MuxCase(zero_writer_pixel_repeater.io.resp.bits.laddr.acc_row(), Seq(exwrite -> io.acc.write(i).bits.addr, (from_mvin_scale || from_mvin_scale_acc) -> dmaread_row)) @@ -644,20 +747,23 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, }.elsewhen (dmaread && !spad_last && !consecutive_write_block) { bio.write.valid := true.B bio.write.bits.data := Mux(from_mvin_scale, - VecInit(mvin_scale_out.bits.out.map(e => e.withWidthOf(accType))).asTypeOf(acc_row_t), + // VecInit(mvin_scale_out.bits.out.map(e => e.withWidthOf(accType))).asTypeOf(acc_row_t), + VecInit(mvin_scale_pixel_repeater.io.resp.bits.out.map(e => e.withWidthOf(accType))).asTypeOf(acc_row_t), mvin_scale_acc_out.bits.out.asTypeOf(acc_row_t)) bio.write.bits.mask := Mux(from_mvin_scale, { val n = accType.getWidth / inputType.getWidth - val mask =
mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) + // val mask = mvin_scale_out.bits.tag.mask take ((spad_w / (aligned_to * 8)) max 1) + val mask = mvin_scale_pixel_repeater.io.resp.bits.mask take ((spad_w / (aligned_to * 8)) max 1) val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) expanded }, mvin_scale_acc_out.bits.tag.mask) when(from_mvin_scale) { - mvin_scale_out.ready := bio.write.ready + // mvin_scale_out.ready := bio.write.ready + mvin_scale_pixel_repeater.io.resp.ready := bio.write.ready }.otherwise { mvin_scale_acc_out.ready := bio.write.ready } @@ -666,17 +772,20 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T, bio.write.bits.data := 0.U.asTypeOf(acc_row_t) bio.write.bits.mask := { val n = accType.getWidth / 8 - val mask = zero_writer.io.resp.bits.mask + // val mask = zero_writer.io.resp.bits.mask + val mask = zero_writer_pixel_repeater.io.resp.bits.mask val expanded = VecInit(mask.flatMap(e => Seq.fill(n)(e))) expanded } - zero_writer.io.resp.ready := bio.write.ready + // zero_writer.io.resp.ready := bio.write.ready + zero_writer_pixel_repeater.io.resp.ready := bio.write.ready }.otherwise { bio.write.bits.data := DontCare bio.write.bits.mask := DontCare } } + banks } // Counter connection diff --git a/src/main/scala/gemmini/SharedExtMem.scala b/src/main/scala/gemmini/SharedExtMem.scala new file mode 100644 index 00000000..9d0e1802 --- /dev/null +++ b/src/main/scala/gemmini/SharedExtMem.scala @@ -0,0 +1,80 @@ +package gemmini + +import chisel3._ +import chisel3.util._ + +import Util._ + + +class ExtMemIO extends Bundle { + val read_en = Output(Bool()) + val read_addr = Output(UInt()) + val read_data = Input(UInt()) + + val write_en = Output(Bool()) + val write_addr = Output(UInt()) + val write_data = Output(UInt()) + val write_mask = Output(UInt()) +} + +class ExtSpadMemIO(sp_banks: Int, acc_banks: Int, acc_sub_banks: Int) extends Bundle { + val spad = Vec(sp_banks, new ExtMemIO) + val acc = Vec(acc_banks, Vec(acc_sub_banks, new ExtMemIO)) + override def cloneType: this.type = new ExtSpadMemIO(sp_banks, acc_banks, acc_sub_banks).asInstanceOf[this.type] +} + + +class SharedSyncReadMem(nSharers: Int, depth: Int, mask_len: Int, data_len: Int) extends Module { + val io = IO(new Bundle { + val in = Vec(nSharers, Flipped(new ExtMemIO())) + }) + val mem = SyncReadMem(depth, Vec(mask_len, UInt(data_len.W))) + val wens = io.in.map(_.write_en) + val wen = wens.reduce(_||_) + val waddr = Mux1H(wens, io.in.map(_.write_addr)) + val wmask = Mux1H(wens, io.in.map(_.write_mask)) + val wdata = Mux1H(wens, io.in.map(_.write_data)) + assert(PopCount(wens) <= 1.U) + val rens = io.in.map(_.read_en) + assert(PopCount(rens) <= 1.U) + val ren = rens.reduce(_||_) + val raddr = Mux1H(rens, io.in.map(_.read_addr)) + val rdata = mem.read(raddr, ren && !wen) + io.in.foreach(_.read_data := rdata.asUInt) + when (wen) { + mem.write(waddr, wdata.asTypeOf(Vec(mask_len, UInt(data_len.W))), wmask.asTypeOf(Vec(mask_len, Bool()))) + } + +} + +class SharedExtMem( + sp_banks: Int, acc_banks: Int, acc_sub_banks: Int, + sp_depth: Int, sp_mask_len: Int, sp_data_len: Int, + acc_depth: Int, acc_mask_len: Int, acc_data_len: Int +) extends Module { + val nSharers = 2 + val io = IO(new Bundle { + val in = Vec(nSharers, Flipped(new ExtSpadMemIO(sp_banks, acc_banks, acc_sub_banks))) + }) + for (i <- 0 until sp_banks) { + val spad_mem = Module(new SharedSyncReadMem(nSharers, sp_depth, sp_mask_len, sp_data_len)) + for (w <- 0 until nSharers) { + spad_mem.io.in(w) <> 
+
+class SharedExtMem(
+  sp_banks: Int, acc_banks: Int, acc_sub_banks: Int,
+  sp_depth: Int, sp_mask_len: Int, sp_data_len: Int,
+  acc_depth: Int, acc_mask_len: Int, acc_data_len: Int
+) extends Module {
+  val nSharers = 2
+  val io = IO(new Bundle {
+    val in = Vec(nSharers, Flipped(new ExtSpadMemIO(sp_banks, acc_banks, acc_sub_banks)))
+  })
+  for (i <- 0 until sp_banks) {
+    val spad_mem = Module(new SharedSyncReadMem(nSharers, sp_depth, sp_mask_len, sp_data_len))
+    for (w <- 0 until nSharers) {
+      spad_mem.io.in(w) <> io.in(w).spad(i)
+    }
+  }
+  for (i <- 0 until acc_banks) {
+    for (s <- 0 until acc_sub_banks) {
+      val acc_mem = Module(new SharedSyncReadMem(nSharers, acc_depth, acc_mask_len, acc_data_len))
+
+      acc_mem.io.in(0) <> io.in(0).acc(i)(s)
+      // The FP gemmini expects a taller, skinnier accumulator mem
+      acc_mem.io.in(1) <> io.in(1).acc(i)(s)
+      acc_mem.io.in(1).read_addr := io.in(1).acc(i)(s).read_addr >> 1
+      io.in(1).acc(i)(s).read_data := acc_mem.io.in(1).read_data.asTypeOf(Vec(2, UInt((acc_data_len * acc_mask_len / 2).W)))(RegNext(io.in(1).acc(i)(s).read_addr(0)))
+
+      acc_mem.io.in(1).write_addr := io.in(1).acc(i)(s).write_addr >> 1
+      acc_mem.io.in(1).write_data := Cat(io.in(1).acc(i)(s).write_data, io.in(1).acc(i)(s).write_data)
+      acc_mem.io.in(1).write_mask := Mux(io.in(1).acc(i)(s).write_addr(0), io.in(1).acc(i)(s).write_mask << (acc_mask_len / 2), io.in(1).acc(i)(s).write_mask)
+    }
+  }
+}
diff --git a/src/main/scala/gemmini/Tile.scala b/src/main/scala/gemmini/Tile.scala
index 59807893..9c2a418c 100644
--- a/src/main/scala/gemmini/Tile.scala
+++ b/src/main/scala/gemmini/Tile.scala
@@ -4,6 +4,7 @@ package gemmini

 import chisel3._
 import chisel3.util._
+import Util._

 /**
   * A Tile is a purely combinational 2D array of passThrough PEs.
@@ -12,7 +13,7 @@ import chisel3.util._
   * @param rows Number of PEs on each row
   * @param columns Number of PEs on each column
   */
-class Tile[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, df: Dataflow.Value, pe_latency: Int, max_simultaneous_matmuls: Int, val rows: Int, val columns: Int) extends Module {
+class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, tree_reduction: Boolean, max_simultaneous_matmuls: Int, val rows: Int, val columns: Int)(implicit ev: Arithmetic[T]) extends Module {
   val io = IO(new Bundle {
     val in_a = Input(Vec(rows, inputType))
     val in_b = Input(Vec(columns, outputType)) // This is the output of the tile next to it
@@ -32,11 +33,13 @@ class Tile[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, df:
     val in_valid = Input(Vec(columns, Bool()))
     val out_valid = Output(Vec(columns, Bool()))
-
+    val bad_dataflow = Output(Bool())
   })

-  val tile = Seq.fill(rows, columns)(Module(new PE(inputType, outputType, accType, df, pe_latency, max_simultaneous_matmuls)))
+  import ev._
+
+  val tile = Seq.fill(rows, columns)(Module(new PE(inputType, outputType, accType, df, max_simultaneous_matmuls)))
   val tileT = tile.transpose

   // TODO: abstract hori/vert broadcast, all these connections look the same
@@ -53,7 +56,7 @@ class Tile[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, df:
   for (c <- 0 until columns) {
     tileT(c).foldLeft(io.in_b(c)) {
       case (in_b, pe) =>
-        pe.io.in_b := in_b
+        pe.io.in_b := (if (tree_reduction) in_b.zero else in_b)
         pe.io.out_b
     }
   }
@@ -106,11 +109,19 @@
   // Drive the Tile's bottom IO
   for (c <- 0 until columns) {
     io.out_c(c) := tile(rows-1)(c).io.out_c
-    io.out_b(c) := tile(rows-1)(c).io.out_b
     io.out_control(c) := tile(rows-1)(c).io.out_control
     io.out_id(c) := tile(rows-1)(c).io.out_id
     io.out_last(c) := tile(rows-1)(c).io.out_last
     io.out_valid(c) := tile(rows-1)(c).io.out_valid
+
+    io.out_b(c) := {
+      if (tree_reduction) {
+        val prods = tileT(c).map(_.io.out_b)
+        accumulateTree(prods :+ io.in_b(c))
+      } else {
+        tile(rows - 1)(c).io.out_b
+      }
+    }
   }

   io.bad_dataflow := tile.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_)
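The sharer-1 plumbing in SharedExtMem above adapts the FP Gemmini's taller, skinnier accumulator layout onto the shared wide memory: addresses are halved to find the physical row, write data is duplicated across both halves with the mask shifted onto the addressed half, and reads select a half-row by the low address bit, delayed one cycle through RegNext to line up with SyncReadMem's read latency. A stand-alone sketch of the read side under the same assumptions; the module name, ports, and parameters are hypothetical:

import chisel3._
import chisel3.util._

// Illustrative only: a logical memory of 2*depth half-rows emulated on
// one physical memory of `depth` rows of w bits.
class NarrowReadAdapter(depth: Int, w: Int) extends Module {
  require(w % 2 == 0)
  val io = IO(new Bundle {
    val raddr = Input(UInt(log2Ceil(2 * depth).W))
    val ren   = Input(Bool())
    val rdata = Output(UInt((w / 2).W))
  })
  val mem = SyncReadMem(depth, UInt(w.W))
  // Drop the low address bit to find the physical row...
  val wide = mem.read(io.raddr >> 1, io.ren)
  // ...and use it, registered to match the memory's one-cycle read
  // latency, to select which half of the row the narrow port sees
  val halves = wide.asTypeOf(Vec(2, UInt((w / 2).W)))
  io.rdata := halves(RegNext(io.raddr(0)))
}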
diff --git a/src/main/scala/gemmini/Util.scala b/src/main/scala/gemmini/Util.scala
index 511cfee2..907c4ad2 100644
--- a/src/main/scala/gemmini/Util.scala
+++ b/src/main/scala/gemmini/Util.scala
@@ -109,6 +109,22 @@ object Util {
     Mux(u1 < u2, u1, u2)
   }

+  def accumulateTree[T <: Data](xs: Seq[T])(implicit ev: Arithmetic[T]): T = {
+    import ev._
+
+    assert(xs.nonEmpty, "can't accumulate 0 elements")
+
+    if (xs.length == 1) {
+      xs.head
+    } else {
+      val upperRowLen = 1 << log2Ceil(xs.length)
+      val upperRow = xs.padTo(upperRowLen, xs.head.zero)
+      val pairs = upperRow.grouped(2)
+      val lowerRow = pairs.map { case Seq(a, b) => a + b }
+      accumulateTree(lowerRow.toSeq)
+    }
+  }
+
   // An undirectioned Valid bundle
   class UDValid[T <: Data](t: T) extends Bundle {
     val valid = Bool()
diff --git a/src/main/scala/gemmini/VectorScalarMultiplier.scala b/src/main/scala/gemmini/VectorScalarMultiplier.scala
index d1cefcb3..7cb8c14f 100644
--- a/src/main/scala/gemmini/VectorScalarMultiplier.scala
+++ b/src/main/scala/gemmini/VectorScalarMultiplier.scala
@@ -9,6 +9,7 @@ class VectorScalarMultiplierReq[T <: Data, U <: Data, Tag <: Data](block_cols: I
   val in: Vec[T] = Vec(block_cols, t.cloneType)
   val scale: U = u.cloneType
   val repeats: UInt = UInt(16.W) // TODO magic number
+  val pixel_repeats: UInt = UInt(8.W) // TODO magic number
   val last: Bool = Bool()
   val tag: Tag = tag_t.cloneType

@@ -81,7 +82,6 @@ class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](
       in.valid := false.B
     }

-
   if (num_scale_units == -1) {
     val pipe = Module(new Pipeline(
       new VectorScalarMultiplierResp(block_cols, t, tag_t),
@@ -120,7 +120,7 @@ class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](
       head_oh := (head_oh << 1) | head_oh(nEntries-1)
     }
     in_fire := (in.valid &&
-      (!Mux1H(tail_oh.asBools, regs.map(_.valid)) || (tail_oh === head_oh && io.resp.fire()))
+      (!Mux1H(tail_oh.asBools, regs.map(_.valid)))
     )
     when (in_fire) {
       for (i <- 0 until nEntries) {
@@ -144,8 +144,6 @@ class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](
       tail_oh := (tail_oh << 1) | tail_oh(nEntries-1)
     }

-
-
     val inputs = Seq.fill(width*nEntries) { Wire(Decoupled(new DataWithIndex(t, u))) }
     for (i <- 0 until nEntries) {
       for (w <- 0 until width) {
@@ -172,7 +170,6 @@ class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](
       arbOut.valid := false.B
     }

-
     val pipe = Module(new ScalePipe(t, mvin_scale_args.get))
     pipe.io.in := arbOut
     val pipe_out = pipe.io.out
@@ -187,14 +184,11 @@ class VectorScalarMultiplier[T <: Data, U <: Data, Tag <: Data](
         }
       }
     }
+
     when (reset.asBool) {
       regs.foreach(_.valid := false.B)
    }
-
-
  }
-
-
}

 object VectorScalarMultiplier {
diff --git a/src/main/scala/gemmini/XactTracker.scala b/src/main/scala/gemmini/XactTracker.scala
index e8581a26..84821d4e 100644
--- a/src/main/scala/gemmini/XactTracker.scala
+++ b/src/main/scala/gemmini/XactTracker.scala
@@ -15,6 +15,8 @@ class XactTrackerEntry[U <: Data](maxShift: Int, spadWidth: Int, accWidth: Int,
   val has_acc_bitwidth = Bool()
   val scale = UInt(mvin_scale_t_bits.W)
   val repeats = UInt(16.W) // TODO magic number
+  val pixel_repeats = UInt(8.W) // TODO magic number
+  val len = UInt(16.W) // TODO magic number
   val block_stride = UInt(16.W) // TODO magic number
   val spad_row_offset = UInt(log2Up(spadWidth max accWidth).W)
   val lg_len_req = UInt(log2Up(log2Up(maxReqBytes+1)+1).W)
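The accumulateTree helper added to Util.scala replaces the adder chain a foldLeft would synthesize with a balanced tree, cutting the combinational depth from n-1 adders to ceil(log2 n); this is what Tile.scala's tree_reduction path relies on when it reduces every PE's out_b in one step. A plain-Scala analogue of the same recursion, for intuition only, with Int and 0 standing in for the Arithmetic typeclass and its zero element:

// Software analogue of accumulateTree (illustrative, not part of the diff).
// sumTree(Seq(1, 2, 3, 4, 5)) pads to Seq(1, 2, 3, 4, 5, 0, 0, 0), then
// reduces level by level: Seq(3, 7, 5, 0) -> Seq(10, 5) -> Seq(15).
def sumTree(xs: Seq[Int]): Int = {
  require(xs.nonEmpty, "can't accumulate 0 elements")
  if (xs.length == 1) xs.head
  else {
    // Pad to the next power of two with the additive identity,
    // mirroring the hardware version's xs.head.zero padding
    val upperRowLen = Integer.highestOneBit(xs.length - 1) << 1
    val upperRow = xs.padTo(upperRowLen, 0)
    sumTree(upperRow.grouped(2).map { case Seq(a, b) => a + b }.toSeq)
  }
}

Padding with the additive identity keeps partial sums correct for row lengths that are not powers of two, at the cost of a few wasted adders on the first level.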