// SPDX-FileCopyrightText: 2015 Evan Teran
// SPDX-License-Identifier: MIT

// From: https://github.com/eteran/cpp-utilities/blob/master/fixed/include/cpp-utilities/fixed.h
// See also: http://stackoverflow.com/questions/79677/whats-the-best-way-to-do-fixed-point-math

#pragma once

#include <cstddef> // for size_t
#include <cstdint>
#include <exception>
#include <ostream>
#include <type_traits>

#include "concepts.h"

namespace Common {

template <size_t I, size_t F>
class FixedPoint;

namespace detail {

// Type-selection helpers.
// type_from_size<N> maps a total bit width N to reasonable integer types for
// storing a value of that width. Specializations also name the next larger
// mapping through `next_size`, which the multiply/divide helpers use to find
// a wider intermediate type.
//
// This primary template is the "unsupported width" case: every alias is
// void and is_specialized is false.
template <size_t T>
struct type_from_size {
    static constexpr bool is_specialized = false;

    using value_type = void;
    using signed_type = void;
    using unsigned_type = void;
};

#if defined(__GNUC__) && defined(__x86_64__) && !defined(__STRICT_ANSI__)
// 128-bit mapping, backed by the compiler-provided __int128 types.
// next_size names the 256-bit mapping; no specialization for that width is
// visible here, so it resolves to the unspecialized primary template.
template <>
struct type_from_size<128> {
    using value_type = __int128;
    using signed_type = __int128;
    using unsigned_type = unsigned __int128;
    using next_size = type_from_size<256>;

    static constexpr size_t size = 128;
    static constexpr bool is_specialized = true;
};
#endif

// 64-bit mapping: values are stored in int64_t; the next larger mapping is
// the 128-bit one.
template <>
struct type_from_size<64> {
    using value_type = int64_t;
    using signed_type = std::make_signed<value_type>::type;
    using unsigned_type = std::make_unsigned<value_type>::type;
    using next_size = type_from_size<128>;

    static constexpr size_t size = 64;
    static constexpr bool is_specialized = true;
};

// 32-bit mapping: values are stored in int32_t; the next larger mapping is
// the 64-bit one.
template <>
struct type_from_size<32> {
    using value_type = int32_t;
    using signed_type = std::make_signed<value_type>::type;
    using unsigned_type = std::make_unsigned<value_type>::type;
    using next_size = type_from_size<64>;

    static constexpr size_t size = 32;
    static constexpr bool is_specialized = true;
};

// 16-bit mapping: values are stored in int16_t; the next larger mapping is
// the 32-bit one.
template <>
struct type_from_size<16> {
    using value_type = int16_t;
    using signed_type = std::make_signed<value_type>::type;
    using unsigned_type = std::make_unsigned<value_type>::type;
    using next_size = type_from_size<32>;

    static constexpr size_t size = 16;
    static constexpr bool is_specialized = true;
};

// 8-bit mapping: values are stored in int8_t; the next larger mapping is
// the 16-bit one.
template <>
struct type_from_size<8> {
    using value_type = int8_t;
    using signed_type = std::make_signed<value_type>::type;
    using unsigned_type = std::make_unsigned<value_type>::type;
    using next_size = type_from_size<16>;

