import inspect

from amaranth import *
from amaranth.lib.coding import Encoder

from power_fv import pfv
from power_fv.check.insn import InsnCheck
from power_fv.check.insn import all as all_checks
from power_fv.reg import *


__all__ = ["Context", "Model"]


def _all_specs(**kwargs):
    """Yield an instruction spec for every ``InsnCheck`` subclass exported by ``all_checks``.

    For each such class, an instruction is built from ``insn_cls`` and named after the
    lowercased class name. That instruction and ``kwargs`` are then passed to ``spec_cls``.
    Classes are visited in the order returned by ``inspect.getmembers`` (sorted by name).
    """
    for check_name, check_cls in inspect.getmembers(all_checks, inspect.isclass):
        if not issubclass(check_cls, InsnCheck):
            continue
        insn = check_cls.insn_cls(name=check_name.lower())
        yield check_cls.spec_cls(insn, **kwargs)


class Context(Elaboratable):
    """Shadow copy of the architectural state updated by the reference model.

    The state is exposed through ``self.pfv`` (a ``pfv.Interface``). Combinationally, the
    context drives ``cia`` and every port's ``r_data`` from its registers. Synchronously,
    on cycles where ``self.pfv.stb`` is asserted, it commits ``nia`` and the write ports.

    Parameters
    ----------
    mem_size : int
        Sizes the memory array, which holds ``mem_size * 8 // 64`` words of 64 bits
        (presumably ``mem_size`` is a byte count -- TODO confirm with callers).
    **kwargs
        Forwarded unchanged to ``pfv.Interface``.
    """

    # Interface ports that read/write the GPR file through an ``index``.
    _GPR_PORTS = ("ra", "rb", "rs", "rt")
    # Interface ports backed by a same-named ``Record`` attribute of the context.
    _REG_PORTS = ("cr", "msr", "lr", "ctr", "tar", "xer", "srr0", "srr1")

    def __init__(self, *, mem_size, **kwargs):
        word_count = mem_size * 8 // 64
        self.gpr = Array([Signal(64, name=f"G{idx}") for idx in range(32)])
        self.mem = Array([Signal(64, name=f"M{idx}") for idx in range(word_count)])

        # Direct attribute assignments are kept so Amaranth infers each Record's name.
        self.iar  = Record(_ea_layout)
        self.cr   = Record(cr_layout)
        self.msr  = Record(msr_layout)
        self.lr   = Record(lr_layout)
        self.ctr  = Record(ctr_layout)
        self.tar  = Record(tar_layout)
        self.xer  = Record(xer_layout)
        self.srr0 = Record(srr0_layout)
        self.srr1 = Record(srr1_layout)

        self.pfv = pfv.Interface(**kwargs)

    @staticmethod
    def _forward(dst, src, attrs):
        """Return ``dst.<a>.eq(src.<a>)`` for each attribute name ``a`` in ``attrs``, in order."""
        return [getattr(dst, attr).eq(getattr(src, attr)) for attr in attrs]

    def connect_outputs(self, spec):
        """Return statements copying the context's read-side values into ``spec.pfv``.

        The list covers ``cia`` first, then ``r_data`` of the GPR ports, the memory port
        and the register ports.
        """
        read_ports = (*self._GPR_PORTS, "mem", *self._REG_PORTS)
        stmts = [spec.pfv.cia.eq(self.pfv.cia)]
        stmts.extend(
            getattr(spec.pfv, port_name).r_data.eq(getattr(self.pfv, port_name).r_data)
            for port_name in read_ports
        )
        return stmts

    def connect_inputs(self, spec):
        """Return statements copying ``spec.pfv``'s requests into the context's interface.

        The list covers ``nia``, each GPR port's index/strobes/write data, the memory port's
        address/masks/write data, and each register port's masks/write data.
        """
        stmts = [self.pfv.nia.eq(spec.pfv.nia)]

        for port_name in self._GPR_PORTS:
            stmts += self._forward(getattr(self.pfv, port_name), getattr(spec.pfv, port_name),
                                   ("index", "r_stb", "w_stb", "w_data"))

        stmts += self._forward(self.pfv.mem, spec.pfv.mem,
                               ("addr", "r_mask", "w_mask", "w_data"))

        for port_name in self._REG_PORTS:
            stmts += self._forward(getattr(self.pfv, port_name), getattr(spec.pfv, port_name),
                                   ("r_mask", "w_mask", "w_data"))

        return stmts

    def elaborate(self, platform):
        m = Module()
        iface = self.pfv

        # Read side: combinationally expose the current state on the interface.
        reads = [iface.cia.eq(self.iar)]
        for port_name in self._GPR_PORTS:
            gpr_port = getattr(iface, port_name)
            reads.append(gpr_port.r_data.eq(self.gpr[gpr_port.index]))
        reads.append(iface.mem.r_data.eq(self.mem[iface.mem.addr]))
        for port_name in self._REG_PORTS:
            reads.append(getattr(iface, port_name).r_data.eq(getattr(self, port_name)))
        m.d.comb += reads

        # Write side: commit side-effects only on strobed cycles.
        with m.If(iface.stb):
            m.d.sync += self.iar.eq(iface.nia)

            # Port order is preserved: later ports take priority on a shared index.
            for port_name in self._GPR_PORTS:
                gpr_port = getattr(iface, port_name)
                with m.If(gpr_port.w_stb):
                    m.d.sync += self.gpr[gpr_port.index].eq(gpr_port.w_data)

            # Masked merge: bits set in w_mask come from w_data, the rest are kept.
            mem_port = iface.mem
            word = self.mem[mem_port.addr]
            m.d.sync += word.eq((mem_port.w_data & mem_port.w_mask) | (word & ~mem_port.w_mask))

            for port_name in self._REG_PORTS:
                reg_port = getattr(iface, port_name)
                shadow = getattr(self, port_name)
                m.d.sync += shadow.eq((reg_port.w_data & reg_port.w_mask) | (shadow & ~reg_port.w_mask))

        return m


