local _tl_compat53 = ((tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3) and require('compat53.module'); local assert = _tl_compat53 and _tl_compat53.assert or assert; local io = _tl_compat53 and _tl_compat53.io or io; local ipairs = _tl_compat53 and _tl_compat53.ipairs or ipairs; local load = _tl_compat53 and _tl_compat53.load or load; local math = _tl_compat53 and _tl_compat53.math or math; local os = _tl_compat53 and _tl_compat53.os or os; local package = _tl_compat53 and _tl_compat53.package or package; local pairs = _tl_compat53 and _tl_compat53.pairs or pairs; local string = _tl_compat53 and _tl_compat53.string or string; local table = _tl_compat53 and _tl_compat53.table or table; local _tl_table_unpack = unpack or table.unpack;

-- Empty tables standing in for Teal type declarations (presumably emitted
-- by the tl code generator; they carry no runtime data).
local Env = {}
local TypeCheckOptions = {}
local LoadMode = {}
local LoadFunction = {}

local tl = {
   load = nil,
   process = nil,
   process_string = nil,
   gen = nil,
   type_check = nil,
   init_env = nil,
}

local inspect = function(x)
   return tostring(x)
end

-- Lua reserved words: the lexer re-labels any token whose text is one of
-- these as kind "keyword".
local keywords = {
   ["and"] = true,
   ["break"] = true,
   ["do"] = true,
   ["else"] = true,
   ["elseif"] = true,
   ["end"] = true,
   ["false"] = true,
   ["for"] = true,
   ["function"] = true,
   ["goto"] = true,
   ["if"] = true,
   ["in"] = true,
   ["local"] = true,
   ["nil"] = true,
   ["not"] = true,
   ["or"] = true,
   ["repeat"] = true,
   ["return"] = true,
   ["then"] = true,
   ["true"] = true,
   ["until"] = true,
   ["while"] = true,
}

local TokenKind = {}
local Token = {}

-- Character classes used by tl.lex.

-- characters that may start an identifier
local lex_word_start = {}
for c = string.byte("a"), string.byte("z") do
   lex_word_start[string.char(c)] = true
end
for c = string.byte("A"), string.byte("Z") do
   lex_word_start[string.char(c)] = true
end
lex_word_start["_"] = true

-- characters that may continue an identifier
local lex_word = {}
for c = string.byte("a"), string.byte("z") do
   lex_word[string.char(c)] = true
end
for c = string.byte("A"), string.byte("Z") do
   lex_word[string.char(c)] = true
end
for c = string.byte("0"), string.byte("9") do
   lex_word[string.char(c)] = true
end
lex_word["_"] = true

-- "1".."9" start a decimal number; "0" is handled separately because it may
-- begin a hexadecimal literal
local lex_decimal_start = {}
for c = string.byte("1"), string.byte("9") do
   lex_decimal_start[string.char(c)] = true
end

local lex_decimals = {}
for c = string.byte("0"), string.byte("9") do
   lex_decimals[string.char(c)] = true
end

local lex_hexadecimals = {}
for c = string.byte("0"), string.byte("9") do
   lex_hexadecimals[string.char(c)] = true
end
for c = string.byte("a"), string.byte("f") do
   lex_hexadecimals[string.char(c)] = true
end
for c = string.byte("A"), string.byte("F") do
   lex_hexadecimals[string.char(c)] = true
end

-- single-character tokens whose kind is the character itself
local lex_char_symbols = {}
for _, c in ipairs({ "[", "]", "(", ")", "{", "}", ",", "#", "`", ";" }) do
   lex_char_symbols[c] = true
end

-- single-character operators (kind "op"); "/" is intercepted earlier by
-- tl.lex so that "//" can be recognized
local lex_op_start = {}
for _, c in ipairs({ "+", "*", "/", "|", "&", "%", "^" }) do
   lex_op_start[c] = true
end

local lex_space = {}
for _, c in ipairs({ " ", "\t", "\v", "\n", "\r" }) do
   lex_space[c] = true
end

local LexState = {}

--- Split source text into an array of tokens.
-- A single-pass, character-at-a-time state machine. Every token is a table
-- { x = column, y = line, i = byte offset, tk = text, kind = kind }, where
-- kind is "identifier", "keyword", "number", "string", "op", "$invalid$",
-- or the symbol itself for punctuation such as "(", ":", "::", "...".
-- Whitespace and comments produce no tokens; a leading "#!" line is skipped.
-- A token still open at end of input is kept only if it is an identifier,
-- a number or a lone "/"; anything else (e.g. an unterminated string) is
-- dropped.
-- @tparam string input source text
-- @treturn {table} tokens
-- @treturn {table}|false the "$invalid$" tokens, or false if there were none
function tl.lex(input)
   local tokens = {}

   local state = "start"
   local fwd = true -- false: re-examine the current character in the new state
   local y = 1
   local x = 0
   local i = 0
   local lc_open_lvl = 0  -- number of "=" in a long comment's opening bracket
   local lc_close_lvl = 0 -- number of "=" seen so far in a candidate closing bracket
   local ls_open_lvl = 0  -- same two counters, for long strings
   local ls_close_lvl = 0
   local errs = {}

   -- start position of the token currently being scanned
   local tx
   local ty
   local ti
   local in_token = false

   local function begin_token()
      tx = x
      ty = y
      ti = i
      in_token = true
   end

   -- Emit the pending token. `last` overrides its end offset (default: the
   -- current character); `t` overrides its text entirely.
   local function end_token(kind, last, t)
      local tk = t or input:sub(ti, last or i) or ""
      if keywords[tk] then
         kind = "keyword"
      end
      table.insert(tokens, {
         x = tx,
         y = ty,
         i = ti,
         tk = tk,
         kind = kind,
      })
      in_token = false
   end

   local function drop_token()
      in_token = false
   end

   while i <= #input do
      if fwd then
         i = i + 1
         if i > #input then
            break
         end
      end

      local c = input:sub(i, i)

      if fwd then
         if c == "\n" then
            y = y + 1
            x = 0
         else
            x = x + 1
         end
      else
         fwd = true
      end

      if state == "start" then
         -- skip a shebang line
         if input:sub(1, 2) == "#!" then
            i = input:find("\n")
            if not i then
               break
            end
            c = "\n"
            y = 2
            x = 0
         end
         state = "any"
      end

      if state == "any" then
         if c == "-" then
            state = "maybecomment"
            begin_token()
         elseif c == "." then
            state = "maybedotdot"
            begin_token()
         elseif c == "\"" then
            state = "dblquote_string"
            begin_token()
         elseif c == "'" then
            state = "singlequote_string"
            begin_token()
         elseif lex_word_start[c] then
            state = "identifier"
            begin_token()
         elseif c == "0" then
            state = "decimal_or_hex"
            begin_token()
         elseif lex_decimal_start[c] then
            state = "decimal_number"
            begin_token()
         elseif c == "<" then
            state = "lt"
            begin_token()
         elseif c == ":" then
            state = "colon"
            begin_token()
         elseif c == ">" then
            state = "gt"
            begin_token()
         elseif c == "=" or c == "~" then
            state = "maybeequals"
            begin_token()
         elseif c == "[" then
            state = "maybelongstring"
            begin_token()
         elseif c == "/" then
            -- "/" or floor division "//"; must be tested before
            -- lex_op_start, which would emit a lone "/" immediately
            state = "maybefloordiv"
            begin_token()
         elseif lex_char_symbols[c] then
            begin_token()
            end_token(c)
         elseif lex_op_start[c] then
            begin_token()
            end_token("op")
         elseif lex_space[c] then
            -- whitespace only separates tokens
         else
            begin_token()
            end_token("$invalid$")
            table.insert(errs, tokens[#tokens])
         end
      elseif state == "maybecomment" then
         if c == "-" then
            state = "maybecomment2"
         else
            end_token("op", nil, "-")
            fwd = false
            state = "any"
         end
      elseif state == "maybecomment2" then
         if c == "[" then
            state = "maybelongcomment"
         else
            fwd = false
            state = "comment"
            drop_token()
         end
      elseif state == "maybelongcomment" then
         if c == "[" then
            state = "longcomment"
         elseif c == "=" then
            lc_open_lvl = lc_open_lvl + 1
         else
            fwd = false
            state = "comment"
            drop_token()
            lc_open_lvl = 0
         end
      elseif state == "longcomment" then
         if c == "]" then
            state = "maybelongcommentend"
         end
      elseif state == "maybelongcommentend" then
         if c == "]" then
            if lc_close_lvl == lc_open_lvl then
               drop_token()
               state = "any"
               lc_open_lvl = 0
               lc_close_lvl = 0
            else
               -- wrong level: this "]" may itself open the real closing
               -- bracket (e.g. "]]=]" closing "--[=["), so restart the count
               lc_close_lvl = 0
            end
         elseif c == "=" then
            lc_close_lvl = lc_close_lvl + 1
         else
            state = "longcomment"
            lc_close_lvl = 0
         end
      elseif state == "dblquote_string" then
         if c == "\\" then
            state = "escape_dblquote_string"
         elseif c == "\"" then
            end_token("string")
            state = "any"
         end
      elseif state == "escape_dblquote_string" then
         state = "dblquote_string"
      elseif state == "singlequote_string" then
         if c == "\\" then
            state = "escape_singlequote_string"
         elseif c == "'" then
            end_token("string")
            state = "any"
         end
      elseif state == "escape_singlequote_string" then
         state = "singlequote_string"
      elseif state == "maybeequals" then
         if c == "=" then
            end_token("op")
            state = "any"
         else
            end_token("op", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "lt" then
         if c == "=" or c == "<" then
            end_token("op")
            state = "any"
         else
            end_token("op", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "colon" then
         if c == ":" then
            end_token("::")
            state = "any"
         else
            end_token(":", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "gt" then
         if c == "=" or c == ">" then
            end_token("op")
            state = "any"
         else
            end_token("op", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "maybefloordiv" then
         if c == "/" then
            end_token("op")
            state = "any"
         else
            end_token("op", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "maybelongstring" then
         if c == "[" then
            state = "longstring"
         elseif c == "=" then
            ls_open_lvl = ls_open_lvl + 1
         else
            end_token("[", i - 1)
            fwd = false
            state = "any"
            ls_open_lvl = 0
         end
      elseif state == "longstring" then
         if c == "]" then
            state = "maybelongstringend"
         end
      elseif state == "maybelongstringend" then
         if c == "]" then
            if ls_close_lvl == ls_open_lvl then
               end_token("string")
               state = "any"
               ls_open_lvl = 0
               ls_close_lvl = 0
            else
               -- wrong level: the count restarts from this "]", which may
               -- begin the real closing bracket
               ls_close_lvl = 0
            end
         elseif c == "=" then
            ls_close_lvl = ls_close_lvl + 1
         else
            state = "longstring"
            ls_close_lvl = 0
         end
      elseif state == "maybedotdot" then
         if c == "." then
            state = "maybedotdotdot"
         elseif lex_decimals[c] then
            state = "decimal_float"
         else
            end_token(".", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "maybedotdotdot" then
         if c == "." then
            end_token("...")
            state = "any"
         else
            end_token("op", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "comment" then
         if c == "\n" then
            state = "any"
         end
      elseif state == "identifier" then
         if not lex_word[c] then
            end_token("identifier", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "decimal_or_hex" then
         if c == "x" or c == "X" then
            state = "hex_number"
         elseif c == "e" or c == "E" then
            state = "power_sign"
         elseif lex_decimals[c] then
            state = "decimal_number"
         elseif c == "." then
            state = "decimal_float"
         else
            end_token("number", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "hex_number" then
         if c == "." then
            state = "hex_float"
         elseif c == "p" or c == "P" then
            state = "power_sign"
         elseif not lex_hexadecimals[c] then
            end_token("number", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "hex_float" then
         if c == "p" or c == "P" then
            state = "power_sign"
         elseif not lex_hexadecimals[c] then
            end_token("number", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "decimal_number" then
         if c == "." then
            state = "decimal_float"
         elseif c == "e" or c == "E" then
            state = "power_sign"
         elseif not lex_decimals[c] then
            end_token("number", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "decimal_float" then
         if c == "e" or c == "E" then
            state = "power_sign"
         elseif not lex_decimals[c] then
            end_token("number", i - 1)
            fwd = false
            state = "any"
         end
      elseif state == "power_sign" then
         if c == "-" or c == "+" then
            state = "power"
         elseif lex_decimals[c] then
            state = "power"
         else
            end_token("$invalid$")
            table.insert(errs, tokens[#tokens])
            state = "any"
         end
      elseif state == "power" then
         if not lex_decimals[c] then
            end_token("number", i - 1)
            fwd = false
            state = "any"
         end
      end
   end

   -- states in which a token still open at end of input is complete, and
   -- the kind it should be emitted as
   local terminals = {
      ["identifier"] = "identifier",
      ["decimal_or_hex"] = "number",
      ["decimal_number"] = "number",
      ["decimal_float"] = "number",
      ["hex_number"] = "number",
      ["hex_float"] = "number",
      ["power"] = "number",
      ["maybefloordiv"] = "op",
   }
   if in_token then
      if terminals[state] then
         end_token(terminals[state], i - 1)
      else
         drop_token()
      end
   end

   return tokens, (#errs > 0) and errs
end

-- "<previous kind>:<next kind>" pairs between which the pretty-printer
-- inserts a single space.
local add_space = {
   ["word:keyword"] = true,
   ["word:word"] = true,
   ["word:string"] = true,
   ["word:="] = true,
   ["word:op"] = true,
   ["keyword:word"] = true,
   ["keyword:keyword"] = true,
   ["keyword:string"] = true,
   ["keyword:number"] = true,
   ["keyword:="] = true,
   ["keyword:op"] = true,
   ["keyword:{"] = true,
   ["keyword:("] = true,
   ["keyword:#"] = true,
   ["=:word"] = true,
   ["=:keyword"] = true,
   ["=:string"] = true,
   ["=:number"] =
true,
   ["=:{"] = true,
   ["=:("] = true,
   ["op:("] = true,
   ["op:{"] = true,
   ["op:#"] = true,
   [",:word"] = true,
   [",:keyword"] = true,
   [",:string"] = true,
   [",:{"] = true,
   ["):op"] = true,
   ["):word"] = true,
   ["):keyword"] = true,
   ["op:string"] = true,
   ["op:number"] = true,
   ["op:word"] = true,
   ["op:keyword"] = true,
   ["]:word"] = true,
   ["]:keyword"] = true,
   ["]:="] = true,
   ["]:op"] = true,
   ["string:op"] = true,
   ["string:word"] = true,
   ["string:keyword"] = true,
   ["number:word"] = true,
   ["number:keyword"] = true,
}

-- Tokens that close a nesting level before they are printed.
local should_unindent = {
   ["end"] = true,
   ["elseif"] = true,
   ["else"] = true,
   ["}"] = true,
}

-- Tokens that open a nesting level for the lines after them.
local should_indent = {
   ["{"] = true,
   ["for"] = true,
   ["if"] = true,
   ["while"] = true,
   ["elseif"] = true,
   ["else"] = true,
   ["function"] = true,
}

--- Render a token array back to text.
-- Each token stays on its original line (`y`); indentation is recomputed
-- from should_indent/should_unindent, and a space is placed between two
-- adjacent tokens whose "<kind>:<kind>" pair appears in add_space.
-- @tparam {table} tokens tokens as produced by tl.lex
-- @treturn string
function tl.pretty_print_tokens(tokens)
   local pieces = {}
   local line = 1
   local depth = 0
   local at_line_start = false
   local prev_kind = ""

   for _, tok in ipairs(tokens) do
      -- emit newlines until we reach the token's source line
      while tok.y > line do
         pieces[#pieces + 1] = "\n"
         line = line + 1
         at_line_start = true
         prev_kind = ""
      end

      if should_unindent[tok.tk] then
         depth = math.max(depth - 1, 0)
      end

      if at_line_start then
         pieces[#pieces + 1] = string.rep(" ", depth)
         at_line_start = false
      end

      if should_indent[tok.tk] then
         depth = depth + 1
      end

      if add_space[prev_kind .. ":" .. tok.kind] then
         pieces[#pieces + 1] = " "
      end
      pieces[#pieces + 1] = tok.tk
      prev_kind = tok.kind or ""
   end

   return table.concat(pieces)
end

-- Counter behind new_typeid; every type object gets a distinct id.
local last_typeid = 0

local function new_typeid()
   last_typeid = last_typeid + 1
   return last_typeid
end

local ParseError = {}
local TypeName = {}

local table_types = {
   ["array"] = true,
   ["map"] = true,
   ["arrayrecord"] = true,
   ["record"] = true,
   ["emptytable"] = true,
}

local Type = {}
local Operator = {}
local NodeKind = {}
local FactType = {}
local Fact = {}
local KeyParsed = {}
local Node = {}

local function is_array_type(t)
   local name = t.typename
   return name == "array" or name == "arrayrecord"
end

local function is_record_type(t)
   local name = t.typename
   return name == "record" or name == "arrayrecord"
end

local function is_type(t)
   local name = t.typename
   return name == "typetype" or name == "nestedtype"
end

local ParseState = {}
local ParseTypeListMode = {}

-- Forward declarations for the mutually recursive parser functions.
local parse_type_list
local parse_expression
local parse_statements
local parse_argument_list
local parse_argument_type_list
local parse_type
local parse_newtype

--- Append a syntax error for token `i` to ps.errs and return the index to
-- resume from: i + 1, clamped to the last token. When token `i` does not
-- exist the error is placed on the last token instead.
local function fail(ps, i, msg)
   local tok = ps.tokens[i]
   if tok then
      table.insert(ps.errs, { y = tok.y, x = tok.x, msg = msg or "syntax error" })
      return math.min(#ps.tokens, i + 1)
   end
   local eof = ps.tokens[#ps.tokens]
   table.insert(ps.errs, { y = eof.y, x = eof.x, msg = msg or "unexpected end of file" })
   return #ps.tokens
end

--- Consume token `i` if its text is `tk`; otherwise record an error.
local function verify_tk(ps, i, tk)
   local found = ps.tokens[i].tk
   if found == tk then
      return i + 1
   end
   return fail(ps, i, "syntax error, expected '" .. tk .. "'")
end

--- Build a parse-tree node positioned at token `i`; `kind` defaults to the
-- token's own kind.
local function new_node(tokens, i, kind)
   local tok = tokens[i]
   return {
      y = tok.y,
      x = tok.x,
      tk = tok.tk,
      kind = kind or tok.kind,
   }
end

-- Stamp a fresh type id onto `t` (mutating it) and return it.
local function a_type(t)
   t.typeid = new_typeid()
   return t
end

--- Build a type object of kind `typename` positioned at token `i`.
local function new_type(ps, i, typename)
   local tok = ps.tokens[i]
   local typ = {
      typename = assert(typename),
      filename = ps.filename,
      y = tok.y,
      x = tok.x,
      tk = tok.tk,
   }
   return a_type(typ)
end

--- Consume token `i` if its kind is `kind`, returning the next index and a
-- node for it (of kind `node_kind` when given); otherwise record an error.
local function verify_kind(ps, i, kind, node_kind)
   if ps.tokens[i].kind ~= kind then
      return fail(ps, i, "syntax error, expected " .. kind)
   end
   return i + 1, new_node(ps.tokens, i, node_kind)
end

local is_newtype = {
   ["enum"] = true,
   ["record"] = true,
}

-- A table value is either an inline enum/record declaration or an expression.
local function parse_table_value(ps, i)
   if not is_newtype[ps.tokens[i].tk] then
      -- keep only two results (parse_expression also returns a third)
      local next_i, node = parse_expression(ps, i)
      return next_i, node
   end
   return parse_newtype(ps, i)
end

--- Parse one table-constructor item: "[k] = v", "name = v",
-- "name: type = v", or a positional value. `n` is the next implicit
-- array index; returns the new position, the item node and the updated `n`.
local function parse_table_item(ps, i, n)
   local node = new_node(ps.tokens, i, "table_item")

   if ps.tokens[i].kind == "$EOF$" then
      return fail(ps, i)
   end

   if ps.tokens[i].tk == "[" then
      node.key_parsed = "long"
      i, node.key = parse_expression(ps, i + 1)
      i = verify_tk(ps, i, "]")
      i = verify_tk(ps, i, "=")
      i, node.value = parse_table_value(ps, i)
      return i, node, n
   end

   if ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == "=" then
      node.key_parsed = "short"
      i, node.key = verify_kind(ps, i, "identifier", "string")
      node.key.conststr = node.key.tk
      node.key.tk = '"' .. node.key.tk .. '"'
      i = verify_tk(ps, i, "=")
      i, node.value = parse_table_value(ps, i)
      return i, node, n
   end

   if ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == ":" then
      -- "name: type = value" is attempted with a scratch error list; if it
      -- does not pan out, rewind and fall back to a positional item
      node.key_parsed = "short"
      local rewind_to = i
      local scratch = {
         filename = ps.filename,
         tokens = ps.tokens,
         errs = {},
      }
      i, node.key = verify_kind(scratch, i, "identifier", "string")
      node.key.conststr = node.key.tk
      node.key.tk = '"' .. node.key.tk .. '"'
      i = verify_tk(scratch, i, ":")
      i, node.decltype = parse_type(scratch, i)
      if node.decltype and ps.tokens[i].tk == "=" then
         i = verify_tk(scratch, i, "=")
         i, node.value = parse_table_value(scratch, i)
         if node.value then
            for _, err in ipairs(scratch.errs) do
               table.insert(ps.errs, err)
            end
            return i, node, n
         end
      end
      node.decltype = nil
      i = rewind_to
   end

   -- positional item with implicit numeric key n
   node.key = new_node(ps.tokens, i, "number")
   node.key_parsed = "implicit"
   node.key.constnum = n
   node.key.tk = tostring(n)
   i, node.value = parse_expression(ps, i)
   return i, node, n + 1
end

local ParseItem = {}
local SeparatorMode = {}

--- Parse items with `parse_item` into `list` until a token whose text is in
-- the set `close` (left unconsumed). With sep == "sep" items are separated
-- by "," and a trailing comma is an error; with sep == "term" items may also
-- be separated by ";" and a trailing separator is allowed.
-- Returns the position of the closing token and `list`, or only a position
-- after recording an error.
local function parse_list(ps, i, list, close, sep, parse_item)
   local n = 1
   while ps.tokens[i].kind ~= "$EOF$" do
      if close[ps.tokens[i].tk] then
         list.yend = ps.tokens[i].y
         break
      end

      local item
      i, item, n = parse_item(ps, i, n)
      table.insert(list, item)

      local after = ps.tokens[i].tk
      if after == "," then
         i = i + 1
         if sep == "sep" and close[ps.tokens[i].tk] then
            return fail(ps, i)
         end
      elseif sep == "term" and after == ";" then
         i = i + 1
      elseif not close[after] then
         return fail(ps, i)
      end
   end
   return i, list
end

-- parse_list wrapped in the `open` ... `close` delimiters.
local function parse_bracket_list(ps, i, list, open, close, sep, parse_item)
   i = verify_tk(ps, i, open)
   i = parse_list(ps, i, list, { [close] = true }, sep, parse_item)
   return verify_tk(ps, i, close), list
end

local function parse_table_literal(ps, i)
   local literal = new_node(ps.tokens, i, "table_literal")
   return parse_bracket_list(ps, i, literal, "{", "}", "term", parse_table_item)
end

--- Parse a comma-separated list whose first item is optional. The first
-- item is tried with a scratch error list: if it yields nothing, return
-- without consuming input or reporting errors; otherwise keep its errors
-- and parse any further ", item" entries normally.
local function parse_trying_list(ps, i, list, parse_item)
   local scratch = {
      filename = ps.filename,
      tokens = ps.tokens,
      errs = {},
   }
   local after, item = parse_item(scratch, i)
   if not item then
      return i, list
   end
   for _, err in ipairs(scratch.errs) do
      table.insert(ps.errs, err)
   end
   i = after
   table.insert(list, item)
   while ps.tokens[i].tk == "," do
      i, item = parse_item(ps, i + 1)
      table.insert(list, item)
   end
   return i, list
end

-- Parse one entry of a "<...>" type-argument declaration: "T" or "`T".
local function parse_typearg_type(ps, i)
   local backtick =
false
   -- Remember where the type argument starts so the node is positioned on
   -- its first token (the backtick when present, otherwise the identifier)
   -- rather than on whatever token precedes it.
   local start = i
   if ps.tokens[i].tk == "`" then
      i = verify_tk(ps, i, "`")
      backtick = true
   end
   i = verify_kind(ps, i, "identifier")
   return i, a_type({
      y = ps.tokens[start].y,
      x = ps.tokens[start].x,
      typename = "typearg",
      typearg = (backtick and "`" or "") .. ps.tokens[i - 1].tk,
   })
end

-- Parse a type-variable reference "`name".
local function parse_typevar_type(ps, i)
   i = verify_tk(ps, i, "`")
   i = verify_kind(ps, i, "identifier")
   return i, a_type({
      y = ps.tokens[i - 2].y,
      x = ps.tokens[i - 2].x,
      typename = "typevar",
      typevar = "`" .. ps.tokens[i - 1].tk,
   })
end

-- Parse a "<T, U, ...>" type-argument declaration into a tuple type.
local function parse_typearg_list(ps, i)
   local typ = new_type(ps, i, "tuple")
   return parse_bracket_list(ps, i, typ, "<", ">", "sep", parse_typearg_type)
end

-- Parse a "<type, type, ...>" list of type arguments into a tuple type.
local function parse_typeval_list(ps, i)
   local typ = new_type(ps, i, "tuple")
   return parse_bracket_list(ps, i, typ, "<", ">", "sep", parse_type)
end

local function parse_return_types(ps, i)
   return parse_type_list(ps, i, "rets")
end

-- Parse a function type starting at the "function" token. A bare
-- "function" (no argument list) is typed as taking and returning any...
local function parse_function_type(ps, i)
   local node = new_type(ps, i, "function")
   node.args = {}
   node.rets = {}
   i = i + 1
   if ps.tokens[i].tk == "<" then
      i, node.typeargs = parse_typearg_list(ps, i)
   end
   if ps.tokens[i].tk == "(" then
      i, node.args = parse_argument_type_list(ps, i)
      i, node.rets = parse_return_types(ps, i)
   else
      node.args = { a_type({ typename = "any", is_va = true }) }
      node.rets = { a_type({ typename = "any", is_va = true }) }
   end
   return i, node
end

-- Parse a single non-union type: a primitive, "table", a function type,
-- "{T}" (array) or "{K:V}" (map), a type variable, or a possibly dotted,
-- possibly generic nominal type name.
local function parse_base_type(ps, i)
   if ps.tokens[i].tk == "string" or ps.tokens[i].tk == "boolean" or ps.tokens[i].tk == "nil" or ps.tokens[i].tk == "number" or ps.tokens[i].tk == "thread" then
      local typ = new_type(ps, i, ps.tokens[i].tk)
      typ.tk = nil
      return i + 1, typ
   elseif ps.tokens[i].tk == "table" then
      local typ = new_type(ps, i, "map")
      typ.keys = a_type({ typename = "any" })
      typ.values = a_type({ typename = "any" })
      return i + 1, typ
   elseif ps.tokens[i].tk == "function" then
      return parse_function_type(ps, i)
   elseif ps.tokens[i].tk == "{" then
      i = i + 1
      local decl = new_type(ps, i, "array")
      local t
      i, t =
parse_type(ps, i)
      if ps.tokens[i].tk == "}" then
         decl.elements = t
         decl.yend = ps.tokens[i].y
         i = verify_tk(ps, i, "}")
      elseif ps.tokens[i].tk == ":" then
         decl.typename = "map"
         i = i + 1
         decl.keys = t
         i, decl.values = parse_type(ps, i)
         decl.yend = ps.tokens[i].y
         i = verify_tk(ps, i, "}")
      else
         -- Neither "{T}" nor "{K:V}": report it instead of returning an
         -- array type with no element type.
         return fail(ps, i, "syntax error, expected '}' or ':'")
      end
      return i, decl
   elseif ps.tokens[i].tk == "`" then
      return parse_typevar_type(ps, i)
   elseif ps.tokens[i].kind == "identifier" then
      local typ = new_type(ps, i, "nominal")
      typ.names = { ps.tokens[i].tk }
      i = i + 1
      while ps.tokens[i].tk == "." do
         i = i + 1
         if ps.tokens[i].kind == "identifier" then
            table.insert(typ.names, ps.tokens[i].tk)
            i = i + 1
         else
            return fail(ps, i, "syntax error, expected identifier")
         end
      end
      if ps.tokens[i].tk == "<" then
         i, typ.typevals = parse_typeval_list(ps, i)
      end
      return i, typ
   end
   return fail(ps, i)
end

--- Parse a type: a parenthesized type, a base type, or a union "A | B | ...".
parse_type = function(ps, i)
   if ps.tokens[i].tk == "(" then
      i = i + 1
      local t
      i, t = parse_type(ps, i)
      i = verify_tk(ps, i, ")")
      return i, t
   end
   local bt
   local istart = i
   i, bt = parse_base_type(ps, i)
   if not bt then
      return i
   end
   if ps.tokens[i].tk == "|" then
      local u = new_type(ps, istart, "union")
      u.types = { bt }
      while ps.tokens[i].tk == "|" do
         i = i + 1
         i, bt = parse_base_type(ps, i)
         if not bt then
            return i
         end
         table.insert(u.types, bt)
      end
      bt = u
   end
   return i, bt
end

--- Parse a comma-separated list of types into a tuple type.
-- In "rets" and "decltype" modes the list must be introduced by ":";
-- without it an empty tuple is returned and nothing is consumed. The list
-- may be wrapped in parentheses, and in "rets" mode a trailing "..." marks
-- the last type as variadic.
parse_type_list = function(ps, i, mode)
   local list = new_type(ps, i, "tuple")

   local first_token = ps.tokens[i].tk
   if mode == "rets" or mode == "decltype" then
      if first_token == ":" then
         i = i + 1
      else
         return i, list
      end
   end

   local optional_paren = false
   if ps.tokens[i].tk == "(" then
      optional_paren = true
      i = i + 1
   end

   local prev_i = i
   i = parse_trying_list(ps, i, list, parse_type)
   if i == prev_i and ps.tokens[i].tk ~= ")" then
      fail(ps, i - 1, "expected a type list")
   end

   if mode == "rets" and ps.tokens[i].tk == "..." then
      i = i + 1
      local nrets = #list
      if nrets > 0 then
         list[nrets].is_va = true
      else
         return fail(ps, i, "unexpected '...'")
      end
   end

   if optional_paren then
      i = verify_tk(ps, i, ")")
   end

   return i, list
end

-- Parse what follows a function's name: optional "<typeargs>", the argument
-- list, optional return types, the body and the closing "end". Fills in
-- `node` and returns the index after "end".
local function parse_function_args_rets_body(ps, i, node)
   if ps.tokens[i].tk == "<" then
      i, node.typeargs = parse_typearg_list(ps, i)
   end
   i, node.args = parse_argument_list(ps, i)
   i, node.rets = parse_return_types(ps, i)
   i, node.body = parse_statements(ps, i)
   node.yend = ps.tokens[i].y
   i = verify_tk(ps, i, "end")
   return i, node
end

-- Parse an anonymous "function (...) ... end" expression.
local function parse_function_value(ps, i)
   local node = new_node(ps.tokens, i, "function")
   i = verify_tk(ps, i, "function")
   return parse_function_args_rets_body(ps, i, node)
end

-- Strip the delimiters from a string token: the quotes of "..." / '...',
-- or the matching "[=*[" ... "]=*]" brackets of a long string. Escape
-- sequences are not interpreted.
local function unquote(str)
   local f = str:sub(1, 1)
   if f == '"' or f == "'" then
      return str:sub(2, -2)
   end
   f = str:match("^%[=*%[")
   local l = #f + 1
   return str:sub(l, -l)
end

-- Parse a literal value, a variable name or a function expression.
local function parse_literal(ps, i)
   if ps.tokens[i].tk == "{" then
      return parse_table_literal(ps, i)
   elseif ps.tokens[i].kind == "..."
then
      return verify_kind(ps, i, "...")
   elseif ps.tokens[i].kind == "string" then
      local tk = unquote(ps.tokens[i].tk)
      local node
      i, node = verify_kind(ps, i, "string")
      node.conststr = tk
      return i, node
   elseif ps.tokens[i].kind == "identifier" then
      return verify_kind(ps, i, "identifier", "variable")
   elseif ps.tokens[i].kind == "number" then
      local n = tonumber(ps.tokens[i].tk)
      local node
      i, node = verify_kind(ps, i, "number")
      node.constnum = n
      return i, node
   elseif ps.tokens[i].tk == "true" then
      return verify_kind(ps, i, "keyword", "boolean")
   elseif ps.tokens[i].tk == "false" then
      return verify_kind(ps, i, "keyword", "boolean")
   elseif ps.tokens[i].tk == "nil" then
      return verify_kind(ps, i, "keyword", "nil")
   elseif ps.tokens[i].tk == "function" then
      return parse_function_value(ps, i)
   end
   return fail(ps, i)
end

do
   -- Operator precedences, higher binds tighter. [1] holds unary and [2]
   -- binary operators. The Lua ordering is: or < and < comparison < | < ~
   -- < & < shifts < .. < (+ -) < (* / // %) < unary < ^.
   local precedences = {
      [1] = {
         ["not"] = 11,
         ["#"] = 11,
         ["-"] = 11,
         ["~"] = 11,
      },
      [2] = {
         ["or"] = 1,
         ["and"] = 2,
         ["is"] = 3,
         ["<"] = 3,
         [">"] = 3,
         ["<="] = 3,
         [">="] = 3,
         ["~="] = 3,
         ["=="] = 3,
         ["|"] = 4,
         ["~"] = 5,
         ["&"] = 6,
         ["<<"] = 7,
         [">>"] = 7,
         [".."] = 8,
         ["+"] = 9, -- same level as "-" and above ".."
         ["-"] = 9,
         ["*"] = 10,
         ["/"] = 10,
         ["//"] = 10,
         ["%"] = 10,
         ["^"] = 12,
         ["as"] = 50,
         ["@funcall"] = 100,
         ["@index"] = 100,
         ["."] = 100,
         [":"] = 100,
      },
   }

   local is_right_assoc = {
      ["^"] = true,
      [".."] = true,
   }

   local function new_operator(tk, arity, op)
      op = op or tk.tk
      return { y = tk.y, x = tk.x, arity = arity, op = op, prec = precedences[arity][op] }
   end

   -- Forward declaration: P calls E to let "^" claim a unary operand.
   local E

   -- P parses a primary expression: a unary-operator application, a
   -- parenthesized expression or a literal, followed by any postfix
   -- operations (calls, indexing, "."/":" access, "as"/"is").
   local function P(ps, i)
      if ps.tokens[i].kind == "$EOF$" then
         return i
      end
      local e1
      local t1 = ps.tokens[i]
      if precedences[1][ps.tokens[i].tk] ~= nil then
         local op = new_operator(ps.tokens[i], 1)
         i = i + 1
         i, e1 = P(ps, i)
         -- "^" binds tighter than unary operators ("-x^2" is "-(x^2)"), so
         -- it gets the operand before the unary operator is applied.
         i, e1 = E(ps, i, e1, precedences[2]["^"])
         e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 }
      elseif ps.tokens[i].tk == "(" then
         i = i + 1
         i, e1 = parse_expression(ps, i)
         e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 }
         i = verify_tk(ps, i, ")")
      else
         i, e1 = parse_literal(ps, i)
      end

      while true do
         if ps.tokens[i].kind == "string" or ps.tokens[i].kind == "{" then
            -- call with a single string or table-constructor argument
            local op = new_operator(ps.tokens[i], 2, "@funcall")
            local args = new_node(ps.tokens, i, "expression_list")
            local arg
            if ps.tokens[i].kind == "string" then
               arg = new_node(ps.tokens, i)
               arg.conststr = unquote(ps.tokens[i].tk)
               i = i + 1
            else
               i, arg = parse_table_literal(ps, i)
            end
            table.insert(args, arg)
            e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = args }
         elseif ps.tokens[i].tk == "(" then
            local op = new_operator(ps.tokens[i], 2, "@funcall")
            local args = new_node(ps.tokens, i, "expression_list")
            i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression)
            e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = args }
         elseif ps.tokens[i].tk == "[" then
            local op = new_operator(ps.tokens[i], 2, "@index")
            local idx
            i = i + 1
            i, idx = parse_expression(ps, i)
            i = verify_tk(ps, i, "]")
            e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = idx }
         elseif ps.tokens[i].tk == "." or ps.tokens[i].tk == ":" then
            local op = new_operator(ps.tokens[i], 2)
            local key
            i = i + 1
            i, key = verify_kind(ps, i, "identifier")
            e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = key }
         elseif ps.tokens[i].tk == "as" or ps.tokens[i].tk == "is" then
            local op = new_operator(ps.tokens[i], 2, ps.tokens[i].tk)
            i = i + 1
            local cast = new_node(ps.tokens, i, "cast")
            if ps.tokens[i].tk == "(" then
               i, cast.casttype = parse_type_list(ps, i, "casttype")
            else
               i, cast.casttype = parse_type(ps, i)
            end
            e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr }
         else
            break
         end
      end

      return i, e1
   end

   -- E extends `lhs` with binary operators of precedence >= min_precedence
   -- (precedence climbing). Right-associative operators also recurse at
   -- equal precedence.
   E = function(ps, i, lhs, min_precedence)
      local lookahead = ps.tokens[i].tk
      while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do
         local t1 = ps.tokens[i]
         local op = new_operator(t1, 2)
         i = i + 1
         local rhs
         i, rhs = P(ps, i)
         lookahead = ps.tokens[i].tk
         while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or
            (is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do
            i, rhs = E(ps, i, rhs, precedences[2][lookahead])
            lookahead = ps.tokens[i].tk
         end
         lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs }
      end
      return i, lhs
   end

   --- Parse a full expression. Returns the next index, the node and 0 (the
   -- third value lets it serve as a parse_list item parser), or only an
   -- index after recording "expected an expression".
   parse_expression = function(ps, i)
      local lhs
      i, lhs = P(ps, i)
      i, lhs = E(ps, i, lhs, 0)
      if lhs then
         return i, lhs, 0
      else
         return fail(ps, i, "expected an expression")
      end
   end
end

-- Parse "name" or "name <attrib>"; only "<const>" sets is_const, any other
-- attribute name is accepted and ignored.
local function parse_variable_name(ps, i)
   local is_const = false
   local node
   i, node = verify_kind(ps, i, "identifier")
   if not node then
      return i
   end
   if ps.tokens[i].tk == "<" then
      i = i + 1
      local annotation
      i, annotation = verify_kind(ps, i, "identifier")
      if annotation and annotation.tk == "const" then
         is_const = true
      end
      i = verify_tk(ps, i, ">")
   end
   node.is_const = is_const
   return i, node
end

-- Parse one function argument: "name", "name: type", "..." or "...: type".
local function parse_argument(ps, i)
   local node
   if ps.tokens[i].tk == "..." then
      i, node = verify_kind(ps, i, "...")
   else
      i, node = verify_kind(ps, i, "identifier", "argument")
   end
   if ps.tokens[i].tk == ":" then
      i = i + 1
      local decltype
      i, decltype = parse_type(ps, i)
      if node then
         node.decltype = decltype
      end
   end
   return i, node, 0
end

parse_argument_list = function(ps, i)
   local node = new_node(ps.tokens, i, "argument_list")
   return parse_bracket_list(ps, i, node, "(", ")", "sep", parse_argument)
end

-- Parse one argument of a function *type*: an optional "name:" or "...:"
-- prefix followed by a type. A "..." with no type is rejected.
local function parse_argument_type(ps, i)
   local is_va = false
   if ps.tokens[i].kind == "identifier" and ps.tokens[i + 1].tk == ":" then
      i = i + 2
   elseif ps.tokens[i].tk == "..." then
      if ps.tokens[i + 1].tk == ":" then
         i = i + 2
         is_va = true
      else
         return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument")
      end
   end
   local typ
   i, typ = parse_type(ps, i)
   if typ then
      typ.is_va = is_va
   end
   return i, typ, 0
end

parse_argument_type_list = function(ps, i)
   local list = new_type(ps, i, "tuple")
   return parse_bracket_list(ps, i, list, "(", ")", "sep", parse_argument_type)
end

-- Parse "local function name ...".
local function parse_local_function(ps, i)
   local node = new_node(ps.tokens, i, "local_function")
   i = verify_tk(ps, i, "local")
   i = verify_tk(ps, i, "function")
   i, node.name = verify_kind(ps, i, "identifier")
   return parse_function_args_rets_body(ps, i, node)
end

-- Parse a non-local function declaration: "function name", "function a.b.c"
-- or "function a.b:c"; dotted or method names become "record_function"
-- nodes whose fn_owner is the expression for everything before the last name.
local function parse_function(ps, i)
   local orig_i = i
   local fn = new_node(ps.tokens, i, "global_function")
   local node = fn
   i = verify_tk(ps, i, "function")
   local names = {}
   i, names[1] = verify_kind(ps, i, "identifier", "variable")
   while ps.tokens[i].tk == "." do
      i = i + 1
      i, names[#names + 1] = verify_kind(ps, i, "identifier")
   end
   if ps.tokens[i].tk == ":" then
      i = i + 1
      i, names[#names + 1] = verify_kind(ps, i, "identifier")
      fn.is_method = true
   end
   if #names > 1 then
      fn.kind = "record_function"
      local owner = names[1]
      for i = 2, #names - 1 do
         local dot = { y = names[i].y, x = names[i].x - 1, arity = 2, op = "."
}
         names[i].kind = "identifier"
         -- wrap the owner so far in an "owner.name" index expression
         owner = {
            y = names[i].y,
            x = names[i].x,
            kind = "op",
            op = dot,
            e1 = owner,
            e2 = names[i],
         }
      end
      fn.fn_owner = owner
   end
   fn.name = names[#names]

   -- the implicit "self" of a method is positioned at the token after the name
   local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y
   i = parse_function_args_rets_body(ps, i, fn)
   if fn.is_method then
      table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "variable" })
   end
   if fn.name then
      return i, node
   end
   return orig_i
end

--- Parse "if e then ... {elseif e then ...} [else ...] end".
-- elseif/else branches are separate nodes linked back via parent_if.
local function parse_if(ps, i)
   local node = new_node(ps.tokens, i, "if")
   i = verify_tk(ps, i, "if")
   i, node.exp = parse_expression(ps, i)
   i = verify_tk(ps, i, "then")
   i, node.thenpart = parse_statements(ps, i)

   node.elseifs = {}
   while ps.tokens[i].tk == "elseif" do
      local branch = new_node(ps.tokens, i, "elseif")
      branch.parent_if = node
      branch.elseif_n = #node.elseifs + 1
      i, branch.exp = parse_expression(ps, i + 1)
      i = verify_tk(ps, i, "then")
      i, branch.thenpart = parse_statements(ps, i)
      table.insert(node.elseifs, branch)
   end

   if ps.tokens[i].tk == "else" then
      local branch = new_node(ps.tokens, i, "else")
      branch.parent_if = node
      i, branch.elsepart = parse_statements(ps, i + 1)
      node.elsepart = branch
   end

   node.yend = ps.tokens[i].y
   return verify_tk(ps, i, "end"), node
end

--- Parse "while e do ... end".
local function parse_while(ps, i)
   local node = new_node(ps.tokens, i, "while")
   i = verify_tk(ps, i, "while")
   i, node.exp = parse_expression(ps, i)
   i = verify_tk(ps, i, "do")
   i, node.body = parse_statements(ps, i)
   node.yend = ps.tokens[i].y
   return verify_tk(ps, i, "end"), node
end

--- Parse a numeric for: "for v = from, to [, step] do ... end".
local function parse_fornum(ps, i)
   local node = new_node(ps.tokens, i, "fornum")
   i, node.var = verify_kind(ps, i + 1, "identifier")
   i = verify_tk(ps, i, "=")
   i, node.from = parse_expression(ps, i)
   i = verify_tk(ps, i, ",")
   i, node.to = parse_expression(ps, i)
   if ps.tokens[i].tk == "," then
      i, node.step = parse_expression(ps, i + 1)
   end
   i = verify_tk(ps, i, "do")
   i, node.body = parse_statements(ps, i)
   node.yend = ps.tokens[i].y
   return verify_tk(ps, i, "end"), node
end

--- Parse a generic for: "for a, b in e1 [, e2 [, e3]] do ... end".
-- One to three iterator expressions are accepted.
local function parse_forin(ps, i)
   local node = new_node(ps.tokens, i, "forin")
   i = i + 1
   local vars = new_node(ps.tokens, i, "variables")
   i, node.vars = parse_list(ps, i, vars, { ["in"] = true }, "sep", parse_variable_name)
   i = verify_tk(ps, i, "in")
   node.exps = new_node(ps.tokens, i, "expression_list")
   i = parse_list(ps, i, node.exps, { ["do"] = true }, "sep", parse_expression)
   local nexps = #node.exps
   if nexps < 1 then
      return fail(ps, i, "missing iterator expression in generic for")
   elseif nexps > 3 then
      return fail(ps, i, "too many expressions in generic for")
   end
   i = verify_tk(ps, i, "do")
   i, node.body = parse_statements(ps, i)
   node.yend = ps.tokens[i].y
   return verify_tk(ps, i, "end"), node
end

-- "for name =" selects the numeric form; anything else is a generic for.
local function parse_for(ps, i)
   local numeric = ps.tokens[i + 1].kind == "identifier" and ps.tokens[i + 2].tk == "="
   if numeric then
      return parse_fornum(ps, i)
   end
   return parse_forin(ps, i)
end

--- Parse "repeat ... until e"; the body block is flagged with is_repeat.
local function parse_repeat(ps, i)
   local node = new_node(ps.tokens, i, "repeat")
   i = verify_tk(ps, i, "repeat")
   i, node.body = parse_statements(ps, i)
   node.body.is_repeat = true
   node.yend = ps.tokens[i].y
   i = verify_tk(ps, i, "until")
   local after
   after, node.exp = parse_expression(ps, i)
   return after, node
end

--- Parse "do ... end".
local function parse_do(ps, i)
   local node = new_node(ps.tokens, i, "do")
   i, node.body = parse_statements(ps, verify_tk(ps, i, "do"))
   node.yend = ps.tokens[i].y
   return verify_tk(ps, i, "end"), node
end

local function parse_break(ps, i)
   local node = new_node(ps.tokens, i, "break")
   return verify_tk(ps, i, "break"), node
end

--- Parse "goto name".
local function parse_goto(ps, i)
   local node = new_node(ps.tokens, i, "goto")
   i = verify_tk(ps, i, "goto")
   node.label = ps.tokens[i].tk
   return verify_kind(ps, i, "identifier"), node
end

--- Parse "::name::".
local function parse_label(ps, i)
   local node = new_node(ps.tokens, i, "label")
   i = verify_tk(ps, i, "::")
   node.label = ps.tokens[i].tk
   i = verify_kind(ps, i, "identifier")
   return verify_tk(ps, i, "::"), node
end

-- Tokens that end a statement block.
local stop_statement_list = {
["end"] = true, ["else"] = true, ["elseif"] = true, ["until"] = true, } local stop_return_list = { [";"] = true, ["$EOF$"] = true, } for k, v in pairs(stop_statement_list) do stop_return_list[k] = v end local function parse_return(ps, i) local node = new_node(ps.tokens, i, "return") i = verify_tk(ps, i, "return") node.exps = new_node(ps.tokens, i, "expression_list") i = parse_list(ps, i, node.exps, stop_return_list, "sep", parse_expression) if ps.tokens[i].kind == ";" then i = i + 1 end return i, node end local function store_field_in_record(name, def, nt) if def.fields[name] then return false end def.fields[name] = nt.newtype table.insert(def.field_order, name) return true end local ParseBody = {} local function parse_nested_type(ps, i, def, typename, parse_body) i = i + 1 local v i, v = verify_kind(ps, i, "identifier", "variable") if not v then return fail(ps, i, "expected a variable name") end local nt = new_node(ps.tokens, i, "newtype") nt.newtype = new_type(ps, i, "typetype") local rdef = new_type(ps, i, typename) local iok = parse_body(ps, i, rdef, nt) if iok then i = iok nt.newtype.def = rdef end local ok = store_field_in_record(v.tk, def, nt) if not ok then fail(ps, i, "attempt to redeclare field '" .. v.tk .. 
"' (only functions can be overloaded)") end return i end local function parse_enum_body(ps, i, def, node) def.enumset = {} while not ((not ps.tokens[i]) or ps.tokens[i].tk == "end") do local item i, item = verify_kind(ps, i, "string", "enum_item") if item then table.insert(node, item) def.enumset[unquote(item.tk)] = true end end node.yend = ps.tokens[i].y i = verify_tk(ps, i, "end") return i, node end local function parse_record_body(ps, i, def, node) def.fields = {} def.field_order = {} if ps.tokens[i].tk == "<" then i, def.typeargs = parse_typearg_list(ps, i) end while not ((not ps.tokens[i]) or ps.tokens[i].tk == "end") do if ps.tokens[i].tk == "{" then if def.typename == "arrayrecord" then return fail(ps, i, "duplicated declaration of array element type in record") end i = i + 1 local t i, t = parse_type(ps, i) if ps.tokens[i].tk == "}" then node.yend = ps.tokens[i].y i = verify_tk(ps, i, "}") else return fail(ps, i, "expected an array declaration") end def.typename = "arrayrecord" def.elements = t elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then i = i + 1 local v i, v = verify_kind(ps, i, "identifier", "variable") if not v then return fail(ps, i, "expected a variable name") end i = verify_tk(ps, i, "=") local nt i, nt = parse_newtype(ps, i) if not nt or not nt.newtype then return fail(ps, i, "expected a type definition") end local ok = store_field_in_record(v.tk, def, nt) if not ok then return fail(ps, i, "attempt to redeclare field '" .. v.tk .. 
"' (only functions can be overloaded)") end elseif ps.tokens[i].tk == "record" and ps.tokens[i + 1].tk ~= ":" then i = parse_nested_type(ps, i, def, "record", parse_record_body) elseif ps.tokens[i].tk == "enum" and ps.tokens[i + 1].tk ~= ":" then i = parse_nested_type(ps, i, def, "enum", parse_enum_body) else local v i, v = verify_kind(ps, i, "identifier", "variable") local iv = i if not v then return fail(ps, i, "expected a variable name") end if ps.tokens[i].tk == ":" then i = verify_tk(ps, i, ":") local t i, t = parse_type(ps, i) if not t then return fail(ps, i, "expected a type") end if not def.fields[v.tk] then def.fields[v.tk] = t table.insert(def.field_order, v.tk) else local prev_t = def.fields[v.tk] if t.typename == "function" and prev_t.typename == "function" then def.fields[v.tk] = new_type(ps, iv, "poly") def.fields[v.tk].types = { prev_t, t } elseif t.typename == "function" and prev_t.typename == "poly" then table.insert(prev_t.types, t) else return fail(ps, i, "attempt to redeclare field '" .. v.tk .. "' (only functions can be overloaded)") end end elseif ps.tokens[i].tk == "=" then local next_word = ps.tokens[i + 1].tk if next_word == "record" or next_word == "enum" then return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'") elseif next_word == "functiontype" then return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...") else return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. 
" = '...") end end end end node.yend = ps.tokens[i].y i = verify_tk(ps, i, "end") return i, node end parse_newtype = function(ps, i) local node = new_node(ps.tokens, i, "newtype") node.newtype = new_type(ps, i, "typetype") if ps.tokens[i].tk == "record" then local def = new_type(ps, i, "record") i = i + 1 i = parse_record_body(ps, i, def, node) node.newtype.def = def return i, node elseif ps.tokens[i].tk == "enum" then local def = new_type(ps, i, "enum") i = i + 1 i = parse_enum_body(ps, i, def, node) node.newtype.def = def return i, node else i, node.newtype.def = parse_type(ps, i) return i, node end return fail(ps, i) end local function parse_call_or_assignment(ps, i) local asgn = new_node(ps.tokens, i, "assignment") local tryi = i asgn.vars = new_node(ps.tokens, i, "variables") i = parse_trying_list(ps, i, asgn.vars, parse_expression) if #asgn.vars < 1 then return fail(ps, i) end local lhs = asgn.vars[1] if ps.tokens[i].tk == "=" then asgn.exps = new_node(ps.tokens, i, "values") repeat i = i + 1 local val i, val = parse_expression(ps, i) table.insert(asgn.exps, val) until ps.tokens[i].tk ~= "," return i, asgn end if #asgn.vars > 1 then local err_ps = { tokens = ps.tokens, errs = {}, } local expi = parse_expression(err_ps, tryi) return fail(ps, expi or i) end if lhs.op and lhs.op.op == "@funcall" and #asgn.vars == 1 then return i, lhs end return fail(ps, i) end local function parse_variable_declarations(ps, i, node_name) local asgn = new_node(ps.tokens, i, node_name) asgn.vars = new_node(ps.tokens, i, "variables") i = parse_trying_list(ps, i, asgn.vars, parse_variable_name) if #asgn.vars == 0 then return fail(ps, i, "expected a local variable definition") end local lhs = asgn.vars[1] i, asgn.decltype = parse_type_list(ps, i, "decltype") if ps.tokens[i].tk == "=" then if ps.tokens[i + 1].tk == "record" or ps.tokens[i + 1].tk == "enum" then local scope = node_name == "local_declaration" and "local" or "global" fail(ps, i, "syntax error: this syntax is no longer 
valid; use '" .. scope .. " " .. ps.tokens[i + 1].tk .. " " .. asgn.vars[1].tk .. "'") elseif ps.tokens[i + 1].tk == "functiontype" then local scope = node_name == "local_declaration" and "local" or "global" fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. asgn.vars[1].tk .. " = function('...") end asgn.exps = new_node(ps.tokens, i, "values") local v = 1 repeat i = i + 1 local val i, val = parse_expression(ps, i) table.insert(asgn.exps, val) v = v + 1 until ps.tokens[i].tk ~= "," end return i, asgn end local function parse_type_declaration(ps, i, node_name) i = i + 2 local asgn = new_node(ps.tokens, i, node_name) i, asgn.var = parse_variable_name(ps, i) if not asgn.var then return fail(ps, i, "expected a type name") end i = verify_tk(ps, i, "=") i, asgn.value = parse_newtype(ps, i) if asgn.value then asgn.value.newtype.def.names = { asgn.var.tk } else return i end return i, asgn end local ParseBody = {} local function parse_type_constructor(ps, i, node_name, type_name, parse_body) local asgn = new_node(ps.tokens, i, node_name) local nt = new_node(ps.tokens, i, "newtype") asgn.value = nt nt.newtype = new_type(ps, i, "typetype") local def = new_type(ps, i, type_name) nt.newtype.def = def i = i + 2 i, asgn.var = verify_kind(ps, i, "identifier") if not asgn.var then return fail(ps, i, "expected a type name") end nt.newtype.def.names = { asgn.var.tk } i = parse_body(ps, i, def, nt) return i, asgn end local function parse_statement(ps, i) if ps.tokens[i].tk == "local" then if ps.tokens[i + 1].tk == "type" and ps.tokens[i + 2].kind == "identifier" then return parse_type_declaration(ps, i, "local_type") elseif ps.tokens[i + 1].tk == "function" then return parse_local_function(ps, i) elseif ps.tokens[i + 1].tk == "record" and ps.tokens[i + 2].kind == "identifier" then return parse_type_constructor(ps, i, "local_type", "record", parse_record_body) elseif ps.tokens[i + 1].tk == "enum" and ps.tokens[i + 2].kind == "identifier" then 
return parse_type_constructor(ps, i, "local_type", "enum", parse_enum_body) else i = i + 1 return parse_variable_declarations(ps, i, "local_declaration") end elseif ps.tokens[i].tk == "global" then if ps.tokens[i + 1].tk == "type" and ps.tokens[i + 2].kind == "identifier" then return parse_type_declaration(ps, i, "global_type") elseif ps.tokens[i + 1].tk == "record" and ps.tokens[i + 2].kind == "identifier" then return parse_type_constructor(ps, i, "global_type", "record", parse_record_body) elseif ps.tokens[i + 1].tk == "enum" and ps.tokens[i + 2].kind == "identifier" then return parse_type_constructor(ps, i, "global_type", "enum", parse_enum_body) elseif ps.tokens[i + 1].tk == "function" then i = i + 1 return parse_function(ps, i) else i = i + 1 return parse_variable_declarations(ps, i, "global_declaration") end elseif ps.tokens[i].tk == "function" then return parse_function(ps, i) elseif ps.tokens[i].tk == "if" then return parse_if(ps, i) elseif ps.tokens[i].tk == "while" then return parse_while(ps, i) elseif ps.tokens[i].tk == "repeat" then return parse_repeat(ps, i) elseif ps.tokens[i].tk == "for" then return parse_for(ps, i) elseif ps.tokens[i].tk == "do" then return parse_do(ps, i) elseif ps.tokens[i].tk == "break" then return parse_break(ps, i) elseif ps.tokens[i].tk == "return" then return parse_return(ps, i) elseif ps.tokens[i].tk == "goto" then return parse_goto(ps, i) elseif ps.tokens[i].tk == "::" then return parse_label(ps, i) else return parse_call_or_assignment(ps, i) end end parse_statements = function(ps, i, filename, toplevel) local node = new_node(ps.tokens, i, "statements") while true do while ps.tokens[i].kind == ";" do i = i + 1 end if ps.tokens[i].kind == "$EOF$" then break end if (not toplevel) and stop_statement_list[ps.tokens[i].tk] then break end local item i, item = parse_statement(ps, i) if filename then for j = 1, #ps.errs do if not ps.errs[j].filename then ps.errs[j].filename = filename end end end if not item then break end 
table.insert(node, item) end return i, node end function tl.parse_program(tokens, errs, filename) errs = errs or {} local ps = { tokens = tokens, errs = errs, filename = filename, } local last = ps.tokens[#ps.tokens] or { y = 1, x = 1, tk = "" } table.insert(ps.tokens, { y = last.y, x = last.x + #last.tk, tk = "$EOF$", kind = "$EOF$" }) return parse_statements(ps, 1, filename, true) end local VisitorCallbacks = {} local Visitor = {} local function visit_before(ast, kind, visit) assert(visit.cbs[kind], "no visitor for " .. (kind)) if visit.cbs[kind].before then visit.cbs[kind].before(ast) end end local function visit_after(ast, kind, visit, xs) if visit.after and visit.after.before then visit.after.before(ast, xs) end local ret if visit.cbs[kind].after then ret = visit.cbs[kind].after(ast, xs) end if visit.after and visit.after.after then ret = visit.after.after(ast, xs, ret) end return ret end local function recurse_type(ast, visit) visit_before(ast, ast.typename, visit) local xs = {} if ast.typeargs then for _, child in ipairs(ast.typeargs) do table.insert(xs, recurse_type(child, visit)) end end for i, child in ipairs(ast) do xs[i] = recurse_type(child, visit) end if ast.types then for i, child in ipairs(ast.types) do table.insert(xs, recurse_type(child, visit)) end end if ast.def then table.insert(xs, recurse_type(ast.def, visit)) end if ast.keys then table.insert(xs, recurse_type(ast.keys, visit)) end if ast.values then table.insert(xs, recurse_type(ast.values, visit)) end if ast.elements then table.insert(xs, recurse_type(ast.elements, visit)) end if ast.fields then for _, child in pairs(ast.fields) do table.insert(xs, recurse_type(child, visit)) end end if ast.args then for i, child in ipairs(ast.args) do if i > 1 or not ast.is_method then table.insert(xs, recurse_type(child, visit)) end end end if ast.rets then for _, child in ipairs(ast.rets) do table.insert(xs, recurse_type(child, visit)) end end if ast.typevals then for _, child in ipairs(ast.typevals) do 
table.insert(xs, recurse_type(child, visit)) end end if ast.ktype then table.insert(xs, recurse_type(ast.ktype, visit)) end if ast.vtype then table.insert(xs, recurse_type(ast.vtype, visit)) end return visit_after(ast, ast.typename, visit, xs) end local function recurse_node(ast, visit_node, visit_type) if not ast then return end visit_before(ast, ast.kind, visit_node) local xs = {} local cbs = visit_node.cbs[ast.kind] if ast.kind == "statements" or ast.kind == "variables" or ast.kind == "values" or ast.kind == "argument_list" or ast.kind == "expression_list" or ast.kind == "table_literal" then for i, child in ipairs(ast) do xs[i] = recurse_node(child, visit_node, visit_type) end elseif ast.kind == "local_declaration" or ast.kind == "global_declaration" or ast.kind == "assignment" then xs[1] = recurse_node(ast.vars, visit_node, visit_type) if ast.exps then xs[2] = recurse_node(ast.exps, visit_node, visit_type) end if ast.decltype then xs[3] = recurse_type(ast.decltype, visit_type) end elseif ast.kind == "local_type" or ast.kind == "global_type" then xs[1] = recurse_node(ast.var, visit_node, visit_type) xs[2] = recurse_node(ast.value, visit_node, visit_type) elseif ast.kind == "table_item" then xs[1] = recurse_node(ast.key, visit_node, visit_type) xs[2] = recurse_node(ast.value, visit_node, visit_type) elseif ast.kind == "if" then xs[1] = recurse_node(ast.exp, visit_node, visit_type) if cbs.before_statements then cbs.before_statements(ast, xs) end xs[2] = recurse_node(ast.thenpart, visit_node, visit_type) for i, e in ipairs(ast.elseifs) do table.insert(xs, recurse_node(e, visit_node, visit_type)) end if ast.elsepart then table.insert(xs, recurse_node(ast.elsepart, visit_node, visit_type)) end elseif ast.kind == "while" then xs[1] = recurse_node(ast.exp, visit_node, visit_type) if cbs.before_statements then cbs.before_statements(ast, xs) end xs[2] = recurse_node(ast.body, visit_node, visit_type) elseif ast.kind == "repeat" then xs[1] = recurse_node(ast.body, 
visit_node, visit_type) xs[2] = recurse_node(ast.exp, visit_node, visit_type) elseif ast.kind == "function" then xs[1] = recurse_node(ast.args, visit_node, visit_type) xs[2] = recurse_type(ast.rets, visit_type) xs[3] = recurse_node(ast.body, visit_node, visit_type) elseif ast.kind == "forin" then xs[1] = recurse_node(ast.vars, visit_node, visit_type) xs[2] = recurse_node(ast.exps, visit_node, visit_type) if cbs.before_statements then cbs.before_statements(ast) end xs[3] = recurse_node(ast.body, visit_node, visit_type) elseif ast.kind == "fornum" then xs[1] = recurse_node(ast.var, visit_node, visit_type) xs[2] = recurse_node(ast.from, visit_node, visit_type) xs[3] = recurse_node(ast.to, visit_node, visit_type) xs[4] = ast.step and recurse_node(ast.step, visit_node, visit_type) xs[5] = recurse_node(ast.body, visit_node, visit_type) elseif ast.kind == "elseif" then xs[1] = recurse_node(ast.exp, visit_node, visit_type) if cbs.before_statements then cbs.before_statements(ast, xs) end xs[2] = recurse_node(ast.thenpart, visit_node, visit_type) elseif ast.kind == "else" then xs[1] = recurse_node(ast.elsepart, visit_node, visit_type) elseif ast.kind == "return" then xs[1] = recurse_node(ast.exps, visit_node, visit_type) elseif ast.kind == "do" then xs[1] = recurse_node(ast.body, visit_node, visit_type) elseif ast.kind == "cast" then elseif ast.kind == "local_function" or ast.kind == "global_function" then xs[1] = recurse_node(ast.name, visit_node, visit_type) xs[2] = recurse_node(ast.args, visit_node, visit_type) xs[3] = recurse_type(ast.rets, visit_type) xs[4] = recurse_node(ast.body, visit_node, visit_type) elseif ast.kind == "record_function" then xs[1] = recurse_node(ast.fn_owner, visit_node, visit_type) xs[2] = recurse_node(ast.name, visit_node, visit_type) xs[3] = recurse_node(ast.args, visit_node, visit_type) xs[4] = recurse_type(ast.rets, visit_type) if cbs.before_statements then cbs.before_statements(ast, xs) end xs[5] = recurse_node(ast.body, visit_node, 
visit_type) elseif ast.kind == "paren" then xs[1] = recurse_node(ast.e1, visit_node, visit_type) elseif ast.kind == "op" then xs[1] = recurse_node(ast.e1, visit_node, visit_type) local p1 = ast.e1.op and ast.e1.op.prec or nil if ast.op.op == ":" and ast.e1.kind == "string" then p1 = -999 end xs[2] = p1 if ast.op.arity == 2 then if cbs.before_e2 then cbs.before_e2(ast, xs) end if ast.op.op == "is" or ast.op.op == "as" then xs[3] = recurse_type(ast.e2.casttype, visit_type) else xs[3] = recurse_node(ast.e2, visit_node, visit_type) end xs[4] = (ast.e2.op and ast.e2.op.prec) end elseif ast.kind == "newtype" then xs[1] = recurse_type(ast.newtype, visit_type) elseif ast.kind == "variable" or ast.kind == "argument" or ast.kind == "identifier" or ast.kind == "string" or ast.kind == "number" or ast.kind == "break" or ast.kind == "goto" or ast.kind == "label" or ast.kind == "nil" or ast.kind == "..." or ast.kind == "boolean" then if ast.decltype then xs[1] = recurse_type(ast.decltype, visit_type) end else if not ast.kind then error("wat: " .. inspect(ast)) end error("unknown node kind " .. 
ast.kind) end return visit_after(ast, ast.kind, visit_node, xs) end local tight_op = { [1] = { ["-"] = true, ["~"] = true, ["#"] = true, }, [2] = { ["."] = true, [":"] = true, }, } local spaced_op = { [1] = { ["not"] = true, }, [2] = { ["or"] = true, ["and"] = true, ["<"] = true, [">"] = true, ["<="] = true, [">="] = true, ["~="] = true, ["=="] = true, ["|"] = true, ["~"] = true, ["&"] = true, ["<<"] = true, [">>"] = true, [".."] = true, ["+"] = true, ["-"] = true, ["*"] = true, ["/"] = true, ["//"] = true, ["%"] = true, ["^"] = true, }, } local PrettyPrintOpts = {} local default_pretty_print_ast_opts = { preserve_indent = true, preserve_newlines = true, } local fast_pretty_print_ast_opts = { preserve_indent = false, preserve_newlines = true, } function tl.pretty_print_ast(ast, mode) local indent = 0 local opts if type(mode) == "table" then opts = mode elseif mode == true then opts = fast_pretty_print_ast_opts else opts = default_pretty_print_ast_opts end local Output = {} local function increment_indent() indent = indent + 1 end if not opts.preserve_indent then increment_indent = nil end local function add(out, s) table.insert(out, s) end local function add_string(out, s) table.insert(out, s) if string.find(s, "\n", 1, true) then for nl in s:gmatch("\n") do out.h = out.h + 1 end end end local function add_child(out, child, space, indent) if #child == 0 then return end if child.y < out.y then out.y = child.y end if child.y > out.y + out.h and opts.preserve_newlines then local delta = child.y - (out.y + out.h) out.h = out.h + delta table.insert(out, ("\n"):rep(delta)) else if space then table.insert(out, space) indent = nil end end if indent and opts.preserve_indent then table.insert(out, (" "):rep(indent)) end table.insert(out, child) out.h = out.h + child.h end local function concat_output(out) for i, s in ipairs(out) do if type(s) == "table" then out[i] = concat_output(s) end end return table.concat(out) end local function print_record_def(typ) local out = { "{" 
} for name, field in pairs(typ.fields) do if field.typename == "typetype" and is_record_type(field.def) then table.insert(out, name) table.insert(out, " = ") table.insert(out, print_record_def(field.def)) table.insert(out, ", ") end end table.insert(out, "}") return table.concat(out) end local visit_node = {} visit_node.cbs = { ["statements"] = { after = function(node, children) local out = { y = node.y, h = 0 } local space for i, child in ipairs(children) do add_child(out, children[i], space, indent) space = "; " end return out end, }, ["local_declaration"] = { after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "local") add_child(out, children[1], " ") if children[2] then table.insert(out, " =") add_child(out, children[2], " ") end return out end, }, ["local_type"] = { after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "local") add_child(out, children[1], " ") table.insert(out, " =") add_child(out, children[2], " ") return out end, }, ["global_type"] = { after = function(node, children) local out = { y = node.y, h = 0 } add_child(out, children[1], " ") table.insert(out, " =") add_child(out, children[2], " ") return out end, }, ["global_declaration"] = { after = function(node, children) local out = { y = node.y, h = 0 } if children[2] then add_child(out, children[1]) table.insert(out, " =") add_child(out, children[2], " ") end return out end, }, ["assignment"] = { after = function(node, children) local out = { y = node.y, h = 0 } add_child(out, children[1]) table.insert(out, " =") add_child(out, children[2], " ") return out end, }, ["if"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "if") add_child(out, children[1], " ") table.insert(out, " then") add_child(out, children[2], " ") indent = indent - 1 for i = 3, #children do add_child(out, children[i], " ", indent) end add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", 
indent) return out end, }, ["while"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "while") add_child(out, children[1], " ") table.insert(out, " do") add_child(out, children[2], " ") indent = indent - 1 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) return out end, }, ["repeat"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "repeat") add_child(out, children[1], " ") if opts.preserve_indent then indent = indent - 1 end add_child(out, { y = node.yend, h = 0, [1] = "until " }, " ", indent) add_child(out, children[2]) return out end, }, ["do"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "do") add_child(out, children[1], " ") indent = indent - 1 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) return out end, }, ["forin"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") table.insert(out, " in") add_child(out, children[2], " ") table.insert(out, " do") add_child(out, children[3], " ") indent = indent - 1 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) return out end, }, ["fornum"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "for") add_child(out, children[1], " ") table.insert(out, " =") add_child(out, children[2], " ") table.insert(out, ",") add_child(out, children[3], " ") if children[4] then table.insert(out, ",") add_child(out, children[4], " ") end table.insert(out, " do") add_child(out, children[5], " ") indent = indent - 1 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) return out end, }, ["return"] = { after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "return") if 
#children[1] > 0 then add_child(out, children[1], " ") end return out end, }, ["break"] = { after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "break") return out end, }, ["elseif"] = { after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "elseif") add_child(out, children[1], " ") table.insert(out, " then") add_child(out, children[2], " ") return out end, }, ["else"] = { after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "else") add_child(out, children[1], " ") return out end, }, ["variables"] = { after = function(node, children) local out = { y = node.y, h = 0 } local space for i, child in ipairs(children) do if i > 1 then table.insert(out, ",") space = " " end add_child(out, child, space) end return out end, }, ["table_literal"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } if #children == 0 then indent = indent - 1 table.insert(out, "{}") return out end table.insert(out, "{") local n = #children for i, child in ipairs(children) do add_child(out, child, " ", child.y ~= node.y and indent) if i < n or node.yend ~= node.y then table.insert(out, ",") end end indent = indent - 1 add_child(out, { y = node.yend, h = 0, [1] = "}" }, " ", indent) return out end, }, ["table_item"] = { after = function(node, children) local out = { y = node.y, h = 0 } if node.key_parsed ~= "implicit" then if node.key_parsed == "short" then children[1][1] = children[1][1]:sub(2, -2) add_child(out, children[1]) table.insert(out, " = ") else table.insert(out, "[") add_child(out, children[1]) table.insert(out, "] = ") end end add_child(out, children[2]) return out end, }, ["local_function"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "local function") add_child(out, children[1], " ") table.insert(out, "(") add_child(out, children[2]) table.insert(out, ")") add_child(out, 
children[4], " ") indent = indent - 1 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) return out end, }, ["global_function"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") table.insert(out, "(") add_child(out, children[2]) table.insert(out, ")") add_child(out, children[4], " ") indent = indent - 1 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) return out end, }, ["record_function"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "function") add_child(out, children[1], " ") table.insert(out, node.is_method and ":" or ".") add_child(out, children[2]) table.insert(out, "(") if node.is_method then table.remove(children[3], 1) if children[3][1] == "," then table.remove(children[3], 1) table.remove(children[3], 1) end end add_child(out, children[3]) table.insert(out, ")") add_child(out, children[5], " ") indent = indent - 1 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) return out end, }, ["function"] = { before = increment_indent, after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "function(") add_child(out, children[1]) table.insert(out, ")") add_child(out, children[3], " ") indent = indent - 1 add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent) return out end, }, ["cast"] = {}, ["paren"] = { after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "(") add_child(out, children[1], "", indent) table.insert(out, ")") return out end, }, ["op"] = { after = function(node, children) local out = { y = node.y, h = 0 } if node.op.op == "@funcall" then add_child(out, children[1], "", indent) table.insert(out, "(") add_child(out, children[3], "", indent) table.insert(out, ")") elseif node.op.op == "@index" then add_child(out, children[1], "", indent) table.insert(out, 
"[") add_child(out, children[3], "", indent) table.insert(out, "]") elseif node.op.op == "as" then add_child(out, children[1], "", indent) elseif node.op.op == "is" then table.insert(out, "type(") add_child(out, children[1], "", indent) table.insert(out, ") == \"") add_child(out, children[3], "", indent) table.insert(out, "\"") elseif spaced_op[node.op.arity][node.op.op] or tight_op[node.op.arity][node.op.op] then local space = spaced_op[node.op.arity][node.op.op] and " " or "" if children[2] and node.op.prec > tonumber(children[2]) then table.insert(children[1], 1, "(") table.insert(children[1], ")") end if node.op.arity == 1 then table.insert(out, node.op.op) add_child(out, children[1], space, indent) elseif node.op.arity == 2 then add_child(out, children[1], "", indent) if space == " " then table.insert(out, " ") end table.insert(out, node.op.op) if children[4] and node.op.prec > tonumber(children[4]) then table.insert(children[3], 1, "(") table.insert(children[3], ")") end add_child(out, children[3], space, indent) end else error("unknown node op " .. 
node.op.op) end return out end, }, ["variable"] = { after = function(node, children) local out = { y = node.y, h = 0 } add_string(out, node.tk) return out end, }, ["newtype"] = { after = function(node, children) local out = { y = node.y, h = 0 } if is_record_type(node.newtype.def) then table.insert(out, print_record_def(node.newtype.def)) else table.insert(out, "{}") end return out end, }, ["goto"] = { after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "goto ") table.insert(out, node.label) return out end, }, ["label"] = { after = function(node, children) local out = { y = node.y, h = 0 } table.insert(out, "::") table.insert(out, node.label) table.insert(out, "::") return out end, }, } local primitive = { ["function"] = "function", ["enum"] = "string", ["boolean"] = "boolean", ["string"] = "string", ["nil"] = "nil", ["number"] = "number", ["thread"] = "thread", } local visit_type = {} visit_type.cbs = { ["string"] = { after = function(typ, children) local out = { y = typ.y, h = 0 } table.insert(out, primitive[typ.typename] or "table") return out end, }, } visit_type.cbs["typetype"] = visit_type.cbs["string"] visit_type.cbs["typevar"] = visit_type.cbs["string"] visit_type.cbs["typearg"] = visit_type.cbs["string"] visit_type.cbs["function"] = visit_type.cbs["string"] visit_type.cbs["thread"] = visit_type.cbs["string"] visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] visit_type.cbs["arrayrecord"] = visit_type.cbs["string"] visit_type.cbs["record"] = visit_type.cbs["string"] visit_type.cbs["enum"] = visit_type.cbs["string"] visit_type.cbs["boolean"] = visit_type.cbs["string"] visit_type.cbs["nil"] = visit_type.cbs["string"] visit_type.cbs["number"] = visit_type.cbs["string"] visit_type.cbs["union"] = visit_type.cbs["string"] visit_type.cbs["nominal"] = visit_type.cbs["string"] visit_type.cbs["bad_nominal"] = visit_type.cbs["string"] visit_type.cbs["emptytable"] = 
visit_type.cbs["string"] visit_type.cbs["table_item"] = visit_type.cbs["string"] visit_type.cbs["unknown_emptytable_value"] = visit_type.cbs["string"] visit_type.cbs["tuple"] = visit_type.cbs["string"] visit_type.cbs["poly"] = visit_type.cbs["string"] visit_type.cbs["any"] = visit_type.cbs["string"] visit_type.cbs["unknown"] = visit_type.cbs["string"] visit_type.cbs["invalid"] = visit_type.cbs["string"] visit_type.cbs["unresolved"] = visit_type.cbs["string"] visit_type.cbs["none"] = visit_type.cbs["string"] visit_node.cbs["values"] = visit_node.cbs["variables"] visit_node.cbs["expression_list"] = visit_node.cbs["variables"] visit_node.cbs["argument_list"] = visit_node.cbs["variables"] visit_node.cbs["identifier"] = visit_node.cbs["variable"] visit_node.cbs["string"] = visit_node.cbs["variable"] visit_node.cbs["number"] = visit_node.cbs["variable"] visit_node.cbs["nil"] = visit_node.cbs["variable"] visit_node.cbs["boolean"] = visit_node.cbs["variable"] visit_node.cbs["..."] = visit_node.cbs["variable"] visit_node.cbs["argument"] = visit_node.cbs["variable"] local out = recurse_node(ast, visit_node, visit_type) local code if opts.preserve_newlines then code = { y = 1, h = 0 } add_child(code, out) else code = out end return concat_output(code) end local ANY = a_type({ typename = "any" }) local NONE = a_type({ typename = "none" }) local NIL = a_type({ typename = "nil" }) local NUMBER = a_type({ typename = "number" }) local STRING = a_type({ typename = "string" }) local OPT_NUMBER = a_type({ typename = "number" }) local OPT_STRING = a_type({ typename = "string" }) local VARARG_ANY = a_type({ typename = "any", is_va = true }) local VARARG_STRING = a_type({ typename = "string", is_va = true }) local VARARG_NUMBER = a_type({ typename = "number", is_va = true }) local VARARG_UNKNOWN = a_type({ typename = "unknown", is_va = true }) local VARARG_ALPHA = a_type({ typename = "typevar", typevar = "@a", is_va = true }) local BOOLEAN = a_type({ typename = "boolean" }) local 
ARG_ALPHA = a_type({ typename = "typearg", typearg = "@a" })
-- Generic building blocks used by the builtin signatures in standard_library:
-- ARG_ALPHA/ARG_BETA are the type arguments a generic function declares in
-- `typeargs`; ALPHA/BETA are the matching type variables ("@a"/"@b") that
-- appear in its `args`/`rets`.
local ARG_BETA = a_type({ typename = "typearg", typearg = "@b" })
local ALPHA = a_type({ typename = "typevar", typevar = "@a" })
local BETA = a_type({ typename = "typevar", typevar = "@b" })
local ARRAY_OF_STRING = a_type({ typename = "array", elements = STRING })
local ARRAY_OF_ALPHA = a_type({ typename = "array", elements = ALPHA })
local MAP_OF_ALPHA_TO_BETA = a_type({ typename = "map", keys = ALPHA, values = BETA })
-- TABLE: a map from any to any. FUNCTION: a function accepting and returning
-- a variable number (is_va) of values of type any.
local TABLE = a_type({ typename = "map", keys = ANY, values = ANY })
local FUNCTION = a_type({ typename = "function", args = { a_type({ typename = "any", is_va = true }) }, rets = { a_type({ typename = "any", is_va = true }) } })
local THREAD = a_type({ typename = "thread" })
local INVALID = a_type({ typename = "invalid" })
local UNKNOWN = a_type({ typename = "unknown" })
-- Nominal references by name; their `.found` fields are pointed at the
-- standard_library "FILE"/"METATABLE" entries once that table is built.
local NOMINAL_FILE = a_type({ typename = "nominal", names = { "FILE" } })
local NOMINAL_METATABLE = a_type({ typename = "nominal", names = { "METATABLE" } })
-- Record type that os.date may return (see the "os" entry of standard_library).
local OS_DATE_TABLE = a_type({ typename = "record", fields = {
   ["year"] = NUMBER,
   ["month"] = NUMBER,
   ["day"] = NUMBER,
   ["hour"] = NUMBER,
   ["min"] = NUMBER,
   ["sec"] = NUMBER,
   ["wday"] = NUMBER,
   ["yday"] = NUMBER,
   ["isdst"] = BOOLEAN,
}, })
-- Record type returned by the debug.getinfo overloads.
local DEBUG_GETINFO_TABLE = a_type({ typename = "record", fields = {
   ["name"] = STRING,
   ["namewhat"] = STRING,
   ["source"] = STRING,
   ["short_src"] = STRING,
   ["linedefined"] = NUMBER,
   ["lastlinedefined"] = NUMBER,
   ["what"] = STRING,
   ["currentline"] = NUMBER,
   ["istailcall"] = BOOLEAN,
   ["nups"] = NUMBER,
   ["nparams"] = NUMBER,
   ["isvararg"] = BOOLEAN,
   ["func"] = ANY,
   ["activelines"] = a_type({ typename = "map", keys = NUMBER, values = BOOLEAN }),
}, })

-- Operator typing tables. Each maps the left operand's typename to a table
-- that maps the right operand's typename to the result type. Pairs that are
-- not listed are presumably rejected by the checker (the lookup code is
-- outside this chunk -- verify).
local numeric_binop = {
   ["number"] = {
      ["number"] = NUMBER,
   },
}

local relational_binop = {
   ["number"] = {
      ["number"] = BOOLEAN,
   },
   ["string"] = {
      ["string"] = BOOLEAN,
   },
   ["boolean"] = {
      ["boolean"] = BOOLEAN,
   },
}

local equality_binop = {
   ["number"] = {
      ["number"] = BOOLEAN,
      ["nil"] = BOOLEAN,
   },
   ["string"] = {
      ["string"] = BOOLEAN,
      ["nil"] = BOOLEAN,
   },
   ["boolean"] = {
      ["boolean"] = BOOLEAN,
      ["nil"] = BOOLEAN,
   },
   ["record"] = {
      ["emptytable"] = BOOLEAN,
      ["arrayrecord"] = BOOLEAN,
      ["record"] = BOOLEAN,
      ["nil"] = BOOLEAN,
   },
   ["array"] = {
      ["emptytable"] = BOOLEAN,
      ["arrayrecord"] = BOOLEAN,
      ["array"] = BOOLEAN,
      ["nil"] = BOOLEAN,
   },
   ["arrayrecord"] = {
      ["emptytable"] = BOOLEAN,
      ["arrayrecord"] = BOOLEAN,
      ["record"] = BOOLEAN,
      ["array"] = BOOLEAN,
      ["nil"] = BOOLEAN,
   },
   ["map"] = {
      ["emptytable"] = BOOLEAN,
      ["map"] = BOOLEAN,
      ["nil"] = BOOLEAN,
   },
   ["thread"] = {
      ["thread"] = BOOLEAN,
      ["nil"] = BOOLEAN,
   },
}

-- Unary operators: operator -> operand typename -> result type.
local unop_types = {
   ["#"] = {
      ["arrayrecord"] = NUMBER,
      ["string"] = NUMBER,
      ["array"] = NUMBER,
      ["map"] = NUMBER,
      ["emptytable"] = NUMBER,
   },
   ["-"] = {
      ["number"] = NUMBER,
   },
   ["not"] = {
      ["string"] = BOOLEAN,
      ["number"] = BOOLEAN,
      ["boolean"] = BOOLEAN,
      ["record"] = BOOLEAN,
      ["arrayrecord"] = BOOLEAN,
      ["array"] = BOOLEAN,
      ["map"] = BOOLEAN,
      ["emptytable"] = BOOLEAN,
      ["thread"] = BOOLEAN,
   },
}

-- Binary operators: operator -> left typename -> right typename -> result.
-- Arithmetic/bitwise, equality and relational operators share the tables
-- above ("-" has its own table with the same content as numeric_binop);
-- "or" and ".." have dedicated tables.
local binop_types = {
   ["+"] = numeric_binop,
   ["-"] = {
      ["number"] = {
         ["number"] = NUMBER,
      },
   },
   ["*"] = numeric_binop,
   ["%"] = numeric_binop,
   ["/"] = numeric_binop,
   ["^"] = numeric_binop,
   ["&"] = numeric_binop,
   ["|"] = numeric_binop,
   ["<<"] = numeric_binop,
   [">>"] = numeric_binop,
   ["=="] = equality_binop,
   ["~="] = equality_binop,
   ["<="] = relational_binop,
   [">="] = relational_binop,
   ["<"] = relational_binop,
   [">"] = relational_binop,
   ["or"] = {
      ["boolean"] = {
         ["boolean"] = BOOLEAN,
         ["function"] = FUNCTION,
      },
      ["number"] = {
         ["number"] = NUMBER,
         ["boolean"] = BOOLEAN,
      },
      ["string"] = {
         ["string"] = STRING,
         ["boolean"] = BOOLEAN,
         ["enum"] = STRING,
      },
      ["function"] = {
         ["function"] = FUNCTION,
         ["boolean"] = BOOLEAN,
      },
      ["array"] = {
         ["boolean"] = BOOLEAN,
      },
      ["record"] = {
         ["boolean"] = BOOLEAN,
      },
      ["arrayrecord"] = {
         ["boolean"] = BOOLEAN,
      },
      ["map"] = {
         ["boolean"] = BOOLEAN,
      },
      ["enum"] = {
         ["string"] = STRING,
      },
      ["thread"] = {
         ["boolean"] = BOOLEAN,
},
   },
   [".."] = {
      ["string"] = { ["string"] = STRING, ["enum"] = STRING, ["number"] = STRING, },
      ["number"] = { ["number"] = STRING, ["string"] = STRING, ["enum"] = STRING, },
      ["enum"] = { ["number"] = STRING, ["string"] = STRING, ["enum"] = STRING, },
   },
}

--- Return true when `t` is one of the "not yet known" placeholder types.
local function is_unknown(t)
   return t.typename == "unknown" or t.typename == "unknown_emptytable_value"
end

-- Forward declaration: show_type_base and show_type call each other.
local show_type

--- Render a type as a human-readable string (without inference location).
-- @param t the type to render
-- @param seen set of types already being rendered; meeting one again
--        renders "..." so cyclic types terminate
-- @return string
local function show_type_base(t, seen)
   if seen[t] then
      return "..."
   end
   seen[t] = true

   local function show(t)
      return show_type(t, seen)
   end

   if t.typename == "nominal" then
      if t.typevals then
         local out = { table.concat(t.names, "."), "<" }
         local vals = {}
         for _, v in ipairs(t.typevals) do
            table.insert(vals, show(v))
         end
         table.insert(out, table.concat(vals, ", "))
         table.insert(out, ">")
         return table.concat(out)
      else
         return table.concat(t.names, ".")
      end
   elseif t.typename == "tuple" then
      local out = {}
      for _, v in ipairs(t) do
         table.insert(out, show(v))
      end
      return "(" .. table.concat(out, ", ") .. ")"
   elseif t.typename == "poly" then
      local out = {}
      for _, v in ipairs(t.types) do
         table.insert(out, show(v))
      end
      return table.concat(out, " or ")
   elseif t.typename == "union" then
      local out = {}
      for _, v in ipairs(t.types) do
         table.insert(out, show(v))
      end
      return table.concat(out, " | ")
   elseif t.typename == "emptytable" then
      return "{}"
   elseif t.typename == "map" then
      return "{" .. show(t.keys) .. " : " .. show(t.values) .. "}"
   elseif t.typename == "array" then
      return "{" .. show(t.elements) .. "}"
   elseif t.typename == "enum" then
      return t.names and table.concat(t.names, ".") or "enum"
   elseif is_record_type(t) then
      local out = {}
      for _, k in ipairs(t.field_order) do
         local v = t.fields[k]
         table.insert(out, k .. ": " .. show(v))
      end
      return "{" .. table.concat(out, ", ") .. "}"
   elseif t.typename == "function" then
      local out = {}
      table.insert(out, "function(")
      local args = {}
      if t.is_method then
         table.insert(args, "self")
      end
      for i, v in ipairs(t.args) do
         -- for methods the first declared arg is the receiver, already shown as "self"
         if not t.is_method or i > 1 then
            table.insert(args, show(v))
         end
      end
      table.insert(out, table.concat(args, ","))
      table.insert(out, ")")
      if #t.rets > 0 then
         table.insert(out, ":")
         local rets = {}
         for _, v in ipairs(t.rets) do
            table.insert(rets, show(v))
         end
         table.insert(out, table.concat(rets, ","))
      end
      return table.concat(out)
   elseif t.typename == "number" or t.typename == "boolean" or t.typename == "thread" then
      return t.typename
   elseif t.typename == "string" then
      return t.typename .. (t.tk and " " .. t.tk or "")
   elseif t.typename == "typevar" then
      return t.typevar
   elseif t.typename == "typearg" then
      return t.typearg
   elseif is_unknown(t) then
      return ""
   elseif t.typename == "invalid" then
      return ""
   elseif t.typename == "any" then
      return ""
   elseif t.typename == "nil" then
      return "nil"
   elseif t.typename == "typetype" then
      return "type " .. show(t.def)
   elseif t.typename == "bad_nominal" then
      return table.concat(t.names, ".") .. " (an unknown type)"
   else
      return inspect(t)
   end
end

--- Render a type, appending where it was inferred when that is known.
-- @param t the type to render
-- @param seen optional cycle-guard set (a fresh one is created when nil)
-- @return string
show_type = function(t, seen)
   local ret = show_type_base(t, seen or {})
   if t.inferred_at then
      -- inferred_at_file is copied from the type checker's filename, which is
      -- nil when no filename was given; omit it instead of failing to concat nil.
      local file = t.inferred_at_file
      local where = (file and (file .. ":") or "") .. t.inferred_at.y .. ":" .. t.inferred_at.x
      ret = ret .. " (inferred at " .. where .. ": )"
   end
   return ret
end

local Error = {}

local Result = {}

--- Look for `module_name` along a ';'-separated search path, replacing the
-- '?' in each template with the module's slash path and a trailing ".lua"
-- with `suffix`.
-- Every candidate that cannot be opened is appended to `tried`.
-- @return filename, open read handle, tried  (filename and handle are nil when nothing matched)
local function search_for(module_name, suffix, path, tried)
   -- "a.b.c" -> "a/b/c"; independent of the path entry, so computed once.
   local slash_name = module_name:gsub("%.", "/")
   for entry in path:gmatch("[^;]+") do
      -- A function replacement keeps any '%' in the module name literal.
      local filename = entry:gsub("%?", function() return slash_name end)
      local tl_filename = filename:gsub("%.lua$", suffix)
      local fd = io.open(tl_filename, "r")
      if fd then
         return tl_filename, fd, tried
      end
      table.insert(tried, "no file '" .. tl_filename .. "'")
   end
   return nil, nil, tried
end

--- Locate a module on TL_PATH (falling back to package.path).
-- Candidates are tried in order: ".d.tl" (only when `search_dtl`), ".tl", ".lua".
-- @return filename, open handle on success; nil, nil, list of tried paths otherwise
function tl.search_module(module_name, search_dtl)
   local path = os.getenv("TL_PATH") or package.path
   local tried = {}
   local suffixes = search_dtl and { ".d.tl", ".tl", ".lua" } or { ".tl", ".lua" }
   for _, suffix in ipairs(suffixes) do
      local found, fd = search_for(module_name, suffix, path, tried)
      if found then
         return found, fd
      end
   end
   return nil, nil, tried
end

local Variable = {}

--- Give a record type a `field_order` list: its field names, sorted.
-- Types whose typename is not "record" are left untouched.
local function fill_field_order(t)
   if t.typename == "record" then
      t.field_order = {}
      for k in pairs(t.fields) do
         table.insert(t.field_order, k)
      end
      table.sort(t.field_order)
   end
end

--- Resolve the type exported by `module_name`, type-checking its source if needed.
-- Results are memoized in env.modules.
-- @return the module's type (UNKNOWN if it could not be loaded), and whether a file was found
local function require_module(module_name, lax, env, result)
   local modules = env.modules
   if modules[module_name] then
      return modules[module_name], true
   end
   -- Registered before processing: a re-entrant require of the same name gets
   -- this placeholder back instead of processing the module again.
   modules[module_name] = UNKNOWN

   local found, fd = tl.search_module(module_name, true)
   if not found then
      return UNKNOWN, false
   end
   -- tl.process receives only the filename, so the handle search_module opened
   -- is released on every path (previously it leaked when the file was rejected).
   fd:close()

   -- Outside lax mode only Teal sources (.tl / .d.tl) are type-loaded.
   if lax or found:match("tl$") then
      local _result, err = tl.process(found, env, result)
      assert(_result, err)
      if not _result.type then
         _result.type = BOOLEAN
      end
      modules[module_name] = _result.type
      return _result.type, true
   end
   return UNKNOWN, true
end

-- Types of the Lua standard library globals, keyed by global name.
-- Library tables (coroutine, io, math, ...) are records; overloaded functions
-- are "poly" types whose accepted signatures are listed under `types`.
local standard_library = {
   ["..."] = a_type({ typename = "tuple", STRING, STRING, STRING, STRING, STRING }),
   ["@return"] = a_type({ typename = "tuple", ANY }),
   ["any"] = a_type({ typename = "typetype", def = ANY }),
   ["arg"] = ARRAY_OF_STRING,
   ["assert"] = a_type({ typename = "poly", types = {
      a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ALPHA }, rets = { ALPHA } }),
      a_type({ typename = "function", typeargs = { ARG_ALPHA, ARG_BETA }, args = { ALPHA, BETA }, rets = { ALPHA } }),
   }, }),
   ["collectgarbage"] = a_type({ typename = "function", args = { STRING }, rets = { a_type({ typename = "union", types = { BOOLEAN, NUMBER } }), NUMBER, NUMBER } }),
   ["dofile"] = a_type({ typename = "function", args = {
OPT_STRING }, rets = { VARARG_ANY } }),
   ["error"] = a_type({ typename = "function", args = { STRING, NUMBER }, rets = {} }),
   ["getmetatable"] = a_type({ typename = "function", args = { ANY }, rets = { NOMINAL_METATABLE } }),
   -- ipairs takes an array of @a and returns an iterator function yielding (number, @a).
   ["ipairs"] = a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA }, rets = {
      a_type({ typename = "function", args = {}, rets = { NUMBER, ALPHA } }),
   }, }),
   -- Overloaded ("poly") entries list one signature per accepted form under `types`.
   ["load"] = a_type({ typename = "poly", types = {
      a_type({ typename = "function", args = { STRING }, rets = { FUNCTION, STRING } }),
      a_type({ typename = "function", args = { STRING, STRING }, rets = { FUNCTION, STRING } }),
      a_type({ typename = "function", args = { STRING, STRING, STRING }, rets = { FUNCTION, STRING } }),
      a_type({ typename = "function", args = { STRING, STRING, STRING, TABLE }, rets = { FUNCTION, STRING } }),
   }, }),
   ["loadfile"] = a_type({ typename = "poly", types = {
      a_type({ typename = "function", args = {}, rets = { FUNCTION, ANY } }),
      a_type({ typename = "function", args = { STRING }, rets = { FUNCTION, ANY } }),
      a_type({ typename = "function", args = { STRING, STRING }, rets = { FUNCTION, ANY } }),
      a_type({ typename = "function", args = { STRING, STRING, TABLE }, rets = { FUNCTION, ANY } }),
   }, }),
   -- next: the map overloads yield (key, value); the array overloads yield (number, element).
   ["next"] = a_type({ typename = "poly", types = {
      a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { MAP_OF_ALPHA_TO_BETA }, rets = { ALPHA, BETA } }),
      a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { MAP_OF_ALPHA_TO_BETA, ALPHA }, rets = { ALPHA, BETA } }),
      a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ARRAY_OF_ALPHA }, rets = { NUMBER, ALPHA } }),
      a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ARRAY_OF_ALPHA, ALPHA }, rets = { NUMBER, ALPHA } }),
   }, }),
   -- pairs takes a map {@a : @b} and returns an iterator function yielding (@a, @b).
   ["pairs"] = a_type({ typename = "function", typeargs = { ARG_ALPHA, ARG_BETA }, args = { a_type({ typename = "map", keys = ALPHA, values = BETA }) }, rets = {
      a_type({ typename = "function", args = {}, rets = { ALPHA, BETA } }),
   }, }),
   ["pcall"] = a_type({ typename = "function", args = { FUNCTION, VARARG_ANY }, rets = { BOOLEAN, ANY } }),
   ["xpcall"] = a_type({ typename = "function", args = { FUNCTION, FUNCTION, VARARG_ANY }, rets = { BOOLEAN, ANY } }),
   ["print"] = a_type({ typename = "function", args = { VARARG_ANY }, rets = {} }),
   ["rawequal"] = a_type({ typename = "function", args = { ANY, ANY }, rets = { BOOLEAN } }),
   ["rawget"] = a_type({ typename = "function", args = { TABLE, ANY }, rets = { ANY } }),
   ["rawlen"] = a_type({ typename = "poly", types = {
      a_type({ typename = "function", args = { TABLE }, rets = { NUMBER } }),
      a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }),
   }, }),
   ["rawset"] = a_type({ typename = "poly", types = {
      a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { MAP_OF_ALPHA_TO_BETA, ALPHA, BETA }, rets = {} }),
      a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ARRAY_OF_ALPHA, NUMBER, ALPHA }, rets = {} }),
      a_type({ typename = "function", args = { TABLE, ANY, ANY }, rets = {} }),
   }, }),
   ["require"] = a_type({ typename = "function", args = { STRING }, rets = {} }),
   ["select"] = a_type({ typename = "poly", types = {
      a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { NUMBER, VARARG_ALPHA }, rets = { ALPHA } }),
      a_type({ typename = "function", args = { NUMBER, VARARG_ANY }, rets = { ANY } }),
      a_type({ typename = "function", args = { STRING, VARARG_ANY }, rets = { NUMBER } }),
   }, }),
   ["setmetatable"] = a_type({ typeargs = { ARG_ALPHA }, typename = "function", args = { ALPHA, NOMINAL_METATABLE }, rets = { ALPHA } }),
   ["tonumber"] = a_type({ typename = "function", args = { ANY, NUMBER }, rets = { NUMBER } }),
   ["tostring"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
   ["type"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
   -- FILE: the type behind NOMINAL_FILE (file handles). Its method fields list
   -- the handle itself as their first argument.
   ["FILE"] = a_type({ typename = "typetype", def = a_type({ typename = "record", fields = {
      ["close"] = a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING } }),
      ["flush"] = a_type({ typename = "function", args = { NOMINAL_FILE }, rets = {} }),
      ["lines"] = a_type({ typename = "function", args = { NOMINAL_FILE, a_type({ typename = "union", types = { STRING, NUMBER }, is_va = true }) }, rets = {
         a_type({ typename = "function", args = {}, rets = { VARARG_STRING } }),
      }, }),
      ["read"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { NOMINAL_FILE, STRING }, rets = { STRING, STRING } }),
         a_type({ typename = "function", args = { NOMINAL_FILE, NUMBER }, rets = { STRING, STRING } }),
      }, }),
      ["seek"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { NUMBER, STRING } }),
         a_type({ typename = "function", args = { NOMINAL_FILE, STRING }, rets = { NUMBER, STRING } }),
         a_type({ typename = "function", args = { NOMINAL_FILE, STRING, NUMBER }, rets = { NUMBER, STRING } }),
      }, }),
      ["setvbuf"] = a_type({ typename = "function", args = { NOMINAL_FILE, STRING, OPT_NUMBER }, rets = {} }),
      ["write"] = a_type({ typename = "function", args = { NOMINAL_FILE, VARARG_STRING }, rets = { NOMINAL_FILE, STRING } }),
   }, }), }),
   -- METATABLE: the type behind NOMINAL_METATABLE, i.e. what setmetatable
   -- accepts and getmetatable returns above.
   ["METATABLE"] = a_type({ typename = "typetype", def = a_type({ typename = "record", fields = {
      ["__call"] = FUNCTION,
      ["__gc"] = a_type({ typename = "function", args = { ANY }, rets = {} }),
      ["__index"] = ANY,
      ["__len"] = a_type({ typename = "function", args = { ANY }, rets = { NUMBER } }),
      ["__mode"] = a_type({ typename = "enum", enumset = { ["k"] = true, ["v"] = true, ["kv"] = true } }),
      ["__newindex"] = ANY,
      ["__pairs"] = a_type({ typeargs = { ARG_ALPHA, ARG_BETA }, typename = "function", args = { a_type({ typename = "map", keys = ALPHA, values = BETA }) }, rets = {
         a_type({ typename = "function", args = {}, rets = { ALPHA, BETA } }),
      }, }),
      ["__tostring"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
      ["__name"] = STRING,
      ["__add"] = FUNCTION,
      ["__sub"] = FUNCTION,
      ["__mul"] = FUNCTION,
      ["__div"] = FUNCTION,
      ["__idiv"] = FUNCTION,
      ["__mod"] = FUNCTION,
      ["__pow"] = FUNCTION,
      ["__unm"] = FUNCTION,
      ["__band"] = FUNCTION,
      ["__bor"] = FUNCTION,
      ["__bxor"] = FUNCTION,
      ["__bnot"] = FUNCTION,
      ["__shl"] = FUNCTION,
      ["__shr"] = FUNCTION,
      ["__concat"] = FUNCTION,
      ["__eq"] = FUNCTION,
      ["__lt"] = FUNCTION,
      ["__le"] = FUNCTION,
   }, }), }),
   ["coroutine"] = a_type({ typename = "record", fields = {
      ["create"] = a_type({ typename = "function", args = { FUNCTION }, rets = { THREAD } }),
      ["close"] = a_type({ typename = "function", args = { THREAD }, rets = { BOOLEAN, STRING } }),
      ["isyieldable"] = a_type({ typename = "function", args = {}, rets = { BOOLEAN } }),
      ["resume"] = a_type({ typename = "function", args = { THREAD, VARARG_ANY }, rets = { BOOLEAN, VARARG_ANY } }),
      ["running"] = a_type({ typename = "function", args = {}, rets = { THREAD, BOOLEAN } }),
      ["status"] = a_type({ typename = "function", args = { THREAD }, rets = { STRING } }),
      ["wrap"] = a_type({ typename = "function", args = { FUNCTION }, rets = { FUNCTION } }),
      ["yield"] = a_type({ typename = "function", args = { VARARG_ANY }, rets = { VARARG_ANY } }),
   }, }),
   ["debug"] = a_type({ typename = "record", fields = {
      ["traceback"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { THREAD, STRING, NUMBER }, rets = { STRING } }),
         a_type({ typename = "function", args = { STRING, NUMBER }, rets = { STRING } }),
      }, }),
      ["getinfo"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { ANY }, rets = { DEBUG_GETINFO_TABLE } }),
         a_type({ typename = "function", args = { ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }),
         a_type({ typename = "function", args = { ANY, ANY, STRING }, rets = { DEBUG_GETINFO_TABLE } }),
      }, }),
   }, }),
   ["io"] = a_type({ typename = "record", fields = {
      ["close"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = {}, rets = { BOOLEAN, STRING }
}),
         a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { BOOLEAN, STRING } }),
      }, }),
      ["flush"] = a_type({ typename = "function", args = {}, rets = {} }),
      ["input"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = {}, rets = { NOMINAL_FILE } }),
         a_type({ typename = "function", args = { STRING }, rets = { NOMINAL_FILE } }),
         a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { NOMINAL_FILE } }),
      }, }),
      ["lines"] = a_type({ typename = "function", args = { OPT_STRING, a_type({ typename = "union", types = { STRING, NUMBER }, is_va = true }) }, rets = {
         a_type({ typename = "function", args = {}, rets = { VARARG_STRING } }),
      }, }),
      ["open"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { NOMINAL_FILE, STRING } }),
      ["output"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = {}, rets = { NOMINAL_FILE } }),
         a_type({ typename = "function", args = { STRING }, rets = { NOMINAL_FILE } }),
         a_type({ typename = "function", args = { NOMINAL_FILE }, rets = { NOMINAL_FILE } }),
      }, }),
      ["popen"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { NOMINAL_FILE, STRING } }),
      -- io.read(...) reads from the default input file (it is equivalent to
      -- io.input():read(...)), so unlike FILE.read it takes no file-handle
      -- argument: the format is the first argument.
      ["read"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { STRING }, rets = { STRING, STRING } }),
         a_type({ typename = "function", args = { NUMBER }, rets = { STRING, STRING } }),
      }, }),
      ["stderr"] = NOMINAL_FILE,
      ["stdin"] = NOMINAL_FILE,
      ["stdout"] = NOMINAL_FILE,
      ["tmpfile"] = a_type({ typename = "function", args = {}, rets = { NOMINAL_FILE } }),
      ["type"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
      ["write"] = a_type({ typename = "function", args = { VARARG_STRING }, rets = { NOMINAL_FILE, STRING } }),
   }, }),
   ["math"] = a_type({ typename = "record", fields = {
      ["abs"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["acos"] = a_type({ typename = "function", args = { NUMBER },
rets = { NUMBER } }),
      ["asin"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      -- math.atan(y [, x]) is overloaded. Like every other "poly" entry, its
      -- overloads must be listed under `types`: consumers such as
      -- show_type_base iterate `t.types` and would fail on nil otherwise.
      ["atan"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
         a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
      }, }),
      ["atan2"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
      ["ceil"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["cos"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["cosh"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["deg"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["exp"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["floor"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["fmod"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
      ["frexp"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER, NUMBER } }),
      ["huge"] = NUMBER,
      ["ldexp"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
      ["log"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["log10"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["max"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { NUMBER } }),
      ["maxinteger"] = NUMBER,
      ["min"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { NUMBER } }),
      ["mininteger"] = NUMBER,
      ["modf"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER, NUMBER } }),
      ["pi"] = NUMBER,
      ["pow"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
      ["rad"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["random"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
      ["randomseed"] = a_type({
typename = "function", args = { NUMBER }, rets = {} }),
      ["sin"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["sinh"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["sqrt"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["tan"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["tanh"] = a_type({ typename = "function", args = { NUMBER }, rets = { NUMBER } }),
      ["tointeger"] = a_type({ typename = "function", args = { ANY }, rets = { NUMBER } }),
      ["type"] = a_type({ typename = "function", args = { ANY }, rets = { STRING } }),
      ["ult"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { BOOLEAN } }),
   }, }),
   ["os"] = a_type({ typename = "record", fields = {
      ["clock"] = a_type({ typename = "function", args = {}, rets = { NUMBER } }),
      -- os.date([format [, time]]): `time` is a timestamp number (the kind
      -- os.time returns), not a string. With a format the result is typed
      -- string | OS_DATE_TABLE.
      ["date"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = {}, rets = { STRING } }),
         a_type({ typename = "function", args = { STRING, OPT_NUMBER }, rets = { a_type({ typename = "union", types = { STRING, OS_DATE_TABLE } }) } }),
      }, }),
      ["difftime"] = a_type({ typename = "function", args = { NUMBER, NUMBER }, rets = { NUMBER } }),
      ["execute"] = a_type({ typename = "function", args = { STRING }, rets = { BOOLEAN, STRING, NUMBER } }),
      ["exit"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { NUMBER, BOOLEAN }, rets = {} }),
         a_type({ typename = "function", args = { BOOLEAN, BOOLEAN }, rets = {} }),
      }, }),
      ["getenv"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }),
      ["remove"] = a_type({ typename = "function", args = { STRING }, rets = { BOOLEAN, STRING } }),
      ["rename"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { BOOLEAN, STRING } }),
      ["setlocale"] = a_type({ typename = "function", args = { STRING, OPT_STRING }, rets = { STRING } }),
      ["time"] = a_type({ typename = "function", args = {}, rets = { NUMBER } }),
["tmpname"] = a_type({ typename = "function", args = {}, rets = { STRING } }),
   }, }),
   ["package"] = a_type({ typename = "record", fields = {
      ["config"] = STRING,
      ["cpath"] = STRING,
      ["loaded"] = a_type({ typename = "map", keys = STRING, values = ANY, }),
      -- `loaders` and `searchers` are typed identically: arrays of functions
      -- taking a module name.
      ["loaders"] = a_type({ typename = "array", elements = a_type({ typename = "function", args = { STRING }, rets = { ANY } }), }),
      ["loadlib"] = a_type({ typename = "function", args = { STRING, STRING }, rets = { FUNCTION } }),
      ["path"] = STRING,
      ["preload"] = TABLE,
      ["searchers"] = a_type({ typename = "array", elements = a_type({ typename = "function", args = { STRING }, rets = { ANY } }), }),
      ["searchpath"] = a_type({ typename = "function", args = { STRING, STRING, OPT_STRING, OPT_STRING }, rets = { STRING, STRING } }),
   }, }),
   ["string"] = a_type({ typename = "record", fields = {
      ["byte"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }),
         a_type({ typename = "function", args = { STRING, NUMBER }, rets = { NUMBER } }),
         a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { VARARG_NUMBER } }),
      }, }),
      ["char"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { STRING } }),
      ["dump"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { FUNCTION }, rets = { STRING } }),
         a_type({ typename = "function", args = { FUNCTION, BOOLEAN }, rets = { STRING } }),
      }, }),
      ["find"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { STRING, STRING }, rets = { NUMBER, NUMBER, VARARG_STRING } }),
         a_type({ typename = "function", args = { STRING, STRING, NUMBER }, rets = { NUMBER, NUMBER, VARARG_STRING } }),
         a_type({ typename = "function", args = { STRING, STRING, NUMBER, BOOLEAN }, rets = { NUMBER, NUMBER, VARARG_STRING } }),
      }, }),
      ["format"] = a_type({ typename = "function", args = { STRING, VARARG_ANY }, rets = { STRING } }),
      ["gmatch"] = a_type({ typename = "function", args = { STRING, STRING }, rets = {
         a_type({ typename = "function", args = {}, rets = { STRING } }),
      }, }),
      -- One gsub overload per replacement kind: a string, a {string : string}
      -- map, or a function receiving the captures.
      ["gsub"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", args = { STRING, STRING, STRING, NUMBER }, rets = { STRING, NUMBER } }),
         a_type({ typename = "function", args = { STRING, STRING, a_type({ typename = "map", keys = STRING, values = STRING }), NUMBER }, rets = { STRING, NUMBER } }),
         a_type({ typename = "function", args = { STRING, STRING, a_type({ typename = "function", args = { VARARG_STRING }, rets = { STRING } }) }, rets = { STRING, NUMBER } }),
      }, }),
      ["len"] = a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }),
      ["lower"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }),
      ["match"] = a_type({ typename = "function", args = { STRING, STRING, NUMBER }, rets = { VARARG_STRING } }),
      ["pack"] = a_type({ typename = "function", args = { STRING, VARARG_ANY }, rets = { STRING } }),
      ["packsize"] = a_type({ typename = "function", args = { STRING }, rets = { NUMBER } }),
      ["rep"] = a_type({ typename = "function", args = { STRING, NUMBER }, rets = { STRING } }),
      ["reverse"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }),
      ["sub"] = a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { STRING } }),
      ["unpack"] = a_type({ typename = "function", args = { STRING, STRING, OPT_NUMBER }, rets = { VARARG_ANY } }),
      ["upper"] = a_type({ typename = "function", args = { STRING }, rets = { STRING } }),
   }, }),
   ["table"] = a_type({ typename = "record", fields = {
      ["concat"] = a_type({ typename = "function", args = { ARRAY_OF_STRING, OPT_STRING, OPT_NUMBER, OPT_NUMBER }, rets = { STRING } }),
      ["insert"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, ALPHA }, rets = {} }),
         a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, ALPHA }, rets = {} }),
      }, }),
      ["move"] =
