use crate::FloatConstants;

/// "Load Exponent": Gives `x * (2**exp)`
///
/// Of course the value of this function is that, ideally, the operation is
/// performed in much less time than that number of multiplies would require.
///
/// ## Algorithm
///
/// Doing this fairly efficiently relies on bit manipulation of the float value.
///
/// The first goal is to bring the exponent (which can be positive or negative)
/// closer to zero, and then second we do a multiplication by a specially
/// constructed float value.
///
/// First, some important type-based constants:
/// * `MAX_REAL_EXP` is the largest exponent value that the float type supports
///   before that bit pattern becomes NaN / Infinity. Eg: for `f32` this would
///   be 127, and for `f64` it would be 1023. The bit pattern for this is always
///   all exponent bits set *except* for the lowest one.
/// * `MIN_REAL_EXP` is the smallest exponent value that the float type
///   supports. For `f32` this is -126, and for `f64` it would be -1022. The bit
///   pattern in this case is to have *only* the lowest bit set. (Having all
///   exponent bits zero is a separate, alternate case that we don't want for
///   this function's purposes).
/// * `MANTISSA_DIGITS` is as per the rust standard library definition: the
///   number of significant digits in the mantissa of the float. This is the
///   number of mantissa bits +1.
///
/// The operation itself is fairly simple:
/// * If the `exp` is *more* than `MAX_REAL_EXP` then multiply `x` by
///   `2**MAX_REAL_EXP` and subtract `MAX_REAL_EXP` from `exp`.
///   * Do this once or twice as needed, but not more than 2 times.
///   * If `exp` is still more than the maximum real exponent after two times
///     then just set it to `MAX_REAL_EXP`.
/// * If the `exp` is *less* than `MIN_REAL_EXP` then we'll also multiply and
///   adjust. For the small exponent case, we want to multiply by
///   `(2**MIN_REAL_EXP) * (2**MANTISSA_DIGITS)`, and then add
///   `(abs(MIN_REAL_EXP)-MANTISSA_DIGITS)` to `exp`
///   * As with the large exponent case, do this 0, 1, or 2 times.
///   * If `exp` is still less than `MIN_REAL_EXP` just set it to
///     `MIN_REAL_EXP`.
/// * Finally, we multiply `x` by a specially constructed bit pattern. We want
///   `MAX_REAL_EXP + exp` (this will always be 1 through 2*`MAX_REAL_EXP`
///   because of the above adjustments), left shifted by the number of mantissa
///   bits, and then bitcast into the float type.
pub trait LdExp<INT> {
  /// Gives `x * (2**exp)`
  ///
  /// * `self` is the value `x` being scaled.
  /// * `exp` is the power of two to scale by; it may be positive or negative,
  ///   and is handled as described in the trait-level algorithm notes.
  fn ldexp(self, exp: INT) -> Self;
}

/// Calls the [ldexp](LdExp::ldexp) trait method.
#[must_use]
#[inline(always)]
pub fn ldexp<F, I>(x: F, i: I) -> F
where
  F: LdExp<I>,
{
  x.ldexp(i)
}

