; Symbolic matrix algebra with dimension tracking.
;
; NOTE: `;` starts a comment that runs to the end of the physical line, so
; every command below must stay on its own line(s). Collapsing this file onto
; a single line comments out everything after the first `;`.

; ---------------------------------------------------------------------------
; Dimensions
; ---------------------------------------------------------------------------

; A dimension is a product of dimensions, a named symbol, or an integer literal.
(datatype Dim
    (Times Dim Dim)
    (NamedDim String)
    (Lit i64))

; Times is associative (both directions so either grouping is reachable).
(rewrite (Times a (Times b c)) (Times (Times a b) c))
(rewrite (Times (Times a b) c) (Times a (Times b c)))
; Fold a product of two literals into a single literal.
(rewrite (Times (Lit i) (Lit j)) (Lit (* i j)))
; Times is commutative.
(rewrite (Times a b) (Times b a))

; ---------------------------------------------------------------------------
; Matrix expressions
; ---------------------------------------------------------------------------

(datatype MExpr
    (MMul MExpr MExpr)   ; product of two matrix expressions
    (Kron MExpr MExpr)   ; Kronecker-style product: row/column counts multiply (see nrows/ncols rules)
    (NamedMat String)    ; opaque named matrix
    (Id Dim)             ; identity of the given dimension
    ; Constructors not modelled yet:
    ; DSum
    ; HStack
    ; VStack
    ; Transpose
    ; Inverse
    ; Zero Math Math
    ; ScalarMul
)

; alternative encoding (type A) = (Matrix n m) may be more useful for "large story example"
(function nrows (MExpr) Dim)
(function ncols (MExpr) Dim)

; Dimensions of a Kron are the products of the operands' dimensions.
(rewrite (nrows (Kron A B)) (Times (nrows A) (nrows B)))
(rewrite (ncols (Kron A B)) (Times (ncols A) (ncols B)))

; Dimensions of an MMul: rows of the left operand, columns of the right operand.
(rewrite (nrows (MMul A B)) (nrows A))
(rewrite (ncols (MMul A B)) (ncols B))

; An identity of dimension n has n rows and n columns.
(rewrite (nrows (Id n)) n)
(rewrite (ncols (Id n)) n)

; Multiplying by an identity on either side is a no-op.
; NOTE(review): these two rules do not check that n matches the dimensions of A.
(rewrite (MMul (Id n) A) A)
(rewrite (MMul A (Id n)) A)

; MMul is associative.
(rewrite (MMul A (MMul B C)) (MMul (MMul A B) C))
(rewrite (MMul (MMul A B) C) (MMul A (MMul B C)))

; Kron is associative.
(rewrite (Kron A (Kron B C)) (Kron (Kron A B) C))
(rewrite (Kron (Kron A B) C) (Kron A (Kron B C)))

; Mixed-product rule, "split" direction: applied unconditionally.
(rewrite (Kron (MMul A C) (MMul B D))
         (MMul (Kron A B) (Kron C D)))

; Mixed-product rule, "merge" direction: only applied when the inner
; dimensions of both resulting products are known to agree.
(rewrite (MMul (Kron A B) (Kron C D))
         (Kron (MMul A C) (MMul B D))
    :when
        ((= (ncols A) (nrows C))
         (= (ncols B) (nrows D)))
)

; demand
; For every MMul / Kron, create the nrows/ncols terms of both operands so the
; dimension rewrites above and the :when guard have terms to match against.
(rule ((= e (MMul A B)))
      ((ncols A)
       (nrows A)
       (ncols B)
       (nrows B))
)

(rule ((= e (Kron A B)))
      ((ncols A)
       (nrows A)
       (ncols B)
       (nrows B))
)

; ---------------------------------------------------------------------------
; Examples
; ---------------------------------------------------------------------------

(let n (NamedDim "n"))
(let m (NamedDim "m"))
(let p (NamedDim "p"))

(let A (NamedMat "A"))
(let B (NamedMat "B"))
(let C (NamedMat "C"))

; Declare the shapes: A is n x n, B is m x m, C is p x p.
(union (nrows A) n)
(union (ncols A) n)
(union (nrows B) m)
(union (ncols B) m)
(union (nrows C) p)
(union (ncols C) p)

; ex1 = (Id n (x) B) * (A (x) Id m): inner dimensions agree, so the guarded
; mixed-product rule applies and the identities cancel, leaving (Kron A B).
(let ex1 (MMul (Kron (Id n) B) (Kron A (Id m))))
(let rows1 (nrows ex1))
(let cols1 (ncols ex1))

(run 20)

(check (= (nrows B) m))
(check (= (nrows (Kron (Id n) B)) (Times n m)))
(let simple_ex1 (Kron A B))
(check (= ex1 simple_ex1))

; ex2 = (Id p (x) C) * (A (x) Id m): (ncols (Id p)) is p but (nrows A) is n,
; so the guard of the merge rule is not satisfied and ex2 must NOT be proven
; equal to (Kron A C).
(let ex2 (MMul (Kron (Id p) C) (Kron A (Id m))))
(run 10)
(fail (check (= ex2 (Kron A C))))