a_type({ typename = "poly", types = {
         a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, NUMBER, NUMBER }, rets = { ARRAY_OF_ALPHA } }),
         a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, NUMBER, NUMBER, ARRAY_OF_ALPHA }, rets = { ARRAY_OF_ALPHA } }),
      }, }),
      ["pack"] = a_type({ typename = "function", args = { VARARG_ANY }, rets = { TABLE } }),
      ["remove"] = a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, OPT_NUMBER }, rets = { ALPHA } }),
      ["sort"] = a_type({ typename = "poly", types = {
         a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA }, rets = {} }),
         a_type({ typename = "function", typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, a_type({ typename = "function", args = { ALPHA, ALPHA }, rets = { BOOLEAN } }) }, rets = {} }),
      }, }),
      ["unpack"] = a_type({ typename = "function", needs_compat53 = true, typeargs = { ARG_ALPHA }, args = { ARRAY_OF_ALPHA, NUMBER, NUMBER }, rets = { VARARG_ALPHA }, }),
   }, }),
   ["utf8"] = a_type({ typename = "record", fields = {
      ["char"] = a_type({ typename = "function", args = { VARARG_NUMBER }, rets = { STRING } }),
      ["charpattern"] = STRING,
      ["codepoint"] = a_type({ typename = "function", args = { STRING, OPT_NUMBER, OPT_NUMBER }, rets = { VARARG_NUMBER } }),
      -- utf8.codes iterates (byte position, code point); both values are
      -- integers, so the iterator returns two numbers.
      ["codes"] = a_type({ typename = "function", args = { STRING }, rets = {
         a_type({ typename = "function", args = {}, rets = { NUMBER, NUMBER } }),
      }, }),
      ["len"] = a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { NUMBER } }),
      ["offset"] = a_type({ typename = "function", args = { STRING, NUMBER, NUMBER }, rets = { NUMBER } }),
   }, }),
}