impl LdExp<i32> for f32 {
  /// Gives `self * (2**exp)` using at most two range-reduction multiplies
  /// followed by one final multiply by an exact power of two.
  fn ldexp(self, exp: i32) -> Self {
    // Power of two whose raw exponent field holds `biased` (mantissa all zero).
    let pow2 = |biased: u32| f32::from_bits(biased << Self::MANTISSA_BITS);

    // The exponent field is reduced by MAX_REAL_EXP when the float is
    // interpreted, so storing twice that much yields 2**MAX_REAL_EXP.
    let step_up = pow2((Self::MAX_REAL_EXP * 2) as u32);
    // (2**MIN_REAL_EXP) * (2**MANTISSA_DIGITS), and the amount of exponent
    // that one multiplication by it accounts for.
    let step_down = pow2(1) * pow2((Self::MAX_REAL_EXP as u32) + f32::MANTISSA_DIGITS);
    let step_down_exp = Self::MIN_REAL_EXP.abs() - (Self::MANTISSA_DIGITS as i32);

    let mut value = self;
    let mut remaining = exp;
    // At most two reduction steps in total. A step never flips direction: an
    // upward step leaves `remaining` positive and a downward step leaves it
    // negative, so one bounded loop covers both cases.
    for _ in 0..2 {
      if remaining > Self::MAX_REAL_EXP {
        value *= step_up;
        remaining -= Self::MAX_REAL_EXP;
      } else if remaining < Self::MIN_REAL_EXP {
        value *= step_down;
        remaining += step_down_exp;
      } else {
        break;
      }
    }
    // Whatever is still out of range saturates to the nearest real exponent.
    let remaining = remaining.clamp(Self::MIN_REAL_EXP, Self::MAX_REAL_EXP);

    // `MAX_REAL_EXP + remaining` is now a valid non-zero, non-all-ones
    // exponent field, so this is an ordinary power of two.
    value * pow2((Self::MAX_REAL_EXP + remaining) as u32)
  }
}

impl LdExp<i64> for f64 {
  /// Gives `self * (2**exp)` using at most two range-reduction multiplies
  /// followed by one final multiply by an exact power of two.
  fn ldexp(self, exp: i64) -> Self {
    // Power of two whose raw exponent field holds `biased` (mantissa all zero).
    let pow2 = |biased: u64| f64::from_bits(biased << Self::MANTISSA_BITS);

    // The exponent field is reduced by MAX_REAL_EXP when the float is
    // interpreted, so storing twice that much yields 2**MAX_REAL_EXP.
    let step_up = pow2((Self::MAX_REAL_EXP * 2) as u64);
    // (2**MIN_REAL_EXP) * (2**MANTISSA_DIGITS), and the amount of exponent
    // that one multiplication by it accounts for.
    let step_down =
      pow2(1) * pow2((Self::MAX_REAL_EXP as u64) + (f64::MANTISSA_DIGITS as u64));
    let step_down_exp = Self::MIN_REAL_EXP.abs() - (Self::MANTISSA_DIGITS as i64);

    let mut value = self;
    let mut remaining = exp;
    // At most two reduction steps in total. A step never flips direction: an
    // upward step leaves `remaining` positive and a downward step leaves it
    // negative, so one bounded loop covers both cases.
    for _ in 0..2 {
      if remaining > Self::MAX_REAL_EXP {
        value *= step_up;
        remaining -= Self::MAX_REAL_EXP;
      } else if remaining < Self::MIN_REAL_EXP {
        value *= step_down;
        remaining += step_down_exp;
      } else {
        break;
      }
    }
    // Whatever is still out of range saturates to the nearest real exponent.
    let remaining = remaining.clamp(Self::MIN_REAL_EXP, Self::MAX_REAL_EXP);

    // `MAX_REAL_EXP + remaining` is now a valid non-zero, non-all-ones
    // exponent field, so this is an ordinary power of two.
    value * pow2((Self::MAX_REAL_EXP + remaining) as u64)
  }
}

impl LdExp<i32> for f64 {
  fn ldexp(self, exp: i32) -> Self {
    self.ldexp(i64::from(exp))
  }
}

