Add step 12 in top.

This commit is contained in:
Bastian Löher
2023-01-16 10:30:13 +01:00
parent 33760eb512
commit 47fe7bee36
6 changed files with 399 additions and 0 deletions

View File

@@ -0,0 +1,61 @@
from amaranth import *
from amaranth.sim import *
from soc import SOC
# Device under test: the complete system-on-chip.
soc = SOC()
# Simulator instance that drives the SOC.
sim = Simulator(soc)
# Slow-clock value seen on the previous pass through proc()'s loop; proc()
# compares against it to detect rising edges.
prev_clk = 0
def proc():
    """Simulation process that prints a trace of the CPU to the console.

    On every rising edge of the SOC's slow clock the current FSM state is
    sampled and, depending on its value, the LEDs, the program counter, the
    decoded instruction or the write-back data are printed.

    The process ends (leaves its loop) as soon as a SYSTEM instruction is
    seen while tracing an instruction.
    """
    global prev_clk
    cpu = soc.cpu
    while True:
        clk = yield soc.slow_clk
        # Only act on a rising edge of the slow clock.
        if prev_clk == 0 and prev_clk != clk:
            state = (yield cpu.fsm.state)
            if state == 2:
                print("-- NEW CYCLE -----------------------")
                print(" F: LEDS = {:05b}".format((yield soc.leds)))
                print(" F: pc={}".format((yield cpu.pc)))
                print(" F: instr={:#032b}".format((yield cpu.instr)))
                if (yield cpu.isALUreg):
                    print(" ALUreg rd={} rs1={} rs2={} funct3={}".format(
                        (yield cpu.rdId), (yield cpu.rs1Id),
                        (yield cpu.rs2Id), (yield cpu.funct3)))
                if (yield cpu.isALUimm):
                    print(" ALUimm rd={} rs1={} imm={} funct3={}".format(
                        (yield cpu.rdId), (yield cpu.rs1Id),
                        (yield cpu.Iimm), (yield cpu.funct3)))
                if (yield cpu.isBranch):
                    print(" BRANCH rs1={} rs2={}".format(
                        (yield cpu.rs1Id), (yield cpu.rs2Id)))
                if (yield cpu.isLoad):
                    print(" LOAD")
                if (yield cpu.isStore):
                    print(" STORE")
                if (yield cpu.isSystem):
                    print(" SYSTEM")
                    # A SYSTEM instruction ends the trace.
                    break
            # NOTE(review): the values 4 and 8 look like left-overs from a
            # one-hot (1 << state) encoding; the raw FSM state presumably
            # never reaches them -- confirm against the CPU's state machine
            # before relying on the "R:" and "NEW" output.
            if state == 4:
                print(" R: LEDS = {:05b}".format((yield soc.leds)))
                print(" R: rs1={}".format((yield cpu.rs1)))
                print(" R: rs2={}".format((yield cpu.rs2)))
            if state == 1:
                print(" E: LEDS = {:05b}".format((yield soc.leds)))
                print(" E: Writeback x{} = {:032b}".format(
                    (yield cpu.rdId), (yield cpu.writeBackData)))
            if state == 8:
                print(" NEW")
        # Advance the simulation by one tick of the process clock.
        yield
        prev_clk = clk
# Clock with a period of 1e-6 (as passed to add_clock) for the tracing process.
sim.add_clock(1e-6)
sim.add_sync_process(proc)
# Dump the signals exported by the SOC as a VCD trace plus a GTKWave save file.
with sim.write_vcd("bench.vcd", "bench.gtkw", traces=soc.ports):
    # Run with a generous deadline so the test program can finish.
    sim.run_until(2)

View File

@@ -0,0 +1,38 @@
from amaranth import *
from amaranth_boards.arty_a7 import *
from soc import SOC
# A platform contains board specific information about FPGA pin assignments,
# toolchain and specific information for uploading the bitfile.
platform = ArtyA7_35Platform(toolchain="Symbiflow")

# We need a top level module
m = Module()

# This is the instance of our SOC
soc = SOC()

# The SOC is turned into a submodule (fragment) of our top level module.
m.submodules.soc = soc

# The platform allows access to the various resources defined by the board
# definition from amaranth-boards.
led0 = platform.request('led', 0)
led1 = platform.request('led', 1)
led2 = platform.request('led', 2)
led3 = platform.request('led', 3)
rgb = platform.request('rgb_led')