-- Give every record-typed entry (and the record behind each typetype entry),
-- plus the two standalone record types, a sorted field_order list.
for _, t in pairs(standard_library) do
   fill_field_order(t)
   if t.typename == "typetype" then
      fill_field_order(t.def)
   end
end
fill_field_order(OS_DATE_TABLE)
fill_field_order(DEBUG_GETINFO_TABLE)

-- Point the nominal FILE/METATABLE references at their definitions.
NOMINAL_FILE.found = standard_library["FILE"]
NOMINAL_METATABLE.found =
standard_library["METATABLE"]

-- Compiled compatibility preamble snippets, keyed by name ("compat53",
-- "table.unpack", or a global such as "ipairs"). The table is module-level,
-- so each snippet is lexed, parsed and type-checked at most once and its AST
-- is then reused by every program that needs it.
local compat53_code_cache = {}

--- Prepend compatibility bindings to `program` for every name in `used_set`.
-- "table.unpack" becomes `local _tl_table_unpack = unpack or table.unpack`.
-- Any other NAME becomes `local NAME = _tl_compat53 and _tl_compat53.NAME or NAME`,
-- preceded once per program by the statement that requires compat53.module
-- on Lua versions below 5.3. Entries are inserted at the top of `program` in
-- sorted name order.
-- @param program program AST (array of statements); modified in place
-- @param used_set set (name -> true) of names needing a compat binding
local function add_compat53_entries(program, used_set)
   if not next(used_set) then
      return
   end

   local used_list = {}
   for name in pairs(used_set) do
      table.insert(used_list, name)
   end
   table.sort(used_list)

   local compat53_loaded = false
   local n = 1 -- next insertion index, keeps entries in generation order

   -- Insert the (cached) single-statement AST for `text` at position n.
   local function load_code(name, text)
      local code = compat53_code_cache[name]
      if not code then
         local tokens = tl.lex(text)
         local _
         _, code = tl.parse_program(tokens, {}, "@internal")
         tl.type_check(code, { lax = false, skip_compat53 = true })
         code = code[1]
         compat53_code_cache[name] = code
      end
      table.insert(program, n, code)
      n = n + 1
   end

   -- load_code already consults the cache. Every program must receive its own
   -- insertions even when a snippet was compiled for an earlier program, so
   -- the cache must not be used here to skip names. Skipping them would drop
   -- e.g. the _tl_table_unpack binding from every program after the first.
   for _, name in ipairs(used_list) do
      if name == "table.unpack" then
         load_code(name, "local _tl_table_unpack = unpack or table.unpack")
      else
         if not compat53_loaded then
            load_code("compat53", "local _tl_compat53 = ((tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3) and require('compat53.module')")
            compat53_loaded = true
         end
         load_code(name, (("local $NAME = _tl_compat53 and _tl_compat53.$NAME or $NAME"):gsub("$NAME", name)))
      end
   end
   program.y = 1