#[cfg(feature = "portable_simd")]
impl<const N: usize> LdExp<core::simd::Simd<i32, N>> for core::simd::Simd<f32, N>
where
  core::simd::LaneCount<N>: core::simd::SupportedLaneCount,
{
  /// Lanewise `self * (2**exp)`.
  ///
  /// This is the branch-free counterpart of the scalar `f32` implementation:
  /// each conditional step is a `simdif!` select, so every lane goes through
  /// the same sequence of operations. Two upward reduction steps and two
  /// downward reduction steps are written out, each followed by a clamp of the
  /// remaining exponent into `MIN_REAL_EXP ..= MAX_REAL_EXP`.
  ///
  /// NOTE(review): `Self::MAX_REAL_EXP`, `Self::MIN_REAL_EXP` and
  /// `Self::MANTISSA_BITS` are presumably the vector-typed constants provided
  /// by `FloatConstants` for `Simd<f32, N>` -- confirm against that trait.
  fn ldexp(self, exp: core::simd::Simd<i32, N>) -> Self {
    use crate::i32xN_to_u32xN;
    use core::simd::Simd;
    type Signed<const N: usize> = Simd<i32, N>;
    type Unsigned<const N: usize> = Simd<u32, N>;
    // The exponent field is reduced by MAX_REAL_EXP when the float is
    // interpreted, so storing twice that much yields 2**MAX_REAL_EXP.
    let two_to_the_max_real_exp: Self =
      Self::splat(f32::from_bits(((f32::MAX_REAL_EXP * 2) as u32) << f32::MANTISSA_BITS));
    let two_to_the_min_real_exp = f32::from_bits(1 << f32::MANTISSA_BITS);
    let two_to_the_mantissa_digits =
      f32::from_bits(((f32::MAX_REAL_EXP as u32) + f32::MANTISSA_DIGITS) << f32::MANTISSA_BITS);
    // (2**MIN_REAL_EXP) * (2**MANTISSA_DIGITS), and the amount of exponent
    // that one multiplication by it accounts for.
    let float_combined = Self::splat(two_to_the_min_real_exp * two_to_the_mantissa_digits);
    let exp_combined = Signed::splat(f32::MIN_REAL_EXP.abs() - (f32::MANTISSA_DIGITS as i32));
    //
    let x = self;
    // First upward step. `x` is updated before `exp` so that both selects see
    // the same, not-yet-adjusted exponent.
    let x = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      x * two_to_the_max_real_exp
    } else {
      x
    });
    let exp = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      exp - Self::MAX_REAL_EXP
    } else {
      exp
    });
    // Second upward step.
    let x = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      x * two_to_the_max_real_exp
    } else {
      x
    });
    let exp = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      exp - Self::MAX_REAL_EXP
    } else {
      exp
    });
    // Lanes that are still too large saturate at the largest real exponent.
    // TODO: fix this once int simd supports lanewise `min` properly.
    let exp = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      Self::MAX_REAL_EXP
    } else {
      exp
    });
    // First downward step.
    let x = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      x * float_combined
    } else {
      x
    });
    let exp = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      exp + exp_combined
    } else {
      exp
    });
    // Second downward step.
    let x = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      x * float_combined
    } else {
      x
    });
    let exp = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      exp + exp_combined
    } else {
      exp
    });
    // Lanes that are still too small saturate at the *smallest* real exponent,
    // matching the scalar `exp.max(MIN_REAL_EXP)`. Using the maximum here would
    // scale such lanes up by 2**MAX_REAL_EXP instead of flushing them toward
    // zero (e.g. `f32::MAX` with `exp = -500` would come out around 2**51).
    // TODO: fix this once int simd supports lanewise `max` properly.
    let exp = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      Self::MIN_REAL_EXP
    } else {
      exp
    });
    // `MAX_REAL_EXP + exp` is now a valid exponent field in every lane, so the
    // bit pattern is an ordinary power of two.
    let bit_pattern = i32xN_to_u32xN(Self::MAX_REAL_EXP + exp) << Self::MANTISSA_BITS;
    x * Self::from_bits(bit_pattern)
  }
}

