Add step 15: load.

This commit is contained in:
Bastian Löher
2023-02-05 03:32:54 +01:00
parent b2b67d191c
commit b70af98ee8
6 changed files with 495 additions and 0 deletions

61
15_load/bench.py Normal file
View File

@@ -0,0 +1,61 @@
from amaranth import *
from amaranth.sim import *
from soc import SOC
# Design under test: the SOC from soc.py (Clockworks, CPU and Memory
# submodules).
soc = SOC()
# Simulator instance wrapping the SOC.
sim = Simulator(soc)
# Last sampled value of the slow clock; proc() uses it to detect rising edges.
prev_clk = 0
def proc():
    """Simulation monitor that prints a CPU trace.

    On every rising edge of the SOC's slow clock the current FSM state value
    is read and, depending on that value, the decoded instruction, the
    register operands or the write-back data are printed. The monitor stops
    as soon as a SYSTEM instruction is seen in the decode printout.

    NOTE(review): the state codes 2, 4, 1 and 8 are compared against the raw
    value of ``fsm.state`` -- confirm they match the FSM's actual encoding.
    """
    global prev_clk
    core = soc.cpu
    memory = soc.memory  # looked up as before; not used by the trace itself

    while True:
        level = yield soc.slow_clk
        # Rising edge: the previous sample was low and the level has changed.
        if prev_clk == 0 and level != prev_clk:
            phase = yield core.fsm.state

            if phase == 2:
                print("-- NEW CYCLE -----------------------")
                leds = yield soc.leds
                print(" F: LEDS = {:05b}".format(leds))
                pc = yield core.pc
                print(" F: pc={}".format(pc))
                word = yield core.instr
                print(" F: instr={:#032b}".format(word))

                if (yield core.isALUreg):
                    rd = yield core.rdId
                    src1 = yield core.rs1Id
                    src2 = yield core.rs2Id
                    fn3 = yield core.funct3
                    print(" ALUreg rd={} rs1={} rs2={} funct3={}".format(
                        rd, src1, src2, fn3))
                if (yield core.isALUimm):
                    rd = yield core.rdId
                    src1 = yield core.rs1Id
                    imm = yield core.Iimm
                    fn3 = yield core.funct3
                    print(" ALUimm rd={} rs1={} imm={} funct3={}".format(
                        rd, src1, imm, fn3))
                if (yield core.isBranch):
                    src1 = yield core.rs1Id
                    src2 = yield core.rs2Id
                    print(" BRANCH rs1={} rs2={}".format(src1, src2))
                if (yield core.isLoad):
                    print(" LOAD")
                if (yield core.isStore):
                    print(" STORE")
                if (yield core.isSystem):
                    print(" SYSTEM")
                    # A SYSTEM instruction ends the trace.
                    break
            elif phase == 4:
                leds = yield soc.leds
                print(" R: LEDS = {:05b}".format(leds))
                operand1 = yield core.rs1
                print(" R: rs1={}".format(operand1))
                operand2 = yield core.rs2
                print(" R: rs2={}".format(operand2))
            elif phase == 1:
                leds = yield soc.leds
                print(" E: LEDS = {:05b}".format(leds))
                rd = yield core.rdId
                data = yield core.writeBackData
                print(" E: Writeback x{} = {:032b}".format(rd, data))
            elif phase == 8:
                print(" NEW")

        # Advance the simulation by one tick, then remember the clock level
        # for the next edge check.
        yield
        prev_clk = level
# Clock the design with a period of 1e-6 and register the trace monitor as a
# synchronous process.
sim.add_clock(1e-6)
sim.add_sync_process(proc)

# Dump waveforms (plus a GTKWave save file) for the exported SOC ports while
# the simulation runs up to the deadline of 2, i.e. for quite a long time.
vcd_writer = sim.write_vcd("bench.vcd", "bench.gtkw", traces=soc.ports)
with vcd_writer:
    sim.run_until(2)

38
15_load/blink.py Normal file
View File

@@ -0,0 +1,38 @@
from amaranth import *
from amaranth_boards.arty_a7 import *
from soc import SOC
# A platform contains board specific information about FPGA pin assignments,
# toolchain and specific information for uploading the bitfile.
platform = ArtyA7_35Platform(toolchain="Symbiflow")

# We need a top level module
m = Module()

# This is the instance of our SOC
soc = SOC()

# The SOC is turned into a submodule (fragment) of our top level module.
m.submodules.soc = soc