end

--- Return the set of standard-library globals that get a compat53 binding.
-- In lax mode only `utf8` is included.
local function get_stdlib_compat53(lax)
   if lax then
      return {
         ["utf8"] = true,
      }
   else
      return {
         ["io"] = true,
         ["math"] = true,
         ["string"] = true,
         ["table"] = true,
         ["utf8"] = true,
         ["coroutine"] = true,
         ["os"] = true,
         ["package"] = true,
         ["debug"] = true,
         ["load"] = true,
         ["loadfile"] = true,
         ["assert"] = true,
         ["pairs"] = true,
         ["ipairs"] = true,
         ["pcall"] = true,
         ["xpcall"] = true,
         ["rawlen"] = true,
      }
   end
end

--- Build the global scope: one constant variable per standard_library entry,
-- flagged with needs_compat53 when get_stdlib_compat53(lax) lists its name.
local function init_globals(lax)
   local globals = {}
   local stdlib_compat53 = get_stdlib_compat53(lax)
   for name, typ in pairs(standard_library) do
      globals[name] = { t = typ, needs_compat53 = stdlib_compat53[name], is_const = true }
   end
   globals["@is_va"] = { t = VARARG_ANY }
   return globals
end

--- Create a fresh type-checking environment (tl.type_check creates one when
-- opts.env is not supplied).
function tl.init_env(lax, skip_compat53)
   local env = {
modules = {},
      globals = init_globals(lax),
      skip_compat53 = skip_compat53,
   }

   -- Every record-typed standard-library entry is also registered as a module.
   for name, var in pairs(standard_library) do
      if var.typename == "record" then
         env.modules[name] = var
      end
   end

   return env