#[cfg(feature = "portable_simd")]
impl<const N: usize> LdExp<core::simd::Simd<i64, N>> for core::simd::Simd<f64, N>
where
  core::simd::LaneCount<N>: core::simd::SupportedLaneCount,
{
  /// Lanewise `self * (2**exp)`.
  ///
  /// This is the branch-free counterpart of the scalar `f64` implementation:
  /// each conditional step is a `simdif!` select, so every lane goes through
  /// the same sequence of operations. Two upward reduction steps and two
  /// downward reduction steps are written out, each followed by a clamp of the
  /// remaining exponent into `MIN_REAL_EXP ..= MAX_REAL_EXP`.
  ///
  /// NOTE(review): `Self::MAX_REAL_EXP`, `Self::MIN_REAL_EXP` and
  /// `Self::MANTISSA_BITS` are presumably the vector-typed constants provided
  /// by `FloatConstants` for `Simd<f64, N>` -- confirm against that trait.
  fn ldexp(self, exp: core::simd::Simd<i64, N>) -> Self {
    use crate::i64xN_to_u64xN;
    use core::simd::Simd;
    type Signed<const N: usize> = Simd<i64, N>;
    type Unsigned<const N: usize> = Simd<u64, N>;
    // The exponent field is reduced by MAX_REAL_EXP when the float is
    // interpreted, so storing twice that much yields 2**MAX_REAL_EXP.
    let two_to_the_max_real_exp: Self =
      Self::splat(f64::from_bits(((f64::MAX_REAL_EXP * 2) as u64) << f64::MANTISSA_BITS));
    let two_to_the_min_real_exp = f64::from_bits(1 << f64::MANTISSA_BITS);
    let two_to_the_mantissa_digits = f64::from_bits(
      ((f64::MAX_REAL_EXP as u64) + (f64::MANTISSA_DIGITS as u64)) << f64::MANTISSA_BITS,
    );
    // (2**MIN_REAL_EXP) * (2**MANTISSA_DIGITS), and the amount of exponent
    // that one multiplication by it accounts for.
    let float_combined = Self::splat(two_to_the_min_real_exp * two_to_the_mantissa_digits);
    let exp_combined = Signed::splat(f64::MIN_REAL_EXP.abs() - (f64::MANTISSA_DIGITS as i64));
    //
    let x = self;
    // First upward step. `x` is updated before `exp` so that both selects see
    // the same, not-yet-adjusted exponent.
    let x = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      x * two_to_the_max_real_exp
    } else {
      x
    });
    let exp = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      exp - Self::MAX_REAL_EXP
    } else {
      exp
    });
    // Second upward step.
    let x = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      x * two_to_the_max_real_exp
    } else {
      x
    });
    let exp = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      exp - Self::MAX_REAL_EXP
    } else {
      exp
    });
    // Lanes that are still too large saturate at the largest real exponent.
    // TODO: fix this once int simd supports lanewise `min` properly.
    let exp = simdif!(exp.lanes_gt(Self::MAX_REAL_EXP) => {
      Self::MAX_REAL_EXP
    } else {
      exp
    });
    // First downward step.
    let x = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      x * float_combined
    } else {
      x
    });
    let exp = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      exp + exp_combined
    } else {
      exp
    });
    // Second downward step.
    let x = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      x * float_combined
    } else {
      x
    });
    let exp = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      exp + exp_combined
    } else {
      exp
    });
    // Lanes that are still too small saturate at the *smallest* real exponent,
    // matching the scalar `exp.max(MIN_REAL_EXP)`. Using the maximum here would
    // scale such lanes up by 2**MAX_REAL_EXP instead of flushing them toward
    // zero.
    // TODO: fix this once int simd supports lanewise `max` properly.
    let exp = simdif!(exp.lanes_lt(Self::MIN_REAL_EXP) => {
      Self::MIN_REAL_EXP
    } else {
      exp
    });
    // `MAX_REAL_EXP + exp` is now a valid exponent field in every lane, so the
    // bit pattern is an ordinary power of two.
    let bit_pattern = i64xN_to_u64xN(Self::MAX_REAL_EXP + exp) << Self::MANTISSA_BITS;
    x * Self::from_bits(bit_pattern)
  }
}