    static constexpr size_t size = 8;
    static constexpr bool is_specialized = true;
};

// Converts a value of the "next size up" type N to the base type B with a
// plain static_cast. It exists as a separate function so that non-native
// base types (e.g. a big-int class) have one place to hook into; it works
// as long as that class supports being cast this way.
template <class B, class N>
constexpr auto next_to_base(N rhs) -> B {
    return static_cast<B>(rhs);
}

// Exception thrown by detail::divide when it is asked to divide by zero.
struct divide_by_zero : public std::exception {
};

// Division for formats whose next-larger integer type exists.
// The raw numerator is widened to next_type and pre-shifted left by the
// number of fractional bits, so a single integer division yields a quotient
// that is already in fixed-point scale.
//
// numerator, denominator: the operands.
// remainder: receives (numerator_raw << F) % denominator_raw as a raw value.
// Returns the quotient.
// Throws divide_by_zero when the denominator's raw value is zero, matching
// the fallback overload, instead of executing an integer division by zero
// (which is undefined behaviour).
template <size_t I, size_t F>
constexpr FixedPoint<I, F> divide(
    FixedPoint<I, F> numerator, FixedPoint<I, F> denominator, FixedPoint<I, F>& remainder,
    std::enable_if_t<type_from_size<I + F>::next_size::is_specialized>* = nullptr) {

    using next_type = typename FixedPoint<I, F>::next_type;
    using base_type = typename FixedPoint<I, F>::base_type;
    constexpr size_t fractional_bits = FixedPoint<I, F>::fractional_bits;

    const base_type divisor = denominator.to_raw();
    if (divisor == 0) {
        throw divide_by_zero();
    }

    // widen first so the pre-shift cannot discard the numerator's high bits
    next_type scaled(numerator.to_raw());
    scaled <<= fractional_bits;

    remainder = FixedPoint<I, F>::from_base(next_to_base<base_type>(scaled % divisor));
    return FixedPoint<I, F>::from_base(next_to_base<base_type>(scaled / divisor));
}

// Fallback division, selected when no next-larger integer type is available
// for this format. It works on magnitudes using only the unsigned base type.
//
// numerator, denominator: the operands (taken by value; negated locally
//                         when they compare less than 0).
// remainder: assigned from `n`, the integer remainder left by the long
//            division of the raw magnitudes.
// Returns the quotient, negated when exactly one operand was negative.
// Throws divide_by_zero when the denominator compares equal to 0.
template <size_t I, size_t F>
constexpr FixedPoint<I, F> divide(
    FixedPoint<I, F> numerator, FixedPoint<I, F> denominator, FixedPoint<I, F>& remainder,
    std::enable_if_t<!type_from_size<I + F>::next_size::is_specialized>* = nullptr) {

    using unsigned_type = typename FixedPoint<I, F>::unsigned_type;

    constexpr int bits = FixedPoint<I, F>::total_bits;

    if (denominator == 0) {
        throw divide_by_zero();
    } else {

        // ends up as 1 when exactly one operand was negative
        int sign = 0;

        FixedPoint<I, F> quotient;

        if (numerator < 0) {
            sign ^= 1;
            numerator = -numerator;
        }

        if (denominator < 0) {
            sign ^= 1;
            denominator = -denominator;
        }

        // n: running remainder, d: shifted divisor, x: quotient bit that
        // corresponds to the current d, answer: accumulated integer quotient
        unsigned_type n = numerator.to_raw();
        unsigned_type d = denominator.to_raw();
        unsigned_type x = 1;
        unsigned_type answer = 0;

        // egyptian division algorithm
        // phase 1: keep doubling d (and x with it) until d is larger than n
        // or d's top bit is set, so a further doubling would lose that bit
        while ((n >= d) && (((d >> (bits - 1)) & 1) == 0)) {
            x <<= 1;
            d <<= 1;
        }

        // phase 2: walk back down; wherever d still fits into n, subtract it
        // and record the matching bit in the answer
        while (x != 0) {
            if (n >= d) {
                n -= d;
                answer += x;
            }

            x >>= 1;
            d >>= 1;
        }

        // l1: what is left of the numerator, l2: the (now positive) divisor
        unsigned_type l1 = n;
        unsigned_type l2 = denominator.to_raw();

        // calculate the lower bits (needs to be unsigned)
        // both values are scaled down together until l1 can be shifted left
        // by F bits without dropping set bits off the top
        while (l1 >> (bits - F) > 0) {
            l1 >>= 1;
            l2 >>= 1;
        }
        const unsigned_type lo = (l1 << F) / l2;

        // integer quotient goes above the fractional bits, lo fills them in
        quotient = FixedPoint<I, F>::from_base((answer << F) | lo);
        // NOTE(review): `n` is a raw unsigned value, yet this assignment appears
        // to go through FixedPoint's implicit arithmetic constructor (which
        // multiplies by `one`) rather than from_base - confirm this is intended.
        remainder = n;

        if (sign) {
            quotient = -quotient;
        }

        return quotient;
    }
}

// The usual multiplication, used when a next-larger integer type exists:
// both raw values are widened, multiplied, and the product is shifted right
// by the fractional bit count to bring it back to fixed-point scale.
template <size_t I, size_t F>
constexpr FixedPoint<I, F> multiply(
    FixedPoint<I, F> lhs, FixedPoint<I, F> rhs,
    std::enable_if_t<type_from_size<I + F>::next_size::is_specialized>* = nullptr) {

    using Result = FixedPoint<I, F>;
    using wide_type = typename Result::next_type;
    using narrow_type = typename Result::base_type;

    const wide_type wide_lhs = static_cast<wide_type>(lhs.to_raw());
    const wide_type wide_rhs = static_cast<wide_type>(rhs.to_raw());

    wide_type product(wide_lhs * wide_rhs);
    product >>= Result::fractional_bits;

    return Result::from_base(next_to_base<narrow_type>(product));
}

// Fallback multiplication for formats that have no next-larger integer type.
// It is slightly slower, but it does not require an upgraded type: each
// operand is split into an integer part and a fractional part, and the four
// partial products are recombined at the right scale.
template <size_t I, size_t F>
constexpr FixedPoint<I, F> multiply(
    FixedPoint<I, F> lhs, FixedPoint<I, F> rhs,
    std::enable_if_t<!type_from_size<I + F>::next_size::is_specialized>* = nullptr) {

    using Result = FixedPoint<I, F>;
    using base_type = typename Result::base_type;

    constexpr size_t shift = Result::fractional_bits;
    constexpr base_type int_mask = Result::integer_mask;
    constexpr base_type frac_mask = Result::fractional_mask;

    const base_type lhs_raw = lhs.to_raw();
    const base_type rhs_raw = rhs.to_raw();

    // integer parts are shifted down to plain integers, fractional parts
    // stay as raw fractional bits
    const base_type lhs_int = (lhs_raw & int_mask) >> shift;
    const base_type rhs_int = (rhs_raw & int_mask) >> shift;
    const base_type lhs_frac = (lhs_raw & frac_mask);
    const base_type rhs_frac = (rhs_raw & frac_mask);

    // the four partial products of (int + frac) * (int + frac)
    const base_type int_int = lhs_int * rhs_int;
    const base_type int_frac = lhs_int * rhs_frac;
    const base_type frac_int = lhs_frac * rhs_int;
    const base_type frac_frac = lhs_frac * rhs_frac;

    return Result::from_base((int_int << shift) + (frac_int + int_frac) + (frac_frac >> shift));
}
} // namespace detail

template <size_t I, size_t F>
class FixedPoint {
    static_assert(detail::type_from_size<I + F>::is_specialized, "invalid combination of sizes");

public:
    static constexpr size_t fractional_bits = F;
    static constexpr size_t integer_bits = I;
    static constexpr size_t total_bits = I + F;