# The platform allows access to the various resources defined by the board
# definition from amaranth-boards.
user_leds = [platform.request('led', i) for i in range(4)]
led0, led1, led2, led3 = user_leds
rgb = platform.request('rgb_led')

# We connect the SOC leds signal to the various LEDs on the board: bits 0..3
# drive the four plain LEDs (one bit per LED), bit 4 drives the red channel
# of the RGB LED. Building the assignments in a loop guarantees that every
# LED gets its own bit.
m.d.comb += [led.o.eq(soc.leds[bit]) for bit, led in enumerate(user_leds)]
m.d.comb += rgb.r.o.eq(soc.leds[4])

# To generate the bitstream, we build() the platform using our top level
# module m.
platform.build(m, do_program=False)

254
15_load/cpu.py Normal file
View File

@@ -0,0 +1,254 @@
from amaranth import *
class CPU(Elaboratable):
    """Multi-cycle RISC-V core with support for load instructions.

    The core talks to memory through ``mem_addr`` (out), ``mem_rstrb`` (read
    strobe, out) and ``mem_rdata`` (in). Every instruction walks through the
    FSM states FETCH_INSTR -> WAIT_INSTR -> FETCH_REGS -> EXECUTE; loads
    additionally pass through LOAD -> WAIT_DATA before the next fetch.

    Stores are decoded (``isStore``) but there is no memory write path in
    this class. ``x10`` mirrors the last value written to register x10 so it
    can be observed from outside.

    Attributes such as ``pc``, ``instr``, ``rs1``, ``rs2``, the ``is*``
    decoder flags and ``fsm`` are only assigned inside ``elaborate()``.
    """

    def __init__(self):
        # Memory bus: address, read strobe and returned data word.
        self.mem_addr = Signal(32)
        self.mem_rstrb = Signal()
        self.mem_rdata = Signal(32)
        # Debug copy of register x10.
        self.x10 = Signal(32)
        # Replaced by the FSM object in elaborate().
        self.fsm = None

    def elaborate(self, platform):
        """Build and return the module describing the CPU logic."""
        m = Module()

        # Program counter
        pc = Signal(32)
        self.pc = pc

        # Memory
        mem_rdata = self.mem_rdata

        # Current instruction (reset value is an ALUreg opcode)
        instr = Signal(32, reset=0b0110011)
        self.instr = instr

        # Register bank
        regs = Array([Signal(32, name="x"+str(x)) for x in range(32)])
        rs1 = Signal(32)
        rs2 = Signal(32)
        # The test bench reads cpu.rs1 / cpu.rs2, so they must be exposed.
        self.rs1 = rs1
        self.rs2 = rs2

        # ALU registers
        aluOut = Signal(32)
        # Single-bit flag: every branch condition below is one bit wide.
        takeBranch = Signal()

        # Opcode decoder
        # It is nice to have these as actual signals for simulation
        isALUreg = Signal()
        isALUimm = Signal()
        isBranch = Signal()
        isJALR = Signal()
        isJAL = Signal()
        isAUIPC = Signal()
        isLUI = Signal()
        isLoad = Signal()
        isStore = Signal()
        isSystem = Signal()
        m.d.comb += [
            isALUreg.eq(instr[0:7] == 0b0110011),
            isALUimm.eq(instr[0:7] == 0b0010011),
            isBranch.eq(instr[0:7] == 0b1100011),
            isJALR.eq(instr[0:7] == 0b1100111),
            isJAL.eq(instr[0:7] == 0b1101111),
            isAUIPC.eq(instr[0:7] == 0b0010111),
            isLUI.eq(instr[0:7] == 0b0110111),
            isLoad.eq(instr[0:7] == 0b0000011),
            isStore.eq(instr[0:7] == 0b0100011),
            isSystem.eq(instr[0:7] == 0b1110011)
        ]
        self.isALUreg = isALUreg
        self.isALUimm = isALUimm
        self.isBranch = isBranch
        self.isLoad = isLoad
        self.isStore = isStore
        self.isSystem = isSystem

        # Extend a signal with a sign bit repeated n times
        def SignExtend(signal, sign, n):
            return Cat(signal, Repl(sign, n))

        # Immediate format decoder (Cat() places its first argument in the
        # least significant bits; every immediate is 32 bits wide).
        Uimm = Cat(Repl(0, 12), instr[12:32])
        Iimm = Cat(instr[20:31], Repl(instr[31], 21))
        Simm = Cat(instr[7:12], instr[25:31], Repl(instr[31], 21))
        Bimm = Cat(C(0, 1), instr[8:12], instr[25:31], instr[7],
                   Repl(instr[31], 20))
        Jimm = Cat(C(0, 1), instr[21:31], instr[20], instr[12:20],
                   Repl(instr[31], 12))
        self.Iimm = Iimm

        # Register addresses decoder
        rs1Id = instr[15:20]
        rs2Id = instr[20:25]
        rdId = instr[7:12]
        self.rdId = rdId
        self.rs1Id = rs1Id
        self.rs2Id = rs2Id

        # Function code decoder
        funct3 = instr[12:15]
        funct7 = instr[25:32]
        self.funct3 = funct3

        # ALU
        aluIn1 = Signal.like(rs1)
        aluIn2 = Signal.like(rs2)
        shamt = Signal(5)
        aluMinus = Signal(33)
        aluPlus = Signal.like(aluIn1)
        m.d.comb += [
            aluIn1.eq(rs1),
            aluIn2.eq(Mux((isALUreg | isBranch), rs2, Iimm)),
            # Shift amount: low bits of rs2 for register ops, otherwise the
            # immediate field of the instruction.
            shamt.eq(Mux(isALUreg, rs2[0:5], instr[20:25]))
        ]
        m.d.comb += [
            # 33-bit subtraction aluIn1 - aluIn2, computed as
            # {1, ~aluIn2} + {0, aluIn1} + 1. The low 32 bits are the
            # difference; bit 32 is set exactly when aluIn1 < aluIn2
            # (unsigned), i.e. when the subtraction borrows.
            aluMinus.eq(Cat(~aluIn2, C(1, 1)) + Cat(aluIn1, C(0, 1)) + 1),
            aluPlus.eq(aluIn1 + aluIn2)
        ]

        EQ = aluMinus[0:32] == 0
        LTU = aluMinus[32]
        # Signed compare: if the signs differ the negative operand is the
        # smaller one, otherwise the unsigned borrow decides.
        LT = Mux((aluIn1[31] ^ aluIn2[31]), aluIn1[31], aluMinus[32])

        def flip32(x):
            # Reverse the order of the low 32 bits of x.
            a = [x[i] for i in range(0, 32)]
            return Cat(*reversed(a))

        # One right shifter serves all shifts: a left shift is performed by
        # bit-reversing the input, shifting right and reversing the result.
        # The 33rd bit carries the sign for arithmetic right shifts
        # (instr[30] set); the value must be signed so that '>>' replicates
        # that bit instead of shifting in zeros.
        shifter_in = Mux(funct3 == 0b001, flip32(aluIn1), aluIn1)
        shifter = Cat(shifter_in,
                      (instr[30] & aluIn1[31])).as_signed() >> shamt
        leftshift = flip32(shifter)

        with m.Switch(funct3):
            with m.Case(0b000):
                # SUB only for register ops with funct7[5] set, else ADD.
                m.d.comb += aluOut.eq(Mux(funct7[5] & instr[5],
                                          aluMinus[0:32], aluPlus))
            with m.Case(0b001):
                m.d.comb += aluOut.eq(leftshift)
            with m.Case(0b010):
                m.d.comb += aluOut.eq(LT)
            with m.Case(0b011):
                m.d.comb += aluOut.eq(LTU)
            with m.Case(0b100):
                m.d.comb += aluOut.eq(aluIn1 ^ aluIn2)
            with m.Case(0b101):
                m.d.comb += aluOut.eq(shifter)
            with m.Case(0b110):
                m.d.comb += aluOut.eq(aluIn1 | aluIn2)
            with m.Case(0b111):
                m.d.comb += aluOut.eq(aluIn1 & aluIn2)

        with m.Switch(funct3):
            with m.Case(0b000):
                m.d.comb += takeBranch.eq(EQ)
            with m.Case(0b001):
                m.d.comb += takeBranch.eq(~EQ)
            with m.Case(0b100):
                m.d.comb += takeBranch.eq(LT)
            with m.Case(0b101):
                m.d.comb += takeBranch.eq(~LT)
            with m.Case(0b110):
                m.d.comb += takeBranch.eq(LTU)
            with m.Case(0b111):
                m.d.comb += takeBranch.eq(~LTU)
            with m.Case("---"):
                m.d.comb += takeBranch.eq(0)

        # Next program counter is either next instruction or depends on
        # jump target. Opcode bit 3 is set for JAL only; of the remaining
        # users of pcPlusImm, bit 4 is set for AUIPC and clear for branches.
        pcPlusImm = pc + Mux(instr[3], Jimm[0:32],
                             Mux(instr[4], Uimm[0:32],
                                 Bimm[0:32]))
        pcPlus4 = pc + 4

        nextPc = Mux(((isBranch & takeBranch) | isJAL), pcPlusImm,
                     Mux(isJALR, Cat(C(0, 1), aluPlus[1:32]),
                         pcPlus4))

        # Main state machine
        with m.FSM(reset="FETCH_INSTR") as fsm:
            self.fsm = fsm
            with m.State("FETCH_INSTR"):
                m.next = "WAIT_INSTR"
            with m.State("WAIT_INSTR"):
                m.d.sync += instr.eq(self.mem_rdata)
                m.next = "FETCH_REGS"
            with m.State("FETCH_REGS"):
                m.d.sync += [
                    rs1.eq(regs[rs1Id]),
                    rs2.eq(regs[rs2Id])
                ]
                m.next = "EXECUTE"
            with m.State("EXECUTE"):
                # A SYSTEM instruction leaves the program counter unchanged.
                with m.If(~isSystem):
                    m.d.sync += pc.eq(nextPc)
                with m.If(isLoad):
                    m.next = "LOAD"
                with m.Else():
                    m.next = "FETCH_INSTR"
            with m.State("LOAD"):
                m.next = "WAIT_DATA"
            with m.State("WAIT_DATA"):
                m.next = "FETCH_INSTR"

        ## Load and store
        loadStoreAddr = Signal(32)
        m.d.comb += loadStoreAddr.eq(rs1 + Iimm)

        # Load
        memByteAccess = Signal()
        memHalfwordAccess = Signal()
        loadHalfword = Signal(16)
        loadByte = Signal(8)
        loadSign = Signal()
        loadData = Signal(32)

        m.d.comb += [
            memByteAccess.eq(funct3[0:2] == C(0, 2)),
            memHalfwordAccess.eq(funct3[0:2] == C(1, 2)),
            # Address bit 1 selects the half word, bit 0 the byte within it.
            loadHalfword.eq(Mux(loadStoreAddr[1], mem_rdata[16:32],
                                mem_rdata[0:16])),
            loadByte.eq(Mux(loadStoreAddr[0], loadHalfword[8:16],
                            loadHalfword[0:8])),
            # funct3[2] set means zero extension, so the sign bit is forced
            # to 0 in that case.
            loadSign.eq(~funct3[2] & Mux(memByteAccess, loadByte[7],
                                         loadHalfword[15])),
            loadData.eq(
                Mux(memByteAccess, SignExtend(loadByte, loadSign, 24),
                    Mux(memHalfwordAccess, SignExtend(loadHalfword,
                                                      loadSign, 16),
                        mem_rdata)))
        ]

        # Wire memory address to pc or loadStoreAddr
        m.d.comb += [
            self.mem_addr.eq(
                Mux(fsm.ongoing("WAIT_INSTR") | fsm.ongoing("FETCH_INSTR"),
                    pc, loadStoreAddr)),
            self.mem_rstrb.eq(fsm.ongoing("FETCH_INSTR")
                              | fsm.ongoing("LOAD"))
        ]

        # Register write back
        writeBackData = Mux((isJAL | isJALR), pcPlus4,
                            Mux(isLUI, Uimm,
                                Mux(isAUIPC, pcPlusImm,
                                    Mux(isLoad, loadData,
                                        aluOut))))

        # Loads write back in WAIT_DATA, everything else (except branches
        # and stores, which write no register) in EXECUTE.
        writeBackEn = ((fsm.ongoing("EXECUTE") & ~isBranch & ~isStore
                        & ~isLoad)
                       | fsm.ongoing("WAIT_DATA"))

        self.writeBackData = writeBackData

        # Register x0 is never written.
        with m.If(writeBackEn & (rdId != 0)):
            m.d.sync += regs[rdId].eq(writeBackData)
            # Also assign to debug output to see what is happening
            with m.If(rdId == 10):
                m.d.sync += self.x10.eq(writeBackData)

        return m

