Add pseudo instructions in assembler. Add step 14 with pseudo instructions.

This commit is contained in:
Bastian Löher
2023-02-05 01:57:18 +01:00
parent ba88db9a79
commit b2b67d191c
9 changed files with 592 additions and 16 deletions

View File

@@ -39,7 +39,8 @@ class CPU(Elaboratable):
isLUI = (instr[0:7] == 0b0110111)
isLoad = (instr[0:7] == 0b0000011)
isStore = (instr[0:7] == 0b0100011)
isSystem = (instr[0:7] == 0b1110011)
isSystem = Signal()
m.d.comb += isSystem.eq((instr[0:7] == 0b1110011))
self.isALUreg = isALUreg
self.isALUimm = isALUimm
self.isBranch = isBranch
@@ -160,7 +161,8 @@ class CPU(Elaboratable):
]
m.next = "EXECUTE"
with m.State("EXECUTE"):
m.d.sync += pc.eq(nextPc)
with m.If(~isSystem):
m.d.sync += pc.eq(nextPc)
m.next = "FETCH_INSTR"
# Register write back

View File

@@ -17,7 +17,7 @@ class Memory(Elaboratable):
wait:
ADDI x11, x0, 1
SLLI x11, x11, 15
SLLI x11, x11, 20
l1:
ADDI x11, x11, -1

View File

@@ -0,0 +1,61 @@
from amaranth import *
from amaranth.sim import *
from soc import SOC
soc = SOC()
sim = Simulator(soc)

# Last sampled level of the slow clock; proc() compares against it to detect
# rising edges.
prev_clk = 0


def proc():
    """Simulation process that traces the CPU on rising edges of the slow clock.

    On every rising edge of ``soc.slow_clk`` the raw value of the CPU FSM
    state signal is read and, depending on that value, the LEDs, program
    counter, decoded instruction fields, source registers or the write-back
    data are printed. The process ends when a SYSTEM instruction is decoded.
    """
    global prev_clk
    cpu = soc.cpu
    memory = soc.memory  # not used below
    while True:
        clk = yield soc.slow_clk
        rising_edge = prev_clk == 0 and clk != prev_clk
        if rising_edge:
            # NOTE(review): the constants below are compared against the raw
            # value of cpu.fsm.state -- confirm they match the encoding the
            # FSM actually uses for its states.
            state = yield soc.cpu.fsm.state
            if state == 2:
                print("-- NEW CYCLE -----------------------")
                leds = yield soc.leds
                print(" F: LEDS = {:05b}".format(leds))
                pc = yield cpu.pc
                print(" F: pc={}".format(pc))
                instr = yield cpu.instr
                print(" F: instr={:#032b}".format(instr))
                if (yield cpu.isALUreg):
                    rd = yield cpu.rdId
                    rs1 = yield cpu.rs1Id
                    rs2 = yield cpu.rs2Id
                    funct3 = yield cpu.funct3
                    print(" ALUreg rd={} rs1={} rs2={} funct3={}".format(
                        rd, rs1, rs2, funct3))
                if (yield cpu.isALUimm):
                    rd = yield cpu.rdId
                    rs1 = yield cpu.rs1Id
                    imm = yield cpu.Iimm
                    funct3 = yield cpu.funct3
                    print(" ALUimm rd={} rs1={} imm={} funct3={}".format(
                        rd, rs1, imm, funct3))
                if (yield cpu.isBranch):
                    rs1 = yield cpu.rs1Id
                    rs2 = yield cpu.rs2Id
                    print(" BRANCH rs1={} rs2={}".format(rs1, rs2))
                if (yield cpu.isLoad):
                    print(" LOAD")
                if (yield cpu.isStore):
                    print(" STORE")
                if (yield cpu.isSystem):
                    print(" SYSTEM")
                    # A SYSTEM instruction ends the trace.
                    break
            elif state == 4:
                leds = yield soc.leds
                print(" R: LEDS = {:05b}".format(leds))
                rs1_value = yield cpu.rs1
                print(" R: rs1={}".format(rs1_value))
                rs2_value = yield cpu.rs2
                print(" R: rs2={}".format(rs2_value))
            elif state == 1:
                leds = yield soc.leds
                print(" E: LEDS = {:05b}".format(leds))
                rd = yield cpu.rdId
                data = yield cpu.writeBackData
                print(" E: Writeback x{} = {:032b}".format(rd, data))
            elif state == 8:
                print(" NEW")
        # Advance the simulation by one tick, then remember the clock level.
        yield
        prev_clk = clk