    using base_type_info = detail::type_from_size<total_bits>;

    using base_type = typename base_type_info::value_type;
    using next_type = typename base_type_info::next_size::value_type;
    using unsigned_type = typename base_type_info::unsigned_type;

public:
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Woverflow"
#endif
    static constexpr base_type fractional_mask =
        ~(static_cast<unsigned_type>(~base_type(0)) << fractional_bits);
    static constexpr base_type integer_mask = ~fractional_mask;
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

public:
    static constexpr base_type one = base_type(1) << fractional_bits;

public: // constructors
    constexpr FixedPoint() = default;

    constexpr FixedPoint(const FixedPoint&) = default;
    constexpr FixedPoint& operator=(const FixedPoint&) = default;

    constexpr FixedPoint(FixedPoint&&) noexcept = default;
    constexpr FixedPoint& operator=(FixedPoint&&) noexcept = default;

    template <IsArithmetic Number>
    constexpr FixedPoint(Number n) : data_(static_cast<base_type>(n * one)) {}

public: // conversion
    template <size_t I2, size_t F2>
    constexpr explicit FixedPoint(FixedPoint<I2, F2> other) {
        static_assert(I2 <= I && F2 <= F, "Scaling conversion can only upgrade types");
        using T = FixedPoint<I2, F2>;

        const base_type fractional = (other.data_ & T::fractional_mask);
        const base_type integer = (other.data_ & T::integer_mask) >> T::fractional_bits;
        data_ =
            (integer << fractional_bits) | (fractional << (fractional_bits - T::fractional_bits));
    }

private:
    // this makes it simpler to create a FixedPoint point object from
    // a native type without scaling
    // use "FixedPoint::from_base" in order to perform this.
    struct NoScale {};

