use crate::{float_constants::*, ops::get_exponent_value::*};

/// Truncates a value, in other words this rounds toward zero.
///
/// ## Algorithm
///
/// Performing this operation involves examining the value's bit pattern.
///
/// * If the exponent value of the float is negative, the magnitude is less
///   than one (this includes zero and subnormals), so return a zero with the
///   same sign as the input.
/// * Otherwise keep some number of the high bits and clear the rest. The bits
///   to keep are the non-mantissa bits (sign and exponent) plus as many
///   mantissa bits as the exponent value, because exactly those bits encode
///   the integer part. If the number of bits to keep is greater than or equal
///   to the bit width of the type, the value has no fractional part (this also
///   covers NaN and the infinities, whose exponent field is all ones), so
///   return the value unchanged.
pub trait Truncate {
  /// Truncates the value toward zero, returning its integer part with the
  /// original sign preserved (so e.g. `-0.1` becomes `-0.0`).
  #[must_use]
  fn truncate(self) -> Self;
}

/// Free-function form of [`Truncate::truncate`]: rounds `x` toward zero.
///
/// Useful when a function-call style reads better than method syntax.
#[must_use]
#[inline(always)]
pub fn truncate<T>(x: T) -> T
where
  T: Truncate,
{
  T::truncate(x)
}

impl Truncate for f32 {
  /// Rounds `self` toward zero.
  ///
  /// With the `core_intrinsics` feature this defers to the compiler
  /// intrinsic; otherwise it clears the fractional mantissa bits directly.
  #[inline]
  #[must_use]
  fn truncate(self) -> Self {
    pick! {
      if #[cfg(feature = "core_intrinsics")] {
        // SAFETY: `truncf32` has no preconditions; any `f32` bit pattern is a
        // valid input.
        unsafe { core::intrinsics::truncf32(self) }
      } else {
        let raw = self.to_bits();
        let exponent = get_exponent_value(self);
        if exponent >= 0 {
          // Sign + exponent + `exponent` mantissa bits hold the integer part.
          let kept_bits = exponent + (Self::NON_MANTISSA_BITS as i32);
          if kept_bits < 32 {
            f32::from_bits(raw & !(u32::MAX >> kept_bits))
          } else {
            // Every mantissa bit is integral: nothing to clear.
            self
          }
        } else {
          // |self| < 1 rounds to zero; keep only the sign bit.
          f32::from_bits(raw & Self::SIGN_BIT_MASK)
        }
      }
    }
  }
}

impl Truncate for f64 {
  /// Rounds `self` toward zero.
  ///
  /// With the `core_intrinsics` feature this defers to the compiler
  /// intrinsic; otherwise it clears the fractional mantissa bits directly.
  #[inline]
  #[must_use]
  fn truncate(self) -> Self {
    pick! {
      if #[cfg(feature = "core_intrinsics")] {
        // SAFETY: `truncf64` has no preconditions; any `f64` bit pattern is a
        // valid input.
        unsafe { core::intrinsics::truncf64(self) }
      } else {
        let raw = self.to_bits();
        let exponent = get_exponent_value(self);
        if exponent >= 0 {
          // Sign + exponent + `exponent` mantissa bits hold the integer part.
          let kept_bits = exponent + (Self::NON_MANTISSA_BITS as i64);
          if kept_bits < 64 {
            f64::from_bits(raw & !(u64::MAX >> kept_bits))
          } else {
            // Every mantissa bit is integral: nothing to clear.
            self
          }
        } else {
          // |self| < 1 rounds to zero; keep only the sign bit.
          f64::from_bits(raw & Self::SIGN_BIT_MASK)
        }
      }
    }
  }
}

#[cfg(feature = "portable_simd")]
impl<const N: usize> Truncate for core::simd::Simd<f32, N>
where
  core::simd::LaneCount<N>: core::simd::SupportedLaneCount,
{
  /// Lane-wise truncation toward zero.
  ///
  /// Each lane follows the same bit-masking algorithm as the scalar `f32`
  /// implementation, with the branches expressed through `simdif!`.
  #[inline]
  #[must_use]
  fn truncate(self) -> Self {
    use crate::{i32xN_to_u32xN, u32xN_to_i32xN};
    use core::simd::Simd;
    type Signed<const N: usize> = Simd<i32, N>;
    type Unsigned<const N: usize> = Simd<u32, N>;

    let raw = self.to_bits();
    let exponent = get_exponent_value(self);
    simdif!(exponent.lanes_lt(Signed::splat(0)) => {
      // |lane| < 1: only the sign bit survives.
      Self::from_bits(raw & Self::SIGN_BIT_MASK)
    } else {
      let kept_bits = exponent + u32xN_to_i32xN(Self::NON_MANTISSA_BITS);
      simdif!(kept_bits.lanes_ge(Signed::splat(32)) => {
        // No fractional mantissa bits in this lane.
        self
      } else {
        // Lanes destined for another branch may carry any shift count here,
        // so clamp it into 0..32 to keep the vector shift from overflowing.
        let shift = i32xN_to_u32xN(kept_bits) & Unsigned::splat(31);
        let keep_mask = !(Unsigned::splat(u32::MAX) >> shift);
        Self::from_bits(raw & keep_mask)
      })
    })
  }
}