sim.add_clock(1e-6)
sim.add_sync_process(proc)

with sim.write_vcd('bench.vcd', 'bench.gtkw', traces=soc.ports):
    # Let's run for a quite long time
    sim.run_until(2)

View File

@@ -0,0 +1,38 @@
from amaranth import *
from amaranth_boards.arty_a7 import *
from soc import SOC
# A platform contains board specific information about FPGA pin assignments,
# toolchain and specific information for uploading the bitfile.
platform = ArtyA7_35Platform(toolchain="Symbiflow")

# We need a top level module
m = Module()

# This is the instance of our SOC
soc = SOC()

# The SOC is turned into a submodule (fragment) of our top level module.
m.submodules.soc = soc

# The platform allows access to the various resources defined by the board
# definition from amaranth-boards.
led0, led1, led2, led3 = (platform.request('led', i) for i in range(4))
rgb = platform.request('rgb_led')

# We connect the SOC leds signal to the various LEDs on the board: bits 0..3
# drive the four plain LEDs (one bit per LED, each LED driven exactly once)
# and bit 4 drives the red channel of the RGB LED.
m.d.comb += [
    led0.o.eq(soc.leds[0]),
    led1.o.eq(soc.leds[1]),
    led2.o.eq(soc.leds[2]),
    led3.o.eq(soc.leds[3]),
    rgb.r.o.eq(soc.leds[4]),
]

# To generate the bitstream, we build() the platform using our top level
# module m.
platform.build(m, do_program=False)

205
14_subroutines_v2/cpu.py Normal file
View File

