library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

library work;
use work.common.all;

-- Combinational rotate/shift/mask unit.
-- Rotates a 64-bit value derived from rs, builds left and right masks from
-- the shift count and/or instruction fields, and combines the rotated
-- value with the masks (and with ra for insert-style modes) into result.
entity rotator is
          -- rs: operand that is rotated; also examined for carry_out
    port (rs: in std_ulogic_vector(63 downto 0);
          -- ra: supplies the result bits outside the mask in output
          -- modes "00" and "01"
          ra: in std_ulogic_vector(63 downto 0);
          -- shift: shift/rotate count; bits 5..0 drive the rotator,
          -- bit 6 is forced to '0' when is_32bit = '1'
          shift: in std_ulogic_vector(6 downto 0);
          -- insn: instruction word; bits 10..6, 5 and 5..1 are read as
          -- mask begin/end fields
          insn: in std_ulogic_vector(31 downto 0);
          -- is_32bit: '1' replicates rs(31 downto 0) into both halves of
          -- the rotator input and selects the 32-bit mask calculations
          is_32bit: in std_ulogic;
          -- right_shift: '1' negates the rotate count and (when
          -- clear_left = '0') derives the mask begin from the shift count
          right_shift: in std_ulogic;
          -- arith: together with the top bit of the rotator input,
          -- selects output mode "11" (ones outside the right mask)
          arith: in std_ulogic;
          -- clear_left: '1' takes the mask begin index from insn
          clear_left: in std_ulogic;
          -- clear_right: '1' takes the mask end index from insn when
          -- is_32bit = '1' or clear_left = '0'
          clear_right: in std_ulogic;
          -- sign_ext_rs: when is_32bit = '0', '1' replaces the upper
          -- half of the rotator input with copies of rs(31)
          sign_ext_rs: in std_ulogic;
          -- result: rotated and masked output
          result: out std_ulogic_vector(63 downto 0);
          -- carry_out: '1' only in output mode "11" when rs has a set
          -- bit outside the left mask; otherwise '0'
          carry_out: out std_ulogic
      );
end entity rotator;

architecture behaviour of rotator is
    -- Rotator input: rs(31 downto 0) in the low half; the high half is
    -- chosen by is_32bit / sign_ext_rs in the process below
    signal repl32: std_ulogic_vector(63 downto 0);
    -- Left-rotate amount: shift(5 downto 0), negated for right shifts
    signal rot_count: std_ulogic_vector(5 downto 0);
    -- Outputs of the three rotate stages (rot is the fully rotated value)
    signal rot1, rot2, rot: std_ulogic_vector(63 downto 0);
    -- sh: shift count with bit 6 forced low for 32-bit operations
    -- mb, me: mask begin / mask end indexes (big-endian bit numbering)
    signal sh, mb, me: std_ulogic_vector(6 downto 0);
    -- mr = right_mask(mb), ml = left_mask(me)
    signal mr, ml: std_ulogic_vector(63 downto 0);
    -- Selects how rot, mr, ml and ra are combined into result
    signal output_mode: std_ulogic_vector(1 downto 0);

    -- note BE bit numbering
    -- Build a 64-bit mask with ones in big-endian bit positions
    -- mask_begin .. 63 (big-endian bit i is element 63 - i of the result).
    -- A mask_begin of 64 or more yields all zeros; a mask_begin containing
    -- a metavalue yields all 'X'.
    function right_mask(mask_begin: std_ulogic_vector(6 downto 0)) return std_ulogic_vector is
        variable mask: std_ulogic_vector(63 downto 0) := (others => '0');
        variable first_be_bit: natural range 0 to 127;
    begin
        -- Propagate unknowns instead of converting them to an integer
        if is_X(mask_begin) then
            mask := (others => 'X');
            return mask;
        end if;

        first_be_bit := to_integer(unsigned(mask_begin));
        for be_bit in 0 to 63 loop
            if be_bit >= first_be_bit then
                mask(63 - be_bit) := '1';
            end if;
        end loop;
        return mask;
    end;

    -- Build a 64-bit mask with ones in big-endian bit positions
    -- 0 .. mask_end (big-endian bit i is element 63 - i of the result).
    -- When mask_end(6) is set the mask is all zeros.
    -- A mask_end containing a metavalue yields all 'X', matching the
    -- behaviour of right_mask, rather than silently producing a defined
    -- (and meaningless) mask in simulation.
    function left_mask(mask_end: std_ulogic_vector(6 downto 0)) return std_ulogic_vector is
        variable ret: std_ulogic_vector(63 downto 0);
    begin
        -- Propagate unknowns instead of converting them to an integer
        if is_X(mask_end) then
            ret := (others => 'X');
            return ret;
        end if;
        ret := (others => '0');
        if mask_end(6) = '0' then
            for i in 0 to 63 loop
                if i <= to_integer(unsigned(mask_end)) then
                    ret(63 - i) := '1';
                end if;
            end loop;
        end if;
        return ret;
    end;

