"""Generate type tables and Rust comparison tests for the num_ord crate.

Usage::

    python gen.py common-types   # common supertype for each pair of types
    python gen.py all-types      # every ordered pair of types, one per line
    python gen.py tests          # Rust #[test] functions exercising NumOrd

Candidate values are Python ints or Decimals. The Decimal context precision
(1000 digits) comfortably exceeds the ~310 digits needed for +/-2**1023 +/- 0.5,
so every value and comparison computed here is exact.
"""
import decimal
import functools
import sys
from decimal import Decimal
from typing import Callable, Dict, List, Optional, Sequence, Set, Union

import numpy as np

decimal.getcontext().prec = 1000

# A candidate test value: an integer, or a Decimal for the x.5 neighbours.
Number = Union[int, Decimal]


@functools.total_ordering
class NumType:
    """A Rust primitive numeric type (uN, iN, f32 or f64) and its ranges.

    Attributes:
        bits, signed, isfloat: the constructor arguments.
        name: Rust type name / literal suffix, e.g. "u8", "i128", "f64".
        int_min, int_max: bounds of the range of integers the type can hold.
            For floats this is +/-2**24 (f32) or +/-2**53 (f64), the limits
            of the contiguous run of exactly representable integers.
        min, max: bounds used by the range check in can_exactly_represent.
            For integer types these equal int_min/int_max. For floats they
            are +/-2**127 (f32) and +/-2**1023 (f64), the largest finite
            powers of two -- slightly below the true finite maxima, so larger
            finite floats are treated as out of range.

    Instances order by (isfloat, bits, signed): every integer type before
    the floats, narrower before wider, unsigned before signed at equal width.
    """

    def __init__(self, bits: int, signed: bool, isfloat: bool) -> None:
        self.bits = bits
        self.signed = signed
        self.isfloat = isfloat
        if isfloat:
            assert signed, "floating-point types are always signed"
            self.name = f"f{bits}"
            if bits == 32:
                self.int_min, self.int_max = -2**24, 2**24
                self.min, self.max = -2**127, 2**127
            else:
                assert bits == 64, "only f32 and f64 are supported"
                self.int_min, self.int_max = -2**53, 2**53
                self.min, self.max = -2**1023, 2**1023
        elif signed:
            self.name = f"i{bits}"
            self.int_min, self.int_max = -2**(bits - 1), 2**(bits - 1) - 1
            self.min, self.max = self.int_min, self.int_max
        else:
            self.name = f"u{bits}"
            self.int_min, self.int_max = 0, 2**bits - 1
            self.min, self.max = self.int_min, self.int_max

    def can_exactly_represent(self, x: Number) -> bool:
        """Return True if *x* is a value of this type, with no rounding.

        Values outside [min, max] are rejected first (see the class docstring
        for what that range means for floats).
        """
        if x < self.min or x > self.max:
            return False
        if self.isfloat:
            # Round to the nearest float of this width, then compare exactly:
            # Decimal(float) is exact, so any rounding shows up as inequality.
            # Going through Python float first cannot cause a false positive,
            # because the final comparison is against the original x.
            to_float = np.float32 if self.bits == 32 else np.float64
            return Decimal(float(to_float(float(x)))) == x
        return x == int(x)

    def interesting_values(self) -> Set[Number]:
        """Return this type's bounds together with their near neighbours.

        Each of int_min, int_max, min and max is offset by +/-0.5 and +/-1,
        giving values just inside and just outside every limit. Many of them
        are not representable by this type; callers filter with
        can_exactly_represent().
        """
        bounds = {self.int_min, self.int_max, self.min, self.max}
        half = Decimal("0.5")
        offsets = (half, -half, 1, -1)
        return bounds | {b + d for b in bounds for d in offsets}

    def is_subset_eq(self, other: "NumType") -> bool:
        """Return True if *other*'s integer range contains this type's.

        Only int_min/int_max are compared; fractional values and the float
        range beyond the exact-integer limits are not considered.
        """
        return other.int_min <= self.int_min <= self.int_max <= other.int_max

    def _key(self):
        # Single source of truth for equality, hashing and ordering.
        return (self.isfloat, self.bits, self.signed)

    def __eq__(self, other):
        if not isinstance(other, NumType):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    def __lt__(self, other):
        if not isinstance(other, NumType):
            return NotImplemented
        return self._key() < other._key()

    def __repr__(self):
        return self.name