@@ -0,0 +1,205 @@
from amaranth import *
class CPU(Elaboratable):
    """Multi-cycle CPU core that decodes RV32I opcodes, driven by a 4-state FSM.

    The core fetches an instruction word through the ``mem_addr`` /
    ``mem_rstrb`` / ``mem_rdata`` interface, latches the two source registers,
    and then executes: the program counter is updated (except for SYSTEM
    instructions, which leave it unchanged) and the result is written back to
    the destination register. Load and store opcodes are decoded, but no data
    memory access is performed here.

    Several internal signals are published as attributes during
    ``elaborate()`` (``pc``, ``instr``, ``rs1``, ``rs2``, the ``is*`` decoder
    flags, ``Iimm``, ``rdId``, ``rs1Id``, ``rs2Id``, ``funct3``,
    ``writeBackData``, ``fsm``) so that a test bench can observe them.
    """

    def __init__(self):
        # Memory interface: address and read strobe out, read data in.
        self.mem_addr = Signal(32)
        self.mem_rstrb = Signal()
        self.mem_rdata = Signal(32)
        # Debug output: receives every value written to register x10.
        self.x10 = Signal(32)
        # The FSM object; assigned in elaborate().
        self.fsm = None

    def elaborate(self, platform):
        """Build and return the module describing the CPU logic."""
        m = Module()

        # Program counter
        pc = Signal(32)
        self.pc = pc

        # Current instruction; the reset value carries the ALUreg opcode with
        # all other fields zero.
        instr = Signal(32, reset=0b0110011)
        self.instr = instr

        # Register bank
        regs = Array([Signal(32, name="x"+str(x)) for x in range(32)])
        rs1 = Signal(32)
        rs2 = Signal(32)
        # Published so a test bench can read the latched source registers.
        self.rs1 = rs1
        self.rs2 = rs2

        # ALU registers
        aluOut = Signal(32)
        # Single-bit branch condition.
        takeBranch = Signal()

        # Opcode decoder
        # It is nice to have these as actual signals for simulation
        isALUreg = Signal()
        isALUimm = Signal()
        isBranch = Signal()
        isJALR = Signal()
        isJAL = Signal()
        isAUIPC = Signal()
        isLUI = Signal()
        isLoad = Signal()
        isStore = Signal()
        isSystem = Signal()
        m.d.comb += [
            isALUreg.eq(instr[0:7] == 0b0110011),
            isALUimm.eq(instr[0:7] == 0b0010011),
            isBranch.eq(instr[0:7] == 0b1100011),
            isJALR.eq(instr[0:7] == 0b1100111),
            isJAL.eq(instr[0:7] == 0b1101111),
            isAUIPC.eq(instr[0:7] == 0b0010111),
            isLUI.eq(instr[0:7] == 0b0110111),
            isLoad.eq(instr[0:7] == 0b0000011),
            isStore.eq(instr[0:7] == 0b0100011),
            isSystem.eq(instr[0:7] == 0b1110011)
        ]
        self.isALUreg = isALUreg
        self.isALUimm = isALUimm
        self.isBranch = isBranch
        self.isLoad = isLoad
        self.isStore = isStore
        self.isSystem = isSystem

        def Extend(x, n):
            # Exactly n copies of bit x, used to sign-extend the immediates
            # below so that each of them is 32 bits wide.
            return [x for i in range(n)]

        # Immediate format decoder
        Uimm = Cat(Const(0, 12), instr[12:32])
        Iimm = Cat(instr[20:31], *Extend(instr[31], 21))
        Simm = Cat(instr[7:12], instr[25:31], *Extend(instr[31], 21))
        Bimm = Cat(C(0, 1), instr[8:12], instr[25:31], instr[7],
                   *Extend(instr[31], 20))
        Jimm = Cat(C(0, 1), instr[21:31], instr[20], instr[12:20],
                   *Extend(instr[31], 12))
        self.Iimm = Iimm

        # Register addresses decoder
        rs1Id = instr[15:20]
        rs2Id = instr[20:25]
        rdId = instr[7:12]
        self.rdId = rdId
        self.rs1Id = rs1Id
        self.rs2Id = rs2Id

        # Function code decoder
        funct3 = instr[12:15]
        funct7 = instr[25:32]
        self.funct3 = funct3

        # ALU
        aluIn1 = Signal.like(rs1)
        aluIn2 = Signal.like(rs2)
        shamt = Signal(5)
        aluMinus = Signal(33)
        aluPlus = Signal.like(aluIn1)
        m.d.comb += [
            aluIn1.eq(rs1),
            aluIn2.eq(Mux((isALUreg | isBranch), rs2, Iimm)),
            shamt.eq(Mux(isALUreg, rs2[0:5], instr[20:25]))
        ]

        # Wire memory address to pc
        m.d.comb += self.mem_addr.eq(pc)

        m.d.comb += [
            # 33-bit subtraction aluIn1 - aluIn2, computed as
            # {1, ~aluIn2} + {0, aluIn1} + 1. Bits 0..31 hold the difference
            # and bit 32 is set exactly when aluIn1 < aluIn2 (unsigned).
            aluMinus.eq(Cat(~aluIn2, C(1, 1)) + Cat(aluIn1, C(0, 1)) + 1),
            aluPlus.eq(aluIn1 + aluIn2)
        ]
        EQ = aluMinus[0:32] == 0
        LTU = aluMinus[32]
        # With differing sign bits the negative operand is the smaller one;
        # otherwise the unsigned comparison result applies.
        LT = Mux((aluIn1[31] ^ aluIn2[31]), aluIn1[31], aluMinus[32])

        def flip32(x):
            # Reverse the order of the low 32 bits of x.
            a = [x[i] for i in range(0, 32)]
            return Cat(*reversed(a))

        # A single right shifter serves all three shifts: for a left shift
        # (funct3 == 001) the operand is bit-reversed before and after.
        # Bit 32 is the fill bit: it is aluIn1[31] for an arithmetic shift
        # (instr[30] set) and 0 otherwise. Interpreting the 33-bit value as
        # signed makes ">>" replicate that bit into every vacated position.
        shifter_in = Mux(funct3 == 0b001, flip32(aluIn1), aluIn1)
        shifter = Cat(shifter_in,
                      (instr[30] & aluIn1[31])).as_signed() >> aluIn2[0:5]
        leftshift = flip32(shifter)

        with m.Switch(funct3):
            with m.Case(0b000):
                # Subtract only for register-register ops with funct7[5] set.
                m.d.comb += aluOut.eq(Mux(funct7[5] & instr[5],
                                          aluMinus[0:32], aluPlus))
            with m.Case(0b001):
                m.d.comb += aluOut.eq(leftshift)
            with m.Case(0b010):
                m.d.comb += aluOut.eq(LT)
            with m.Case(0b011):
                m.d.comb += aluOut.eq(LTU)
            with m.Case(0b100):
                m.d.comb += aluOut.eq(aluIn1 ^ aluIn2)
            with m.Case(0b101):
                m.d.comb += aluOut.eq(shifter)
            with m.Case(0b110):
                m.d.comb += aluOut.eq(aluIn1 | aluIn2)
            with m.Case(0b111):
                m.d.comb += aluOut.eq(aluIn1 & aluIn2)

        with m.Switch(funct3):
            with m.Case(0b000):
                m.d.comb += takeBranch.eq(EQ)
            with m.Case(0b001):
                m.d.comb += takeBranch.eq(~EQ)
            with m.Case(0b100):
                m.d.comb += takeBranch.eq(LT)
            with m.Case(0b101):
                m.d.comb += takeBranch.eq(~LT)
            with m.Case(0b110):
                m.d.comb += takeBranch.eq(LTU)
            with m.Case(0b111):
                m.d.comb += takeBranch.eq(~LTU)
            # Any funct3 value not listed above never takes the branch.
            with m.Case("---"):
                m.d.comb += takeBranch.eq(0)

        # Next program counter is either next instruction or depends on
        # jump target. instr[3] selects the J immediate, otherwise instr[4]
        # selects between the U and the B immediate.
        pcPlusImm = pc + Mux(instr[3], Jimm[0:32],
                             Mux(instr[4], Uimm[0:32],
                                 Bimm[0:32]))
        pcPlus4 = pc + 4

        nextPc = Mux(((isBranch & takeBranch) | isJAL), pcPlusImm,
                     Mux(isJALR, Cat(C(0, 1), aluPlus[1:32]),
                         pcPlus4))

        # Main state machine
        with m.FSM(reset="FETCH_INSTR") as fsm:
            self.fsm = fsm
            m.d.comb += self.mem_rstrb.eq(fsm.ongoing("FETCH_INSTR"))
            with m.State("FETCH_INSTR"):
                m.next = "WAIT_INSTR"
            with m.State("WAIT_INSTR"):
                m.d.sync += instr.eq(self.mem_rdata)
                m.next = "FETCH_REGS"
            with m.State("FETCH_REGS"):
                m.d.sync += [
                    rs1.eq(regs[rs1Id]),
                    rs2.eq(regs[rs2Id])
                ]
                m.next = "EXECUTE"
            with m.State("EXECUTE"):
                # A SYSTEM instruction leaves the program counter unchanged.
                with m.If(~isSystem):
                    m.d.sync += pc.eq(nextPc)
                m.next = "FETCH_INSTR"

        # Register write back
        writeBackData = Mux((isJAL | isJALR), pcPlus4,
                            Mux(isLUI, Uimm,
                                Mux(isAUIPC, pcPlusImm, aluOut)))
        writeBackEn = fsm.ongoing("EXECUTE") & ~isBranch & ~isStore
        self.writeBackData = writeBackData

        # Register x0 is never written.
        with m.If(writeBackEn & (rdId != 0)):
            m.d.sync += regs[rdId].eq(writeBackData)
            # Also assign to debug output to see what is happening
            with m.If(rdId == 10):
                m.d.sync += self.x10.eq(writeBackData)

        return m

