在程序 VHDL 中将两个半精度浮点数相乘

Multiplying two half-precision floats in procedural VHDL

我正在尝试在 VHDL 中实现程序化的半浮点数乘法函数。就目前情况而言,我有这个:

function "*" (l,r: half_float) return half_float is
    variable l_full_mantissa, r_full_mantissa: unsigned(l.mantissa'length downto 0);
    variable multiplied_mantissae: unsigned((l.mantissa'length + 1) * 2 - 1 downto 0);
    variable new_exponent: unsigned(fp16_exponent_len - 1 downto 0);
    begin

        if (l = fp16_zero or r = fp16_zero) then
            return fp16_zero;
        end if;
        report "L-Mantissa: " & to_string(l.mantissa);
        report "R-Mantissa: " & to_string(r.mantissa);
        report "L-Exponent: " & integer'image(to_integer(unsigned(l.exponent)));
        report "R-Exponent: " & integer'image(to_integer(unsigned(r.exponent)));
        new_exponent := unsigned(l.exponent) + (unsigned(r.exponent) - to_unsigned(fp16_exponent_bias, get_width_for_unsigned(fp16_exponent_bias))); -- Subtract the bias to prevent double counting
        report "N-Exponent: " & integer'image(to_integer(unsigned(new_exponent)));

        -- Prepend the leading 1s
        l_full_mantissa := unsigned('1' & l.mantissa);
        r_full_mantissa := unsigned('1' & r.mantissa);

        report integer'image(to_integer(l_full_mantissa));
        report integer'image(to_integer(r_full_mantissa));

        -- Multiply the mantissae
        multiplied_mantissae := l_full_mantissa * r_full_mantissa;
        multiplied_mantissae := multiplied_mantissae sll 2; -- <-- not sure about this

        report integer'image(to_integer(multiplied_mantissae));
        report to_string(multiplied_mantissae);

        return (l.sign xor r.sign, std_logic_vector(new_exponent), std_logic_vector(multiplied_mantissae(multiplied_mantissae'high downto multiplied_mantissae'high - fp16_mantissa_len + 1)));
    end function;

这似乎适用于所有这些测试用例除了最后一个:

        report "Testing 5x2";
        assert to_float(five) * to_float(two) = to_float(ten);

        report "Testing 2x42";
        assert to_float(two) * to_float(forty_two) = to_float(eighty_four);

        report "Testing 2.5x4";
        assert to_float(two_point_five_slv) * to_float(four) = to_float(ten);

        report "Testing -4x2.5";
        assert to_float(minus_four) * to_float(two_point_five_slv) = to_float(minus_ten);

        report "Testing 0.25x0.25";
        assert to_float(one_quarter) * to_float(one_quarter) = to_float(one_sixteenth);

        report "Testing 0.45x4" severity note;
        assert to_float(point_four_five) * to_float(four) = to_float(one_point_eight);

        report "Testing 0.0005x640";
        assert to_float(point_o_o_o_five) * to_float(six_forty) = to_float(point_three_two);

        report "Testing -0.96x-0.96";
        assert to_float(minus_point_nine_six) * to_float(minus_point_nine_six) = to_float(point_nine_two_one_four);

对于 -0.96x-0.96 我得到 0.4214 而不是 0.9216

我的函数中有一个 sll 2,我不完全确定为什么需要它。我也不确定是什么让我的最后一个案例与其他案例不同。因此,我希望那是我遗漏的地方。

有什么想法吗?

这并不完整,因为它没有解决所有边缘情况、无穷大或舍入模式,它也不支持次正规。但是,这似乎有效:

  1. 检查是否乘以零,return如果是则为零。
  2. 检查是否乘以 NaN,return如果是则为零。
  3. 通过将两个指数相加来计算中间指数。我们需要从中减去偏差以避免它被双重添加。
  4. 将推断的 1s 添加到尾数。
  5. 尾数相乘
  6. 将此尾数截断为您存储的尾数加 2 的长度。(x.yyyy... * y.zzzz... 始终等于 aa.bbbb... 如果 yyyy...zzzz... 是一样长)
  7. 2 个高位将始终在 {11, 10, 01} 中,因此我们检查高位并右移 1(如果已设置)。如果我们这样做,我们还必须增加指数。
  8. 然后我们将低 n-2 位作为我们的新尾数。
  9. l.sign xor r.sign.
  10. 获得合唱
  11. 连接所有批次以获得结果。

在 VHDL 中,这看起来像:

    function "*" (l,r: float) return float is
    variable l_full_mantissa, r_full_mantissa: unsigned(l.mantissa'length downto 0);
    variable multiplied_mantissae: unsigned((l.mantissa'length + 1) * 2 - 1 downto 0);
    variable new_exponent: unsigned(fp_exponent_len - 1 downto 0);
    variable truncated_mantissa: std_logic_vector(l_full_mantissa'length downto 0);
    begin

        -- Check edge conditions

        if (l = fp_zero or r = fp_zero) then
            return fp_zero;
        end if;

        if (l = nan or r = nan) then
            return nan;
        end if;

        -- Calculate the interim exponent
        new_exponent := (unsigned(l.exponent) + unsigned(r.exponent)) - to_unsigned(fp_exponent_bias, get_width_for_unsigned(fp_exponent_bias)); -- Subtract the bias to prevent double counting

        -- Prepend the leading 1s
        l_full_mantissa := unsigned('1' & l.mantissa);
        r_full_mantissa := unsigned('1' & r.mantissa);

        -- Multiply the mantissae
        multiplied_mantissae := l_full_mantissa * r_full_mantissa;

        -- Truncate resulting mantissa to 24 bits
        truncated_mantissa := std_logic_vector(multiplied_mantissae(multiplied_mantissae'high downto multiplied_mantissae'high - l_full_mantissa'length));

        -- Renormalise the result
        -- Result will always be xx.yyyyyyyy...
        -- xx will always be 11, 10 or 01, so if the high bit is 1 we need to right shift once
        if (truncated_mantissa(truncated_mantissa'high) = '1') then
            truncated_mantissa := truncated_mantissa srl 1;
            new_exponent := new_exponent + 1;
        end if;

        report  to_string((l.sign xor r.sign) & std_logic_vector(new_exponent) & std_logic_vector(truncated_mantissa(truncated_mantissa'high - 2 downto 0)));
        return ((l.sign xor r.sign), std_logic_vector(new_exponent), std_logic_vector(truncated_mantissa(truncated_mantissa'high - 2 downto 0)));
    end function;