From b70af98ee8d433cf22f9b6ff1a90a9c0afb10dc6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bastian=20L=C3=B6her?=
Date: Sun, 5 Feb 2023 03:32:54 +0100
Subject: [PATCH] Add step 15: load.

---
Review notes (not part of the commit message):
 - blink.py: led2 and led3 are now driven; all three assignments used to
   target led1, so led2/led3 were requested but never connected.
 - cpu.py: aluMinus now computes aluIn1 - aluIn2 (the operands were swapped,
   which negated SUB and made LTU/LT true for equal operands).
 - cpu.py: the shifter input is now signed so that >> sign-extends for
   SRA/SRAI.
 - cpu.py: takeBranch is one bit wide; rs1/rs2 are exposed as attributes
   because bench.py reads cpu.rs1 and cpu.rs2.

 15_load/bench.py  |  61 +++++++++++
 15_load/blink.py  |  38 +++++++
 15_load/cpu.py    | 263 ++++++++++++++++++++++++++++++++++++++++++++++
 15_load/memory.py |  58 +++++++++++
 15_load/soc.py    |  82 +++++++++++++++
 boards/top.py     |   2 +
 6 files changed, 504 insertions(+)
 create mode 100644 15_load/bench.py
 create mode 100644 15_load/blink.py
 create mode 100644 15_load/cpu.py
 create mode 100644 15_load/memory.py
 create mode 100644 15_load/soc.py