View File

@@ -0,0 +1,46 @@
from amaranth import *
from riscv_assembler import RiscvAssembler
class Memory(Elaboratable):
    """Instruction memory pre-loaded with an assembled demo program.

    The program text is assembled with ``RiscvAssembler`` when the object is
    created. Each resulting word becomes the reset value of one 32-bit signal.
    While ``mem_rstrb`` is high, the word selected by bits 2..31 of
    ``mem_addr`` is registered into ``mem_rdata``.
    """

    def __init__(self):
        # Memory port: address and read strobe in, read data out.
        self.mem_addr = Signal(32)
        self.mem_rdata = Signal(32)
        self.mem_rstrb = Signal()

        assembler = RiscvAssembler()
        assembler.read("""begin:
LI a0, 0
l0:
ADDI a0, a0, 1
CALL wait
J l0
EBREAK
wait:
LI a1, 1
SLLI a1, a1, 20
l1:
ADDI a1, a1, -1
BNEZ a1, l1
RET
""")
        assembler.assemble()
        self.instructions = assembler.mem
        print("memory = {}".format(self.instructions))

        # Instruction memory initialised with above instructions
        words = []
        for word in self.instructions:
            words.append(Signal(32, reset=word, name="mem"))
        self.mem = Array(words)

    def elaborate(self, platform):
        """Build the synchronous read port of the memory."""
        m = Module()
        # Bits 0 and 1 of the address are dropped to index whole words.
        word_addr = self.mem_addr[2:32]
        with m.If(self.mem_rstrb):
            m.d.sync += self.mem_rdata.eq(self.mem[word_addr])
        return m