class Model(Elaboratable):
    """Reference model that executes one instruction word against a shadow ``Context``.

    Attributes
    ----------
    specs : list
        One spec per ``InsnCheck`` subclass, built by ``_all_specs(**kwargs)``.
    ctx : Context
        Shadow architectural state, updated on ``stb`` by the matching spec.
    stb : Signal(1), in
        Commit strobe; the context only updates when a single spec matched.
    insn : Signal(64), in
        Instruction word fed to every spec.
    err : Record, out
        ``err.insn`` is asserted when the instruction is not matched by exactly one spec.
    """

    def __init__(self, *, mem_size, **kwargs):
        self.specs = list(_all_specs(**kwargs))
        # NOTE(review): `kwargs` reach the specs but not `Context`, whose `pfv.Interface`
        # is therefore built with defaults; confirm both interfaces are meant to match.
        self.ctx   = Context(mem_size=mem_size)

        self.stb   = Signal()
        self.insn  = Signal(64)
        self.err   = Record([("insn", 1)])

    def elaborate(self, platform):
        m = Module()

        # - `self.insn` is wired to the `pfv.insn` input of each spec. If the spec recognizes an
        #   instruction encoding, it asserts its `pfv.stb` output.
        # - If the instruction is recognized by exactly one spec, then the context is updated
        #   according to its execution side-effects. Otherwise, `self.err.insn` is asserted.

        m.submodules.ctx = self.ctx
        m.submodules.enc = enc = Encoder(width=len(self.specs))

        for j, spec in enumerate(self.specs):
            m.submodules[f"spec_{spec.insn.name}"] = spec
            m.d.comb += [
                spec.pfv.insn.eq(self.insn),
                self.ctx.connect_outputs(spec),
            ]
            m.d.comb += enc.i[j].eq(spec.pfv.stb)

        m.d.comb += [
            self.ctx.pfv.insn.eq(self.insn),
            self.ctx.pfv.stb .eq(self.stb & ~enc.n),
            # The Encoder drives `o` to 0 and raises `n` whenever its input is not exactly
            # one-hot. So the decode error must come from `n`: a Default case on `enc.o`
            # would never be reached.
            self.err.insn.eq(enc.n),
        ]

        # When `enc.n` is high, `enc.o` is 0 and spec 0's requests reach the context, but
        # `ctx.pfv.stb` is low then, so no state is committed.
        with m.Switch(enc.o):
            for j, spec in enumerate(self.specs):
                with m.Case(j):
                    m.d.comb += self.ctx.connect_inputs(spec)

        return m