#[cfg(feature = "portable_simd")]
impl<const N: usize> Truncate for core::simd::Simd<f64, N>
where
  core::simd::LaneCount<N>: core::simd::SupportedLaneCount,
{
  /// Lane-wise truncation toward zero.
  ///
  /// Each lane follows the same bit-masking algorithm as the scalar `f64`
  /// implementation, with the branches expressed through `simdif!`.
  #[inline]
  #[must_use]
  fn truncate(self) -> Self {
    use crate::{i64xN_to_u64xN, u64xN_to_i64xN};
    use core::simd::Simd;
    type Signed<const N: usize> = Simd<i64, N>;
    type Unsigned<const N: usize> = Simd<u64, N>;

    let raw = self.to_bits();
    let exponent = get_exponent_value(self);
    simdif!(exponent.lanes_lt(Signed::splat(0)) => {
      // |lane| < 1: only the sign bit survives.
      Self::from_bits(raw & Self::SIGN_BIT_MASK)
    } else {
      let kept_bits = exponent + u64xN_to_i64xN(Self::NON_MANTISSA_BITS);
      simdif!(kept_bits.lanes_ge(Signed::splat(64)) => {
        // No fractional mantissa bits in this lane.
        self
      } else {
        // Lanes destined for another branch may carry any shift count here,
        // so clamp it into 0..64 to keep the vector shift from overflowing.
        let shift = i64xN_to_u64xN(kept_bits) & Unsigned::splat(63);
        let keep_mask = !(Unsigned::splat(u64::MAX) >> shift);
        Self::from_bits(raw & keep_mask)
      })
    })
  }
}

#[test]
fn test_truncate() {
  // Results are compared by bit pattern rather than by float equality: a
  // plain `assert_eq!(0.0, -0.0)` passes, which would hide a wrong sign on a
  // zero result (e.g. `-0.1` truncating to `+0.0`). NaN payloads may
  // legitimately differ, so a NaN result only has to be NaN.
  fn check_f32(input: f32, actual: f32) {
    let expected = input.trunc();
    if expected.is_nan() {
      assert!(actual.is_nan(), "failed for {}", input);
    } else {
      assert_eq!(
        expected.to_bits(),
        actual.to_bits(),
        "failed for {}: expected {}, got {}",
        input,
        expected,
        actual
      );
    }
  }

  fn check_f64(input: f64, actual: f64) {
    let expected = input.trunc();
    if expected.is_nan() {
      assert!(actual.is_nan(), "failed for {}", input);
    } else {
      assert_eq!(
        expected.to_bits(),
        actual.to_bits(),
        "failed for {}: expected {}, got {}",
        input,
        expected,
        actual
      );
    }
  }

  let f32_cases: &[f32] = &[
    0.0,
    0.5,
    1.0,
    1.2,
    -0.0,
    -0.1,
    -1.1,
    -123.75,
    0.99999994, // largest f32 below 1.0
    1e-40,      // subnormal
    -1e-40,
    8_388_607.5, // 2^23 - 0.5: only the lowest mantissa bit is fractional
    -8_388_607.5,
    1e10, // already integral, exponent beyond the mantissa width
    f32::MAX,
    f32::MIN,
    f32::INFINITY,
    f32::NEG_INFINITY,
    f32::MIN_POSITIVE,
    f32::NAN,
  ];

  let f64_cases: &[f64] = &[
    0.0,
    0.5,
    1.0,
    1.2,
    -0.0,
    -0.1,
    -1.1,
    -123.75,
    0.9999999999999999, // largest f64 below 1.0
    1e-310,             // subnormal
    -1e-310,
    4_503_599_627_370_495.5, // 2^52 - 0.5: only the lowest mantissa bit is fractional
    -4_503_599_627_370_495.5,
    1e300, // already integral, exponent beyond the mantissa width
    f64::MAX,
    f64::MIN,
    f64::INFINITY,
    f64::NEG_INFINITY,
    f64::MIN_POSITIVE,
    f64::NAN,
  ];

  for &f in f32_cases {
    check_f32(f, truncate(f));
  }
  for &f in f64_cases {
    check_f64(f, truncate(f));
  }

  #[cfg(feature = "portable_simd")]
  {
    use core::simd::{f32x4, f64x2};
    // Feed *different* values into the lanes (sliding windows over the case
    // lists) so that lanes take different branches of the per-lane select;
    // a splat would send every lane down the same branch.
    for w in f32_cases.windows(4) {
      let out = truncate(f32x4::from_array([w[0], w[1], w[2], w[3]])).to_array();
      for (&f, &actual) in w.iter().zip(out.iter()) {
        check_f32(f, actual);
      }
    }
    for w in f64_cases.windows(2) {
      let out = truncate(f64x2::from_array([w[0], w[1]])).to_array();
      for (&f, &actual) in w.iter().zip(out.iter()) {
        check_f64(f, actual);
      }
    }
  }
}