82
14_subroutines_v2/soc.py Normal file
View File

@@ -0,0 +1,82 @@
import sys
from amaranth import *
from clockworks import Clockworks
from memory import Memory
from cpu import CPU
class SOC(Elaboratable):
    """System-on-chip wiring a CPU to its instruction memory.

    Both the CPU and the memory are moved into the "slow" clock domain, and a
    ``Clockworks`` instance is added as submodule ``cw``. The lower five bits
    of the CPU's ``x10`` debug output drive ``leds``. After elaboration the
    CPU and memory instances are available as ``self.cpu`` and
    ``self.memory``.
    """

    def __init__(self):
        self.leds = Signal(5)
        # Signals in this list can easily be plotted as vcd traces
        self.ports = []

    def elaborate(self, platform):
        """Build the SOC; without a platform, also export the slow clock."""
        m = Module()

        m.submodules.cw = Clockworks()
        self.cpu = cpu = DomainRenamer("slow")(CPU())
        m.submodules.cpu = cpu
        self.memory = memory = DomainRenamer("slow")(Memory())
        m.submodules.memory = memory

        x10 = Signal(32)

        m.d.comb += [
            # Connect memory to CPU
            memory.mem_addr.eq(cpu.mem_addr),
            memory.mem_rstrb.eq(cpu.mem_rstrb),
            cpu.mem_rdata.eq(memory.mem_rdata),
            # CPU debug output
            x10.eq(cpu.x10),
            self.leds.eq(x10[0:5]),
        ]

        def export(signal, name):
            """Publish ``signal`` as attribute ``name`` and add it to ports.

            Anything that is not exactly a ``Signal`` is first copied into a
            new signal called ``name``.
            """
            if type(signal) is Signal:
                exported = signal
            else:
                exported = Signal(signal.shape(), name=name)
                m.d.comb += exported.eq(signal)
            self.ports.append(exported)
            setattr(self, name, exported)

        # Export signals for simulation. Only the slow clock is exported here;
        # signals inside the CPU are reached through self.cpu instead.
        if platform is None:
            export(ClockSignal("slow"), "slow_clk")

        return m

View File

@@ -11,6 +11,7 @@ class Top(Elaboratable):
print("step = {}".format(step))
self.leds = leds
# TODO: this is messy and should be done with iterating over dirs
if step == 1:
path = "01_blink"
elif step == 2:
@@ -37,6 +38,8 @@ class Top(Elaboratable):
path = "12_size_optimisation"
elif step == 13:
path = "13_subroutines"
elif step == 14:
path = "14_subroutines_v2"
else:
print("Invalid step_number {}.".format(step))
exit(1)

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python3
import re
# instructions
@@ -84,6 +85,37 @@ SysInstructions = [
]
SysOps = [x[0] for x in SysInstructions]
# Pseudo instructions known to the assembler. Each table entry is a 1-tuple
# holding the mnemonic; PseudoOps is the flat list of those mnemonics.
PseudoInstructions = [(mnemonic,) for mnemonic in (
    "LI",
    "CALL",
    "RET",
    "MV",
    "NOP",
    "J",
    "BEQZ",
    "BNEZ",
    "BGT",
)]
PseudoOps = [entry[0] for entry in PseudoInstructions]
class LabelRef():
    """Textual placeholder for a label operand produced by a pseudo op.

    A reference records the pseudo op it came from (``op``), which part of
    the value is wanted (``name``) and the label itself (``arg``). Its
    ``repr`` is the text ``LABELREF(op name arg)``, which ``fromString``
    parses back into a ``LabelRef``.
    """

    # Field separators in the textual form: runs of spaces and parentheses.
    _SEPARATORS = re.compile(r'[ ()]+')

    def __init__(self, op, name, arg):
        self.op, self.name, self.arg = op, name, arg

    def __repr__(self):
        # The op is padded to a width of four characters.
        return "LABELREF({:4} {} {})".format(self.op, self.name, self.arg)

    @classmethod
    def fromString(cls, string):
        """Parse the text produced by ``__repr__`` into a new ``LabelRef``.

        Raises ``IndexError`` when fewer than three fields follow the leading
        keyword.
        """
        fields = cls._SEPARATORS.split(string)
        # fields[0] is whatever precedes the opening parenthesis.
        return cls(fields[1], fields[2], fields[3])