58
15_load/memory.py Normal file
View File

@@ -0,0 +1,58 @@
from amaranth import *
from riscv_assembler import RiscvAssembler
class Memory(Elaboratable):
    """Word-addressed read-only memory holding the demo program and data.

    The assembled program occupies the first words. The array is then padded
    with zeros up to word 100 and four data words are appended, so the data
    starts at byte address 400 (which the program's ``LB`` reads) as long as
    the program itself is shorter than 100 words.

    Interface: ``mem_addr`` (byte address in), ``mem_rstrb`` (read strobe in)
    and ``mem_rdata`` (registered data word out).
    """

    def __init__(self):
        program = """begin:
LI s0, 0
LI s1, 16
l0:
LB a0, s0, 400
CALL wait
ADDI s0, s0, 1
BNE s0, s1, l0
EBREAK
wait:
LI t0, 1
SLLI t0, t0, 20
l1:
ADDI t0, t0, -1
BNEZ t0, l1
RET
"""
        assembler = RiscvAssembler()
        assembler.read(program)
        assembler.assemble()
        self.instructions = assembler.mem
        print("memory = {}".format(self.instructions))

        # Instruction memory initialised with above instructions
        words = []
        for index, value in enumerate(self.instructions):
            words.append(Signal(32, reset=value, name=f"mem{index}"))
        self.mem = Array(words)

        self.mem_addr = Signal(32)
        self.mem_rdata = Signal(32)
        self.mem_rstrb = Signal()

        # Zero-fill up to word 100 (nothing is added if already that long).
        for _ in range(100 - len(self.mem)):
            self.mem.append(0)

        # Data section: the byte values 0x01..0x0f followed by 0xff.
        for data_word in (0x04030201, 0x08070605, 0x0c0b0a09, 0xff0f0e0d):
            self.mem.append(data_word)
        print(self.mem)

    def elaborate(self, platform):
        """Return a module that latches the addressed word on a read strobe."""
        m = Module()

        # Bits 2 and up of the byte address form the word index.
        word_index = self.mem_addr[2:32]
        with m.If(self.mem_rstrb):
            m.d.sync += self.mem_rdata.eq(self.mem[word_index])

        return m