/// Compares every `ldexp` implementation in this file against `libm`.
///
/// One input table per float type is shared by the scalar and the SIMD checks,
/// so the SIMD `f64` path is exercised with exponents that are actually out of
/// range for `f64` (the table it previously used only held `f32`-scale
/// exponents, which never reach the range-reduction steps for `f64`).
#[test]
fn test_ldexp() {
  // Exponents 128/256/500 need one step, two steps, and two steps plus the
  // final clamp respectively for `f32`.
  let f32_cases: &[(f32, i32)] = &[
    (1.5, 3),
    (1.5, 128),
    (1.5, 256),
    (1.5, 500),
    (-1.5, 3),
    (-1.5, 128),
    (-1.5, 256),
    (-1.5, 500),
    (1.5, -3),
    (1.5, -128),
    (1.5, -256),
    (1.5, -500),
    (-1.5, -3),
    (-1.5, -128),
    (-1.5, -256),
    (-1.5, -500),
    (f32::INFINITY, -2),
    (f32::INFINITY, -1),
    (f32::INFINITY, 0),
    (f32::INFINITY, 1),
    (f32::INFINITY, 2),
    (f32::NEG_INFINITY, -2),
    (f32::NEG_INFINITY, -1),
    (f32::NEG_INFINITY, 0),
    (f32::NEG_INFINITY, 1),
    (f32::NEG_INFINITY, 2),
    (f32::NAN, 10),
  ];

  // Same idea scaled to the `f64` exponent range.
  let f64_cases: &[(f64, i64)] = &[
    (1.5, 3),
    (1.5, 1280),
    (1.5, 2560),
    (1.5, 5000),
    (-1.5, 3),
    (-1.5, 1280),
    (-1.5, 2560),
    (-1.5, 5000),
    (1.5, -3),
    (1.5, -1280),
    (1.5, -2560),
    (1.5, -5000),
    (-1.5, -3),
    (-1.5, -1280),
    (-1.5, -2560),
    (-1.5, -5000),
    (f64::INFINITY, -2),
    (f64::INFINITY, -1),
    (f64::INFINITY, 0),
    (f64::INFINITY, 1),
    (f64::INFINITY, 2),
    (f64::NEG_INFINITY, -2),
    (f64::NEG_INFINITY, -1),
    (f64::NEG_INFINITY, 0),
    (f64::NEG_INFINITY, 1),
    (f64::NEG_INFINITY, 2),
    (f64::NAN, 10),
  ];

  // f32
  for &(f, i) in f32_cases {
    let expected = libm::ldexpf(f, i);
    let actual = ldexp(f, i);
    // NaN never compares equal to itself, so it needs its own check.
    if expected.is_nan() {
      assert!(actual.is_nan(), "failed on inputs ({}, {})", f, i);
    } else {
      assert_eq!(expected, actual, "failed on inputs ({}, {})", f, i);
    }
  }

  // f64
  for &(f, i) in f64_cases {
    // `libm::ldexp` takes an `i32` exponent; every table entry fits.
    let expected = libm::ldexp(f, i.try_into().unwrap());
    let actual = ldexp(f, i);
    if expected.is_nan() {
      assert!(actual.is_nan(), "failed on inputs ({}, {})", f, i);
    } else {
      assert_eq!(expected, actual, "failed on inputs ({}, {})", f, i);
    }
  }

  #[cfg(feature = "portable_simd")]
  {
    use core::simd::{f32x4, f64x2, i32x4, i64x2};

    // f32x4: all lanes get the same input, lane 0 is inspected.
    for &(f, i) in f32_cases {
      let expected: f32 = libm::ldexpf(f, i);
      let actual: f32 = ldexp(f32x4::splat(f), i32x4::splat(i)).to_array()[0];
      if expected.is_nan() {
        assert!(actual.is_nan(), "failed on inputs ({}, {})", f, i);
      } else {
        assert_eq!(expected, actual, "failed on inputs ({}, {})", f, i);
      }
    }

    // f64x2: all lanes get the same input, lane 0 is inspected.
    for &(f, i) in f64_cases {
      let expected: f64 = libm::ldexp(f, i.try_into().unwrap());
      let actual: f64 = ldexp(f64x2::splat(f), i64x2::splat(i)).to_array()[0];
      if expected.is_nan() {
        assert!(actual.is_nan(), "failed on inputs ({}, {})", f, i);
      } else {
        assert_eq!(expected, actual, "failed on inputs ({}, {})", f, i);
      }
    }
  }
}