class Instruction():
def __init__(self, op, *args):
self.op = op
@@ -145,6 +177,7 @@ class RiscvAssembler():
def __init__(self):
self.pc = 0
self.labels = {}
self.pseudos = {}
self.instructions = []
self.mem = []
@@ -250,6 +283,58 @@ class RiscvAssembler():
else:
print("Unhandled system op {}".format(op))
def unravelPseudoOps(self, instruction):
    """Expand a pseudo instruction into the instructions it stands for.

    Args:
        instruction: parsed instruction; its ``op`` and, for ops that take
            operands, its ``args`` are read.

    Returns:
        A tuple ``(instructions, isPseudo)``. For a known pseudo op,
        ``instructions`` is the list of replacement instructions built with
        ``self.iFromLine`` and ``isPseudo`` is ``True``. For any other op the
        result is ``([instruction], False)``.
    """
    op = instruction.op
    instr = []

    def emit(template, *fields):
        # Format one line of assembly text and parse it into an instruction.
        instr.append(self.iFromLine(template.format(*fields)))

    if op == "NOP":
        emit("ADD x0, x0, x0")
    elif op == "LI":
        rd = instruction.args[0]
        imm = self.imm2int(instruction.args[1])
        if imm == 0:
            emit("ADD {}, zero, zero", rd)
        elif -2048 <= imm < 2048:
            emit("ADDI {}, zero, {}", rd, imm)
        else:
            # Split into a LUI value and a signed 12-bit ADDI part. The ADDI
            # immediate is sign-extended, so when bit 11 is set the LUI value
            # is raised by 0x1000 to compensate. Masking to 32 bits keeps the
            # printed value non-negative for negative immediates.
            # NOTE(review): assumes LUI is given the full 32-bit value and
            # encodes its upper 20 bits -- confirm against the U-type encoder.
            low12 = imm & 0xfff
            upper = (imm + ((imm & 0x800) << 1)) & 0xffffffff
            emit("LUI {}, {}", rd, hex(upper))
            if low12 != 0:
                if low12 & 0x800:
                    # Express the low part as the signed value it encodes.
                    low12 -= 0x1000
                emit("ADDI {}, {}, {}", rd, rd, low12)
    elif op == "CALL":
        # Both halves refer to the same label; the references are resolved
        # later, once all label addresses are known.
        ref1 = LabelRef(op, "offset", instruction.args[0])
        ref2 = LabelRef(op, "offset12", instruction.args[0])
        emit("AUIPC x6, {}", ref1)
        emit("JALR x1, x6, {}", ref2)
    elif op == "RET":
        emit("JALR x0, x1, 0")
    elif op == "MV":
        rd = instruction.args[0]
        rs1 = instruction.args[1]
        emit("ADD {}, {}, zero", rd, rs1)
    elif op == "J":
        ref = LabelRef(op, "imm", instruction.args[0])
        emit("JAL zero, {}", ref)
    elif op == "BEQZ":
        rs1 = instruction.args[0]
        ref = LabelRef(op, "imm", instruction.args[1])
        emit("BEQ {}, x0, {}", rs1, ref)
    elif op == "BNEZ":
        rs1 = instruction.args[0]
        ref = LabelRef(op, "imm", instruction.args[1])
        emit("BNE {}, x0, {}", rs1, ref)
    elif op == "BGT":
        rs1 = instruction.args[0]
        rs2 = instruction.args[1]
        ref = LabelRef(op, "imm", instruction.args[2])
        # "rs1 > rs2" is expressed as "rs2 < rs1" with swapped operands.
        emit("BLT {}, {}, {}", rs2, rs1, ref)
    else:
        return [instruction], False
    return instr, True