# We connect the SOC leds signal to the various LEDs on the board.
# Every LED is driven by exactly one bit of soc.leds; the fifth bit goes to
# the red channel of the RGB LED.
m.d.comb += [
    led0.o.eq(soc.leds[0]),
    led1.o.eq(soc.leds[1]),
    led2.o.eq(soc.leds[2]),
    led3.o.eq(soc.leds[3]),
    rgb.r.o.eq(soc.leds[4]),
]

# To generate the bitstream, we build() the platform using our top level
# module m.
platform.build(m, do_program=False)

180
12_size_optimisation/cpu.py Normal file
View File

@@ -0,0 +1,180 @@
from amaranth import *
class CPU(Elaboratable):
    """Multi-cycle RISC-V style CPU core with a four-state control FSM.

    Interface signals (created in ``__init__``):

    * ``mem_addr``  -- address presented to the memory (wired to the pc)
    * ``mem_rstrb`` -- read strobe, active while the FSM is in FETCH_INSTR
    * ``mem_rdata`` -- data word returned by the memory
    * ``x1``        -- debug output, receives every value written back to
      the register bank

    ``fsm`` is ``None`` until ``elaborate()`` has run. ``elaborate()`` also
    publishes a number of decoder values as attributes (``pc``, ``instr``,
    ``isALUreg``, ``isALUimm``, ``isBranch``, ``isLoad``, ``isStore``,
    ``isSystem``, ``Iimm``, ``rdId``, ``rs1Id``, ``rs2Id``, ``funct3``,
    ``writeBackData``) so that a test bench can inspect them.

    Load and store opcodes are decoded but no load/store data path is
    implemented in this class.
    """

    def __init__(self):
        self.mem_addr = Signal(32)
        self.mem_rstrb = Signal()
        self.mem_rdata = Signal(32)
        self.x1 = Signal(32)
        self.fsm = None

    def elaborate(self, platform):
        """Build and return the module describing the CPU logic."""
        m = Module()

        # Program counter
        pc = Signal(32)
        self.pc = pc

        # Current instruction
        instr = Signal(32, reset=0b0110011)
        self.instr = instr

        # Register bank
        regs = Array([Signal(32, name="x"+str(x)) for x in range(32)])
        rs1 = Signal(32)
        rs2 = Signal(32)

        # ALU registers
        aluOut = Signal(32)
        # Branch decision is a single-bit flag.
        takeBranch = Signal()

        # Opcode decoder
        isALUreg = (instr[0:7] == 0b0110011)
        isALUimm = (instr[0:7] == 0b0010011)
        isBranch = (instr[0:7] == 0b1100011)
        isJALR = (instr[0:7] == 0b1100111)
        isJAL = (instr[0:7] == 0b1101111)
        isAUIPC = (instr[0:7] == 0b0010111)
        isLUI = (instr[0:7] == 0b0110111)
        isLoad = (instr[0:7] == 0b0000011)
        isStore = (instr[0:7] == 0b0100011)
        isSystem = (instr[0:7] == 0b1110011)
        self.isALUreg = isALUreg
        self.isALUimm = isALUimm
        self.isBranch = isBranch
        self.isLoad = isLoad
        self.isStore = isStore
        self.isSystem = isSystem

        def Extend(x, n):
            """Return a list with exactly n copies of x (sign extension)."""
            return [x] * n

        # Immediate format decoder. Every immediate is exactly 32 bits wide:
        # the explicit fields plus the replicated sign bit instr[31].
        Uimm = Cat(Const(0, 12), instr[12:32])
        Iimm = Cat(instr[20:31], *Extend(instr[31], 21))
        # Simm is decoded for completeness; nothing below uses it.
        Simm = Cat(instr[7:12], instr[25:31], *Extend(instr[31], 21))
        Bimm = Cat(C(0, 1), instr[8:12], instr[25:31], instr[7],
                   *Extend(instr[31], 20))
        Jimm = Cat(C(0, 1), instr[21:31], instr[20], instr[12:20],
                   *Extend(instr[31], 12))
        self.Iimm = Iimm

        # Register addresses decoder
        rs1Id = instr[15:20]
        rs2Id = instr[20:25]
        rdId = instr[7:12]
        self.rdId = rdId
        self.rs1Id = rs1Id
        self.rs2Id = rs2Id

        # Function code decoder
        funct3 = instr[12:15]
        funct7 = instr[25:32]
        self.funct3 = funct3

        # ALU
        aluIn1 = rs1
        aluIn2 = Mux((isALUreg | isBranch), rs2, Iimm)

        # Wire memory address to pc
        m.d.comb += self.mem_addr.eq(pc)

        # 33 bit subtraction aluIn1 - aluIn2, written as
        # {1, ~aluIn2} + {0, aluIn1} + 1. The low 32 bits are the
        # difference and bit 32 is set exactly when aluIn1 < aluIn2
        # (unsigned). One subtractor thus serves SUB, the comparisons and
        # the branches.
        aluMinus = Cat(~aluIn2, C(1, 1)) + Cat(aluIn1, C(0, 1)) + 1
        aluPlus = aluIn1 + aluIn2

        EQ = aluMinus[0:32] == 0
        LTU = aluMinus[32]
        # Signed compare: if the signs differ the negative operand is the
        # smaller one, otherwise the unsigned result applies.
        LT = Mux((aluIn1[31] ^ aluIn2[31]), aluIn1[31], aluMinus[32])

        def flip32(x):
            """Return the lowest 32 bits of x in reversed bit order."""
            a = [x[i] for i in range(0, 32)]
            return Cat(*reversed(a))

        # A single right shifter handles all shifts: left shifts reverse the
        # operand before and after shifting. The 33rd bit carries the sign
        # for arithmetic shifts (instr[30] set); the value has to be signed
        # so that '>>' replicates that bit instead of shifting in zeros.
        shifter_in = Mux(funct3 == 0b001, flip32(aluIn1), aluIn1)
        shifter = (Cat(shifter_in, (instr[30] & aluIn1[31])).as_signed()
                   >> aluIn2[0:5])
        leftshift = flip32(shifter)

        with m.Switch(funct3):
            with m.Case(0b000):
                # Subtract only for register-register operations with
                # funct7 bit 5 set, add otherwise.
                m.d.comb += aluOut.eq(Mux(funct7[5] & instr[5],
                                          aluMinus[0:32], aluPlus))
            with m.Case(0b001):
                m.d.comb += aluOut.eq(leftshift)
            with m.Case(0b010):
                m.d.comb += aluOut.eq(LT)
            with m.Case(0b011):
                m.d.comb += aluOut.eq(LTU)
            with m.Case(0b100):
                m.d.comb += aluOut.eq(aluIn1 ^ aluIn2)
            with m.Case(0b101):
                m.d.comb += aluOut.eq(shifter)
            with m.Case(0b110):
                m.d.comb += aluOut.eq(aluIn1 | aluIn2)
            with m.Case(0b111):
                m.d.comb += aluOut.eq(aluIn1 & aluIn2)

        with m.Switch(funct3):
            with m.Case(0b000):
                m.d.comb += takeBranch.eq(EQ)
            with m.Case(0b001):
                m.d.comb += takeBranch.eq(~EQ)
            with m.Case(0b100):
                m.d.comb += takeBranch.eq(LT)
            with m.Case(0b101):
                m.d.comb += takeBranch.eq(~LT)
            with m.Case(0b110):
                m.d.comb += takeBranch.eq(LTU)
            with m.Case(0b111):
                m.d.comb += takeBranch.eq(~LTU)
            with m.Case("---"):
                m.d.comb += takeBranch.eq(0)

        # Next program counter is either next instruction or depends on
        # jump target
        pcPlusImm = pc + Mux(instr[3], Jimm[0:32],
                             Mux(instr[4], Uimm[0:32],
                                 Bimm[0:32]))
        pcPlus4 = pc + 4
        nextPc = Mux(((isBranch & takeBranch) | isJAL), pcPlusImm,
                     Mux(isJALR, Cat(C(0, 1), aluPlus[1:32]),
                         pcPlus4))

        # Main state machine
        with m.FSM(reset="FETCH_INSTR") as fsm:
            self.fsm = fsm
            m.d.comb += self.mem_rstrb.eq(fsm.ongoing("FETCH_INSTR"))
            with m.State("FETCH_INSTR"):
                m.next = "WAIT_INSTR"
            with m.State("WAIT_INSTR"):
                m.d.sync += instr.eq(self.mem_rdata)
                m.next = "FETCH_REGS"
            with m.State("FETCH_REGS"):
                m.d.sync += [
                    rs1.eq(regs[rs1Id]),
                    rs2.eq(regs[rs2Id])
                ]
                m.next = "EXECUTE"
            with m.State("EXECUTE"):
                m.d.sync += pc.eq(nextPc)
                m.next = "FETCH_INSTR"

        # Register write back
        writeBackData = Mux((isJAL | isJALR), pcPlus4,
                            Mux(isLUI, Uimm,
                                Mux(isAUIPC, pcPlusImm, aluOut)))
        writeBackEn = fsm.ongoing("EXECUTE") & ~isBranch & ~isStore
        self.writeBackData = writeBackData

        # Register 0 is never written.
        with m.If(writeBackEn & (rdId != 0)):
            m.d.sync += regs[rdId].eq(writeBackData)
            # Also assign to debug output to see what is happening
            m.d.sync += self.x1.eq(writeBackData)

        return m

