#!/usr/bin/env python3 import copy import json import re import subprocess from enum import Enum as PyEnum from typing import Callable from urllib import request VoidFn = Callable[[], None] CHEATCODES_JSON_URL = "https://raw.githubusercontent.com/foundry-rs/foundry/master/crates/cheatcodes/assets/cheatcodes.json" OUT_PATH = "src/Vm.sol" VM_SAFE_DOC = """\ /// The `VmSafe` interface does not allow manipulation of the EVM state or other actions that may /// result in Script simulations differing from on-chain execution. It is recommended to only use /// these cheats in scripts. """ VM_DOC = """\ /// The `Vm` interface does allow manipulation of the EVM state. These are all intended to be used /// in tests, but it is not recommended to use these cheats in scripts. """ def main(): json_str = request.urlopen(CHEATCODES_JSON_URL).read().decode("utf-8") contract = Cheatcodes.from_json(json_str) ccs = contract.cheatcodes ccs = list(filter(lambda cc: cc.status not in ["experimental", "internal"], ccs)) ccs.sort(key=lambda cc: cc.func.id) safe = list(filter(lambda cc: cc.safety == "safe", ccs)) safe.sort(key=CmpCheatcode) unsafe = list(filter(lambda cc: cc.safety == "unsafe", ccs)) unsafe.sort(key=CmpCheatcode) assert len(safe) + len(unsafe) == len(ccs) prefix_with_group_headers(safe) prefix_with_group_headers(unsafe) out = "" out += "// Automatically @generated by scripts/vm.py. Do not modify manually.\n\n" pp = CheatcodesPrinter( spdx_identifier="MIT OR Apache-2.0", solidity_requirement=">=0.6.2 <0.9.0", abicoder_pragma=True, ) pp.p_prelude() pp.prelude = False out += pp.finish() out += "\n\n" out += VM_SAFE_DOC vm_safe = Cheatcodes( # TODO: Custom errors were introduced in 0.8.4 errors=[], # contract.errors events=contract.events, enums=contract.enums, structs=contract.structs, cheatcodes=safe, ) pp.p_contract(vm_safe, "VmSafe") out += pp.finish() out += "\n\n" out += VM_DOC vm_unsafe = Cheatcodes( errors=[], events=[], enums=[], structs=[], cheatcodes=unsafe, ) pp.p_contract(vm_unsafe, "Vm", "VmSafe") out += pp.finish() # Compatibility with <0.8.0 def memory_to_calldata(m: re.Match) -> str: return " calldata " + m.group(1) out = re.sub(r" memory (.*returns)", memory_to_calldata, out) with open(OUT_PATH, "w") as f: f.write(out) forge_fmt = ["forge", "fmt", OUT_PATH] res = subprocess.run(forge_fmt) assert res.returncode == 0, f"command failed: {forge_fmt}" print(f"Wrote to {OUT_PATH}") class CmpCheatcode: cheatcode: "Cheatcode" def __init__(self, cheatcode: "Cheatcode"): self.cheatcode = cheatcode def __lt__(self, other: "CmpCheatcode") -> bool: return cmp_cheatcode(self.cheatcode, other.cheatcode) < 0 def __eq__(self, other: "CmpCheatcode") -> bool: return cmp_cheatcode(self.cheatcode, other.cheatcode) == 0 def __gt__(self, other: "CmpCheatcode") -> bool: return cmp_cheatcode(self.cheatcode, other.cheatcode) > 0 def cmp_cheatcode(a: "Cheatcode", b: "Cheatcode") -> int: if a.group != b.group: return -1 if a.group < b.group else 1 if a.status != b.status: return -1 if a.status < b.status else 1 if a.safety != b.safety: return -1 if a.safety < b.safety else 1 if a.func.id != b.func.id: return -1 if a.func.id < b.func.id else 1 return 0 # HACK: A way to add group header comments without having to modify printer code def prefix_with_group_headers(cheats: list["Cheatcode"]): s = set() for i, cheat in enumerate(cheats): if cheat.group in s: continue s.add(cheat.group) c = copy.deepcopy(cheat) c.func.description = "" c.func.declaration = f"// ======== {group(c.group)} ========" cheats.insert(i, c) return cheats def group(s: str) -> str: if s == "evm": return "EVM" if s == "json": return "JSON" return s[0].upper() + s[1:] class Visibility(PyEnum): EXTERNAL: str = "external" PUBLIC: str = "public" INTERNAL: str = "internal" PRIVATE: str = "private" def __str__(self): return self.value class Mutability(PyEnum): PURE: str = "pure" VIEW: str = "view" NONE: str = "" def __str__(self): return self.value class Function: id: str description: str declaration: str visibility: Visibility mutability: Mutability signature: str selector: str selector_bytes: bytes def __init__( self, id: str, description: str, declaration: str, visibility: Visibility, mutability: Mutability, signature: str, selector: str, selector_bytes: bytes, ): self.id = id self.description = description self.declaration = declaration self.visibility = visibility self.mutability = mutability self.signature = signature self.selector = selector self.selector_bytes = selector_bytes @staticmethod def from_dict(d: dict) -> "Function": return Function( d["id"], d["description"], d["declaration"], Visibility(d["visibility"]), Mutability(d["mutability"]), d["signature"], d["selector"], bytes(d["selectorBytes"]), ) class Cheatcode: func: Function group: str status: str safety: str def __init__(self, func: Function, group: str, status: str, safety: str): self.func = func self.group = group self.status = status self.safety = safety @staticmethod def from_dict(d: dict) -> "Cheatcode": return Cheatcode( Function.from_dict(d["func"]), str(d["group"]), str(d["status"]), str(d["safety"]), ) class Error: name: str description: str declaration: str def __init__(self, name: str, description: str, declaration: str): self.name = name self.description = description self.declaration = declaration @staticmethod def from_dict(d: dict) -> "Error": return Error(**d) class Event: name: str description: str declaration: str def __init__(self, name: str, description: str, declaration: str): self.name = name self.description = description self.declaration = declaration @staticmethod def from_dict(d: dict) -> "Event": return Event(**d) class EnumVariant: name: str description: str def __init__(self, name: str, description: str): self.name = name self.description = description class Enum: name: str description: str variants: list[EnumVariant] def __init__(self, name: str, description: str, variants: list[EnumVariant]): self.name = name self.description = description self.variants = variants @staticmethod def from_dict(d: dict) -> "Enum": return Enum( d["name"], d["description"], list(map(lambda v: EnumVariant(**v), d["variants"])), ) class StructField: name: str ty: str description: str def __init__(self, name: str, ty: str, description: str): self.name = name self.ty = ty self.description = description class Struct: name: str description: str fields: list[StructField] def __init__(self, name: str, description: str, fields: list[StructField]): self.name = name self.description = description self.fields = fields @staticmethod def from_dict(d: dict) -> "Struct": return Struct( d["name"], d["description"], list(map(lambda f: StructField(**f), d["fields"])), ) class Cheatcodes: errors: list[Error] events: list[Event] enums: list[Enum] structs: list[Struct] cheatcodes: list[Cheatcode] def __init__( self, errors: list[Error], events: list[Event], enums: list[Enum], structs: list[Struct], cheatcodes: list[Cheatcode], ): self.errors = errors self.events = events self.enums = enums self.structs = structs self.cheatcodes = cheatcodes @staticmethod def from_dict(d: dict) -> "Cheatcodes": return Cheatcodes( errors=[Error.from_dict(e) for e in d["errors"]], events=[Event.from_dict(e) for e in d["events"]], enums=[Enum.from_dict(e) for e in d["enums"]], structs=[Struct.from_dict(e) for e in d["structs"]], cheatcodes=[Cheatcode.from_dict(e) for e in d["cheatcodes"]], ) @staticmethod def from_json(s) -> "Cheatcodes": return Cheatcodes.from_dict(json.loads(s)) @staticmethod def from_json_file(file_path: str) -> "Cheatcodes": with open(file_path, "r") as f: return Cheatcodes.from_dict(json.load(f)) class Item(PyEnum): ERROR: str = "error" EVENT: str = "event" ENUM: str = "enum" STRUCT: str = "struct" FUNCTION: str = "function" class ItemOrder: _list: list[Item] def __init__(self, list: list[Item]) -> None: assert len(list) <= len(Item), "list must not contain more items than Item" assert len(list) == len(set(list)), "list must not contain duplicates" self._list = list pass def get_list(self) -> list[Item]: return self._list @staticmethod def default() -> "ItemOrder": return ItemOrder( [ Item.ERROR, Item.EVENT, Item.ENUM, Item.STRUCT, Item.FUNCTION, ] ) class CheatcodesPrinter: buffer: str prelude: bool spdx_identifier: str solidity_requirement: str abicoder_v2: bool block_doc_style: bool indent_level: int _indent_str: str nl_str: str items_order: ItemOrder def __init__( self, buffer: str = "", prelude: bool = True, spdx_identifier: str = "UNLICENSED", solidity_requirement: str = "", abicoder_pragma: bool = False, block_doc_style: bool = False, indent_level: int = 0, indent_with: int | str = 4, nl_str: str = "\n", items_order: ItemOrder = ItemOrder.default(), ): self.prelude = prelude self.spdx_identifier = spdx_identifier self.solidity_requirement = solidity_requirement self.abicoder_v2 = abicoder_pragma self.block_doc_style = block_doc_style self.buffer = buffer self.indent_level = indent_level self.nl_str = nl_str if isinstance(indent_with, int): assert indent_with >= 0 self._indent_str = " " * indent_with elif isinstance(indent_with, str): self._indent_str = indent_with else: assert False, "indent_with must be int or str" self.items_order = items_order def finish(self) -> str: ret = self.buffer.rstrip() self.buffer = "" return ret def p_contract(self, contract: Cheatcodes, name: str, inherits: str = ""): if self.prelude: self.p_prelude(contract) self._p_str("interface ") name = name.strip() if name != "": self._p_str(name) self._p_str(" ") if inherits != "": self._p_str("is ") self._p_str(inherits) self._p_str(" ") self._p_str("{") self._p_nl() self._with_indent(lambda: self._p_items(contract)) self._p_str("}") self._p_nl() def _p_items(self, contract: Cheatcodes): for item in self.items_order.get_list(): if item == Item.ERROR: self.p_errors(contract.errors) elif item == Item.EVENT: self.p_events(contract.events) elif item == Item.ENUM: self.p_enums(contract.enums) elif item == Item.STRUCT: self.p_structs(contract.structs) elif item == Item.FUNCTION: self.p_functions(contract.cheatcodes) else: assert False, f"unknown item {item}" def p_prelude(self, contract: Cheatcodes | None = None): self._p_str(f"// SPDX-License-Identifier: {self.spdx_identifier}") self._p_nl() if self.solidity_requirement != "": req = self.solidity_requirement elif contract and len(contract.errors) > 0: req = ">=0.8.4 <0.9.0" else: req = ">=0.6.0 <0.9.0" self._p_str(f"pragma solidity {req};") self._p_nl() if self.abicoder_v2: self._p_str("pragma experimental ABIEncoderV2;") self._p_nl() self._p_nl() def p_errors(self, errors: list[Error]): for error in errors: self._p_line(lambda: self.p_error(error)) def p_error(self, error: Error): self._p_comment(error.description, doc=True) self._p_line(lambda: self._p_str(error.declaration)) def p_events(self, events: list[Event]): for event in events: self._p_line(lambda: self.p_event(event)) def p_event(self, event: Event): self._p_comment(event.description, doc=True) self._p_line(lambda: self._p_str(event.declaration)) def p_enums(self, enums: list[Enum]): for enum in enums: self._p_line(lambda: self.p_enum(enum)) def p_enum(self, enum: Enum): self._p_comment(enum.description, doc=True) self._p_line(lambda: self._p_str(f"enum {enum.name} {{")) self._with_indent(lambda: self.p_enum_variants(enum.variants)) self._p_line(lambda: self._p_str("}")) def p_enum_variants(self, variants: list[EnumVariant]): for i, variant in enumerate(variants): self._p_indent() self._p_comment(variant.description) self._p_indent() self._p_str(variant.name) if i < len(variants) - 1: self._p_str(",") self._p_nl() def p_structs(self, structs: list[Struct]): for struct in structs: self._p_line(lambda: self.p_struct(struct)) def p_struct(self, struct: Struct): self._p_comment(struct.description, doc=True) self._p_line(lambda: self._p_str(f"struct {struct.name} {{")) self._with_indent(lambda: self.p_struct_fields(struct.fields)) self._p_line(lambda: self._p_str("}")) def p_struct_fields(self, fields: list[StructField]): for field in fields: self._p_line(lambda: self.p_struct_field(field)) def p_struct_field(self, field: StructField): self._p_comment(field.description) self._p_indented(lambda: self._p_str(f"{field.ty} {field.name};")) def p_functions(self, cheatcodes: list[Cheatcode]): for cheatcode in cheatcodes: self._p_line(lambda: self.p_function(cheatcode.func)) def p_function(self, func: Function): self._p_comment(func.description, doc=True) self._p_line(lambda: self._p_str(func.declaration)) def _p_comment(self, s: str, doc: bool = False): s = s.strip() if s == "": return s = map(lambda line: line.lstrip(), s.split("\n")) if self.block_doc_style: self._p_str("/*") if doc: self._p_str("*") self._p_nl() for line in s: self._p_indent() self._p_str(" ") if doc: self._p_str("* ") self._p_str(line) self._p_nl() self._p_indent() self._p_str(" */") self._p_nl() else: first_line = True for line in s: if not first_line: self._p_indent() first_line = False if doc: self._p_str("/// ") else: self._p_str("// ") self._p_str(line) self._p_nl() def _with_indent(self, f: VoidFn): self._inc_indent() f() self._dec_indent() def _p_line(self, f: VoidFn): self._p_indent() f() self._p_nl() def _p_indented(self, f: VoidFn): self._p_indent() f() def _p_indent(self): for _ in range(self.indent_level): self._p_str(self._indent_str) def _p_nl(self): self._p_str(self.nl_str) def _p_str(self, txt: str): self.buffer += txt def _inc_indent(self): self.indent_level += 1 def _dec_indent(self): self.indent_level -= 1 if __name__ == "__main__": main()