From 47fe7bee364409aadab4c07159eeb9831cfdaef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20L=C3=B6her?= Date: Mon, 16 Jan 2023 10:30:13 +0100 Subject: [PATCH] Add step 12 in top. --- 12_size_optimisation/bench.py | 61 +++++++++++ 12_size_optimisation/blink.py | 38 +++++++ 12_size_optimisation/cpu.py | 180 +++++++++++++++++++++++++++++++++ 12_size_optimisation/memory.py | 36 +++++++ 12_size_optimisation/soc.py | 82 +++++++++++++++ boards/top.py | 2 + 6 files changed, 399 insertions(+) create mode 100644 12_size_optimisation/bench.py create mode 100644 12_size_optimisation/blink.py create mode 100644 12_size_optimisation/cpu.py create mode 100644 12_size_optimisation/memory.py create mode 100644 12_size_optimisation/soc.py diff --git a/12_size_optimisation/bench.py b/12_size_optimisation/bench.py new file mode 100644 index 0000000..42c3221 --- /dev/null +++ b/12_size_optimisation/bench.py @@ -0,0 +1,61 @@ +from amaranth import * +from amaranth.sim import * + +from soc import SOC + +soc = SOC() + +sim = Simulator(soc) + +prev_clk = 0 + +def proc(): + cpu = soc.cpu + mem = soc.memory + while True: + global prev_clk + clk = yield soc.slow_clk + if prev_clk == 0 and prev_clk != clk: + state = (yield soc.cpu.fsm.state) + if state == 2: + print("-- NEW CYCLE -----------------------") + print(" F: LEDS = {:05b}".format((yield soc.leds))) + print(" F: pc={}".format((yield cpu.pc))) + print(" F: instr={:#032b}".format((yield cpu.instr))) + if (yield cpu.isALUreg): + print(" ALUreg rd={} rs1={} rs2={} funct3={}".format( + (yield cpu.rdId), (yield cpu.rs1Id), (yield cpu.rs2Id), + (yield cpu.funct3))) + if (yield cpu.isALUimm): + print(" ALUimm rd={} rs1={} imm={} funct3={}".format( + (yield cpu.rdId), (yield cpu.rs1Id), (yield cpu.Iimm), + (yield cpu.funct3))) + if (yield cpu.isBranch): + print(" BRANCH rs1={} rs2={}".format( + (yield cpu.rs1Id), (yield cpu.rs2Id))) + if (yield cpu.isLoad): + print(" LOAD") + if (yield cpu.isStore): + print(" STORE") + if (yield 
cpu.isSystem): + print(" SYSTEM") + break + if state == 4: + print(" R: LEDS = {:05b}".format((yield soc.leds))) + print(" R: rs1={}".format((yield cpu.rs1))) + print(" R: rs2={}".format((yield cpu.rs2))) + if state == 1: + print(" E: LEDS = {:05b}".format((yield soc.leds))) + print(" E: Writeback x{} = {:032b}".format((yield cpu.rdId), + (yield cpu.writeBackData))) + if state == 8: + print(" NEW") + yield + prev_clk = clk + +sim.add_clock(1e-6) +sim.add_sync_process(proc) + +with sim.write_vcd('bench.vcd', 'bench.gtkw', traces=soc.ports): + # Let's run for quite a long time + sim.run_until(2, ) diff --git a/12_size_optimisation/blink.py b/12_size_optimisation/blink.py new file mode 100644 index 0000000..518fef0 --- /dev/null +++ b/12_size_optimisation/blink.py @@ -0,0 +1,38 @@ +from amaranth import * +from amaranth_boards.arty_a7 import * + +from soc import SOC + +# A platform contains board specific information about FPGA pin assignments, +# toolchain and specific information for uploading the bitfile. +platform = ArtyA7_35Platform(toolchain="Symbiflow") + +# We need a top level module +m = Module() + +# This is the instance of our SOC +soc = SOC() + +# The SOC is turned into a submodule (fragment) of our top level module. +m.submodules.soc = soc + +# The platform allows access to the various resources defined by the board +# definition from amaranth-boards. +led0 = platform.request('led', 0) +led1 = platform.request('led', 1) +led2 = platform.request('led', 2) +led3 = platform.request('led', 3) +rgb = platform.request('rgb_led') + +# We connect the SOC leds signal to the various LEDs on the board. +m.d.comb += [ + led0.o.eq(soc.leds[0]), + led1.o.eq(soc.leds[1]), + led2.o.eq(soc.leds[2]), + led3.o.eq(soc.leds[3]), + rgb.r.o.eq(soc.leds[4]), +] + +# To generate the bitstream, we build() the platform using our top level +# module m. 
+platform.build(m, do_program=False) diff --git a/12_size_optimisation/cpu.py b/12_size_optimisation/cpu.py new file mode 100644 index 0000000..7355e7a --- /dev/null +++ b/12_size_optimisation/cpu.py @@ -0,0 +1,180 @@ +from amaranth import * + +class CPU(Elaboratable): + + def __init__(self): + self.mem_addr = Signal(32) + self.mem_rstrb = Signal() + self.mem_rdata = Signal(32) + self.x1 = Signal(32) + self.fsm = None + + def elaborate(self, platform): + m = Module() + + # Program counter + pc = Signal(32) + self.pc = pc + + # Current instruction + instr = Signal(32, reset=0b0110011) + self.instr = instr + + # Register bank + regs = Array([Signal(32, name="x"+str(x)) for x in range(32)]) + rs1 = Signal(32) + rs2 = Signal(32) + + # ALU registers + aluOut = Signal(32) + takeBranch = Signal(32) + + # Opcode decoder + isALUreg = (instr[0:7] == 0b0110011) + isALUimm = (instr[0:7] == 0b0010011) + isBranch = (instr[0:7] == 0b1100011) + isJALR = (instr[0:7] == 0b1100111) + isJAL = (instr[0:7] == 0b1101111) + isAUIPC = (instr[0:7] == 0b0010111) + isLUI = (instr[0:7] == 0b0110111) + isLoad = (instr[0:7] == 0b0000011) + isStore = (instr[0:7] == 0b0100011) + isSystem = (instr[0:7] == 0b1110011) + self.isALUreg = isALUreg + self.isALUimm = isALUimm + self.isBranch = isBranch + self.isLoad = isLoad + self.isStore = isStore + self.isSystem = isSystem + + def Extend(x, n): + return [x for i in range(n + 1)] + + # Immediate format decoder + Uimm = Cat(Const(0, 12), instr[12:32]) + Iimm = Cat(instr[20:31], *Extend(instr[31], 21)) + Simm = Cat(instr[7:12], instr[25:31], *Extend(instr[31], 21)) + Bimm = Cat(0, instr[8:12], instr[25:31], instr[7], + *Extend(instr[31], 20)) + Jimm = Cat(0, instr[21:31], instr[20], instr[12:20], + *Extend(instr[31], 12)) + self.Iimm = Iimm + + # Register addresses decoder + rs1Id = instr[15:20] + rs2Id = instr[20:25] + rdId = instr[7:12] + + self.rdId = rdId + self.rs1Id = rs1Id + self.rs2Id = rs2Id + + # Function code decoder + funct3 = instr[12:15] + 
funct7 = instr[25:32] + self.funct3 = funct3 + + # ALU + aluIn1 = rs1 + aluIn2 = Mux((isALUreg | isBranch), rs2, Iimm) + shamt = Mux(isALUreg, rs2[0:5], instr[20:25]) + + # Wire memory address to pc + m.d.comb += self.mem_addr.eq(pc) + + aluMinus = Cat(~aluIn1, C(0,1)) + Cat(aluIn2, C(0,1)) + 1 + aluPlus = aluIn1 + aluIn2 + + EQ = aluMinus[0:32] == 0 + LTU = aluMinus[32] + LT = Mux((aluIn1[31] ^ aluIn2[31]), aluIn1[31], aluMinus[32]) + + def flip32(x): + a = [x[i] for i in range(0, 32)] + return Cat(*reversed(a)) + + # TODO: check these again! + shifter_in = Mux(funct3 == 0b001, flip32(aluIn1), aluIn1) + shifter = Cat(shifter_in, (instr[30] & aluIn1[31])) >> aluIn2[0:5] + leftshift = flip32(shifter) + + with m.Switch(funct3) as alu: + with m.Case(0b000): + m.d.comb += aluOut.eq(Mux(funct7[5] & instr[5], + aluMinus[0:32], aluPlus)) + with m.Case(0b001): + m.d.comb += aluOut.eq(leftshift) + with m.Case(0b010): + m.d.comb += aluOut.eq(LT) + with m.Case(0b011): + m.d.comb += aluOut.eq(LTU) + with m.Case(0b100): + m.d.comb += aluOut.eq(aluIn1 ^ aluIn2) + with m.Case(0b101): + m.d.comb += aluOut.eq(shifter) + with m.Case(0b110): + m.d.comb += aluOut.eq(aluIn1 | aluIn2) + with m.Case(0b111): + m.d.comb += aluOut.eq(aluIn1 & aluIn2) + + with m.Switch(funct3) as alu_branch: + with m.Case(0b000): + m.d.comb += takeBranch.eq(EQ) + with m.Case(0b001): + m.d.comb += takeBranch.eq(~EQ) + with m.Case(0b100): + m.d.comb += takeBranch.eq(LT) + with m.Case(0b101): + m.d.comb += takeBranch.eq(~LT) + with m.Case(0b110): + m.d.comb += takeBranch.eq(LTU) + with m.Case(0b111): + m.d.comb += takeBranch.eq(~LTU) + with m.Case("---"): + m.d.comb += takeBranch.eq(0) + + # Next program counter is either next instruction or depends on + # jump target + pcPlusImm = pc + Mux(instr[3], Jimm[0:32], + Mux(instr[4], Uimm[0:32], + Bimm[0:32])) + pcPlus4 = pc + 4 + + nextPc = Mux(((isBranch & takeBranch) | isJAL), pcPlusImm, + Mux(isJALR, Cat(C(0, 1), aluPlus[1:32]), + pcPlus4)) + + # Main state 
machine + with m.FSM(reset="FETCH_INSTR") as fsm: + self.fsm = fsm + m.d.comb += self.mem_rstrb.eq(fsm.ongoing("FETCH_INSTR")) + with m.State("FETCH_INSTR"): + m.next = "WAIT_INSTR" + with m.State("WAIT_INSTR"): + m.d.sync += instr.eq(self.mem_rdata) + m.next = ("FETCH_REGS") + with m.State("FETCH_REGS"): + m.d.sync += [ + rs1.eq(regs[rs1Id]), + rs2.eq(regs[rs2Id]) + ] + m.next = "EXECUTE" + with m.State("EXECUTE"): + m.d.sync += pc.eq(nextPc) + m.next = "FETCH_INSTR" + + # Register write back + writeBackData = Mux((isJAL | isJALR), pcPlus4, + Mux(isLUI, Uimm, + Mux(isAUIPC, pcPlusImm, aluOut))) + + writeBackEn = fsm.ongoing("EXECUTE") & ~isBranch & ~isStore + + self.writeBackData = writeBackData + + with m.If(writeBackEn & (rdId != 0)): + m.d.sync += regs[rdId].eq(writeBackData) + # Also assign to debug output to see what is happening + m.d.sync += self.x1.eq(writeBackData) + + return m diff --git a/12_size_optimisation/memory.py b/12_size_optimisation/memory.py new file mode 100644 index 0000000..b31a225 --- /dev/null +++ b/12_size_optimisation/memory.py @@ -0,0 +1,36 @@ +from amaranth import * +from riscv_assembler import RiscvAssembler + +class Memory(Elaboratable): + + def __init__(self): + a = RiscvAssembler() + + a.read("""begin: + ADD x1, x0, x0 + ADDI x2, x0, 31 + l0: + ADDI x1, x1, 1 + BNE x1, x2, l0 + EBREAK + """) + + a.assemble() + self.instructions = a.mem + print("memory = {}".format(self.instructions)) + + # Instruction memory initialised with above instructions + self.mem = Array([Signal(32, reset=x, name="mem") + for x in self.instructions]) + + self.mem_addr = Signal(32) + self.mem_rdata = Signal(32) + self.mem_rstrb = Signal() + + def elaborate(self, platform): + m = Module() + + with m.If(self.mem_rstrb): + m.d.sync += self.mem_rdata.eq(self.mem[self.mem_addr[2:32]]) + + return m diff --git a/12_size_optimisation/soc.py b/12_size_optimisation/soc.py new file mode 100644 index 0000000..2ddf7af --- /dev/null +++ b/12_size_optimisation/soc.py @@ 
-0,0 +1,82 @@ +import sys +from amaranth import * + +from clockworks import Clockworks +from memory import Memory +from cpu import CPU + +class SOC(Elaboratable): + + def __init__(self): + + self.leds = Signal(5) + + # Signals in this list can easily be plotted as vcd traces + self.ports = [] + + def elaborate(self, platform): + + m = Module() + cw = Clockworks(slow=19, sim_slow=10) + memory = DomainRenamer("slow")(Memory()) + cpu = DomainRenamer("slow")(CPU()) + m.submodules.cw = cw + m.submodules.cpu = cpu + m.submodules.memory = memory + + self.cpu = cpu + self.memory = memory + + x1 = Signal(32) + + # Connect memory to CPU + m.d.comb += [ + memory.mem_addr.eq(cpu.mem_addr), + memory.mem_rstrb.eq(cpu.mem_rstrb), + cpu.mem_rdata.eq(memory.mem_rdata) + ] + + # CPU debug output + m.d.comb += [ + x1.eq(cpu.x1), + self.leds.eq(x1[0:5]) + ] + + # Export signals for simulation + def export(signal, name): + if type(signal) is not Signal: + newsig = Signal(signal.shape(), name = name) + m.d.comb += newsig.eq(signal) + else: + newsig = signal + self.ports.append(newsig) + setattr(self, name, newsig) + + if platform is None: + export(ClockSignal("slow"), "slow_clk") + #export(pc, "pc") + #export(instr, "instr") + #export(isALUreg, "isALUreg") + #export(isALUimm, "isALUimm") + #export(isBranch, "isBranch") + #export(isJAL, "isJAL") + #export(isJALR, "isJALR") + #export(isLoad, "isLoad") + #export(isStore, "isStore") + #export(isSystem, "isSystem") + #export(rdId, "rdId") + #export(rs1Id, "rs1Id") + #export(rs2Id, "rs2Id") + #export(Iimm, "Iimm") + #export(Bimm, "Bimm") + #export(Jimm, "Jimm") + #export(funct3, "funct3") + #export(rdId, "rdId") + #export(rs1, "rs1") + #export(rs2, "rs2") + #export(writeBackData, "writeBackData") + #export(writeBackEn, "writeBackEn") + #export(aluOut, "aluOut") + #export((1 << cpu.fsm.state), "state") + + return m diff --git a/boards/top.py b/boards/top.py index 986f0ca..40c6cc1 100644 --- a/boards/top.py +++ b/boards/top.py @@ -33,6 +33,8 @@ 
class Top(Elaboratable): path = "10_lui_auipc" elif step == 11: path = "11_modules" + elif step == 12: + path = "12_size_optimisation" else: print("Invalid step_number {}.".format(step)) exit(1)