View File

@@ -0,0 +1,36 @@
from amaranth import *
from riscv_assembler import RiscvAssembler
class Memory(Elaboratable):
    """Read-only instruction memory preloaded with a small test program.

    The program is assembled in ``__init__`` and stored as one 32 bit signal
    per instruction word. While ``mem_rstrb`` is high, the word selected by
    ``mem_addr`` (bits 2 and up, i.e. word addressing) is registered into
    ``mem_rdata``.
    """

    def __init__(self):
        assembler = RiscvAssembler()
        assembler.read("""begin:
ADD x1, x0, x0
ADDI x2, x0, 31
l0:
ADDI x1, x1, 1
BNE x1, x2, l0
EBREAK
""")
        assembler.assemble()
        self.instructions = assembler.mem
        print("memory = {}".format(self.instructions))

        # Instruction memory: one signal per assembled word, holding that
        # word as its reset value.
        words = []
        for word in self.instructions:
            words.append(Signal(32, reset=word, name="mem"))
        self.mem = Array(words)

        self.mem_addr = Signal(32)
        self.mem_rdata = Signal(32)
        self.mem_rstrb = Signal()

    def elaborate(self, platform):
        """Build and return the module implementing the registered read."""
        m = Module()
        # The two lowest address bits are dropped to get a word index.
        word_index = self.mem_addr[2:32]
        with m.If(self.mem_rstrb):
            m.d.sync += self.mem_rdata.eq(self.mem[word_index])
        return m