begin
    -- Combinational datapath: build the rotator input, rotate it left by
    -- rot_count in three 4:1 mux stages, derive the mask begin/end indexes,
    -- and merge the rotated value with the masks according to output_mode.
    rotator_0: process(all)
        variable upper_word: std_ulogic_vector(31 downto 0);
        variable select_mask: std_ulogic_vector(63 downto 0);
    begin
        -- Upper half of the rotator input: a copy of the low word for
        -- 32-bit operations, the sign of the low word when sign extension
        -- is requested, otherwise the real upper word of rs.
        upper_word := rs(63 downto 32);
        if is_32bit = '1' then
            upper_word := rs(31 downto 0);
        elsif sign_ext_rs = '1' then
            upper_word := (others => rs(31));
        end if;
        repl32 <= upper_word & rs(31 downto 0);

        -- A right shift is performed as a left rotate by the negated count
        rot_count <= shift(5 downto 0);
        if right_shift = '1' then
            rot_count <= std_ulogic_vector(- signed(shift(5 downto 0)));
        end if;

        -- The rotate is split into three stages, each consuming two bits of
        -- the count, so that every stage is a 4:1 multiplexor (a good fit
        -- for the 6-input LUTs of the Xilinx Artix 7).  The low count bits
        -- are used first since they come out of the negation earliest.

        -- Stage 1: rotate left by 0, 1, 2 or 3
        case rot_count(1 downto 0) is
            when "00" =>
                rot1 <= repl32;
            when "01" =>
                rot1 <= std_ulogic_vector(rotate_left(unsigned(repl32), 1));
            when "10" =>
                rot1 <= std_ulogic_vector(rotate_left(unsigned(repl32), 2));
            when others =>
                rot1 <= std_ulogic_vector(rotate_left(unsigned(repl32), 3));
        end case;

        -- Stage 2: rotate left by 0, 4, 8 or 12
        case rot_count(3 downto 2) is
            when "00" =>
                rot2 <= rot1;
            when "01" =>
                rot2 <= std_ulogic_vector(rotate_left(unsigned(rot1), 4));
            when "10" =>
                rot2 <= std_ulogic_vector(rotate_left(unsigned(rot1), 8));
            when others =>
                rot2 <= std_ulogic_vector(rotate_left(unsigned(rot1), 12));
        end case;

        -- Stage 3: rotate left by 0, 16, 32 or 48
        case rot_count(5 downto 4) is
            when "00" =>
                rot <= rot2;
            when "01" =>
                rot <= std_ulogic_vector(rotate_left(unsigned(rot2), 16));
            when "10" =>
                rot <= std_ulogic_vector(rotate_left(unsigned(rot2), 32));
            when others =>
                rot <= std_ulogic_vector(rotate_left(unsigned(rot2), 48));
        end case;

        -- 32-bit operations only use a 6-bit shift count
        sh <= (shift(6) and not is_32bit) & shift(5 downto 0);

        -- Mask begin index (big-endian bit numbering)
        if clear_left = '1' and is_32bit = '1' then
            -- 5-bit field from the instruction, offset into the low word
            mb <= "01" & insn(10 downto 6);
        elsif clear_left = '1' then
            -- 6-bit field from the instruction
            mb <= "0" & insn(5) & insn(10 downto 6);
        elsif right_shift = '1' and is_32bit = '1' then
            -- sh + 32, computed without an adder
            mb <= sh(5) & (not sh(5)) & sh(4 downto 0);
        elsif right_shift = '1' then
            mb <= sh;
        else
            -- 0 for 64-bit operations, 32 for 32-bit operations
            mb <= '0' & is_32bit & "00000";
        end if;

        -- Mask end index (big-endian bit numbering); the default is
        -- effectively 63 - sh and is overridden by instruction fields
        me <= sh(6) & (not sh(5 downto 0));
        if clear_right = '1' then
            if is_32bit = '1' then
                me <= "01" & insn(5 downto 1);
            elsif clear_left = '0' then
                me <= "0" & insn(5) & insn(10 downto 6);
            end if;
        end if;

        -- Expand the indexes into bit masks
        mr <= right_mask(mb);
        ml <= left_mask(me);

        -- Output mode:
        --   00  sl[wd]
        --   0w  rlw*, rldic, rldicr, rldimi, with w = 1 iff mb > me
        --   10  rldicl, sr[wd]
        --   1z  sra[wd][i], with z = 1 if rs is negative
        if right_shift = '1' or (clear_left = '1' and clear_right = '0') then
            output_mode <= '1' & (arith and repl32(63));
        elsif clear_right = '1' and unsigned(mb(5 downto 0)) > unsigned(me(5 downto 0)) then
            output_mode <= "01";
        else
            output_mode <= "00";
        end if;

        -- Merge the rotated value with the masks.  select_mask marks the
        -- bits taken from rot in the insert-style modes; the remaining bits
        -- come from ra.
        select_mask := mr and ml;
        case output_mode is
            when "00" =>
                result <= (rot and select_mask) or (ra and not select_mask);
            when "01" =>
                -- wrapped mask (mb > me): union instead of intersection
                select_mask := mr or ml;
                result <= (rot and select_mask) or (ra and not select_mask);
            when "10" =>
                result <= rot and mr;
            when others =>
                -- fill with ones outside the right mask
                result <= rot or not mr;
        end case;

        -- Carry is only produced in mode "11", when any bit of rs outside
        -- the left mask is set
        carry_out <= '0';
        if output_mode = "11" then
            carry_out <= or (rs and not ml);
        end if;
    end process;
end behaviour;