end

--- Type-check an AST.
-- @param ast   program AST.
-- @param opts  optional table with env, lax, skip_compat53, filename and result.
--              When opts.env is absent, one is created with tl.init_env and
--              stored back into opts.
function tl.type_check(ast, opts)
   opts = opts or {}
   opts.env = opts.env or tl.init_env(opts.lax, opts.skip_compat53)
   local lax = opts.lax
   local filename = opts.filename
   local result = opts.result or {
      syntax_errors = {},
      type_errors = {},
      unknowns = {},
   }

   local stdlib_compat53 = get_stdlib_compat53(lax)

   -- Scope stack: st[1] is the global scope, st[#st] the innermost one.
   -- Each scope maps a name to an entry { t = type, is_const = ..., ... }.
   local st = { opts.env.globals }

   -- Set of globals looked up via find_var that are flagged needs_compat53.
   local all_needs_compat53 = {}

   local errors = result.type_errors or {}
   local unknowns = result.unknowns or {}
   local module_type

   --- Look a name up from the innermost scope outwards.
   -- "_G" is answered with a synthesized record of all non-internal globals.
   -- @return the variable's type and its is_const flag, or nothing if absent.
   local function find_var(name)
      if name == "_G" then
         local globals = {}
         for k, v in pairs(st[1]) do
            -- Skip internal entries whose names start with "@" (e.g. "@is_va").
            if k:sub(1, 1) ~= "@" then
               globals[k] = v.t
            end
         end
         local field_order = {}
         for k, _ in pairs(globals) do
            table.insert(field_order, k)
         end
         return a_type({
            typename = "record",
            field_order = field_order,
            fields = globals,
         }), false
      end
      for i = #st, 1, -1 do
         local scope = st[i]
         if scope[name] then
            if i == 1 and scope[name].needs_compat53 then
               all_needs_compat53[name] = true
            end
            local typ = scope[name].t
            return typ, scope[name].is_const
         end
      end
   end

   --- Deep-copy a type, substituting typevars.
   -- Every typevar bound in the current scopes is replaced by its bound type;
   -- unbound typevars become UNKNOWN. `seen` maps already-visited tables to
   -- their copies, so shared or cyclic substructures are copied once.
   local function resolve_typevars(t, seen)
      seen = seen or {}
      if seen[t] then
         return seen[t]
      end

      local orig_t = t
      local clear_tk = false
      if t.typename == "typevar" then
         local tv = find_var(t.typevar)
         if tv then
            t = tv
            clear_tk = true
         else
            t = UNKNOWN
         end
      end

      local copy = {}
      seen[orig_t] = copy
      for k, v in pairs(t) do
         local cp = copy
         if type(v) == "table" then
            cp[k] = resolve_typevars(v, seen)
         else
            cp[k] = v
         end
      end

      -- A substituted copy does not keep the `tk` of the bound type.
      if clear_tk then
         copy.tk = nil
      end

      return copy
   end

   --- Resolve a dotted type name (list of name parts) to a type.
   -- Later parts are looked up in `fields` (or `def.fields`).
   -- @param accept_typearg also accept an entry whose typename is "typearg".
   -- @return the type if is_type accepts it (or it is an accepted typearg), else nil.
   local function find_type(names, accept_typearg)
      local typ = find_var(names[1])
      if not typ then
         return nil
      end
      for i = 2, #names do
         local nested = typ.fields or (typ.def and typ.def.fields)
         if nested then
            typ = nested[names[i]]
            if typ == nil then
               return nil
            end
         else
            -- NOTE(review): remaining name parts are ignored when an
            -- intermediate has no fields; confirm this is intended.
            break
         end
      end
      if typ then
         if accept_typearg and typ.typename == "typearg" then
            return typ
         end
         if is_type(typ) then
            return typ
         end
      end
      return nil
   end

   --- Record `t` as the type of the variable a `{}` literal was assigned to.
   -- The variable is emptytable.assigned_to. For global declarations only the
   -- global scope is searched.
   local function infer_var(emptytable, t, node)
      local is_global = (emptytable.declared_at and emptytable.declared_at.kind == "global_declaration")
      local nst = is_global and 1 or #st
      -- NOTE(review): the loop does not stop at the first match, so
      -- same-named variables in outer scopes are overwritten as well; confirm.
      for i = nst, 1, -1 do
         local scope = st[i]
         if scope[emptytable.assigned_to] then
            scope[emptytable.assigned_to] = {
               t = t,
               is_const = false,
            }
            t.inferred_at = node
            t.inferred_at_file = filename
         end
      end
   end

   --- Look a name up in the global scope only.
   -- @return type and is_const of the global, or nothing.
   local function find_global(name)
      local scope = st[1]
      if scope[name] then
         return scope[name].t, scope[name].is_const
      end
   end

   --- Collapse a tuple to its first element; an empty tuple becomes NIL.
   local function resolve_tuple(t)
      if t.typename == "tuple" then
         t = t[1]
      end
      if t == nil then
         return NIL
      end
      return t
   end

   --- Build an error record positioned at `where`.
   -- Extra arguments are types. They are rendered with show_type and
   -- substituted into `msg` with string.format. Returns nil if any of them is
   -- "invalid" (presumably to avoid cascading reports from an earlier error).
   local function error_in_type(where, msg, ...)
      local n = select("#", ...)
      if n > 0 then
         local showt = {}
         for i = 1, n do
            local t = select(i, ...)
            if t.typename == "invalid" then
               return nil
            end
            showt[i] = show_type(t)
         end
         msg = msg:format(_tl_table_unpack(showt))
      end

      return {
         y = where.y,
         x = where.x,
         msg = msg,
         filename = where.filename or filename,
      }
   end

   --- Append an error to `errors`; returns true if one was recorded.
   local function type_error(t, msg, ...)
      local e = error_in_type(t, msg, ...)
      if e then
         table.insert(errors, e)
         return true
      else
         return false
      end
   end

   --- Report an error at `node`, set node.type to INVALID and return it.
   local function node_error(node, msg, ...)
      local ok = type_error(node, msg, ...)
      node.type = INVALID
      return node.type
   end

   --- Wrap one error in a list (empty if error_in_type returned nil).
   -- This is the shape used for the second result of is_a / same_type.
   local function terr(t, s, ...)
      return { error_in_type(t, s, ...)
}
   end

   --- Append an "unknown" entry for `name` at node's position to `unknowns`.
   local function add_unknown(node, name)
      table.insert(unknowns, { y = node.y, x = node.x, msg = name, filename = filename })
   end

   --- Declare `var` in the innermost scope.
   -- With is_narrowing set and `var` already present in that scope, only the
   -- entry's type is replaced. The pre-narrowing type is kept in
   -- narrowed_from (on the first narrowing only) so widen_in_scope can
   -- restore it.
   local function add_var(node, var, valtype, is_const, is_narrowing)
      if lax and node and is_unknown(valtype) and (var ~= "self" and var ~= "...") then
         add_unknown(node, var)
      end

      if st[#st][var] and is_narrowing then
         if not st[#st][var].is_narrowed then
            st[#st][var].narrowed_from = st[#st][var].t
         end

         st[#st][var].is_narrowed = true
         st[#st][var].t = valtype
      else
         st[#st][var] = { t = valtype, is_const = is_const, is_narrowed = is_narrowing }
      end
   end

   -- Compiled Teal type-declaration placeholder (unused in the code shown).
   local CompareTypes = {}

   --- Compare two types when at least one of them is a typevar.
   -- Two occurrences of the same typevar match when both or neither are bound.
   -- Otherwise, a bound typevar's type is compared with `comp`. An unbound
   -- typevar is bound, in the innermost scope, to the other side.
   local function compare_typevars(t1, t2, comp)
      local tv1 = find_var(t1.typevar)
      local tv2 = find_var(t2.typevar)
      if t1.typevar == t2.typevar then
         local has_t1 = not not tv1
         local has_t2 = not not tv2
         if has_t1 == has_t2 then
            return true
         end
      end

      local function cmp(k, v, a, b)
         if find_var(k) then
            return comp(a, b)
         else
            add_var(nil, k, resolve_typevars(v))
            return true
         end
      end

      if t2.typename == "typevar" then
         return cmp(t2.typevar, t1, t1, tv2)
      else
         return cmp(t1.typevar, t2, tv1, t2)
      end
   end

   --- Move the errors in `src` (may be nil) into `dst`, prefixing each message.
   -- When `node` has a position, an error is moved to it if the error comes
   -- from another file, has no line, or lies before node.
   local function add_errs_prefixing(src, dst, prefix, node)
      if not src then
         return
      end

      for _, err in ipairs(src) do
         err.msg = prefix .. err.msg

         if node and node.y and (
            (err.filename ~= filename) or
            (not err.y) or
            (node.y > err.y or (node.y == err.y and node.x > err.x))) then

            err.y = node.y
            err.x = node.x
            err.filename = filename
         end

         table.insert(dst, err)
      end
   end

   local is_a

   -- Compiled Teal type-declaration placeholder (unused in the code shown).
   local TypeGetter = {}

   --- Check every field of record t1 against the type `t2(k)` for its key.
   -- A nil `t2(k)` is an "unknown field" error unless lax. Field types are
   -- compared with `cmp` (default is_a).
   -- @return true, or false plus the list of field errors.
   local function match_record_fields(t1, t2, cmp)
      cmp = cmp or is_a
      local fielderrs = {}
      for _, k in ipairs(t1.field_order) do
         local f = t1.fields[k]
         local t2k = t2(k)
         if t2k == nil then
            if not lax then
               table.insert(fielderrs, error_in_type(f, "unknown field " .. k))
            end
         else
            -- Use the caller's comparison. same_type passes itself here, so
            -- record identity must compare field types for identity too.
            -- Calling is_a unconditionally ignored `cmp`.
            local _, errs = cmp(f, t2k)
            add_errs_prefixing(errs, fielderrs, "record field doesn't match: " .. k .. ": ")
         end
      end
      if #fielderrs > 0 then
         return false, fielderrs
      end
      return true
   end

   --- Match record t1's fields against the same-named fields of record t2.
   local function match_fields_to_record(t1, t2, cmp)
      return match_record_fields(t1, function(k) return t2.fields[k] end, cmp)
   end

   --- Check that every field of record t1 fits the value type of map t2.
   local function match_fields_to_map(t1, t2)
      if not match_record_fields(t1, function(_) return t2.values end) then
         return false, { error_in_type(t1, "not all fields have type %s", t2.values) }
      end
      return true
   end

   --- Compare argument type `a` with `b` using `cmp`.
   -- On failure the errors are appended to `errs`, prefixed with
   -- "argument <n>: " and positioned at `at`.
   -- @return true if the argument matches.
   local function arg_check(cmp, a, b, at, n, errs)
      local matches, match_errs = cmp(a, b)
      if not matches then
         add_errs_prefixing(match_errs, errs, "argument " .. n .. ": ", at)
         return false
      end
      return true
   end

   local same_type

   --- True if every type in t1s has some type in t2s with is_a(t2, t1).
   local function has_all_types_of(t1s, t2s)
      for _, t1 in ipairs(t1s) do
         local found = false
         for _, t2 in ipairs(t2s) do
            if is_a(t2, t1) then
               found = true
               break
            end
         end
         if not found then
            return false
         end
      end
      return true
   end

   --- `true` when all_errs is empty, otherwise `false, all_errs`.
   local function any_errors(all_errs)
      if #all_errs == 0 then
         return true
      else
         return false, all_errs
      end
   end

   --- Compare two nominal types.
   -- They must have the same definition (same typeid) and, when type arguments
   -- are present, pairwise-identical type arguments. Unresolvable names are
   -- reported as "unknown type" and yield false.
   local function are_same_nominals(t1, t2)
      local same_names
      if t1.found and t2.found then
         same_names = t1.found.typeid == t2.found.typeid
      else
         local ft1 = t1.found or find_type(t1.names)
         local ft2 = t2.found or find_type(t2.names)
         if ft1 and ft2 then
            same_names = ft1.typeid == ft2.typeid
         else
            if not ft1 then
               type_error(t1, "unknown type %s", t1)
            end
            if not ft2 then
               type_error(t2, "unknown type %s", t2)
            end
            return false, {}
         end
      end

      if same_names then
         if t1.typevals == nil and t2.typevals == nil then
            return true
         elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then
            local all_errs = {}
            for i = 1, #t1.typevals do
               local ok, errs = same_type(t2.typevals[i], t1.typevals[i])
               add_errs_prefixing(errs, all_errs, "type parameter <" .. show_type(t1.typevals[i]) .. 
">: ", t1) end if #all_errs == 0 then return true else return false, all_errs end end else return false, terr(t1, "%s is not a %s", t1, t2) end end same_type = function(t1, t2) assert(type(t1) == "table") assert(type(t2) == "table") if t1.typename == "typevar" or t2.typename == "typevar" then return compare_typevars(t1, t2, same_type) end if t1.typename ~= t2.typename then return false, terr(t1, "got %s, expected %s", t1, t2) end if t1.typename == "array" then return same_type(t1.elements, t2.elements) elseif t1.typename == "map" then local all_errs = {} local k_ok, k_errs = same_type(t1.keys, t2.keys) if not k_ok then add_errs_prefixing(k_errs, all_errs, "keys", t1) end local v_ok, v_errs = same_type(t1.values, t2.values) if not v_ok then add_errs_prefixing(v_errs, all_errs, "values", t1) end return any_errors(all_errs) elseif t1.typename == "union" then if has_all_types_of(t1.types, t2.types) and has_all_types_of(t2.types, t1.types) then return true else return false, terr(t1, "got %s, expected %s", t1, t2) end elseif t1.typename == "nominal" then return are_same_nominals(t1, t2) elseif t1.typename == "record" then return match_fields_to_record(t1, t2, same_type) elseif t1.typename == "function" then if #t1.args ~= #t2.args then return false, terr(t1, "different number of input arguments: got " .. #t1.args .. ", expected " .. #t2.args) end if #t1.rets ~= #t2.rets then return false, terr(t1, "different number of return values: got " .. #t1.args .. ", expected " .. #t2.args) end local all_errs = {} for i = 1, #t1.args do arg_check(same_type, t1.args[i], t2.args[i], t1, i, all_errs) end for i = 1, #t1.rets do local ok, errs = same_type(t1.rets[i], t2.rets[i]) add_errs_prefixing(errs, all_errs, "return " .. 
i, t1) end return any_errors(all_errs) elseif t1.typename == "arrayrecord" then local ok, errs = same_type(t1.elements, t2.elements) if not ok then return ok, errs end return match_fields_to_record(t1, t2, same_type) end return true end local function a_union(types) local ts = {} local stack = {} local i = 1 while types[i] or stack[1] do local t if stack[1] then t = table.remove(stack) else t = types[i] i = i + 1 end if t.typename == "union" then for _, s in ipairs(t.types) do table.insert(stack, s) end else table.insert(ts, t) end end return a_type({ typename = "union", types = ts, }) end local function is_vararg(t) return t.args and #t.args > 0 and t.args[#t.args].is_va end local function combine_errs(...) local errs for i = 1, select("#", ...) do local e = select(i, ...) if e then errs = errs or {} for _, err in ipairs(e) do table.insert(errs, err) end end end if not errs then return true else return false, errs end end local resolve_unary = nil local function is_known_table_type(t) return (t.typename == "array" or t.typename == "map" or t.typename == "record" or t.typename == "arrayrecord") end is_a = function(t1, t2, for_equality) assert(type(t1) == "table") assert(type(t2) == "table") if lax and (is_unknown(t1) or is_unknown(t2)) then return true end if t1.typename == "nil" then return true end if t2.typename ~= "tuple" then t1 = resolve_tuple(t1) end if t2.typename == "tuple" and t1.typename ~= "tuple" then t1 = a_type({ typename = "tuple", [1] = t1, }) end if t1.typename == "typevar" or t2.typename == "typevar" then return compare_typevars(t1, t2, is_a) end if t2.typename == "any" then return true elseif t2.typename == "poly" then for _, t in ipairs(t2.types) do if is_a(t1, t, for_equality) then return true end end return false, terr(t1, "cannot match against any alternatives of the polymorphic type") elseif t1.typename == "union" and t2.typename == "union" then if has_all_types_of(t1.types, t2.types) then return true else return false, terr(t1, "got %s, 
expected %s", t1, t2) end elseif t2.typename == "union" then for _, t in ipairs(t2.types) do if is_a(t1, t, for_equality) then return true end end elseif t1.typename == "poly" then for _, t in ipairs(t1.types) do if is_a(t, t2, for_equality) then return true end end return false, terr(t1, "cannot match against any alternatives of the polymorphic type") elseif t1.typename == "nominal" and t2.typename == "nominal" and #t2.names == 1 and t2.names[1] == "any" then return true elseif t1.typename == "nominal" and t2.typename == "nominal" then return are_same_nominals(t1, t2) elseif t1.typename == "enum" and t2.typename == "string" then local ok if for_equality then ok = t2.tk and t1.enumset[unquote(t2.tk)] else ok = true end if ok then return true else return false, terr(t1, "enum is incompatible with %s", t2) end elseif t1.typename == "string" and t2.typename == "enum" then local ok = t1.tk and t2.enumset[unquote(t1.tk)] if ok then return true else if t1.tk then return false, terr(t1, "%s is not a member of %s", t1, t2) else return false, terr(t1, "string is not a %s", t2) end end elseif t1.typename == "nominal" or t2.typename == "nominal" then local t1u = resolve_unary(t1) local t2u = resolve_unary(t2) local ok, errs = is_a(t1u, t2u, for_equality) if errs and #errs == 1 then if errs[1].msg:match("^got ") then errs = terr(t1, "got %s, expected %s", t1, t2) end end return ok, errs elseif t1.typename == "emptytable" and is_known_table_type(t2) then return true elseif t2.typename == "array" then if is_array_type(t1) then if is_a(t1.elements, t2.elements) then return true end elseif t1.typename == "map" then local _, errs_keys = is_a(t1.keys, NUMBER) local _, errs_values = is_a(t1.values, t2.elements) return combine_errs(errs_keys, errs_values) end elseif t2.typename == "record" then if is_record_type(t1) then return match_fields_to_record(t1, t2) elseif t1.typename == "typetype" and t1.def.typename == "record" then return is_a(t1.def, t2, for_equality) end elseif 
t2.typename == "arrayrecord" then if t1.typename == "array" then return is_a(t1.elements, t2.elements) elseif t1.typename == "record" then return match_fields_to_record(t1, t2) elseif t1.typename == "arrayrecord" then if not is_a(t1.elements, t2.elements) then return false, terr(t1, "array parts have incompatible element types") end return match_fields_to_record(t1, t2) end elseif t2.typename == "map" then if t1.typename == "map" then local _, errs_keys = is_a(t1.keys, t2.keys) local _, errs_values = is_a(t2.values, t1.values) if t2.values.typename == "any" then errs_values = {} end return combine_errs(errs_keys, errs_values) elseif t1.typename == "array" then local _, errs_keys = is_a(NUMBER, t2.keys) local _, errs_values = is_a(t1.elements, t2.values) return combine_errs(errs_keys, errs_values) elseif is_record_type(t1) then if not is_a(t2.keys, STRING) then return false, terr(t1, "can't match a record to a map with non-string keys") end if t2.keys.typename == "enum" then for _, k in ipairs(t1.field_order) do if not t2.keys.enumset[k] then return false, terr(t1, "key is not an enum value: " .. k) end end end return match_fields_to_map(t1, t2) end elseif t1.typename == "function" and t2.typename == "function" then local all_errs = {} if (not is_vararg(t2)) and #t1.args > #t2.args then t1.args.typename = "tuple" t2.args.typename = "tuple" table.insert(all_errs, error_in_type(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args)) else for i = (t1.is_method and 2 or 1), #t1.args do arg_check(is_a, t1.args[i], t2.args[i] or ANY, nil, i, all_errs) end end local diff_by_va = #t2.rets - #t1.rets == 1 and t2.rets[#t2.rets].is_va if #t1.rets < #t2.rets and not diff_by_va then t1.rets.typename = "tuple" t2.rets.typename = "tuple" table.insert(all_errs, error_in_type(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. 
" %s", t1.rets, t2.rets)) else local nrets = #t2.rets if diff_by_va then nrets = nrets - 1 end for i = 1, nrets do local ok, errs = is_a(t1.rets[i], t2.rets[i]) add_errs_prefixing(errs, all_errs, "return " .. i .. ": ") end end if #all_errs == 0 then return true else return false, all_errs end elseif lax and ((not for_equality) and t2.typename == "boolean") then return true elseif t1.typename == t2.typename then return true end return false, terr(t1, "got %s, expected %s", t1, t2) end local function assert_is_a(node, t1, t2, context, name) t1 = resolve_tuple(t1) t2 = resolve_tuple(t2) if lax and (is_unknown(t1) or is_unknown(t2)) then return end if t2.typename == "unknown_emptytable_value" then if same_type(t2.emptytable_type.keys, NUMBER) then infer_var(t2.emptytable_type, a_type({ typename = "array", elements = t1 }), node) else infer_var(t2.emptytable_type, a_type({ typename = "map", keys = t2.emptytable_type.keys, values = t1 }), node) end return elseif t2.typename == "emptytable" then if is_known_table_type(t1) then infer_var(t2, t1, node) elseif t1.typename ~= "emptytable" then node_error(node, "in " .. context .. ": " .. (name and (name .. ": ") or "") .. "assigning %s to a variable declared with {}", t1) end return end local match, match_errs = is_a(t1, t2) add_errs_prefixing(match_errs, errors, "in " .. context .. ": " .. (name and (name .. 
": ") or ""), node) end local function close_types(vars) for name, var in pairs(vars) do if var.t.typename == "typetype" then var.t.closed = true end end end local function begin_scope() table.insert(st, {}) end local function end_scope() local unresolved = st[#st]["@unresolved"] if unresolved then local upper = st[#st - 1]["@unresolved"] if upper then for name, nodes in pairs(unresolved.t.labels) do for _, node in ipairs(nodes) do upper.t.labels[name] = upper.t.labels[name] or {} table.insert(upper.t.labels[name], node) end end for name, types in pairs(unresolved.t.nominals) do for _, typ in ipairs(types) do upper.t.nominals[name] = upper.t.nominals[name] or {} table.insert(upper.t.nominals[name], typ) end end else st[#st - 1]["@unresolved"] = unresolved end end close_types(st[#st]) table.remove(st) end local type_check_function_call do local function try_match_func_args(node, f, args, is_method, argdelta) local ok = true local errs = {} if is_method then argdelta = -1 elseif not argdelta then argdelta = 0 end if f.is_method and not is_method and not (args[1] and is_a(args[1], f.args[1])) then table.insert(errs, { y = node.y, x = node.x, msg = "invoked method as a regular function: use ':' instead of '.'", filename = filename }) return nil, errs end local va = is_vararg(f) local nargs = va and math.max(#args, #f.args) or math.min(#args, #f.args) for a = 1, nargs do local arg = args[a] local farg = f.args[a] or (va and f.args[#f.args]) if arg == nil then if farg.is_va then break end else local at = node.e2 and node.e2[a] or node if not arg_check(is_a, arg, farg, at, (a + argdelta), errs) then ok = false break end end end if ok == true then f.rets.typename = "tuple" for a = 1, #args do local arg = args[a] local farg = f.args[a] or (va and f.args[#f.args]) if arg.typename == "emptytable" then infer_var(arg, resolve_typevars(farg), node.e2[a]) end end return resolve_typevars(f.rets) end return nil, errs end local function revert_typeargs(func) if func.typeargs then 
for _, arg in ipairs(func.typeargs) do if st[#st][arg.typearg] then st[#st][arg.typearg] = nil end end end end local function remove_sorted_duplicates(t) local prev = nil for i = #t, 1, -1 do if t[i] == prev then table.remove(t, i) else prev = t[i] end end end local function check_call(node, func, args, is_method, argdelta) assert(type(func) == "table") assert(type(args) == "table") if lax and is_unknown(func) then func = a_type({ typename = "function", args = { VARARG_UNKNOWN }, rets = { VARARG_UNKNOWN } }) end func = resolve_unary(func) args = args or {} local poly = func.typename == "poly" and func or { types = { func } } local first_errs local expects = {} local tried = {} for i, f in ipairs(poly.types) do if not tried[i] then if f.typename ~= "function" then if lax and is_unknown(f) then return UNKNOWN end return node_error(node, "not a function: %s", f) end table.insert(expects, tostring(#f.args or 0)) local va = is_vararg(f) if #args == (#f.args or 0) or (va and #args > #f.args) then tried[i] = true local matched, errs = try_match_func_args(node, f, args, is_method, argdelta) if matched then return matched else revert_typeargs(f) end first_errs = first_errs or errs end end end for i, f in ipairs(poly.types) do if not tried[i] then tried[i] = true if #args < (#f.args or 0) then tried[i] = true local matched, errs = try_match_func_args(node, f, args, is_method, argdelta) if matched then return matched else revert_typeargs(f) end first_errs = first_errs or errs end end end for i, f in ipairs(poly.types) do if not tried[i] then if is_vararg(f) and #args > (#f.args or 0) then tried[i] = true local matched, errs = try_match_func_args(node, f, args, is_method, argdelta) if matched then return matched else revert_typeargs(f) end first_errs = first_errs or errs end end end if not first_errs then table.sort(expects) remove_sorted_duplicates(expects) node_error(node, "wrong number of arguments (given " .. #args .. ", expects " .. table.concat(expects, " or ") .. 
")") else for _, err in ipairs(first_errs) do table.insert(errors, err) end end poly.types[1].rets.typename = "tuple" return resolve_typevars(poly.types[1].rets) end type_check_function_call = function(node, func, args, is_method, argdelta) begin_scope() local ret = check_call(node, func, args, is_method, argdelta) end_scope() return ret end end local unknown_dots = {} local function add_unknown_dot(node, name) if not unknown_dots[name] then unknown_dots[name] = true add_unknown(node, name) end end local function get_self_type(t) if t.typename == "typetype" then return t.def else return t end end local function match_record_key(node, tbl, key, orig_tbl) assert(type(tbl) == "table") assert(type(key) == "table") tbl = resolve_unary(tbl) local type_description = tbl.typename if tbl.typename == "string" or tbl.typename == "enum" then tbl = find_var("string") end if lax and (is_unknown(tbl) or tbl.typename == "typevar") then if node.e1.kind == "variable" and node.op.op ~= "@funcall" then add_unknown_dot(node, node.e1.tk .. "." .. key.tk) end return UNKNOWN end tbl = get_self_type(tbl) if tbl.typename == "emptytable" then elseif is_record_type(tbl) then assert(tbl.fields, "record has no fields!?") if key.kind == "string" or key.kind == "identifier" then if tbl.fields[key.tk] then return tbl.fields[key.tk] end end else if is_unknown(tbl) then if not lax then node_error(node, "cannot index a value of unknown type") end else node_error(node, "cannot index something that is not a record: %s", tbl) end return INVALID end if lax then if node.e1.kind == "variable" and node.op.op ~= "@funcall" then add_unknown_dot(node, node.e1.tk .. "." .. key.tk) end return UNKNOWN end local description if node.e1.kind == "variable" then description = type_description .. " '" .. node.e1.tk .. "' of type " .. show_type(resolve_tuple(orig_tbl)) else description = "type " .. show_type(resolve_tuple(orig_tbl)) end return node_error(key, "invalid key '" .. key.tk .. "' in " .. 
description)
   end

   --- Undo the narrowing of `var` in `scope`.
   -- narrowed_from is restored if recorded; otherwise the entry is dropped.
   -- @return true if `var` was narrowed.
   local function widen_in_scope(scope, var)
      if scope[var].is_narrowed then
         if scope[var].narrowed_from then
            scope[var].t = scope[var].narrowed_from
            scope[var].narrowed_from = nil
            scope[var].is_narrowed = false
         else
            scope[var] = nil
         end
         return true
      end
      return false
   end

   --- Widen `var` from the innermost scope outwards.
   -- Stops at the first scope where `var` exists but is not narrowed.
   -- @return true if anything was widened.
   local function widen_back_var(var)
      local widened = false
      for i = #st, 1, -1 do
         if st[i][var] then
            if widen_in_scope(st[i], var) then
               widened = true
            else
               break
            end
         end
      end
      return widened
   end

   --- Undo the narrowing of every variable in every scope.
   -- Clearing entries during a pairs() traversal is allowed in Lua.
   local function widen_all_unions()
      for i = #st, 1, -1 do
         for var, _ in pairs(st[i]) do
            widen_in_scope(st[i], var)
         end
      end
   end

   --- Declare `var` in the global scope (st[1]).
   local function add_global(node, var, valtype, is_const)
      if lax and is_unknown(valtype) and (var ~= "self" and var ~= "...") then
         add_unknown(node, var)
      end
      st[1][var] = { t = valtype, is_const = is_const }
   end

   local check_typevars

   --- Apply check_typevars to each element of the list `ts` (nil is allowed).
   local function check_all_typevars(node, ts)
      if ts ~= nil then
         for _, arg in ipairs(ts) do
            check_typevars(node, arg)
         end
      end
   end

   --- Report every typevar inside `t` that is not bound in any scope.
   check_typevars = function(node, t)
      if t == nil then
         return
      end
      if t.typename == "typevar" then
         if not find_var(t.typevar) then
            node_error(node, "unknown type variable " .. t.typevar)
         end
         return
      end
      check_typevars(node, t.elements)
      check_typevars(node, t.keys)
      check_typevars(node, t.values)
      check_all_typevars(node, t.typeargs)
      check_all_typevars(node, t.args)
      check_all_typevars(node, t.rets)
   end

   --- In lax mode, replace an empty return list with a single unknown vararg.
   local function get_rets(rets)
      if lax and (#rets == 0) then
         return { a_type({ typename = "unknown", is_va = true }) }
      end
      return rets
   end

   --- Open the scope of a function body.
   -- Binds the function's type arguments and declares its parameters
   -- (untyped ones as unknown) plus the internal "@is_va" and "@return"
   -- entries. When `recurse` is set, it also declares the function's own name.
   local function begin_function_scope(node, recurse)
      begin_scope()
      local args = {}
      if node.typeargs then
         for i, arg in ipairs(node.typeargs) do
            add_var(nil, arg.typearg, arg)
         end
      end
      local is_va = false
      for i, arg in ipairs(node.args) do
         local t = arg.decltype
         if not t then
            t = a_type({ typename = "unknown" })
         end
         if arg.tk == "..." then
            is_va = true
            t.is_va = true
            if i ~= #node.args then
               node_error(node, "'...' can only be last argument")
            end
         end
         check_typevars(arg, t)
         table.insert(args, t)
         add_var(arg, arg.tk, t)
      end
      add_var(nil, "@is_va", is_va and VARARG_ANY or NIL)
      add_var(nil, "@return", node.rets or a_type({ typename = "tuple" }))
      if recurse then
         add_var(nil, node.name.tk, a_type({
            typename = "function",
            args = args,
            rets = get_rets(node.rets),
         }))
      end
   end

   --- Report everything still pending in the innermost scope's "@unresolved"
   -- entry: gotos without a visible label, and unknown nominal types.
   local function fail_unresolved()
      local unresolved = st[#st]["@unresolved"]
      if unresolved then
         st[#st]["@unresolved"] = nil
         for name, nodes in pairs(unresolved.t.labels) do
            for _, node in ipairs(nodes) do
               node_error(node, "no visible label '" .. name .. "' for goto")
            end
         end
         for name, types in pairs(unresolved.t.nominals) do
            for _, typ in ipairs(types) do
               assert(typ.x)
               assert(typ.y)
               type_error(typ, "unknown type %s", typ)
            end
         end
      end
   end

   --- Close a function scope, reporting unresolved names first.
   local function end_function_scope()
      fail_unresolved()
      end_scope()
   end

   --- Instantiate the generic definition `def` with the type arguments of `t`.
   -- Returns def unchanged when neither side has type arguments. Reports an
   -- error and returns nil on a count mismatch, spurious type arguments or
   -- missing type arguments.
   local function match_typevals(t, def)
      if t.typevals and def.typeargs then
         if #t.typevals ~= #def.typeargs then
            type_error(t, "mismatch in number of type arguments")
            return nil
         end

         -- Bind each type argument in a temporary scope, then substitute.
         begin_scope()
         for i, tt in ipairs(t.typevals) do
            add_var(nil, def.typeargs[i].typearg, tt)
         end
         local ret = resolve_typevars(def)
         end_scope()
         return ret
      elseif t.typevals then
         type_error(t, "spurious type arguments")
         return nil
      elseif def.typeargs then
         type_error(t, "missing type arguments in %s", def)
         return nil
      else
         return def
      end
   end

   --- Resolve a nominal type to its (instantiated) definition.
   -- The result is cached in t.resolved and the typetype found in t.found.
   -- Failures are reported and produce a "bad_nominal" type.
   local function resolve_nominal(t)
      if t.resolved then
         return t.resolved
      end

      local resolved

      local typetype = t.found or find_type(t.names)
      if not typetype then
         type_error(t, "unknown type %s", t)
      elseif is_type(typetype) then
         resolved = match_typevals(t, typetype.def)
      else
         type_error(t, table.concat(t.names, ".") .. " is not a type")
      end

      if not resolved then
         resolved = a_type({ typename = "bad_nominal", names = t.names })
      end

      t.found = typetype
      t.resolved = resolved
      return resolved
   end

   --- Collapse tuples and resolve nominal types.
   resolve_unary = function(t)
      t = resolve_tuple(t)
      if t.typename == "nominal" then
         return resolve_nominal(t)
      end
      return t
   end

   --- Flatten a list of expression types.
   -- All but the last element go through resolve_unary; a trailing tuple is
   -- expanded into its elements.
   local function flatten_list(list)
      local exps = {}
      for i = 1, #list - 1 do
         table.insert(exps, resolve_unary(list[i]))
      end
      if #list > 0 then
         local last = list[#list]
         if last.typename == "tuple" then
            for _, val in ipairs(last) do
               table.insert(exps, val)
            end
         else
            table.insert(exps, last)
         end
      end
      return exps
   end

   --- Expand the value types of an assignment to `wanted` targets.
   -- A trailing tuple is expanded; a trailing vararg is repeated until
   -- `wanted` values exist.
   local function get_assignment_values(vals, wanted)
      local ret = {}
      if vals == nil then
         return ret
      end

      for i = 1, #vals - 1 do
         ret[i] = vals[i]
      end
      -- NOTE(review): an empty `vals` makes `last` nil and the next test fail;
      -- confirm callers never pass an empty list.
      local last = vals[#vals]
      if last.typename == "tuple" then
         for _, v in ipairs(last) do
            table.insert(ret, v)
         end
      elseif last.is_va and #ret < wanted then
         while #ret < wanted do
            table.insert(ret, last)
         end
      else
         table.insert(ret, last)
      end
      return ret
   end

   --- Common type of the fields `field_names` of record `a`.
   -- If the fields do not all have the same type, `errmsg` is reported.
   local function match_all_record_field_names(node, a, field_names, errmsg)
      local t
      for _, k in ipairs(field_names) do
         local f = a.fields[k]
         if not t then
            t = f
         else
            if not same_type(f, t) then
               t = nil
               break
            end
         end
      end
      if t then
         return t
      else
         return node_error(node, errmsg)
      end
   end

   --- Type of indexing `a` with a key of type `b`.
   -- `node` is the index expression and `idxnode` the node errors are
   -- reported at. Indexing a `{}` fixes its key type from the first index used.
   local function type_check_index(node, idxnode, a, b)
      local orig_a = a
      local orig_b = b
      a = resolve_unary(a)
      b = resolve_unary(b)

      if is_array_type(a) and is_a(b, NUMBER) then
         return a.elements
      elseif a.typename == "emptytable" then
         if a.keys == nil then
            a.keys = b
            a.keys_inferred_at = node
            a.keys_inferred_at_file = filename
         else
            if not is_a(b, a.keys) then
               local inferred = " (type of keys inferred at " .. a.keys_inferred_at_file .. ":" .. a.keys_inferred_at.y .. ":" .. a.keys_inferred_at.x .. ": )"
               return node_error(idxnode, "inconsistent index type: %s, expected %s" .. 
inferred, b, a.keys)
            end
         end
         return a_type({ y = node.y, x = node.x, typename = "unknown_emptytable_value", emptytable_type = a })
      elseif a.typename == "map" then
         if is_a(b, a.keys) then
            return a.values
         else
            return node_error(idxnode, "wrong index type: %s, expected %s", orig_b, a.keys)
         end
      elseif node.e2.kind == "string" or node.e2.kind == "enum_item" then
         return match_record_key(node, a, { y = node.e2.y, x = node.e2.x, kind = "string", tk = assert(node.e2.conststr) }, orig_a)
      elseif is_record_type(a) and b.typename == "enum" then
         -- Every enum value must name a field, and all those fields must share one type.
         local field_names = {}
         for k, _ in pairs(b.enumset) do
            table.insert(field_names, k)
         end
         table.sort(field_names)
         for _, k in ipairs(field_names) do
            if not a.fields[k] then
               return node_error(idxnode, "enum value '" .. k .. "' is not a field in %s", a)
            end
         end
         return match_all_record_field_names(idxnode, a, field_names, "cannot index, not all enum values map to record fields of the same type")
      elseif lax and is_unknown(a) then
         return UNKNOWN
      else
         if is_a(b, STRING) then
            return node_error(idxnode, "cannot index object of type %s with a string, consider using an enum", orig_a)
         end
         return node_error(idxnode, "cannot index object of type %s with %s", orig_a, orig_b)
      end
   end

   --- Widen `old` (the type inferred so far for table-literal entries) so `new` fits.
   -- Returns the widened type; `old` may be modified in place.
   --  * old is a map with string keys, new a record: every field of new is
   --    folded into the map's value type.
   --  * both are records: old becomes a map from string to the widened type
   --    of all fields of both records.
   --  * old is a union: new is appended to it.
   --  * otherwise: a fresh union of both is returned.
   local function expand_type(where, old, new)
      if not old then
         return new
      else
         if not is_a(new, old) then
            if old.typename == "map" and is_record_type(new) then
               if old.keys.typename == "string" then
                  for _, ftype in pairs(new.fields) do
                     old.values = expand_type(where, old.values, ftype)
                  end
               else
                  node_error(where, "cannot determine table literal type")
               end
            elseif is_record_type(old) and is_record_type(new) then
               old.typename = "map"
               old.keys = STRING
               -- expand_type(where, nil, ftype) returns ftype, so the first
               -- field simply seeds old.values.
               for _, ftype in pairs(old.fields) do
                  old.values = expand_type(where, old.values, ftype)
               end
               -- The fields of `new` must widen old.values as well. Assigning
               -- the result to new.values discarded it, and left old.values
               -- nil when `old` had no fields.
               for _, ftype in pairs(new.fields) do
                  old.values = expand_type(where, old.values, ftype)
               end
               old.fields = nil
               old.field_order = nil
            elseif old.typename == "union" then
               new.tk = nil
               table.insert(old.types, new)
            else
               old.tk = nil
               new.tk = nil
               return a_union({ old, new })
            end
         end
      end
      return old
   end

   --- Type named by a variable or dotted expression, if still open.
   -- For a variable, its typetype definition is preferred when neither is
   -- closed. Returns nil when the name or a dotted part is missing, or when
   -- the type is closed.
   local function find_in_scope(exp)
      if exp.kind == "variable" then
         local t = find_var(exp.tk)
         if not t then
            -- Undeclared name: report "not found" like the dotted case below.
            return nil
         end
         if t.def then
            if not t.def.closed and not t.closed then
               return t.def
            end
         end
         if not t.closed then
            return t
         end
      elseif exp.kind == "op" and exp.op.op == "." then
         local t = find_in_scope(exp.e1)
         if not t then
            return nil
         end
         while exp.e2.kind == "op" and exp.e2.op.op == "." do
            t = t.fields[exp.e2.e1.tk]
            if not t then
               return nil
            end
            exp = exp.e2
         end
         t = t.fields[exp.e2.tk]
         return t
      end
   end

   local facts_and
   local facts_or
   local facts_not
   do
      --- Group the facts of several fact lists by variable name.
      -- @return map var -> list of facts about it.
      local function join_facts(fss)
         local vars = {}

         for _, fs in ipairs(fss) do
            for _, f in ipairs(fs) do
               if not vars[f.var] then
                  vars[f.var] = {}
               end
               table.insert(vars[f.var], f)
            end
         end
         return vars
      end

      --- Elements of xs (scanned last to first) that match some element of ys under `same`.
      local function intersect(xs, ys, same)
         local rs = {}
         for i = #xs, 1, -1 do
            local x = xs[i]
            for _, y in ipairs(ys) do
               if same(x, y) then
                  table.insert(rs, x)
                  break
               end
            end
         end
         return rs
      end

      --- same_type reduced to its boolean result.
      local function same_type_for_intersect(t, u)
         return (same_type(t, u))
      end

      --- Conjunction of the facts about one variable: intersect their types.
      -- Reports "branch is always false" when the intersection is empty.
      -- @return true and the type when every fact is an "is" fact, else false.
      local function intersect_facts(fs, errnode)
         local all_is = true
         local types = {}

         for i, f in ipairs(fs) do
            if f.fact ~= "is" then
               all_is = false
               break
            end

            if f.typ.typename == "union" then
               if i == 1 then
                  types = f.typ.types
               else
                  types = intersect(types, f.typ.types, same_type_for_intersect)
               end
            else
               if i == 1 then
                  types = { f.typ }
               else
                  types = intersect(types, { f.typ }, same_type_for_intersect)
               end
            end
         end

         if #types == 0 then
            node_error(errnode, "branch is always false")
            return false
         end

         if all_is then
            if #types == 1 then
               return true, types[1]
            else
               return true, a_union(types)
            end
         else
            return false
         end
      end

      --- Disjunction of the facts about one variable: union of their types.
      -- @return true and the type when every fact is an "is" fact, else false.
      local function sum_facts(fs)
         local all_is = true
         local types = {}
         for _, f in ipairs(fs) do
            if f.fact ~= "is" then
               all_is = false
               break
            end
            table.insert(types, f.typ)
         end
         if all_is then
            if #types == 1 then
               return true, types[1]
            else
               return true, a_union(types)
            end
         else
            return false
         end
      end

      --- Members of u1 (or u1 itself if it is not a union) that do not occur in u2.
      -- If nothing remains, reports "branch is always false" at errt and
      -- returns INVALID.
      local function subtract_types(u1, u2, errt)
         local types = {}
         for _, rt in ipairs(u1.types or { u1 }) do
            local not_present = true
            for _, ft in ipairs(u2.types or { u2 }) do
               if same_type(rt, ft) then
                  not_present = false
                  break
               end
            end
            if not_present then
               table.insert(types, rt)
            end
         end

         if #types == 0 then
            type_error(errt, "branch is always false")
            return INVALID
         end

         if #types == 1 then
            return types[1]
         else
            return a_union(types)
         end
      end

      --- Facts that hold when both f1 and f2 hold (nil means "no facts").
      facts_and = function(f1, f2, errnode)
         if not f1 then
            return f2
         end
         if not f2 then
            return f1
         end

         local out = {}
         for v, fs in pairs(join_facts({ f1, f2 })) do
            local ok, u = intersect_facts(fs, errnode)
            if ok then
               table.insert(out, { fact = "is", var = v, typ = u })
            else
               for _, f in ipairs(fs) do
                  table.insert(out, f)
               end
            end
         end
         return out
      end

      --- Facts that hold when f1 or f2 holds; nil if either side has none.
      facts_or = function(f1, f2)
         if not f1 or not f2 then
            return nil
         end

         local out = {}
         for v, fs in pairs(join_facts({ f1, f2 })) do
            local ok, u = sum_facts(fs)
            if ok then
               table.insert(out, { fact = "is", var = v, typ = u })
            else
               for _, f in ipairs(fs) do
                  table.insert(out, f)
               end
            end
         end
         return out
      end

      --- Negation of f1: each variable's current type minus the types f1 asserts.
      -- Variables whose facts are not all "is" facts are dropped.
      facts_not = function(f1)
         if not f1 then
            return nil
         end

         local out = {}
         for v, fs in pairs(join_facts({ f1 })) do
            local realtype = find_var(v)
            if realtype then
               local ok, u = sum_facts(fs)
               if ok then
                  local not_typ = subtract_types(realtype, u, fs[1].typ)
                  table.insert(out, { fact = "is", var = v, typ = not_typ })
               end
            end
         end
         return out
      end
   end

   --- Narrow variables in the innermost scope according to the "is" facts in `facts`.
   local function apply_facts(where, facts)
      if not facts then
         return
      end

      for _, f in ipairs(facts) do
         if f.fact == "is" then
            local t = resolve_typevars(f.typ)
            t.inferred_at = where
            t.inferred_at_file = filename
            add_var(nil, f.var, t, nil, true)
         end
      end
   end

   --- Resolve and drop the innermost scope's pending nominals named `name`.
   local function dismiss_unresolved(name)
      local unresolved = st[#st]["@unresolved"]
      if unresolved then
         if unresolved.t.nominals[name] then
            for _, t in ipairs(unresolved.t.nominals[name]) do
               resolve_nominal(t)
            end
         end
         unresolved.t.nominals[name] = nil
      end
   end

   local
function type_check_funcall(node, a, b, argdelta) argdelta = argdelta or 0 if node.e1.tk == "rawget" then if #b == 2 then local b1 = resolve_unary(b[1]) local b2 = resolve_unary(b[2]) local knode = node.e2[2] if is_record_type(b1) and knode.conststr then return match_record_key(node, b1, { y = knode.y, x = knode.x, kind = "string", tk = assert(knode.conststr) }, b1) else return type_check_index(node, knode, b1, b2) end else node_error(node, "rawget expects two arguments") end elseif node.e1.tk == "print_type" then print(show_type(b)) return BOOLEAN elseif node.e1.tk == "require" then if #b == 1 then if node.e2[1].kind == "string" then local module_name = assert(node.e2[1].conststr) local t, found = require_module(module_name, lax, opts.env, result) if not found then node_error(node, "module not found: '" .. module_name .. "'") elseif not lax and is_unknown(t) then node_error(node, "no type information for required module: '" .. module_name .. "'") end return t else node_error(node, "don't know how to resolve a dynamic require") end else node_error(node, "require expects one literal argument") end elseif node.e1.tk == "pcall" then local ftype = table.remove(b, 1) local fe2 = {} for i = 2, #node.e2 do table.insert(fe2, node.e2[i]) end local fnode = { y = node.y, x = node.x, typename = "op", op = { op = "@funcall" }, e1 = node.e2[1], e2 = fe2, } local rets = type_check_funcall(fnode, ftype, b, argdelta + 1) if rets.typename ~= "tuple" then rets = a_type({ typename = "tuple", rets }) end table.insert(rets, 1, BOOLEAN) return rets elseif node.e1.op and node.e1.op.op == ":" then local func = node.e1.type if func.typename == "function" or func.typename == "poly" then table.insert(b, 1, node.e1.e1.type) return type_check_function_call(node, func, b, true) else if lax and (is_unknown(func)) then if node.e1.e1.kind == "variable" then add_unknown_dot(node, node.e1.e1.tk .. "." .. 
node.e1.e2.tk) end return VARARG_UNKNOWN else return INVALID end end else return type_check_function_call(node, a, b, false, argdelta) end return UNKNOWN end local visit_node = {} visit_node.cbs = { ["statements"] = { before = function() begin_scope() end, after = function(node, children) if #st == 2 then fail_unresolved() end if not node.is_repeat then end_scope() end node.type = NONE end, }, ["local_type"] = { before = function(node) add_var(node.var, node.var.tk, node.value.newtype, node.var.is_const) end, after = function(node, children) dismiss_unresolved(node.var.tk) node.type = NONE end, }, ["global_type"] = { before = function(node) add_global(node.var, node.var.tk, node.value.newtype, node.var.is_const) end, after = function(node, children) local existing, existing_is_const = find_global(node.var.tk) local var = node.var if existing then if existing_is_const == true and not var.is_const then node_error(var, "global was previously declared as : " .. var.tk) end if existing_is_const == false and var.is_const then node_error(var, "global was previously declared as not : " .. var.tk) end if not same_type(existing, node.value.newtype) then node_error(var, "cannot redeclare global with a different type: previous type of " .. var.tk .. " is %s", existing) end end dismiss_unresolved(var.tk) node.type = NONE end, }, ["local_declaration"] = { after = function(node, children) local vals = get_assignment_values(children[2], #node.vars) for i, var in ipairs(node.vars) do local decltype = node.decltype and node.decltype[i] local infertype = vals and vals[i] if lax and infertype and infertype.typename == "nil" then infertype = nil end if decltype and infertype then assert_is_a(node.vars[i], infertype, decltype, "local declaration", var.tk) end local t = decltype or infertype if t == nil then t = a_type({ typename = "unknown" }) if not lax then if node.exps then node_error(node.vars[i], "assignment in declaration did not produce an initial value for variable '" .. 
var.tk .. "'") else node_error(node.vars[i], "variable '" .. var.tk .. "' has no type or initial value") end end elseif t.typename == "emptytable" then t.declared_at = node t.assigned_to = var.tk end assert(var) add_var(var, var.tk, t, var.is_const) dismiss_unresolved(var.tk) end node.type = NONE end, }, ["global_declaration"] = { after = function(node, children) local vals = get_assignment_values(children[2], #node.vars) for i, var in ipairs(node.vars) do local decltype = node.decltype and node.decltype[i] local infertype = vals and vals[i] if lax and infertype and infertype.typename == "nil" then infertype = nil end if decltype and infertype then assert_is_a(node.vars[i], infertype, decltype, "global declaration", var.tk) end local t = decltype or infertype local existing, existing_is_const = find_global(var.tk) if existing then if infertype and existing_is_const then node_error(var, "cannot reassign to global: " .. var.tk) end if existing_is_const == true and not var.is_const then node_error(var, "global was previously declared as : " .. var.tk) end if existing_is_const == false and var.is_const then node_error(var, "global was previously declared as not : " .. var.tk) end if not same_type(existing, t) then node_error(var, "cannot redeclare global with a different type: previous type of " .. var.tk .. 
" is %s", existing) end else if t == nil then t = a_type({ typename = "unknown" }) elseif t.typename == "emptytable" then t.declared_at = node t.assigned_to = var.tk end add_global(var, var.tk, t, var.is_const) dismiss_unresolved(var.tk) end end node.type = NONE end, }, ["assignment"] = { after = function(node, children) local vals = get_assignment_values(children[2], #children[1]) local exps = flatten_list(vals) for i, vartype in ipairs(children[1]) do local varnode = node.vars[i] if varnode.is_const then node_error(varnode, "cannot assign to variable") end if varnode.kind == "variable" then if widen_back_var(varnode.tk) then vartype = find_var(varnode.tk) end end if vartype then local val = exps[i] if resolve_unary(vartype).typename == "typetype" then node_error(varnode, "cannot reassign a type") elseif val then assert_is_a(varnode, val, vartype, "assignment") if varnode.kind == "variable" and vartype.typename == "union" then add_var(varnode, varnode.tk, val, false, true) end else node_error(varnode, "variable is not being assigned a value") end else node_error(varnode, "unknown variable") end end node.type = NONE end, }, ["do"] = { after = function(node, children) node.type = NONE end, }, ["if"] = { before_statements = function(node) begin_scope() apply_facts(node.exp, node.exp.facts) end, after = function(node, children) end_scope() node.type = NONE end, }, ["elseif"] = { before = function(node) end_scope() begin_scope() end, before_statements = function(node) local f = facts_not(node.parent_if.exp.facts) for e = 1, node.elseif_n - 1 do f = facts_and(f, facts_not(node.parent_if.elseifs[e].exp.facts), node) end f = facts_and(f, node.exp.facts, node) apply_facts(node.exp, f) end, after = function(node, children) node.type = NONE end, }, ["else"] = { before = function(node) end_scope() begin_scope() local f = facts_not(node.parent_if.exp.facts) for _, elseifnode in ipairs(node.parent_if.elseifs) do f = facts_and(f, facts_not(elseifnode.exp.facts), node) end 
apply_facts(node, f) end, after = function(node, children) node.type = NONE end, }, ["while"] = { before = function() widen_all_unions() end, before_statements = function(node) begin_scope() apply_facts(node.exp, node.exp.facts) end, after = function(node, children) end_scope() node.type = NONE end, }, ["label"] = { before = function(node) widen_all_unions() local label_id = "::" .. node.label .. "::" if st[#st][label_id] then node_error(node, "label '" .. node.label .. "' already defined at " .. filename) end local unresolved = st[#st]["@unresolved"] if unresolved then unresolved.t.labels[node.label] = nil end node.type = a_type({ y = node.y, x = node.x, typename = "none" }) add_var(node, label_id, node.type) end, }, ["goto"] = { after = function(node, children) if not find_var("::" .. node.label .. "::") then local unresolved = st[#st]["@unresolved"] and st[#st]["@unresolved"].t if not unresolved then unresolved = { typename = "unresolved", labels = {}, nominals = {} } add_var(node, "@unresolved", unresolved) end unresolved.labels[node.label] = unresolved.labels[node.label] or {} table.insert(unresolved.labels[node.label], node) end node.type = NONE end, }, ["repeat"] = { before = function() widen_all_unions() end, after = function(node, children) end_scope() node.type = NONE end, }, ["forin"] = { before = function() begin_scope() end, before_statements = function(node) local exp1 = node.exps[1] local exp1type = resolve_tuple(exp1.type) if exp1type.typename == "function" then if exp1.op and exp1.op.op == "@funcall" then local t = resolve_unary(exp1.e2.type) if exp1.e1.tk == "pairs" and not (t.typename == "map" or t.typename == "record") then if not (lax and is_unknown(t)) then node_error(exp1, "attempting pairs loop on something that's not a map or record: %s", exp1.e2.type) end elseif exp1.e1.tk == "ipairs" and not is_array_type(t) then if not (lax and (is_unknown(t) or t.typename == "emptytable")) then node_error(exp1, "attempting ipairs loop on something 
that's not an array: %s", exp1.e2.type) end end end local last for i, v in ipairs(node.vars) do local r = exp1type.rets[i] if not r then if last and last.is_va then r = last else r = UNKNOWN end end add_var(v, v.tk, r) last = r end else if not (lax and is_unknown(exp1type)) then node_error(exp1, "expression in for loop does not return an iterator") end end end, after = function(node, children) end_scope() node.type = NONE end, }, ["fornum"] = { before = function(node) begin_scope() add_var(nil, node.var.tk, NUMBER) end, after = function(node, children) end_scope() node.type = NONE end, }, ["return"] = { after = function(node, children) local rets = assert(find_var("@return")) local nrets = #rets local vatype if nrets > 0 then vatype = rets[nrets].is_va and rets[nrets] end if #children[1] > nrets and (not lax) and not vatype then rets.typename = "tuple" children[1].typename = "tuple" node_error(node, "excess return values, expected " .. #rets .. " %s, got " .. #children[1] .. " %s", rets, children[1]) end for i = 1, #children[1] do local expected = rets[i] or vatype if expected then expected = resolve_unary(expected) local where = (node.exps[i] and node.exps[i].x) and node.exps[i] or node.exps assert(where and where.x) assert_is_a(where, children[1][i], expected, "return value") end end if #st == 2 then module_type = resolve_unary(children[1]) end node.type = NONE end, }, ["variables"] = { after = function(node, children) node.type = children local n = #children if n > 0 and children[n].typename == "tuple" then local tuple = children[n] for i, c in ipairs(tuple) do children[n + i - 1] = c end end node.type.typename = "tuple" end, }, ["table_literal"] = { after = function(node, children) node.type = a_type({ y = node.y, x = node.x, typename = "emptytable", }) local is_record = false local is_array = false local is_map = false for i, child in ipairs(children) do assert(child.typename == "table_item") if child.kname then is_record = true if not node.type.fields then 
node.type.fields = {} node.type.field_order = {} end node.type.fields[child.kname] = child.vtype table.insert(node.type.field_order, child.kname) elseif child.ktype.typename == "number" then is_array = true if i == #children and node[i].key_parsed == "implicit" and child.vtype.typename == "tuple" then for _, c in ipairs(child.vtype) do node.type.elements = expand_type(node, node.type.elements, c) end else node.type.elements = expand_type(node, node.type.elements, child.vtype) end if not node.type.elements then node_error(node, "cannot determine type of array elements") is_array = false end else is_map = true node.type.keys = expand_type(node, node.type.keys, child.ktype) node.type.values = expand_type(node, node.type.values, child.vtype) end end if is_array and is_map then node_error(node, "cannot determine type of table literal") elseif is_record and is_array then node.type.typename = "arrayrecord" elseif is_record and is_map then if node.type.keys.typename == "string" then node.type.typename = "map" for _, ftype in pairs(node.type.fields) do node.type.values = expand_type(node, node.type.values, ftype) end node.type.fields = nil node.type.field_order = nil else node_error(node, "cannot determine type of table literal") end elseif is_array then node.type.typename = "array" elseif is_record then node.type.typename = "record" elseif is_map then node.type.typename = "map" end end, }, ["table_item"] = { after = function(node, children) local kname = node.key.conststr local ktype = children[1] local vtype = children[2] if node.decltype then vtype = node.decltype assert_is_a(node.value, children[2], node.decltype, "table item") end node.type = a_type({ y = node.y, x = node.x, typename = "table_item", kname = kname, ktype = ktype, vtype = vtype, }) end, }, ["local_function"] = { before = function(node) begin_function_scope(node, true) end, after = function(node, children) end_function_scope() local rets = get_rets(children[3]) add_var(nil, node.name.tk, a_type({ typename 
= "function", args = children[2], rets = rets, })) node.type = NONE end, }, ["global_function"] = { before = function(node) begin_function_scope(node, true) end, after = function(node, children) end_function_scope() add_global(nil, node.name.tk, a_type({ typename = "function", args = children[2], rets = get_rets(children[3]), })) node.type = NONE end, }, ["record_function"] = { before = function(node) begin_function_scope(node) end, before_statements = function(node, children) if node.is_method then local rtype = get_self_type(children[1]) children[3][1] = rtype add_var(nil, "self", rtype) end local rtype = resolve_unary(get_self_type(children[1])) if rtype.typename == "emptytable" then rtype.typename = "record" end if is_record_type(rtype) then local fn_type = a_type({ y = node.y, x = node.x, typename = "function", is_method = node.is_method, args = children[3], rets = get_rets(children[4]), }) local ok = false if lax then ok = true elseif rtype.fields and rtype.fields[node.name.tk] and is_a(fn_type, rtype.fields[node.name.tk]) then ok = true elseif find_in_scope(node.fn_owner) == rtype then ok = true end if ok then rtype.fields = rtype.fields or {} rtype.field_order = rtype.field_order or {} rtype.fields[node.name.tk] = fn_type table.insert(rtype.field_order, node.name.tk) else local name = tl.pretty_print_ast(node.fn_owner, { preserve_indent = true, preserve_newlines = false }) node_error(node, "cannot add undeclared function '" .. node.name.tk .. "' outside of the scope where '" .. name .. 
"' was originally declared") end else if (not lax) or (rtype.typename ~= "unknown") then node_error(node, "not a module: %s", rtype) end end end, after = function(node, children) end_function_scope() node.type = NONE end, }, ["function"] = { before = function(node) begin_function_scope(node) end, after = function(node, children) end_function_scope() node.type = a_type({ y = node.y, x = node.x, typename = "function", args = children[1], rets = children[2], }) end, }, ["cast"] = { after = function(node, children) node.type = node.casttype end, }, ["paren"] = { after = function(node, children) node.type = resolve_unary(children[1]) end, }, ["op"] = { before = function(node) begin_scope() end, before_e2 = function(node) if node.op.op == "and" then apply_facts(node, node.e1.facts) elseif node.op.op == "or" then apply_facts(node, facts_not(node.e1.facts)) end end, after = function(node, children) end_scope() local a = children[1] local b = children[3] local orig_a = a local orig_b = b local ua = a and resolve_unary(a) local ub = b and resolve_unary(b) if node.op.op == "@funcall" then node.type = type_check_funcall(node, a, b) elseif node.op.op == "@index" then node.type = type_check_index(node, node.e2, a, b) elseif node.op.op == "as" then node.type = b elseif node.op.op == "is" then if node.e1.kind == "variable" then node.facts = { { fact = "is", var = node.e1.tk, typ = b } } else node_error(node, "can only use 'is' on variables") end node.type = BOOLEAN elseif node.op.op == "." then a = ua if a.typename == "map" then if is_a(a.keys, STRING) or is_a(a.keys, ANY) then node.type = a.values else node_error(node, "cannot use . index, expects keys of type %s", a.keys) end else node.type = match_record_key(node, a, { y = node.e2.y, x = node.e2.x, kind = "string", tk = node.e2.tk }, orig_a) if node.type.needs_compat53 and not opts.skip_compat53 then local key = node.e1.tk .. "." .. node.e2.tk node.kind = "variable" node.tk = "_tl_" .. node.e1.tk .. "_" .. 
node.e2.tk all_needs_compat53[key] = true end end elseif node.op.op == ":" then node.type = match_record_key(node, node.e1.type, node.e2, orig_a) elseif node.op.op == "not" then node.facts = facts_not(node.e1.facts) node.type = BOOLEAN elseif node.op.op == "and" then node.facts = facts_and(node.e1.facts, node.e2.facts, node) node.type = resolve_tuple(b) elseif node.op.op == "or" and b.typename == "emptytable" then node.facts = nil node.type = resolve_tuple(a) elseif node.op.op == "or" and same_type(ua, ub) then node.facts = facts_or(node.e1.facts, node.e2.facts) node.type = resolve_tuple(a) elseif node.op.op == "or" and b.typename == "nil" then node.facts = nil node.type = resolve_tuple(a) elseif node.op.op == "or" and ((ua.typename == "enum" and ub.typename == "string" and is_a(ub, ua)) or (ua.typename == "string" and ub.typename == "enum" and is_a(ua, ub))) then node.facts = nil node.type = (ua.typename == "enum" and ua or ub) elseif node.op.op == "or" and (a.typename == "nominal" or a.typename == "map") and is_record_type(b) and is_a(b, a) then node.facts = nil node.type = resolve_tuple(a) elseif node.op.op == "==" or node.op.op == "~=" then if is_a(a, b, true) or is_a(b, a, true) then node.type = BOOLEAN else if lax and (is_unknown(a) or is_unknown(b)) then node.type = UNKNOWN else node_error(node, "types are not comparable for equality: %s and %s", a, b) end end elseif node.op.arity == 1 and unop_types[node.op.op] then a = ua local types_op = unop_types[node.op.op] node.type = types_op[a.typename] if not node.type then if lax and is_unknown(a) then node.type = UNKNOWN else node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. 
"' on type %s", orig_a) end end elseif node.op.arity == 2 and binop_types[node.op.op] then if node.op.op == "or" then node.facts = facts_or(node.e1.facts, node.e2.facts) end a = ua b = ub local types_op = binop_types[node.op.op] node.type = types_op[a.typename] and types_op[a.typename][b.typename] if not node.type then if lax and (is_unknown(a) or is_unknown(b)) then node.type = UNKNOWN else node_error(node, "cannot use operator '" .. node.op.op:gsub("%%", "%%%%") .. "' for types %s and %s", orig_a, orig_b) end end else error("unknown node op " .. node.op.op) end end, }, ["variable"] = { after = function(node, children) if node.tk == "..." then local va_sentinel = find_var("@is_va") if not va_sentinel or va_sentinel.typename == "nil" then node.type = UNKNOWN node_error(node, "cannot use '...' outside a vararg function") end end node.type, node.is_const = find_var(node.tk) if node.type == nil then node.type = a_type({ typename = "unknown" }) if lax then add_unknown(node, node.tk) else node_error(node, "unknown variable: " .. node.tk) end end end, }, ["identifier"] = { after = function(node, children) node.type = NONE end, }, ["newtype"] = { after = function(node, children) node.type = node.newtype end, }, } visit_node.cbs["break"] = visit_node.cbs["do"] visit_node.cbs["values"] = visit_node.cbs["variables"] visit_node.cbs["expression_list"] = visit_node.cbs["variables"] visit_node.cbs["argument_list"] = visit_node.cbs["variables"] visit_node.cbs["argument"] = visit_node.cbs["variable"] visit_node.cbs["string"] = { after = function(node, children) node.type = a_type({ y = node.y, x = node.x, typename = node.kind, tk = node.tk, }) return node.type end, } visit_node.cbs["number"] = visit_node.cbs["string"] visit_node.cbs["nil"] = visit_node.cbs["string"] visit_node.cbs["boolean"] = visit_node.cbs["string"] visit_node.cbs["..."] = visit_node.cbs["variable"] visit_node.after = { after = function(node, children) assert(type(node.type) == "table", node.kind .. 
" did not produce a type") assert(type(node.type.typename) == "string", node.kind .. " type does not have a typename") return node.type end, } local visit_type = { cbs = { ["string"] = { after = function(typ, children) return typ end, }, ["function"] = { before = function(typ, children) begin_scope() end, after = function(typ, children) end_scope() return typ end, }, ["record"] = { before = function(typ, children) begin_scope() for name, typ in pairs(typ.fields) do if typ.typename == "typetype" then typ.typename = "nestedtype" add_var(nil, name, typ) end end end, after = function(typ, children) end_scope() for name, typ in pairs(typ.fields) do if typ.typename == "nestedtype" then typ.typename = "typetype" end end return typ end, }, ["typearg"] = { after = function(typ, children) add_var(nil, typ.typearg, a_type({ y = typ.y, x = typ.x, typename = "typearg", typearg = typ.typearg, })) return typ end, }, ["nominal"] = { after = function(typ, children) local t = find_type(typ.names, true) if t then if t.typename == "typearg" then typ.names = nil typ.typename = "typevar" typ.typevar = t.typearg else typ.found = t end else local name = typ.names[1] local unresolved = find_var("@unresolved") if not unresolved then unresolved = { typename = "unresolved", labels = {}, nominals = {} } add_var(nil, "@unresolved", unresolved) end unresolved.nominals[name] = unresolved.nominals[name] or {} table.insert(unresolved.nominals[name], typ) end return typ end, }, ["union"] = { after = function(typ, children) local n_table_types = 0 local n_function_types = 0 local n_string_enum = 0 for _, t in ipairs(typ.types) do t = resolve_unary(t) if table_types[t.typename] then n_table_types = n_table_types + 1 if n_table_types > 1 then type_error(typ, "cannot discriminate a union between multiple table types: %s", typ) break end elseif t.typename == "function" then n_function_types = n_function_types + 1 if n_function_types > 1 then type_error(typ, "cannot discriminate a union between multiple 
function types: %s", typ) break end elseif t.typename == "string" or t.typename == "enum" then n_string_enum = n_string_enum + 1 if n_string_enum > 1 then type_error(typ, "cannot discriminate a union between multiple string/enum types: %s", typ) break end end end return typ end, }, }, after = { after = function(typ, children, ret) assert(type(ret) == "table", typ.typename .. " did not produce a type") assert(type(ret.typename) == "string", "type node does not have a typename") return ret end, }, } visit_type.cbs["typetype"] = visit_type.cbs["string"] visit_type.cbs["nestedtype"] = visit_type.cbs["string"] visit_type.cbs["typevar"] = visit_type.cbs["string"] visit_type.cbs["array"] = visit_type.cbs["string"] visit_type.cbs["map"] = visit_type.cbs["string"] visit_type.cbs["arrayrecord"] = visit_type.cbs["string"] visit_type.cbs["enum"] = visit_type.cbs["string"] visit_type.cbs["boolean"] = visit_type.cbs["string"] visit_type.cbs["nil"] = visit_type.cbs["string"] visit_type.cbs["number"] = visit_type.cbs["string"] visit_type.cbs["thread"] = visit_type.cbs["string"] visit_type.cbs["bad_nominal"] = visit_type.cbs["string"] visit_type.cbs["emptytable"] = visit_type.cbs["string"] visit_type.cbs["table_item"] = visit_type.cbs["string"] visit_type.cbs["unknown_emptytable_value"] = visit_type.cbs["string"] visit_type.cbs["tuple"] = visit_type.cbs["string"] visit_type.cbs["poly"] = visit_type.cbs["string"] visit_type.cbs["any"] = visit_type.cbs["string"] visit_type.cbs["unknown"] = visit_type.cbs["string"] visit_type.cbs["invalid"] = visit_type.cbs["string"] visit_type.cbs["unresolved"] = visit_type.cbs["string"] visit_type.cbs["none"] = visit_type.cbs["string"] recurse_node(ast, visit_node, visit_type) close_types(st[1]) local redundant = {} local lastx, lasty = 0, 0 table.sort(errors, function(a, b) return ((a.filename and b.filename) and a.filename < b.filename) or (a.filename == b.filename and ((a.y < b.y) or (a.y == b.y and a.x < b.x))) end) for i, err in ipairs(errors) 
do
      -- This `do` continues the `for i, err in ipairs(errors)` header that
      -- ends the previous line, near the end of tl.type_check. `errors` is
      -- sorted just before that loop; an error whose x/y equals the
      -- previous error's x/y is recorded as redundant.
      -- NOTE(review): only x and y are compared, not err.filename, so an
      -- error from another file at the same line/column would also be
      -- dropped if it ends up adjacent after sorting -- confirm intended.
      if err.x == lastx and err.y == lasty then
         table.insert(redundant, i)
      end
      lastx, lasty = err.x, err.y
   end
   -- Remove back-to-front so the remaining recorded indices stay valid.
   for i = #redundant, 1, -1 do
      table.remove(errors, redundant[i])
   end
   if not opts.skip_compat53 then
      add_compat53_entries(ast, all_needs_compat53)
   end
   return errors, unknowns, module_type
end

--- Read a file from disk and run it through tl.process_string.
-- The contents are handled as plain Lua (lax mode) when the extension is
-- "lua" and as Teal when it is "tl" (both compared case-insensitively).
-- For any other extension they are handled as Lua only if the first line
-- is a "#!" line mentioning "lua".
-- @param filename path of the file to read
-- @param env optional environment, forwarded to tl.process_string
-- @param result optional result table to accumulate into, forwarded
-- @param preload_modules optional list of module names, forwarded
-- @return the result table, or nil plus an error message
function tl.process(filename, env, result, preload_modules)
   local fd, open_err = io.open(filename, "r")
   if not fd then
      return nil, "could not open " .. filename .. ": " .. open_err
   end

   local input, read_err = fd:read("*a")
   fd:close()
   if not input then
      -- tostring guards against a nil message, which would otherwise make
      -- the concatenation itself raise.
      return nil, "could not read " .. filename .. ": " .. tostring(read_err)
   end

   -- %a (not [a-z]) so upper-case extensions such as ".LUA" are captured
   -- and then normalized by :lower().
   local extension = filename:match("%.(%a+)$")
   extension = extension and extension:lower()

   local is_lua
   if extension == "tl" then
      is_lua = false
   elseif extension == "lua" then
      is_lua = true
   else
      -- Compare against nil so is_lua is a boolean rather than the matched
      -- text. No trailing "\n" is required, so a file that is only a
      -- shebang line is still recognized.
      is_lua = input:match("^#![^\n]*lua") ~= nil
   end

   local processed, err = tl.process_string(input, is_lua, env, result, preload_modules, filename)
   if err then
      return nil, err
   end
   return processed
end

--- Lex, parse and type check a source string.
-- Lexer errors and parse errors are appended to result.syntax_errors.
-- If any syntax errors are present, the result is returned before type
-- checking. Otherwise result.type, result.ast and result.env are filled in.
-- NOTE(review): the early return checks the total number of entries in
-- result.syntax_errors, so a result table passed in that already holds
-- errors also stops here -- confirm this is intended.
-- @param input source text
-- @param is_lua truthy to check in lax (plain Lua) mode
-- @param env optional environment; tl.init_env(is_lua) is used otherwise
-- @param result optional result table to accumulate into
-- @param preload_modules optional list of module names to load first
-- @param filename optional name used in error reports (defaults to "")
-- @return the result table, or nil plus a message if a preload fails
function tl.process_string(input, is_lua, env, result, preload_modules, filename)
   env = env or tl.init_env(is_lua)
   result = result or {
      syntax_errors = {},
      type_errors = {},
      unknowns = {},
   }
   preload_modules = preload_modules or {}
   filename = filename or ""

   local tokens, lex_errs = tl.lex(input)
   if lex_errs then
      for _, lex_err in ipairs(lex_errs) do
         table.insert(result.syntax_errors, {
            y = lex_err.y,
            x = lex_err.x,
            msg = "invalid token '" .. lex_err.tk .. "'",
            filename = filename,
         })
      end
   end

   local _, program = tl.parse_program(tokens, result.syntax_errors, filename)
   if #result.syntax_errors > 0 then
      return result
   end

   for _, name in ipairs(preload_modules) do
      local module_type = require_module(name, is_lua, env, result)
      if module_type == UNKNOWN then
         return nil, string.format("Error: could not preload module '%s'", name)
      end
   end

   local opts = {
      lax = is_lua,
      filename = filename,
      env = env,
      result = result,
      skip_compat53 = env.skip_compat53,
   }
   -- Only the third return value (the module type) is kept. The error and
   -- unknown lists returned by tl.type_check are not used here; presumably
   -- they are also recorded through opts.result -- confirm.
   local _, _, module_type = tl.type_check(program, opts)
   result.type = module_type
   result.ast = program
   result.env = env
   return result
end

--- Compile a Teal source string to Lua source text.
-- @param input Teal source code
-- @param env optional environment; a fresh tl.init_env() is used otherwise
-- @return the generated Lua code and the result table.
--   Returns nil plus the result table when no AST was stored (for example,
--   on syntax errors).
--   Returns nil, nil when tl.process_string reported an error.
function tl.gen(input, env)
   env = env or tl.init_env()
   local result, err = tl.process_string(input, false, env)
   if err then
      -- NOTE(review): the message from tl.process_string is dropped here.
      return nil, nil
   end
   if not result.ast then
      return nil, result
   end
   return tl.pretty_print_ast(result.ast), result
end

--- Package searcher for Teal modules (registered by tl.loader).
-- Finds the module with tl.search_module, parses it, generates Lua with
-- tl.pretty_print_ast, and returns a loader function. When the module is
-- not found, it returns the tried locations joined into one string.
-- Raises on a parse error or when the generated Lua fails to load.
local function tl_package_loader(module_name)
   local found_filename, fd, tried = tl.search_module(module_name, false)
   if not found_filename then
      return table.concat(tried, "\n\t")
   end

   local input = fd:read("*a")
   fd:close()

   local errs = {}
   local _, program = tl.parse_program(tl.lex(input), errs, module_name)
   if #errs > 0 then
      error(module_name .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg)
   end

   local code = tl.pretty_print_ast(program, true)
   local chunk, load_err = load(code, module_name, "t")
   if not chunk then
      error("Internal Compiler Error: Teal generator produced invalid Lua. Please report a bug at https://github.com/teal-language/tl\n\n" .. tostring(load_err))
   end

   return function(...)
      -- Forward the arguments `require` passes to a loader (the module
      -- name, plus the searcher's extra value on Lua 5.2+) so the module
      -- body can read them through `...`, as with Lua's own loaders.
      local ret = chunk(...)
      -- Only store non-nil results. An unconditional assignment would
      -- overwrite a value the module registered in package.loaded itself.
      if ret ~= nil then
         package.loaded[module_name] = ret
      end
      return ret
   end
end

--- Install tl_package_loader as the second package searcher so that
-- `require` can find Teal modules. Uses package.searchers when present and
-- falls back to package.loaders (the Lua 5.1 name).
function tl.loader()
   local searchers = package.searchers or package.loaders
   table.insert(searchers, 2, tl_package_loader)
end

--- Compile Teal source and load the generated Lua, mirroring Lua's load().
-- On a parse error, returns nil plus "chunkname:y:x: message" for the
-- first error (chunkname defaults to "").
function tl.load(input, chunkname, mode, env)
   local tokens = tl.lex(input)
   local errs = {}
   local _, program = tl.parse_program(tokens, errs, chunkname)
   if #errs > 0 then
      local first = errs[1]
      return nil, (chunkname or "") .. ":" .. first.y .. ":" .. first.x .. ": " .. first.msg
   end
   local code = tl.pretty_print_ast(program, true)
   return load(code, chunkname, mode, env)
end

return tl