diff --git a/15_load/bench.py b/15_load/bench.py
new file mode 100644
index 0000000..42c3221
--- /dev/null
+++ b/15_load/bench.py
@@ -0,0 +1,61 @@
+from amaranth import *
+from amaranth.sim import *
+
+from soc import SOC
+
+soc = SOC()
+
+sim = Simulator(soc)
+
+prev_clk = 0
+
+def proc():
+    cpu = soc.cpu
+    mem = soc.memory
+    while True:
+        global prev_clk
+        clk = yield soc.slow_clk
+        if prev_clk == 0 and prev_clk != clk:
+            state = (yield soc.cpu.fsm.state)
+            if state == 2:
+                print("-- NEW CYCLE -----------------------")
+                print(" F: LEDS = {:05b}".format((yield soc.leds)))
+                print(" F: pc={}".format((yield cpu.pc)))
+                print(" F: instr={:#032b}".format((yield cpu.instr)))
+                if (yield cpu.isALUreg):
+                    print(" ALUreg rd={} rs1={} rs2={} funct3={}".format(
+                        (yield cpu.rdId), (yield cpu.rs1Id), (yield cpu.rs2Id),
+                        (yield cpu.funct3)))
+                if (yield cpu.isALUimm):
+                    print(" ALUimm rd={} rs1={} imm={} funct3={}".format(
+                        (yield cpu.rdId), (yield cpu.rs1Id), (yield cpu.Iimm),
+                        (yield cpu.funct3)))
+                if (yield cpu.isBranch):
+                    print(" BRANCH rs1={} rs2={}".format(
+                        (yield cpu.rs1Id), (yield cpu.rs2Id)))
+                if (yield cpu.isLoad):
+                    print(" LOAD")
+                if (yield cpu.isStore):
+                    print(" STORE")
+                if (yield cpu.isSystem):
+                    print(" SYSTEM")
+                    break
+            if state == 4:
+                print(" R: LEDS = {:05b}".format((yield soc.leds)))
+                print(" R: rs1={}".format((yield cpu.rs1)))
+                print(" R: rs2={}".format((yield cpu.rs2)))
+            if state == 1:
+                print(" E: LEDS = {:05b}".format((yield soc.leds)))
+                print(" E: Writeback x{} = {:032b}".format((yield cpu.rdId),
+                    (yield cpu.writeBackData)))
+            if state == 8:
+                print(" NEW")
+        yield
+        prev_clk = clk
+
+sim.add_clock(1e-6)
+sim.add_sync_process(proc)
+
+with sim.write_vcd('bench.vcd', 'bench.gtkw', traces=soc.ports):
+    # Let's run for a quite long time
+    sim.run_until(2, )
diff --git a/15_load/blink.py b/15_load/blink.py
new file mode 100644
index 0000000..518fef0
--- /dev/null
+++ b/15_load/blink.py
@@ -0,0 +1,38 @@
+from amaranth import *
+from amaranth_boards.arty_a7 import *
+
+from soc import SOC
+
+# A platform contains board specific information about FPGA pin assignments,
+# toolchain and specific information for uploading the bitfile.
+platform = ArtyA7_35Platform(toolchain="Symbiflow")
+
+# We need a top level module
+m = Module()
+
+# This is the instance of our SOC
+soc = SOC()
+
+# The SOC is turned into a submodule (fragment) of our top level module.
+m.submodules.soc = soc
+
+# The platform allows access to the various resources defined by the board
+# definition from amaranth-boards.
+led0 = platform.request('led', 0)
+led1 = platform.request('led', 1)
+led2 = platform.request('led', 2)
+led3 = platform.request('led', 3)
+rgb = platform.request('rgb_led')
+
+# We connect the SOC leds signal to the various LEDs on the board.
+m.d.comb += [
+    led0.o.eq(soc.leds[0]),
+    led1.o.eq(soc.leds[1]),
+    led2.o.eq(soc.leds[2]),
+    led3.o.eq(soc.leds[3]),
+    rgb.r.o.eq(soc.leds[4]),
+]
+
+# To generate the bitstream, we build() the platform using our top level
+# module m.
+platform.build(m, do_program=False)
diff --git a/15_load/cpu.py b/15_load/cpu.py
new file mode 100644
index 0000000..1551864
--- /dev/null
+++ b/15_load/cpu.py
@@ -0,0 +1,263 @@
+from amaranth import *
+
+class CPU(Elaboratable):
+
+    def __init__(self):
+        self.mem_addr = Signal(32)
+        self.mem_rstrb = Signal()
+        self.mem_rdata = Signal(32)
+        self.x10 = Signal(32)
+        self.fsm = None
+
+    def elaborate(self, platform):
+        m = Module()
+
+        # Program counter
+        pc = Signal(32)
+        self.pc = pc
+
+        # Memory
+        mem_rdata = self.mem_rdata
+
+        # Current instruction
+        instr = Signal(32, reset=0b0110011)
+        self.instr = instr
+
+        # Register bank
+        regs = Array([Signal(32, name="x"+str(x)) for x in range(32)])
+        rs1 = Signal(32)
+        rs2 = Signal(32)
+        # Exposed because bench.py prints cpu.rs1 and cpu.rs2
+        self.rs1 = rs1
+        self.rs2 = rs2
+
+        # ALU registers
+        aluOut = Signal(32)
+        takeBranch = Signal()
+
+        # Opcode decoder
+        # It is nice to have these as actual signals for simulation
+        isALUreg = Signal()
+        isALUimm = Signal()
+        isBranch = Signal()
+        isJALR = Signal()
+        isJAL = Signal()
+        isAUIPC = Signal()
+        isLUI = Signal()
+        isLoad = Signal()
+        isStore = Signal()
+        isSystem = Signal()
+        m.d.comb += [
+            isALUreg.eq(instr[0:7] == 0b0110011),
+            isALUimm.eq(instr[0:7] == 0b0010011),
+            isBranch.eq(instr[0:7] == 0b1100011),
+            isJALR.eq(instr[0:7] == 0b1100111),
+            isJAL.eq(instr[0:7] == 0b1101111),
+            isAUIPC.eq(instr[0:7] == 0b0010111),
+            isLUI.eq(instr[0:7] == 0b0110111),
+            isLoad.eq(instr[0:7] == 0b0000011),
+            isStore.eq(instr[0:7] == 0b0100011),
+            isSystem.eq(instr[0:7] == 0b1110011)
+        ]
+        self.isALUreg = isALUreg
+        self.isALUimm = isALUimm
+        self.isBranch = isBranch
+        self.isLoad = isLoad
+        self.isStore = isStore
+        self.isSystem = isSystem
+
+        # Extend a signal with a sign bit repeated n times
+        def SignExtend(signal, sign, n):
+            return Cat(signal, Repl(sign, n))
+
+        # Immediate format decoder
+        Uimm = Cat(Repl(0, 12), instr[12:32])
+        Iimm = Cat(instr[20:31], Repl(instr[31], 21))
+        Simm = Cat(instr[7:12], instr[25:31], Repl(instr[31], 21))
+        Bimm = Cat(0, instr[8:12], instr[25:31], instr[7],
+                   Repl(instr[31], 20))
+        Jimm = Cat(0, instr[21:31], instr[20], instr[12:20],
+                   Repl(instr[31], 12))
+        self.Iimm = Iimm
+
+        # Register addresses decoder
+        rs1Id = instr[15:20]
+        rs2Id = instr[20:25]
+        rdId = instr[7:12]
+
+        self.rdId = rdId
+        self.rs1Id = rs1Id
+        self.rs2Id = rs2Id
+
+        # Function code decoder
+        funct3 = instr[12:15]
+        funct7 = instr[25:32]
+        self.funct3 = funct3
+
+        # ALU
+        aluIn1 = Signal.like(rs1)
+        aluIn2 = Signal.like(rs2)
+        shamt = Signal(5)
+        aluMinus = Signal(33)
+        aluPlus = Signal.like(aluIn1)
+
+        m.d.comb += [
+            aluIn1.eq(rs1),
+            aluIn2.eq(Mux((isALUreg | isBranch), rs2, Iimm)),
+            shamt.eq(Mux(isALUreg, rs2[0:5], instr[20:25]))
+        ]
+
+        # aluMinus = aluIn1 - aluIn2 in 33 bits ({1, ~aluIn2} + {0, aluIn1} + 1):
+        # bits 0..31 are the difference, bit 32 is set iff aluIn1 < aluIn2
+        # (unsigned), which is what SUB, LTU and LT below rely on.
+        m.d.comb += [
+            aluMinus.eq(Cat(~aluIn2, C(1,1)) + Cat(aluIn1, C(0,1)) + 1),
+            aluPlus.eq(aluIn1 + aluIn2)
+        ]
+
+        EQ = aluMinus[0:32] == 0
+        LTU = aluMinus[32]
+        LT = Mux((aluIn1[31] ^ aluIn2[31]), aluIn1[31], aluMinus[32])
+
+        def flip32(x):
+            a = [x[i] for i in range(0, 32)]
+            return Cat(*reversed(a))
+
+        # Shared right shifter: left shifts bit-reverse the operand, shift
+        # right and bit-reverse the result. The 33-bit input has to be
+        # signed so that >> replicates the sign bit for SRA/SRAI.
+        shifter_in = Mux(funct3 == 0b001, flip32(aluIn1), aluIn1)
+        shifter = Cat(shifter_in,
+                      (instr[30] & aluIn1[31])).as_signed() >> aluIn2[0:5]
+        leftshift = flip32(shifter)
+
+        with m.Switch(funct3) as alu:
+            with m.Case(0b000):
+                m.d.comb += aluOut.eq(Mux(funct7[5] & instr[5],
+                                          aluMinus[0:32], aluPlus))
+            with m.Case(0b001):
+                m.d.comb += aluOut.eq(leftshift)
+            with m.Case(0b010):
+                m.d.comb += aluOut.eq(LT)
+            with m.Case(0b011):
+                m.d.comb += aluOut.eq(LTU)
+            with m.Case(0b100):
+                m.d.comb += aluOut.eq(aluIn1 ^ aluIn2)
+            with m.Case(0b101):
+                m.d.comb += aluOut.eq(shifter)
+            with m.Case(0b110):
+                m.d.comb += aluOut.eq(aluIn1 | aluIn2)
+            with m.Case(0b111):
+                m.d.comb += aluOut.eq(aluIn1 & aluIn2)
+
+        with m.Switch(funct3) as alu_branch:
+            with m.Case(0b000):
+                m.d.comb += takeBranch.eq(EQ)
+            with m.Case(0b001):
+                m.d.comb += takeBranch.eq(~EQ)
+            with m.Case(0b100):
+                m.d.comb += takeBranch.eq(LT)
+            with m.Case(0b101):
+                m.d.comb += takeBranch.eq(~LT)
+            with m.Case(0b110):
+                m.d.comb += takeBranch.eq(LTU)
+            with m.Case(0b111):
+                m.d.comb += takeBranch.eq(~LTU)
+            with m.Case("---"):
+                m.d.comb += takeBranch.eq(0)
+
+        # Next program counter is either next instruction or depends on
+        # jump target
+        pcPlusImm = pc + Mux(instr[3], Jimm[0:32],
+                             Mux(instr[4], Uimm[0:32],
+                                 Bimm[0:32]))
+        pcPlus4 = pc + 4
+
+        nextPc = Mux(((isBranch & takeBranch) | isJAL), pcPlusImm,
+                     Mux(isJALR, Cat(C(0, 1), aluPlus[1:32]),
+                         pcPlus4))
+
+        # Main state machine
+        with m.FSM(reset="FETCH_INSTR") as fsm:
+            self.fsm = fsm
+            with m.State("FETCH_INSTR"):
+                m.next = "WAIT_INSTR"
+            with m.State("WAIT_INSTR"):
+                m.d.sync += instr.eq(self.mem_rdata)
+                m.next = ("FETCH_REGS")
+            with m.State("FETCH_REGS"):
+                m.d.sync += [
+                    rs1.eq(regs[rs1Id]),
+                    rs2.eq(regs[rs2Id])
+                ]
+                m.next = "EXECUTE"
+            with m.State("EXECUTE"):
+                with m.If(~isSystem):
+                    m.d.sync += pc.eq(nextPc)
+                with m.If(isLoad):
+                    m.next = "LOAD"
+                with m.Else():
+                    m.next = "FETCH_INSTR"
+            with m.State("LOAD"):
+                m.next = "WAIT_DATA"
+            with m.State("WAIT_DATA"):
+                m.next = "FETCH_INSTR"
+
+        ## Load and store
+
+        loadStoreAddr = Signal(32)
+        m.d.comb += loadStoreAddr.eq(rs1 + Iimm)
+
+        # Load
+        memByteAccess = Signal()
+        memHalfwordAccess = Signal()
+        loadHalfword = Signal(16)
+        loadByte = Signal(8)
+        loadSign = Signal()
+        loadData = Signal(32)
+
+        m.d.comb += [
+            memByteAccess.eq(funct3[0:2] == C(0,2)),
+            memHalfwordAccess.eq(funct3[0:2] == C(1,2)),
+            loadHalfword.eq(Mux(loadStoreAddr[1], mem_rdata[16:32],
+                                mem_rdata[0:16])),
+            loadByte.eq(Mux(loadStoreAddr[0], loadHalfword[8:16],
+                            loadHalfword[0:8])),
+            loadSign.eq(~funct3[2] & Mux(memByteAccess, loadByte[7],
+                                         loadHalfword[15])),
+            loadData.eq(
+                Mux(memByteAccess, SignExtend(loadByte, loadSign, 24),
+                    Mux(memHalfwordAccess, SignExtend(loadHalfword,
+                                                      loadSign, 16),
+                        mem_rdata)))
+        ]
+
+        # Wire memory address to pc or loadStoreAddr
+        m.d.comb += [
+            self.mem_addr.eq(
+                Mux(fsm.ongoing("WAIT_INSTR") | fsm.ongoing("FETCH_INSTR"),
+                    pc, loadStoreAddr)),
+            self.mem_rstrb.eq(fsm.ongoing("FETCH_INSTR") | fsm.ongoing("LOAD"))
+        ]
+
+
+        # Register write back
+        writeBackData = Mux((isJAL | isJALR), pcPlus4,
+                        Mux(isLUI, Uimm,
+                            Mux(isAUIPC, pcPlusImm,
+                                Mux(isLoad, loadData,
+                                    aluOut))))
+
+        writeBackEn = ((fsm.ongoing("EXECUTE") & ~isBranch & ~isStore & ~isLoad)
+                       | fsm.ongoing("WAIT_DATA"))
+
+        self.writeBackData = writeBackData
+
+
+        with m.If(writeBackEn & (rdId != 0)):
+            m.d.sync += regs[rdId].eq(writeBackData)
+            # Also assign to debug output to see what is happening
+            with m.If(rdId == 10):
+                m.d.sync += self.x10.eq(writeBackData)
+
+        return m
diff --git a/15_load/memory.py b/15_load/memory.py
new file mode 100644
index 0000000..1682c4e
--- /dev/null
+++ b/15_load/memory.py
@@ -0,0 +1,58 @@
+from amaranth import *
+from riscv_assembler import RiscvAssembler
+
+class Memory(Elaboratable):
+
+    def __init__(self):
+        a = RiscvAssembler()
+
+        a.read("""begin:
+        LI s0, 0
+        LI s1, 16
+
+        l0:
+        LB a0, s0, 400
+        CALL wait
+        ADDI s0, s0, 1
+        BNE s0, s1, l0
+        EBREAK
+
+        wait:
+        LI t0, 1
+        SLLI t0, t0, 20
+
+        l1:
+        ADDI t0, t0, -1
+        BNEZ t0, l1
+        RET
+        """)
+
+        a.assemble()
+        self.instructions = a.mem
+        print("memory = {}".format(self.instructions))
+
+        # Instruction memory initialised with above instructions
+        self.mem = Array([Signal(32, reset=x, name="mem{}".format(i))
+                          for i,x in enumerate(self.instructions)])
+
+        self.mem_addr = Signal(32)
+        self.mem_rdata = Signal(32)
+        self.mem_rstrb = Signal()
+
+        while(len(self.mem) < 100):
+            self.mem.append(0)
+
+        self.mem.append(0x04030201)
+        self.mem.append(0x08070605)
+        self.mem.append(0x0c0b0a09)
+        self.mem.append(0xff0f0e0d)
+
+        print(self.mem)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        with m.If(self.mem_rstrb):
+            m.d.sync += self.mem_rdata.eq(self.mem[self.mem_addr[2:32]])
+
+        return m
diff --git a/15_load/soc.py b/15_load/soc.py
new file mode 100644
index 0000000..beba9ea
--- /dev/null
+++ b/15_load/soc.py
@@ -0,0 +1,82 @@
+import sys
+from amaranth import *
+
+from clockworks import Clockworks
+from memory import Memory
+from cpu import CPU
+
+class SOC(Elaboratable):
+
+    def __init__(self):
+
+        self.leds = Signal(5)
+
+        # Signals in this list can easily be plotted as vcd traces
+        self.ports = []
+
+    def elaborate(self, platform):
+
+        m = Module()
+        cw = Clockworks()
+        memory = DomainRenamer("slow")(Memory())
+        cpu = DomainRenamer("slow")(CPU())
+        m.submodules.cw = cw
+        m.submodules.cpu = cpu
+        m.submodules.memory = memory
+
+        self.cpu = cpu
+        self.memory = memory
+
+        x10 = Signal(32)
+
+        # Connect memory to CPU
+        m.d.comb += [
+            memory.mem_addr.eq(cpu.mem_addr),
+            memory.mem_rstrb.eq(cpu.mem_rstrb),
+            cpu.mem_rdata.eq(memory.mem_rdata)
+        ]
+
+        # CPU debug output
+        m.d.comb += [
+            x10.eq(cpu.x10),
+            self.leds.eq(x10[0:5])
+        ]
+
+        # Export signals for simulation
+        def export(signal, name):
+            if type(signal) is not Signal:
+                newsig = Signal(signal.shape(), name = name)
+                m.d.comb += newsig.eq(signal)
+            else:
+                newsig = signal
+            self.ports.append(newsig)
+            setattr(self, name, newsig)
+
+        if platform is None:
+            export(ClockSignal("slow"), "slow_clk")
+            #export(pc, "pc")
+            #export(instr, "instr")
+            #export(isALUreg, "isALUreg")
+            #export(isALUimm, "isALUimm")
+            #export(isBranch, "isBranch")
+            #export(isJAL, "isJAL")
+            #export(isJALR, "isJALR")
+            #export(isLoad, "isLoad")
+            #export(isStore, "isStore")
+            #export(isSystem, "isSystem")
+            #export(rdId, "rdId")
+            #export(rs1Id, "rs1Id")
+            #export(rs2Id, "rs2Id")
+            #export(Iimm, "Iimm")
+            #export(Bimm, "Bimm")
+            #export(Jimm, "Jimm")
+            #export(funct3, "funct3")
+            #export(rdId, "rdId")
+            #export(rs1, "rs1")
+            #export(rs2, "rs2")
+            #export(writeBackData, "writeBackData")
+            #export(writeBackEn, "writeBackEn")
+            #export(aluOut, "aluOut")
+            #export((1 << cpu.fsm.state), "state")
+
+        return m
diff --git a/boards/top.py b/boards/top.py
index 636febc..ffd6bf3 100644
--- a/boards/top.py
+++ b/boards/top.py
@@ -40,6 +40,8 @@ class Top(Elaboratable):
             path = "13_subroutines"
         elif step == 14:
             path = "14_subroutines_v2"
+        elif step == 15:
+            path = "15_load"
         else:
             print("Invalid step_number {}.".format(step))
             exit(1)