//! A somewhat contrived arithmetic that parses string literals and only allows to add them //! and compare strings. use std::{fmt, str::FromStr}; use arithmetic_parser::{ grammars::{Parse, ParseLiteral}, BinaryOp, InputSpan, NomResult, }; use arithmetic_typing::{ arith::*, defs::Assertions, error::{ErrorPathFragment, OpErrors}, Annotated, PrimitiveType, Type, TypeEnvironment, }; /// Primitive type: string or boolean. #[derive(Debug, Clone, Copy, PartialEq)] enum StrType { Str, Bool, } impl fmt::Display for StrType { fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.write_str(match self { Self::Str => "Str", Self::Bool => "Bool", }) } } impl FromStr for StrType { type Err = anyhow::Error; fn from_str(s: &str) -> Result { match s { "Str" => Ok(Self::Str), "Bool" => Ok(Self::Bool), _ => Err(anyhow::anyhow!("Expected `Str` or `Bool`")), } } } impl PrimitiveType for StrType {} impl WithBoolean for StrType { const BOOL: Self = Self::Bool; } impl LinearType for StrType { fn is_linear(&self) -> bool { matches!(self, Self::Str) } } /// Grammar parsing strings as literals. #[derive(Debug, Clone, Copy)] struct StrGrammar; impl ParseLiteral for StrGrammar { type Lit = String; /// Parses an ASCII string like `"Hello, world!"`. fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> { use nom::{ branch::alt, bytes::complete::{escaped_transform, is_not}, character::complete::char as tag_char, combinator::{cut, map, opt}, sequence::{preceded, terminated}, }; let parser = escaped_transform( is_not("\\\"\n"), '\\', alt(( map(tag_char('\\'), |_| "\\"), map(tag_char('"'), |_| "\""), map(tag_char('n'), |_| "\n"), )), ); map( preceded(tag_char('"'), cut(terminated(opt(parser), tag_char('"')))), Option::unwrap_or_default, )(input) } } #[derive(Debug, Clone, Copy)] struct StrArithmetic; impl MapPrimitiveType for StrArithmetic { type Prim = StrType; fn type_of_literal(&self, _lit: &String) -> Self::Prim { StrType::Str } } impl TypeArithmetic for StrArithmetic { fn process_unary_op( &self, substitutions: &mut Substitutions, context: &UnaryOpContext, errors: OpErrors<'_, StrType>, ) -> Type { BoolArithmetic.process_unary_op(substitutions, context, errors) } fn process_binary_op( &self, substitutions: &mut Substitutions, context: &BinaryOpContext, mut errors: OpErrors<'_, StrType>, ) -> Type { const OP_SETTINGS: OpConstraintSettings<'static, StrType> = OpConstraintSettings { lin: &Linearity, ops: &Ops, }; match context.op { BinaryOp::Add => { NumArithmetic::unify_binary_op(substitutions, context, errors, OP_SETTINGS) } BinaryOp::Gt | BinaryOp::Lt | BinaryOp::Ge | BinaryOp::Le => { let lhs_ty = &context.lhs; let rhs_ty = &context.rhs; substitutions.unify( &Type::Prim(StrType::Str), lhs_ty, errors.join_path(ErrorPathFragment::Lhs), ); substitutions.unify( &Type::Prim(StrType::Str), rhs_ty, errors.join_path(ErrorPathFragment::Rhs), ); Type::BOOL } _ => BoolArithmetic.process_binary_op(substitutions, context, errors), } } } type Parser = Annotated; fn main() -> anyhow::Result<()> { let code = r#" x = "foo" + "bar"; // Spreading logic is reused from `NumArithmetic` and just works. y = "foo" + ("bar", "quux"); // Boolean logic works as well. assert("bar" != "baz"); assert("foo" > "bar" && "foo" <= "quux"); "#; let ast = Parser::parse_statements(code)?; let mut env = TypeEnvironment::::new(); env.insert("assert", Assertions::Assert); env.process_with_arithmetic(&StrArithmetic, &ast)?; assert_eq!(env["x"], Type::Prim(StrType::Str)); assert_eq!(env["y"].to_string(), "(Str, Str)"); let bogus_code = r#""foo" - "bar""#; let bogus_ast = Parser::parse_statements(bogus_code)?; let err = env .process_with_arithmetic(&StrArithmetic, &bogus_ast) .unwrap_err(); assert_eq!(err.to_string(), "1:1: Unsupported binary op: subtraction"); Ok(()) }