View File

@@ -0,0 +1,82 @@
import sys
from amaranth import *
from clockworks import Clockworks
from memory import Memory
from cpu import CPU
class SOC(Elaboratable):
    """System-on-chip: clock generator, instruction memory and CPU.

    Attributes:
        leds:   5 bit output showing the lowest bits of the CPU's debug
                output ``x1``.
        ports:  signals collected for plotting as VCD traces; filled during
                ``elaborate()`` when no platform is given.
        cpu, memory: the submodule instances; ``None`` until ``elaborate()``
                has run.
    """

    def __init__(self):
        self.leds = Signal(5)
        # Signals in this list can easily be plotted as vcd traces
        self.ports = []
        # Set by elaborate() so that a test bench can reach the submodules.
        self.cpu = None
        self.memory = None

    def elaborate(self, platform):
        """Build and return the SOC module.

        With ``platform`` set to ``None`` (simulation) the slow clock is
        additionally exported as ``self.slow_clk`` and added to ``ports``.
        """
        m = Module()

        cw = Clockworks(slow=19, sim_slow=10)
        # Memory and CPU both run in the "slow" clock domain.
        memory = DomainRenamer("slow")(Memory())
        cpu = DomainRenamer("slow")(CPU())
        m.submodules.cw = cw
        m.submodules.cpu = cpu
        m.submodules.memory = memory
        self.cpu = cpu
        self.memory = memory

        x1 = Signal(32)

        # Connect memory to CPU
        m.d.comb += [
            memory.mem_addr.eq(cpu.mem_addr),
            memory.mem_rstrb.eq(cpu.mem_rstrb),
            cpu.mem_rdata.eq(memory.mem_rdata)
        ]

        # CPU debug output
        m.d.comb += [
            x1.eq(cpu.x1),
            self.leds.eq(x1[0:5])
        ]

        # Export signals for simulation
        def export(signal, name):
            """Publish signal as attribute `name` and add it to ports.

            Anything that is not already a Signal is first copied into a
            new, named Signal of the same shape.
            """
            if isinstance(signal, Signal):
                newsig = signal
            else:
                newsig = Signal(signal.shape(), name=name)
                m.d.comb += newsig.eq(signal)
            self.ports.append(newsig)
            setattr(self, name, newsig)

        if platform is None:
            export(ClockSignal("slow"), "slow_clk")
            # Further values can be exported the same way, e.g.
            # export(cpu.pc, "pc") -- presumably only once the CPU has
            # been elaborated, since it creates them in elaborate().

        return m

View File

@@ -33,6 +33,8 @@ class Top(Elaboratable):
path = "10_lui_auipc"
elif step == 11:
path = "11_modules"
elif step == 12:
path = "12_size_optimisation"
else:
print("Invalid step_number {}.".format(step))
exit(1)