@ -1,33 +1,32 @@
from amaranth import *
from amaranth import *
from amaranth.asserts import *
from amaranth.asserts import *
from .. import pfv
from .. import PowerFVCheck
from ..insn import *
from ... import pfv, tb
from ..utils import iea_mask
from ...utils import iea_mask
from ._fmt import *
from ._insn import *
__all__ = ["Check"]
__all__ = ["BranchSpec", "BranchCheck"]
class Check(Elaboratable):
_insn_cls = None
def __init_subclass__(cls, *, insn_cls):
class BranchSpec(Elaboratable):
cls._insn_cls = insn_cls
def __init__(self, insn_cls, post):
self.insn_cls = insn_cls
self.pfv = pfv.Interface()
self.post = tb.Trigger(cycle=post)
def __init__(self):
def triggers(self):
self.pfv = pfv.Interface()
yield self.post
self.trig = Record([
("pre", 1),
("post", 1),
])
def elaborate(self, platform):
def elaborate(self, platform):
m = Module()
m = Module()
spec_insn = self._insn_cls()
spec_insn = self.insn_cls()
with m.If(self.trig.post):
with m.If(self.post.stb):
m.d.sync += [
m.d.sync += [
Assume(self.pfv.stb),
Assume(self.pfv.stb),
Assume(self.pfv.insn[32:] == spec_insn),
Assume(self.pfv.insn[32:] == spec_insn),
@ -37,7 +36,7 @@ class Check(Elaboratable):
msr_w_sf = Signal()
msr_w_sf = Signal()
m.d.comb += msr_w_sf.eq(self.pfv.msr.w_data[63])
m.d.comb += msr_w_sf.eq(self.pfv.msr.w_data[63])
if isinstance(spec_insn, (Instruction_B, Instruction_XL_b)):
if isinstance(spec_insn, (Instruction_B, Instruction_XL_bc)):
bo_valid_patterns = [
bo_valid_patterns = [
"001--",
"001--",
"011--",
"011--",
@ -57,7 +56,7 @@ class Check(Elaboratable):
bo_invalid = Signal()
bo_invalid = Signal()
m.d.comb += bo_invalid.eq(~spec_insn.bo.matches(*bo_valid_patterns))
m.d.comb += bo_invalid.eq(~spec_insn.bo.matches(*bo_valid_patterns))
with m.If(self.trig.post):
with m.If(self.post.stb):
m.d.sync += Assert(bo_invalid.implies(self.pfv.intr))
m.d.sync += Assert(bo_invalid.implies(self.pfv.intr))
# NIA
# NIA
@ -73,7 +72,7 @@ class Check(Elaboratable):
offset.eq(spec_insn.li)
offset.eq(spec_insn.li)
]
]
elif isinstance(spec_insn, (Instruction_B, Instruction_XL_b)):
elif isinstance(spec_insn, (Instruction_B, Instruction_XL_bc)):
cond_bit = Signal()
cond_bit = Signal()
ctr_any = Signal()
ctr_any = Signal()
cond_ok = Signal()
cond_ok = Signal()
@ -116,7 +115,7 @@ class Check(Elaboratable):
m.d.comb += spec_nia.eq(iea_mask(target, msr_w_sf))
m.d.comb += spec_nia.eq(iea_mask(target, msr_w_sf))
with m.If(self.trig.post & ~self.pfv.intr):
with m.If(self.post.stb & ~self.pfv.intr):
m.d.sync += Assert(self.pfv.nia == spec_nia)
m.d.sync += Assert(self.pfv.nia == spec_nia)
# CR
# CR
@ -125,12 +124,12 @@ class Check(Elaboratable):
if isinstance(spec_insn, Instruction_I):
if isinstance(spec_insn, Instruction_I):
m.d.comb += spec_cr_r_stb.eq(0)
m.d.comb += spec_cr_r_stb.eq(0)
elif isinstance(spec_insn, (Instruction_B, Instruction_XL_b)):
elif isinstance(spec_insn, (Instruction_B, Instruction_XL_bc)):
m.d.comb += spec_cr_r_stb[::-1].bit_select(spec_insn.bi[2:], width=1).eq(1)
m.d.comb += spec_cr_r_stb[::-1].bit_select(spec_insn.bi[2:], width=1).eq(1)
else:
else:
assert False
assert False
with m.If(self.trig.post & ~self.pfv.intr):
with m.If(self.post.stb & ~self.pfv.intr):
for i, spec_cr_r_stb_bit in enumerate(spec_cr_r_stb):
for i, spec_cr_r_stb_bit in enumerate(spec_cr_r_stb):
pfv_cr_r_stb_bit = self.pfv.cr.r_stb[i]
pfv_cr_r_stb_bit = self.pfv.cr.r_stb[i]
m.d.sync += Assert(spec_cr_r_stb_bit.implies(pfv_cr_r_stb_bit))
m.d.sync += Assert(spec_cr_r_stb_bit.implies(pfv_cr_r_stb_bit))
@ -143,7 +142,7 @@ class Check(Elaboratable):
if isinstance(spec_insn, (Instruction_I, Instruction_B)):
if isinstance(spec_insn, (Instruction_I, Instruction_B)):
m.d.comb += spec_lr_r_stb.eq(0)
m.d.comb += spec_lr_r_stb.eq(0)
elif isinstance(spec_insn, (Instruction_XL_b)):
elif isinstance(spec_insn, (Instruction_XL_bc)):
if isinstance(spec_insn, (BCLR, BCLRL)):
if isinstance(spec_insn, (BCLR, BCLRL)):
m.d.comb += spec_lr_r_stb.eq(1)
m.d.comb += spec_lr_r_stb.eq(1)
else:
else:
@ -159,7 +158,7 @@ class Check(Elaboratable):
spec_lr_w_data.eq(iea_mask(cia_4, msr_w_sf)),
spec_lr_w_data.eq(iea_mask(cia_4, msr_w_sf)),
]
]
with m.If(self.trig.post & ~self.pfv.intr):
with m.If(self.post.stb & ~self.pfv.intr):
m.d.sync += [
m.d.sync += [
Assert(self.pfv.lr.r_stb == spec_lr_r_stb),
Assert(self.pfv.lr.r_stb == spec_lr_r_stb),
Assert(self.pfv.lr.w_stb == spec_lr_w_stb),
Assert(self.pfv.lr.w_stb == spec_lr_w_stb),
@ -176,7 +175,7 @@ class Check(Elaboratable):
m.d.comb += spec_ctr_r_stb.eq(0)
m.d.comb += spec_ctr_r_stb.eq(0)
elif isinstance(spec_insn, Instruction_B):
elif isinstance(spec_insn, Instruction_B):
m.d.comb += spec_ctr_r_stb.eq(~spec_insn.bo[4-2])
m.d.comb += spec_ctr_r_stb.eq(~spec_insn.bo[4-2])
elif isinstance(spec_insn, Instruction_XL_b):
elif isinstance(spec_insn, Instruction_XL_bc):
if isinstance(spec_insn, (BCCTR, BCCTRL)):
if isinstance(spec_insn, (BCCTR, BCCTRL)):
m.d.comb += spec_ctr_r_stb.eq(1)
m.d.comb += spec_ctr_r_stb.eq(1)
else:
else:
@ -188,7 +187,7 @@ class Check(Elaboratable):
m.d.comb += spec_ctr_w_stb.eq(0)
m.d.comb += spec_ctr_w_stb.eq(0)
elif isinstance(spec_insn, Instruction_B):
elif isinstance(spec_insn, Instruction_B):
m.d.comb += spec_ctr_w_stb.eq(~spec_insn.bo[4-2])
m.d.comb += spec_ctr_w_stb.eq(~spec_insn.bo[4-2])
elif isinstance(spec_insn, Instruction_XL_b):
elif isinstance(spec_insn, Instruction_XL_bc):
if isinstance(spec_insn, (BCCTR, BCCTRL)):
if isinstance(spec_insn, (BCCTR, BCCTRL)):
m.d.comb += spec_ctr_w_stb.eq(0)
m.d.comb += spec_ctr_w_stb.eq(0)
else:
else:
@ -198,7 +197,7 @@ class Check(Elaboratable):
m.d.comb += spec_ctr_w_data.eq(self.pfv.ctr.r_data - 1)
m.d.comb += spec_ctr_w_data.eq(self.pfv.ctr.r_data - 1)
with m.If(self.trig.post & ~self.pfv.intr):
with m.If(self.post.stb & ~self.pfv.intr):
m.d.sync += [
m.d.sync += [
Assert(self.pfv.ctr.r_stb == spec_ctr_r_stb),
Assert(self.pfv.ctr.r_stb == spec_ctr_r_stb),
Assert(self.pfv.ctr.w_stb == spec_ctr_w_stb),
Assert(self.pfv.ctr.w_stb == spec_ctr_w_stb),
@ -211,7 +210,7 @@ class Check(Elaboratable):
if isinstance(spec_insn, (Instruction_I, Instruction_B)):
if isinstance(spec_insn, (Instruction_I, Instruction_B)):
m.d.comb += spec_tar_r_stb.eq(0)
m.d.comb += spec_tar_r_stb.eq(0)
elif isinstance(spec_insn, (Instruction_XL_b)):
elif isinstance(spec_insn, (Instruction_XL_bc)):
if isinstance(spec_insn, (BCTAR, BCTARL)):
if isinstance(spec_insn, (BCTAR, BCTARL)):
m.d.comb += spec_tar_r_stb.eq(1)
m.d.comb += spec_tar_r_stb.eq(1)
else:
else:
@ -219,9 +218,20 @@ class Check(Elaboratable):
else:
else:
assert False
assert False
with m.If(self.trig.post & ~self.pfv.intr):
with m.If(self.post.stb & ~self.pfv.intr):
m.d.sync += [
m.d.sync += [
Assert(self.pfv.tar.r_stb == spec_tar_r_stb),
Assert(self.pfv.tar.r_stb == spec_tar_r_stb),
]
]
return m
return m
class BranchCheck(PowerFVCheck, name=None):
def __init_subclass__(cls, name, insn_cls):
super().__init_subclass__(name)
cls.insn_cls = insn_cls
def get_testbench(self, dut, post):
tb_spec = BranchSpec(self.insn_cls, post)
tb_top = tb.Testbench(tb_spec, dut)
return tb_top