def encode(self, instruction):
encoded = 0
if instruction.op in ROps:
@@ -273,11 +358,30 @@ class RiscvAssembler():
else:
print("Unhandled instruction / opcode {}".format(instruction))
exit(1)
for l in self.labels:
if self.labels[l] == self.pc:
print(" lab@pc=0x{:03x}={} -> {}".format(self.pc, self.pc, l))
if self.pc in self.pseudos:
print(" psu@pc=0x{:03x}={} -> {}".format(self.pc, self.pc,
self.pseudos[self.pc]))
print(" enc@pc=0x{:03x} {} -> 0b{:032b}".format(
self.pc, instruction, encoded))
self.pc += 4
return encoded
def iFromLine(self, line):
    """Parse one line of assembly text into an ``Instruction``.

    The mnemonic is separated from its operands by the first run of
    whitespace, and the operands are separated by commas. Mnemonic and
    operands are both upper-cased, including a mnemonic that stands alone on
    the line.

    Args:
        line: a single line of assembly text without a label.

    Returns:
        The parsed ``Instruction``, or ``None`` when the line is empty or
        contains only whitespace.
    """
    fields = line.split(maxsplit=1)
    if not fields:
        return None
    op = fields[0].upper()
    if len(fields) == 1:
        return Instruction(op)
    # print("op = {}, rest = {}".format(op, fields[1]))
    items = [x.strip() for x in fields[1].upper().split(',')]
    return Instruction(op, *items)
def read(self, text):
instructions = []
for line in text.splitlines():
@@ -288,18 +392,15 @@ class RiscvAssembler():
pc = len(instructions) * 4
self.labels[label.upper()] = pc
print("found label '{}', pc = {}".format(label, pc))
if len(line) == 0:
continue
if ' ' not in line:
i = Instruction(line)
else:
op, rest = [x.strip().upper() for x in (
line.split(' ', maxsplit=1))]
# print("op = {}, rest = {}".format(op, rest))
items = [x.strip() for x in rest.split(',')]
i = Instruction(op, *items)
i = self.iFromLine(line)
if i is not None:
instructions.append(i)
unravelled, isPseudo = self.unravelPseudoOps(i)
if isPseudo:
pc = len(instructions) * 4
self.pseudos[pc] = i.op
print("found peudo '{}', pc = {}".format(i.op, pc))
for u in unravelled:
instructions.append(u)
self.instructions += instructions
def imm2int(self, arg):
@@ -309,12 +410,28 @@ class RiscvAssembler():
offset = self.labels[arg] - self.pc
# print("label offset = {}".format(offset))
return offset
if arg.startswith("LABELREF"):
print(" found labelref")
l = LabelRef.fromString(arg)
if l.op == "CALL":
offset = self.imm2int(l.arg)
print(" resolving label {} -> {}".format(l.arg, offset))
# print("offset = {}".format(offset))
if l.name == "OFFSET":
return offset
if l.name == "OFFSET12":
return (offset + 4) & 0xfff
elif (l.op in ["J", "BEQZ", "BNEZ", "BGT"]):
if l.name == "IMM":
imm = self.imm2int(l.arg)
print(" resolving label {} -> {}".format(l.arg, imm))
return imm
try:
return int(arg)
except ValueError as e:
if 'B' in arg.upper():
if 'B' in arg.upper()[1]:
return int(arg, 2)
elif 'X' in arg.upper():
elif 'X' in arg.upper()[1]:
return int(arg, 16)
else:
raise ValueError("Can't parse arg {}".format(arg))
@@ -343,6 +460,7 @@ if __name__ == "__main__":
jumps:
JAL x4, 255
JALR x5, x7, start
JALR x5, x7, future
branches:
BEQ x3, x4, 1
BNE x3, x4, 1
@@ -350,6 +468,7 @@ if __name__ == "__main__":
BGE x3, x4, 1
BLTU x3, x4, 1
BGEU x3, x4, 1
future:
luiandauipc:
lui: LUI x5, 0x30000
AUIPC x5, 0x30000
@@ -364,6 +483,26 @@ if __name__ == "__main__":
SB x7, x10, 1
SH x7, x10, 2
SW x7, x10, 3
before_li:
LI x3, 400
after_li:
LI a1, 0
LI a2, 128
LI a3, 4000
LI a4, 0x2000
test_other_pseudos:
CALL load
CALL futurelabel
RET
test_mv:
MV x2, x3
NOP
J after_li
BEQZ a2, store
BNEZ a1, store
BGT a3, a2, store
futurelabel:
NOP
EBREAK
""")
print(a.instructions)