82
15_load/soc.py Normal file
View File

@@ -0,0 +1,82 @@
import sys
from amaranth import *
from clockworks import Clockworks
from memory import Memory
from cpu import CPU
class SOC(Elaboratable):
    """System on chip: Clockworks, a CPU and a Memory wired together.

    The CPU and the memory run in the "slow" clock domain. The low five bits
    of the CPU's x10 debug output drive ``leds``. ``cpu`` and ``memory``
    become available as attributes once ``elaborate()`` has run.
    """

    def __init__(self):
        self.leds = Signal(5)

        # Signals in this list can easily be plotted as vcd traces
        self.ports = []

    def elaborate(self, platform):
        """Build the SOC; when ``platform`` is None, also export the slow
        clock for simulation as ``self.slow_clk``."""
        m = Module()

        cw = Clockworks()
        to_slow = DomainRenamer("slow")
        memory = to_slow(Memory())
        cpu = to_slow(CPU())

        m.submodules.cw = cw
        m.submodules.cpu = cpu
        m.submodules.memory = memory

        self.cpu = cpu
        self.memory = memory

        x10 = Signal(32)

        # Connect memory to CPU
        m.d.comb += memory.mem_addr.eq(cpu.mem_addr)
        m.d.comb += memory.mem_rstrb.eq(cpu.mem_rstrb)
        m.d.comb += cpu.mem_rdata.eq(memory.mem_rdata)

        # CPU debug output
        m.d.comb += x10.eq(cpu.x10)
        m.d.comb += self.leds.eq(x10[0:5])

        # Export signals for simulation
        def export(signal, name):
            if type(signal) is Signal:
                # Plain signals are exported as they are.
                exported = signal
            else:
                # Anything else is copied into a fresh, named signal of the
                # same shape so that it can be traced.
                exported = Signal(signal.shape(), name=name)
                m.d.comb += exported.eq(signal)
            self.ports.append(exported)
            setattr(self, name, exported)

        if platform is None:
            export(ClockSignal("slow"), "slow_clk")
            # Further values can be exported the same way, for example:
            #   export((1 << cpu.fsm.state), "state")

        return m

View File

@@ -40,6 +40,8 @@ class Top(Elaboratable):
path = "13_subroutines"
elif step == 14:
path = "14_subroutines_v2"
elif step == 15:
path = "15_load"
else:
print("Invalid step_number {}.".format(step))
exit(1)