def implies(a: bool, b: bool) -> bool:
    """Return the logical implication a -> b."""
    return not a or b


def common_type(a: NumType, b: NumType,
                types: Sequence[NumType]) -> Optional[NumType]:
    """Return the smallest type in *types* whose integer range covers a and b.

    If either input is a float, the result must also be a float. "Smallest"
    follows NumType ordering. Returns None when no candidate qualifies.
    """
    needs_float = a.isfloat or b.isfloat
    candidates = (
        t for t in types
        if a.is_subset_eq(t) and b.is_subset_eq(t)
        and implies(needs_float, t.isfloat)
    )
    return min(candidates, default=None)


def default_types() -> List[NumType]:
    """Return u8, i8, u16, ..., u128, i128, f32, f64 in that order."""
    types = [NumType(bits, signed, False)
             for bits in (8, 16, 32, 64, 128)
             for signed in (False, True)]
    types += [NumType(32, True, True), NumType(64, True, True)]
    return types


# Header of the generated Rust file. The `use` lines must stay on their own
# lines; joined onto the first line they would be part of the `//` comment.
_TESTS_HEADER = """// Automatically generated by tools/gen.py.

use core::cmp::Ordering;
use num_ord::NumOrd;
"""


def _print_common_types(types: Sequence[NumType]) -> None:
    """Print `a, b => c;` for each pair a < b, then pairs with no common type."""
    unhandled = []
    for a in types:
        for b in types:
            if a < b:
                c = common_type(a, b, types)
                if c is not None:
                    print(f"{str(a):>4}, {str(b):>4} => {str(c):>4};")
                else:
                    unhandled.append((a, b))
    for a, b in unhandled:
        print("NO COMMON TYPE", a, b)


def _print_all_types(types: Sequence[NumType]) -> None:
    """Print every ordered pair of types, including each type with itself."""
    for a in types:
        for b in types:
            print(f"{a}, {b};")


def _print_tests(types: Sequence[NumType]) -> None:
    """Print Rust tests comparing NumOrd values of every distinct type pair.

    For each ordered pair (t1, t2) with t1 != t2, every interesting value
    exactly representable in t1 is compared with every one representable in
    t2. The expected results of <, <=, >, >= and partial_cmp come from exact
    int/Decimal comparison.
    """
    print(_TESTS_HEADER)
    for t1 in types:
        for t2 in types:
            if t1 == t2:
                continue
            print(f"#[test] fn test_{t1}_{t2}() {{")
            values = sorted(t1.interesting_values() | t2.interesting_values())
            # Filter once per type instead of once per (v1, v2) pair; the
            # nested loop below still visits the same pairs in the same order.
            lhs_values = [v for v in values if t1.can_exactly_represent(v)]
            rhs_values = [v for v in values if t2.can_exactly_represent(v)]
            for v1 in lhs_values:
                lhs = f"NumOrd({v1}{t1})"
                for v2 in rhs_values:
                    rhs = f"NumOrd({v2}{t2})"
                    checks = (("<", v1 < v2), ("<=", v1 <= v2),
                              (">", v1 > v2), (">=", v1 >= v2))
                    for op, expected in checks:
                        print(f"    assert_eq!({lhs} {op} {rhs}, "
                              f"{str(expected).lower()});")
                    if v1 > v2:
                        ordering = "Greater"
                    elif v1 < v2:
                        ordering = "Less"
                    else:
                        ordering = "Equal"
                    print(f"    assert_eq!({lhs}.partial_cmp(&{rhs}), "
                          f"Some(Ordering::{ordering}));")
            print("}\n")


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run the sub-command named by argv[1] and return an exit status.

    Args:
        argv: the full argument vector including the program name;
            defaults to sys.argv.

    Returns:
        0 on success, 2 (after printing usage to stderr) when the
        sub-command is missing or unknown.
    """
    if argv is None:
        argv = sys.argv
    commands: Dict[str, Callable[[Sequence[NumType]], None]] = {
        "common-types": _print_common_types,
        "all-types": _print_all_types,
        "tests": _print_tests,
    }
    if len(argv) < 2 or argv[1] not in commands:
        prog = argv[0] if argv else "gen.py"
        choices = "|".join(commands)
        print(f"usage: {prog} {{{choices}}}", file=sys.stderr)
        return 2
    commands[argv[1]](default_types())
    return 0


if __name__ == "__main__":
    sys.exit(main())