;; Implements part of the simplification layer of herbie in egglog🫡

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;; Datatypes

;; The single sort every expression lives in.
(datatype Math
  ; Ground terms: rational literals and named variables
  (Num Rational)
  (Var String)

  ; Custom ops, identified by a string tag (e.g. "sin", "exp", "PI")
  (Const String)
  (Unary String Math)
  ; unneeded for now
  ; (Binary String Math Math)

  ; Constant-folding ops
  (Add Math Math)
  (Sub Math Math)
  (Mul Math Math)
  (Div Math Math)
  (Pow Math Math)
  (Neg Math)
  (Sqrt Math)
  (Cbrt Math) ; cube root
  (Fabs Math)
  (Ceil Math)
  (Floor Math)
  (Round Math)
  (Log Math))

;; Frequently used rationals ...
(let r-zero (rational 0 1))
(let r-one  (rational 1 1))
(let r-two  (rational 2 1))

;; ... and the Math terms built from them.
(let zero  (Num r-zero))
(let one   (Num r-one))
(let two   (Num r-two))
(let three (Num (rational 3 1)))
;; Note: this is the term (Neg 1), not the literal (Num -1).
(let neg-one (Neg one))

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;; Analyses
;; --------
;; This example has three analyses:
;; the hi and lo components of an interval analysis,
;; and a non-zero analysis.
;; The non-zero analysis is built off the interval analysis (in order to prove
;; that rewrites are sound, even if some parts of an expr can't be const-evaled)

; TODO: unbounded intervals?
;; Upper bound of an e-class; merging keeps the smaller (tighter) value.
(function hi (Math) Rational :merge (min old new))
;; Lower bound of an e-class; merging keeps the larger (tighter) value.
(function lo (Math) Rational :merge (max old new))
;; Holds for terms whose value has been shown to differ from zero.
(relation non-zero (Math))

;; First, constant folding!
;; We don't need an explicit constant folding analysis, we can just union
;; with nums when we can

; Binary cases
(rewrite (Add (Num p) (Num q)) (Num (+ p q)))
(rewrite (Sub (Num p) (Num q)) (Num (- p q)))
(rewrite (Mul (Num p) (Num q)) (Num (* p q)))
; Division folds only when the denominator is known to be non-zero.
(rewrite (Div (Num p) denom) (Num (/ p q))
         :when ((= denom (Num q))
                (non-zero denom)))
; The result is bound in the condition, so the rule only applies when
; `pow` yields a value.
(rewrite (Pow (Num p) (Num q)) (Num res)
         :when ((= res (pow p q))))

; Unary cases
(rewrite (Neg (Num v)) (Num (neg v)))
;; TODO unimplemented
;; (rewrite (Sqrt (Num a)) (Num res) :when ((= res (sqrt a))))
;; (rewrite (Cbrt (Num a)) (Num res) :when ((= res (cbrt a))))
(rewrite (Fabs (Num v)) (Num (abs v)))
(rewrite (Ceil (Num v)) (Num (ceil v)))
(rewrite (Floor (Num v)) (Num (floor v)))
(rewrite (Round (Num v)) (Num (round v)))
(rewrite (Log (Num v)) (Num res)
         :when ((= res (log v))))

;; To check if something is non-zero, we check that zero is not contained in
;; the interval. There are two possible (overlapping!) cases:
;; - There exists a lo bound, in which case it must be larger than 0
;; - There exists a hi bound, in which case it must be smaller than 0
;; This assumes that intervals are well-formed: lo <= hi at all times.
;; Non-zero analysis: a strictly positive lower bound or a strictly negative
;; upper bound proves that the value cannot be zero.
(rule ((= l (lo e)) (> l r-zero))
      ((non-zero e)))
(rule ((= h (hi e)) (< h r-zero))
      ((non-zero e)))

;; A literal is its own degenerate interval.
(rule ((= e (Num ve)))
      ((set (lo e) ve)
       (set (hi e) ve)))

;; The interval analyses are similar to the constant-folding analysis,
;; except we have to take the lower/upper bound of the results we get
(rule ((= e (Add a b)) (= la (lo a)) (= lb (lo b)))
      ((set (lo e) (+ la lb))))
(rule ((= e (Add a b)) (= ha (hi a)) (= hb (hi b)))
      ((set (hi e) (+ ha hb))))

(rule ((= e (Sub a b)) (= la (lo a)) (= ha (hi a)) (= lb (lo b)) (= hb (hi b)))
      ((set (lo e)
            (min (min (- la lb) (- la hb))
                 (min (- ha lb) (- ha hb))))
       (set (hi e)
            (max (max (- la lb) (- la hb))
                 (max (- ha lb) (- ha hb))))))

(rule ((= e (Mul a b)) (= la (lo a)) (= ha (hi a)) (= lb (lo b)) (= hb (hi b)))
      ((set (lo e)
            (min (min (* la lb) (* la hb))
                 (min (* ha lb) (* ha hb))))
       (set (hi e)
            (max (max (* la lb) (* la hb))
                 (max (* ha lb) (* ha hb))))))

;; Division: the endpoint quotients only bound the result when the divisor's
;; interval excludes zero. If it contains zero the quotient is unbounded (and
;; an endpoint division may divide by zero), so no bound is recorded.
;; Case 1: divisor strictly positive.
(rule ((= e (Div a b)) (= la (lo a)) (= ha (hi a)) (= lb (lo b)) (= hb (hi b))
       (> lb r-zero))
      ((set (lo e)
            (min (min (/ la lb) (/ la hb))
                 (min (/ ha lb) (/ ha hb))))
       (set (hi e)
            (max (max (/ la lb) (/ la hb))
                 (max (/ ha lb) (/ ha hb))))))
;; Case 2: divisor strictly negative.
(rule ((= e (Div a b)) (= la (lo a)) (= ha (hi a)) (= lb (lo b)) (= hb (hi b))
       (< hb r-zero))
      ((set (lo e)
            (min (min (/ la lb) (/ la hb))
                 (min (/ ha lb) (/ ha hb))))
       (set (hi e)
            (max (max (/ la lb) (/ la hb))
                 (max (/ ha lb) (/ ha hb))))))

; TODO: Pow

(rule ((= e (Neg a)) (= la (lo a)) (= ha (hi a)))
      ((set (lo e) (neg ha))
       (set (hi e) (neg la))))

; TODO: Sqrt
; TODO: Cbrt

;; Absolute value.
;; The lower bound depends on where the argument's interval sits:
;;   - always at least 0;
;;   - at least lo(a) when a is strictly positive;
;;   - at least -hi(a) when a is strictly negative.
;; Using min(|lo|, |hi|) unconditionally would be wrong for an interval that
;; straddles zero (e.g. [-1, 2] would claim |a| >= 1).
;; Since `lo` merges with max, the tightest applicable bound wins.
(rule ((= e (Fabs a)))
      ((set (lo e) r-zero)))
(rule ((= e (Fabs a)) (= la (lo a)) (> la r-zero))
      ((set (lo e) la)))
(rule ((= e (Fabs a)) (= ha (hi a)) (< ha r-zero))
      ((set (lo e) (neg ha))))
;; The upper bound is the larger endpoint magnitude in every case.
(rule ((= e (Fabs a)) (= la (lo a)) (= ha (hi a)))
      ((set (hi e) (max (abs la) (abs ha)))))

;; Ceil, Floor and Round are monotone, so each bound maps independently.
(rule ((= e (Ceil a)) (= la (lo a)))
      ((set (lo e) (ceil la))))
(rule ((= e (Ceil a)) (= ha (hi a)))
      ((set (hi e) (ceil ha))))

(rule ((= e (Floor a)) (= la (lo a)))
      ((set (lo e) (floor la))))
(rule ((= e (Floor a)) (= ha (hi a)))
      ((set (hi e) (floor ha))))

(rule ((= e (Round a)) (= la (lo a)))
      ((set (lo e) (round la))))
(rule ((= e (Round a)) (= ha (hi a)))
      ((set (hi e) (round ha))))

; TODO: Log

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;; Rewrites
;; --------
;; These rewrites were compiled from src/syntax/rules.rkt in the herbie repo,
;; using all rewrites
;; in the `simplify` rewrite group.

;; Commutativity
(rewrite (Add a b) (Add b a))
(rewrite (Mul a b) (Mul b a))

;; Associativity
; additive: regrouping across Add/Sub
(rewrite (Add a (Add b c)) (Add (Add a b) c))
(rewrite (Add (Add a b) c) (Add a (Add b c)))
(rewrite (Add a (Sub b c)) (Sub (Add a b) c))
(rewrite (Add (Sub a b) c) (Sub a (Sub b c)))
(rewrite (Sub a (Add b c)) (Sub (Sub a b) c))
(rewrite (Sub (Add a b) c) (Add a (Sub b c)))
(rewrite (Sub (Sub a b) c) (Sub a (Add b c)))
(rewrite (Sub a (Sub b c)) (Add (Sub a b) c))
; multiplicative: regrouping across Mul/Div
(rewrite (Mul a (Mul b c)) (Mul (Mul a b) c))
(rewrite (Mul (Mul a b) c) (Mul a (Mul b c)))
(rewrite (Mul a (Div b c)) (Div (Mul a b) c))
(rewrite (Mul (Div a b) c) (Div (Mul a c) b))
(rewrite (Div a (Mul b c)) (Div (Div a b) c))
; these introduce or remove a division by the guarded term
(rewrite (Div (Mul b c) a) (Div b (Div a c))
         :when ((non-zero c)))
(rewrite (Div a (Div b c)) (Mul (Div a b) c)
         :when ((non-zero c)))
(rewrite (Div (Div b c) a) (Div b (Mul a c))
         :when ((non-zero a)))

;; Counting
(rewrite (Add x x) (Mul two x))

;; Distributivity
(rewrite (Mul a (Add b c)) (Add (Mul a b) (Mul a c)))
(rewrite (Mul a (Add b c)) (Add (Mul b a) (Mul c a)))
(rewrite (Add (Mul a b) (Mul a c)) (Mul a (Add b c)))
(rewrite (Sub (Mul a b) (Mul a c)) (Mul a (Sub b c)))
(rewrite (Add (Mul b a) (Mul c a)) (Mul a (Add b c)))
(rewrite (Sub (Mul b a) (Mul c a)) (Mul a (Sub b c)))
(rewrite (Add (Mul b a) a) (Mul (Add b one) a))
(rewrite (Add a (Mul c a)) (Mul (Add c one) a))

; moving negation through products, sums and quotients
(rewrite (Neg (Mul a b)) (Mul (Neg a) b))
(rewrite (Neg (Mul a b)) (Mul a (Neg b)))
(rewrite (Mul (Neg a) b) (Neg (Mul a b)))
(rewrite (Mul a (Neg b)) (Neg (Mul a b)))
(rewrite (Neg (Add a b)) (Add (Neg a) (Neg b)))
(rewrite (Add (Neg a) (Neg b)) (Neg (Add a b)))
(rewrite (Div (Neg a) b) (Neg (Div a b)))
(rewrite (Neg (Div a b)) (Div (Neg a) b))

(rewrite (Sub a (Mul (Neg b) c)) (Add a (Mul b c)))
(rewrite (Sub a (Mul b c)) (Add a (Mul (Neg b) c)))

;; Difference of squares
(rewrite (Mul (Mul a b) (Mul a b)) (Mul (Mul a a) (Mul b b)))
(rewrite (Mul (Mul a a) (Mul b b)) (Mul (Mul a b) (Mul a b)))
(rewrite (Sub (Mul a a) (Mul b b)) (Mul (Add a b) (Sub a b)))
(rewrite (Sub (Mul a a) one) (Mul (Add a one) (Sub a one)))
(rewrite (Add (Mul a a) (Neg one)) (Mul (Add a one) (Sub a one)))
(rewrite (Pow a b) (Mul (Pow a (Div b two)) (Pow a (Div b two))))
(rewrite (Mul (Pow a b) (Pow a b)) (Pow a (Mul two b)))

;; Identity
;; This isn't subsumed by const folding since this can return results
;; even if we can't evaluate a precise value for x
(rewrite (Div one (Div one x)) x
         :when ((non-zero x)))
(rewrite (Mul x (Div one x)) one
         :when ((non-zero x)))
(rewrite (Mul (Div one x) x) one
         :when ((non-zero x)))

(rewrite (Sub x x) zero)
(rewrite (Div x x) one
         :when ((non-zero x)))
(rewrite (Div zero x) zero
         :when ((non-zero x)))
(rewrite (Mul zero x) zero)
(rewrite (Mul x zero) zero)

(rewrite (Add zero x) x)
(rewrite (Add x zero) x)
(rewrite (Sub zero x) (Neg x))
(rewrite (Sub x zero) x)
(rewrite (Neg (Neg x)) x)
(rewrite (Mul one x) x)
(rewrite (Mul x one) x)
(rewrite (Div x one) x)
(rewrite (Mul neg-one x) (Neg x))

(rewrite (Sub a b) (Add a (Neg b)))
(rewrite (Add a (Neg b)) (Sub a b))
(rewrite (Neg x) (Sub zero x))
(rewrite (Neg x) (Mul neg-one x))

(rewrite (Div x y) (Mul x (Div one y)))
(rewrite (Mul x (Div one y)) (Div x y))
(rewrite (Div x y) (Div one (Div y x))
         :when ((non-zero x)
                (non-zero y)))

; FIXME: this rule can't be expressed in its full generality;
;        we can't express the general rule x -> 1*x since
;        we can't quantify over Math yet
;        for now we just apply it to vars
;        it's also p slow lmao
(rewrite (Var x) (Mul one (Var x)))

;; Fractions
(rewrite (Div (Sub a b) c) (Sub (Div a c) (Div b c)))
(rewrite (Div (Mul a b) (Mul c d)) (Mul (Div a c) (Div b d)))

;; Square root
(rewrite (Mul (Sqrt x) (Sqrt x)) x)
(rewrite (Sqrt (Mul x x)) (Fabs x))

(rewrite (Mul (Neg x) (Neg x)) (Mul x x))
(rewrite (Mul (Fabs x) (Fabs x)) (Mul x x))

;; Absolute values
(rewrite (Fabs (Fabs x)) (Fabs x))
(rewrite (Fabs (Sub a b)) (Fabs (Sub b a)))
(rewrite (Fabs (Neg x)) (Fabs x))
;; Absolute values of products and quotients
(rewrite (Fabs (Mul x x)) (Mul x x))
(rewrite (Fabs (Mul a b)) (Mul (Fabs a) (Fabs b)))
(rewrite (Fabs (Div a b)) (Div (Fabs a) (Fabs b)))

;; Cube root
(rewrite (Pow (Cbrt x) three) x)
(rewrite (Cbrt (Pow x three)) x)
(rewrite (Mul (Mul (Cbrt x) (Cbrt x)) (Cbrt x)) x)
(rewrite (Mul (Cbrt x) (Mul (Cbrt x) (Cbrt x))) x)
(rewrite (Pow (Neg x) three) (Neg (Pow x three)))

(rewrite (Pow (Mul x y) three) (Mul (Pow x three) (Pow y three)))
(rewrite (Pow (Div x y) three) (Div (Pow x three) (Pow y three)))

(rewrite (Pow x three) (Mul x (Mul x x)))
; FIXME: this rewrite is slow and has the potential to blow up the egraph
;        this is bc this rule and the second-to-last difference of squares rule
;        have some cyclic behavior goin on
;        the last identity rule compounds this behavior
(rewrite (Mul x (Mul x x)) (Pow x three))

;; Exponentials
(rewrite (Unary "exp" (Log x)) x)
(rewrite (Log (Unary "exp" x)) x)

(rewrite (Unary "exp" zero) one)
(rewrite (Unary "exp" one) (Const "E"))
;; (rewrite one (Unary "exp" zero))
(rewrite (Const "E") (Unary "exp" one))

(rewrite (Unary "exp" (Add a b)) (Mul (Unary "exp" a) (Unary "exp" b)))
(rewrite (Unary "exp" (Sub a b)) (Div (Unary "exp" a) (Unary "exp" b)))
(rewrite (Unary "exp" (Neg a)) (Div one (Unary "exp" a)))

(rewrite (Mul (Unary "exp" a) (Unary "exp" b)) (Unary "exp" (Add a b)))
(rewrite (Div one (Unary "exp" a)) (Unary "exp" (Neg a)))
(rewrite (Div (Unary "exp" a) (Unary "exp" b)) (Unary "exp" (Sub a b)))
(rewrite (Unary "exp" (Mul a b)) (Pow (Unary "exp" a) b))
(rewrite (Unary "exp" (Div a two)) (Sqrt (Unary "exp" a)))
(rewrite (Unary "exp" (Div a three)) (Cbrt (Unary "exp" a)))
(rewrite (Unary "exp" (Mul a two)) (Mul (Unary "exp" a) (Unary "exp" a)))
(rewrite (Unary "exp" (Mul a three)) (Pow (Unary "exp" a) three))

;; Powers
(rewrite (Pow a neg-one) (Div one a))
(rewrite (Pow a one) a)

; 0^0 is undefined
(rewrite (Pow a zero) one
         :when ((non-zero a)))
(rewrite (Pow one a) one)

; The operator tag must be spelled "exp" like every other exponential rule;
; a differently-cased tag would never match any term built by those rules.
(rewrite (Unary "exp" (Mul (Log a) b)) (Pow a b))
(rewrite (Mul (Pow a b) a) (Pow a (Add b one)))
(rewrite (Pow a (Num (rational 1 2))) (Sqrt a))
(rewrite (Pow a two) (Mul a a))
(rewrite (Pow a (Num (rational 1 3))) (Cbrt a))
(rewrite (Pow a three) (Mul (Mul a a) a))

; 0^0 is undefined, and 0^a for negative a is a division by zero,
; so only fire when the exponent is provably positive.
(rewrite (Pow zero a) zero
         :when ((= l (lo a))
                (> l r-zero)))

;; Logarithms
(rewrite (Log (Mul a b)) (Add (Log a) (Log b)))
(rewrite (Log (Div a b)) (Sub (Log a) (Log b)))
(rewrite (Log (Div one a)) (Neg (Log a)))
(rewrite (Log (Pow a b)) (Mul b (Log a)))
(rewrite (Log (Const "E")) one)

;; Trigonometry
; Pythagorean identities
(rewrite (Add (Mul (Unary "cos" a) (Unary "cos" a))
              (Mul (Unary "sin" a) (Unary "sin" a)))
         one)
(rewrite (Sub one (Mul (Unary "cos" a) (Unary "cos" a)))
         (Mul (Unary "sin" a) (Unary "sin" a)))
(rewrite (Sub one (Mul (Unary "sin" a) (Unary "sin" a)))
         (Mul (Unary "cos" a) (Unary "cos" a)))
(rewrite (Add (Mul (Unary "cos" a) (Unary "cos" a)) (Num (rational -1 1)))
         (Neg (Mul (Unary "sin" a) (Unary "sin" a))))
(rewrite (Add (Mul (Unary "sin" a) (Unary "sin" a)) (Num (rational -1 1)))
         (Neg (Mul (Unary "cos" a) (Unary "cos" a))))
(rewrite (Sub (Mul (Unary "cos" a) (Unary "cos" a)) one)
         (Neg (Mul (Unary "sin" a) (Unary "sin" a))))
(rewrite (Sub (Mul (Unary "sin" a) (Unary "sin" a)) one)
         (Neg (Mul (Unary "cos" a) (Unary "cos" a))))

; sin at special angles and shifts
(rewrite (Unary "sin" (Div (Const "PI") (Num (rational 6 1))))
         (Num (rational 1 2)))
(rewrite (Unary "sin" (Div (Const "PI") (Num (rational 4 1))))
         (Div (Sqrt two) two))
(rewrite (Unary "sin" (Div (Const "PI") three))
         (Div (Sqrt three) two))
(rewrite (Unary "sin" (Div (Const "PI") two))
         one)
(rewrite (Unary "sin" (Const "PI"))
         zero)
(rewrite (Unary "sin" (Add x (Const "PI")))
         (Neg (Unary "sin" x)))
(rewrite (Unary "sin" (Add x (Div (Const "PI") two)))
         (Unary "cos" x))

; cos at special angles and shifts
(rewrite (Unary "cos" (Div (Const "PI") (Num (rational 6 1))))
         (Div (Sqrt three) two))
(rewrite (Unary "cos" (Div (Const "PI") (Num (rational 4 1))))
         (Div (Sqrt two) two))
(rewrite (Unary "cos" (Div (Const "PI") three))
         (Num (rational 1 2)))
(rewrite (Unary "cos" (Div (Const "PI") two))
         zero)
(rewrite (Unary "cos" (Const "PI"))
         (Num (rational -1 1)))
(rewrite (Unary "cos" (Add x (Const "PI")))
         (Neg (Unary "cos" x)))
(rewrite (Unary "cos" (Add x (Div (Const "PI") two)))
         (Neg (Unary "sin" x)))

; tan at special angles and shifts
(rewrite (Unary "tan" (Div (Const "PI") (Num (rational 6 1))))
         (Div one (Sqrt three)))
(rewrite (Unary "tan" (Div (Const "PI") (Num (rational 4 1))))
         one)
(rewrite (Unary "tan" (Div (Const "PI") three))
         (Sqrt three))
(rewrite (Unary "tan" (Const "PI"))
         zero)
(rewrite (Unary "tan" (Add x (Const "PI")))
         (Unary "tan" x))
(rewrite (Unary "tan" (Add x (Div (Const "PI") two)))
         (Div neg-one (Unary "tan" x)))

; half-angle forms
(rewrite (Div (Unary "sin" a) (Add one (Unary "cos" a)))
         (Unary "tan" (Div a two)))
(rewrite (Div (Neg (Unary "sin" a)) (Add one (Unary "cos" a)))
         (Unary "tan" (Div (Neg a) two)))
(rewrite (Div (Sub one (Unary "cos" a)) (Unary "sin" a))
         (Unary "tan" (Div a two)))
(rewrite (Div (Sub one (Unary "cos" a)) (Neg (Unary "sin" a)))
         (Unary "tan" (Div (Neg a) two)))
(rewrite (Div (Add (Unary "sin" a) (Unary "sin" b))
              (Add (Unary "cos" a) (Unary "cos" b)))
         (Unary "tan" (Div (Add a b) two)))
(rewrite (Div (Sub (Unary "sin" a) (Unary "sin" b))
              (Add (Unary "cos" a) (Unary "cos" b)))
         (Unary "tan" (Div (Sub a b) two)))

; values at zero
(rewrite (Unary "sin" zero) zero)
(rewrite (Unary "cos" zero) one)
(rewrite (Unary "tan" zero) zero)

; parity: sin and tan are odd, cos is even
(rewrite (Unary "sin" (Neg x)) (Neg (Unary "sin" x)))
(rewrite (Unary "cos" (Neg x)) (Unary "cos" x))
(rewrite (Unary "tan" (Neg x)) (Neg (Unary "tan" x)))

; Hyperbolics
(rewrite (Unary "sinh" x)
         (Div (Sub (Unary "exp" x) (Unary "exp" (Neg x))) two))
(rewrite (Unary "cosh" x)
         (Div (Add (Unary "exp" x) (Unary "exp" (Neg x))) two))
(rewrite (Unary "tanh" x)
         (Div (Sub (Unary "exp" x) (Unary "exp" (Neg x)))
              (Add (Unary "exp" x) (Unary "exp" (Neg x)))))
(rewrite (Unary "tanh" x)
         (Div (Sub (Unary "exp" (Mul two x)) one)
              (Add (Unary "exp" (Mul two x)) one)))
(rewrite (Unary "tanh" x)
         (Div (Sub one (Unary "exp" (Mul (Num (rational -2 1)) x)))
              (Add one (Unary "exp" (Mul (Num (rational -2 1)) x)))))
(rewrite (Sub (Mul (Unary "cosh" x) (Unary "cosh" x))
              (Mul (Unary "sinh" x) (Unary "sinh" x)))
         one)
(rewrite (Add (Unary "cosh" x) (Unary "sinh" x)) (Unary "exp" x))
(rewrite (Sub (Unary "cosh" x) (Unary "sinh" x)) (Unary "exp" (Neg x)))

;; Unimplemented: misc. rewrites (conditionals, specialized numerical fns)

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;
;; Testing
;; -------
;; In actuality, herbie would be responsible for plugging exprs in here.
;; For our purposes, we take some test cases from herbie
;; (src/core/simplify.rkt)
;; Each case runs inside push/pop so its terms do not leak into the next one.

(push)
(let e (Add one zero))
(run 1)
(check (= e one))
(pop)

(push)
(let five (Num (rational 5 1)))
(let six (Num (rational 6 1)))
(let e2 (Add one five))
(run 1)
(check (= e2 six))
(pop)

;; Shared by the remaining cases, so defined outside any push/pop.
(let x (Var "x"))

(push)
(let e3 (Add x zero))
(run 1)
(check (= e3 x))
(pop)

(push)
(let e4 (Sub x zero))
(run 1)
(check (= e4 x))
(pop)

(push)
(let e5 (Mul x one))
(run 1)
(check (= e5 x))
(pop)

(push)
(let e6 (Div x one))
(run 1)
(check (= e6 x))
(pop)

(push)
(let e7 (Sub (Mul one x) (Mul (Add x one) one)))
(run 3)
(check (= e7 (Num (rational -1 1))))
(pop)

(push)
(let e8 (Sub (Add x one) x))
(run 4)
(check (= e8 one))
(pop)

(push)
(let e9 (Sub (Add x one) one))
(run 4)
(check (= e9 x))
(pop)

(push)
;; x >= 1 makes x provably non-zero, which the cancellation needs.
(set (lo x) r-one)
(let e10 (Div (Mul x three) x))
(run 3)
(check (= e10 three))
(pop)

(push)
(let e11 (Sub (Mul (Sqrt (Add x one)) (Sqrt (Add x one))) (Mul (Sqrt x) (Sqrt x))))
(run 5)
(check (= one e11))
(pop)

(push)
(let e12 (Add (Num (rational 1 5)) (Num (rational 3 10))))
(run 1)
(check (= e12 (Num (rational 1 2))))
(pop)

(push)
(let e13 (Unary "cos" (Const "PI")))
(run 1)
(check (= e13 (Num (rational -1 1))))
(pop)

(push)
(let sqrt5 (Sqrt (Num (rational 5 1))))
(let e14
  (Div one
       (Sub (Div (Add one sqrt5) two)
            (Div (Sub one sqrt5) two))))
(let tgt (Div one sqrt5))
(run 6)
(check (= e14 tgt))
(pop)

;; Regression: tan is odd, so tan(-x) = -tan(x).
(push)
(let e15 (Unary "tan" (Neg x)))
(run 1)
(check (= e15 (Neg (Unary "tan" x))))
(pop)

;; Regression: exp(log(x) * 2) = x^2 via the exp-to-pow rule.
(push)
(let e16 (Unary "exp" (Mul (Log x) two)))
(run 1)
(check (= e16 (Pow x two)))
(pop)