    constexpr FixedPoint(base_type n, const NoScale&) : data_(n) {}

public:
    static constexpr FixedPoint from_base(base_type n) {
        return FixedPoint(n, NoScale());
    }

public: // comparison operators
    friend constexpr auto operator<=>(FixedPoint lhs, FixedPoint rhs) = default;

public: // unary operators
    [[nodiscard]] constexpr bool operator!() const {
        return !data_;
    }

    [[nodiscard]] constexpr FixedPoint operator~() const {
        // NOTE(eteran): this will often appear to "just negate" the value
        // that is not an error, it is because -x == (~x+1)
        // and that "+1" is adding an infinitesimally small fraction to the
        // complimented value
        return FixedPoint::from_base(~data_);
    }

    [[nodiscard]] constexpr FixedPoint operator-() const {
        return FixedPoint::from_base(-data_);
    }

    [[nodiscard]] constexpr FixedPoint operator+() const {
        return FixedPoint::from_base(+data_);
    }

    constexpr FixedPoint& operator++() {
        data_ += one;
        return *this;
    }

    constexpr FixedPoint& operator--() {
        data_ -= one;
        return *this;
    }

    constexpr FixedPoint operator++(int) {
        FixedPoint tmp(*this);
        data_ += one;
        return tmp;
    }

    constexpr FixedPoint operator--(int) {
        FixedPoint tmp(*this);
        data_ -= one;
        return tmp;
    }

public: // basic math operators
    constexpr FixedPoint& operator+=(FixedPoint n) {
        data_ += n.data_;
        return *this;
    }

    constexpr FixedPoint& operator-=(FixedPoint n) {
        data_ -= n.data_;
        return *this;
    }

    constexpr FixedPoint& operator*=(FixedPoint n) {
        return assign(detail::multiply(*this, n));
    }

    constexpr FixedPoint& operator/=(FixedPoint n) {
        FixedPoint temp;
        return assign(detail::divide(*this, n, temp));
    }

private:
    constexpr FixedPoint& assign(FixedPoint rhs) {
        data_ = rhs.data_;
        return *this;
    }

public: // binary math operators, effects underlying bit pattern since these
        // don't really typically make sense for non-integer values
    constexpr FixedPoint& operator&=(FixedPoint n) {
        data_ &= n.data_;
        return *this;
    }

    constexpr FixedPoint& operator|=(FixedPoint n) {
        data_ |= n.data_;
        return *this;
    }

    constexpr FixedPoint& operator^=(FixedPoint n) {
        data_ ^= n.data_;
        return *this;
    }

    template <IsIntegral Integer>
    constexpr FixedPoint& operator>>=(Integer n) {
        data_ >>= n;
        return *this;
    }

    template <IsIntegral Integer>
    constexpr FixedPoint& operator<<=(Integer n) {
        data_ <<= n;
        return *this;
    }

public: // conversion to basic types
    constexpr void round_up() {
        data_ += (data_ & fractional_mask) >> 1;
    }

    [[nodiscard]] constexpr int to_int() {
        round_up();
        return static_cast<int>((data_ & integer_mask) >> fractional_bits);
    }

    [[nodiscard]] constexpr unsigned int to_uint() {
        round_up();
        return static_cast<unsigned int>((data_ & integer_mask) >> fractional_bits);
    }

    [[nodiscard]] constexpr int64_t to_long() {
        round_up();
        return static_cast<int64_t>((data_ & integer_mask) >> fractional_bits);
    }

    [[nodiscard]] constexpr int to_int_floor() const {
        return static_cast<int>((data_ & integer_mask) >> fractional_bits);
    }

    [[nodiscard]] constexpr int64_t to_long_floor() const {
        return static_cast<int64_t>((data_ & integer_mask) >> fractional_bits);
    }

    [[nodiscard]] constexpr unsigned int to_uint_floor() const {
        return static_cast<unsigned int>((data_ & integer_mask) >> fractional_bits);
    }

    [[nodiscard]] constexpr float to_float() const {
        return static_cast<float>(data_) / FixedPoint::one;
    }

    [[nodiscard]] constexpr double to_double() const {
        return static_cast<double>(data_) / FixedPoint::one;
    }

    [[nodiscard]] constexpr base_type to_raw() const {
        return data_;
    }

    constexpr void clear_int() {
        data_ &= fractional_mask;
    }

    [[nodiscard]] constexpr base_type get_frac() const {
        return data_ & fractional_mask;
    }

public:
    constexpr void swap(FixedPoint& rhs) noexcept {
        using std::swap;
        swap(data_, rhs.data_);
    }

public:
    base_type data_{};
};

// Operands that share the fractional width but differ in integer width are
// trivially upgraded: both raw values are re-wrapped in whichever of the two
// types has more integer bits, and the sum is computed in that type.
template <size_t I1, size_t I2, size_t F>
constexpr std::conditional_t<I1 >= I2, FixedPoint<I1, F>, FixedPoint<I2, F>> operator+(
    FixedPoint<I1, F> lhs, FixedPoint<I2, F> rhs) {

    using Wider = FixedPoint<(I1 >= I2 ? I1 : I2), F>;

    return Wider::from_base(lhs.to_raw()) + Wider::from_base(rhs.to_raw());
}

// Subtraction for operands with the same fractional width but different
// integer widths; evaluated in the type that has more integer bits.
template <size_t I1, size_t I2, size_t F>
constexpr std::conditional_t<I1 >= I2, FixedPoint<I1, F>, FixedPoint<I2, F>> operator-(
    FixedPoint<I1, F> lhs, FixedPoint<I2, F> rhs) {

    using Wider = FixedPoint<(I1 >= I2 ? I1 : I2), F>;

    return Wider::from_base(lhs.to_raw()) - Wider::from_base(rhs.to_raw());
}

// Multiplication for operands with the same fractional width but different
// integer widths; evaluated in the type that has more integer bits.
template <size_t I1, size_t I2, size_t F>
constexpr std::conditional_t<I1 >= I2, FixedPoint<I1, F>, FixedPoint<I2, F>> operator*(
    FixedPoint<I1, F> lhs, FixedPoint<I2, F> rhs) {

    using Wider = FixedPoint<(I1 >= I2 ? I1 : I2), F>;

    return Wider::from_base(lhs.to_raw()) * Wider::from_base(rhs.to_raw());
}

// Division for operands with the same fractional width but different
// integer widths; evaluated in the type that has more integer bits.
template <size_t I1, size_t I2, size_t F>
constexpr std::conditional_t<I1 >= I2, FixedPoint<I1, F>, FixedPoint<I2, F>> operator/(
    FixedPoint<I1, F> lhs, FixedPoint<I2, F> rhs) {

    using Wider = FixedPoint<(I1 >= I2 ? I1 : I2), F>;

    return Wider::from_base(lhs.to_raw()) / Wider::from_base(rhs.to_raw());
}

// Writes the value to the stream as the double returned by to_double().
template <size_t I, size_t F>
std::ostream& operator<<(std::ostream& os, FixedPoint<I, F> f) {
    return os << f.to_double();
}

// basic math operators
// Sum of two values of the same format, built on operator+=.
template <size_t I, size_t F>
constexpr FixedPoint<I, F> operator+(FixedPoint<I, F> lhs, FixedPoint<I, F> rhs) {
    return lhs += rhs;
}
// Difference of two values of the same format, built on operator-=.
template <size_t I, size_t F>
constexpr FixedPoint<I, F> operator-(FixedPoint<I, F> lhs, FixedPoint<I, F> rhs) {
    return lhs -= rhs;
}
// Product of two values of the same format, built on operator*=.
template <size_t I, size_t F>
constexpr FixedPoint<I, F> operator*(FixedPoint<I, F> lhs, FixedPoint<I, F> rhs) {
    return lhs *= rhs;
}
// Quotient of two values of the same format, built on operator/=.
template <size_t I, size_t F>
constexpr FixedPoint<I, F> operator/(FixedPoint<I, F> lhs, FixedPoint<I, F> rhs) {
    return lhs /= rhs;
}

// fixed + number: the number is first converted to the fixed-point format.
template <size_t I, size_t F, IsArithmetic Number>
constexpr FixedPoint<I, F> operator+(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs += converted;
}
// fixed - number: the number is first converted to the fixed-point format.
template <size_t I, size_t F, IsArithmetic Number>
constexpr FixedPoint<I, F> operator-(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs -= converted;
}
// fixed * number: the number is first converted to the fixed-point format.
template <size_t I, size_t F, IsArithmetic Number>
constexpr FixedPoint<I, F> operator*(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs *= converted;
}
// fixed / number: the number is first converted to the fixed-point format.
template <size_t I, size_t F, IsArithmetic Number>
constexpr FixedPoint<I, F> operator/(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs /= converted;
}

// number + fixed: the number is converted to the fixed-point format and the
// right-hand side is added to it.
template <size_t I, size_t F, IsArithmetic Number>
constexpr FixedPoint<I, F> operator+(Number lhs, FixedPoint<I, F> rhs) {
    return FixedPoint<I, F>(lhs) += rhs;
}
// number - fixed: the number is converted to the fixed-point format and the
// right-hand side is subtracted from it.
template <size_t I, size_t F, IsArithmetic Number>
constexpr FixedPoint<I, F> operator-(Number lhs, FixedPoint<I, F> rhs) {
    return FixedPoint<I, F>(lhs) -= rhs;
}
// number * fixed: the number is converted to the fixed-point format and
// multiplied by the right-hand side.
template <size_t I, size_t F, IsArithmetic Number>
constexpr FixedPoint<I, F> operator*(Number lhs, FixedPoint<I, F> rhs) {
    return FixedPoint<I, F>(lhs) *= rhs;
}
// number / fixed: the number is converted to the fixed-point format and
// divided by the right-hand side.
template <size_t I, size_t F, IsArithmetic Number>
constexpr FixedPoint<I, F> operator/(Number lhs, FixedPoint<I, F> rhs) {
    return FixedPoint<I, F>(lhs) /= rhs;
}

// shift operators
// Left shift of the underlying raw value, built on operator<<=.
template <size_t I, size_t F, IsIntegral Integer>
constexpr FixedPoint<I, F> operator<<(FixedPoint<I, F> lhs, Integer rhs) {
    return lhs <<= rhs;
}
// Right shift of the underlying raw value, built on operator>>=.
template <size_t I, size_t F, IsIntegral Integer>
constexpr FixedPoint<I, F> operator>>(FixedPoint<I, F> lhs, Integer rhs) {
    return lhs >>= rhs;
}

// comparison operators
// fixed > number, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator>(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs > converted;
}
// fixed < number, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator<(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs < converted;
}
// fixed >= number, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator>=(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs >= converted;
}
// fixed <= number, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator<=(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs <= converted;
}
// fixed == number, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator==(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs == converted;
}
// fixed != number, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator!=(FixedPoint<I, F> lhs, Number rhs) {
    const FixedPoint<I, F> converted(rhs);
    return lhs != converted;
}

// number > fixed, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator>(Number lhs, FixedPoint<I, F> rhs) {
    const FixedPoint<I, F> converted(lhs);
    return converted > rhs;
}
// number < fixed, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator<(Number lhs, FixedPoint<I, F> rhs) {
    const FixedPoint<I, F> converted(lhs);
    return converted < rhs;
}
// number >= fixed, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator>=(Number lhs, FixedPoint<I, F> rhs) {
    const FixedPoint<I, F> converted(lhs);
    return converted >= rhs;
}
// number <= fixed, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator<=(Number lhs, FixedPoint<I, F> rhs) {
    const FixedPoint<I, F> converted(lhs);
    return converted <= rhs;
}
// number == fixed, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator==(Number lhs, FixedPoint<I, F> rhs) {
    const FixedPoint<I, F> converted(lhs);
    return converted == rhs;
}
// number != fixed, compared after converting the number to fixed point.
template <size_t I, size_t F, IsArithmetic Number>
constexpr bool operator!=(Number lhs, FixedPoint<I, F> rhs) {
    const FixedPoint<I, F> converted(lhs);
    return converted != rhs;
}

} // namespace Common
