#!/usr/bin/env python3
"""
Generates Rust source codes from Wayland Protocol definition XML files.
"""
import argparse
import xmltodict
import jinja2
import os
from pathlib import Path
import textwrap
import re
import subprocess
from multiprocessing.pool import ThreadPool

# Preamble shared by every generated per-interface module.
_FILE_HEADER = """\
{{ description["#text"] | comment("//! ") }}
//
//
// GENERATED BY OUR WAYLAND-SCANNER. DO NOT EDIT!
//
//
#![allow(unused)]
#![allow(clippy::from_over_into)]
#![allow(clippy::match_single_binding)]

use smallvec::smallvec;
use alloc::rc::Rc;
use core::cell::RefCell;
use alloc::string::String;
"""

# Imports the type of every other interface into the generated module.
_SIBLING_IMPORTS = """\
{% for name in interface_names %}
{% if name != interface_name -%}
use crate::wl::protocols::common::{{ name }}::{{ name | camelcase }};
{%- endif -%}
{%- endfor %}
"""

# NOTE(review): the generated `From<int> for <Enum>` impls end in
# `_ => unreachable!()`, so any value not listed in the XML (e.g. OR-ed flags
# of a bitfield enum, or a value added by a newer protocol version) panics
# while deserializing. TODO: consider generating TryFrom instead.
COMMON_TEMPLATE = _FILE_HEADER + """\
use crate::wl::{
    Connection, SendError, Interface, NewId, Array, Handle, Opcode, ObjectId,
    Message, RawMessage, Payload, PayloadType, DeserializeError,
};
use crate::wl::protocols::common::{EventSet, RequestSet};
""" + _SIBLING_IMPORTS + """
macro_rules! from_optional_object_payload {
    ($ty:ident, $con:expr, $v:expr) => {
        match ($v).clone() {
            Payload::ObjectId(id) if id.is_null() => None,
            Payload::ObjectId(id) => Some($ty::new($con, id)),
            _ => return Err(DeserializeError::UnexpectedType), // Abort deserializing.
        }
    }
}

macro_rules! from_object_payload {
    ($ty:ident, $con:expr, $v:expr) => {
        match ($v).clone() {
            Payload::ObjectId(id) if id.is_null() => return Err(DeserializeError::ObjectIsNull),
            Payload::ObjectId(id) => $ty::new($con, id),
            _ => return Err(DeserializeError::UnexpectedType),
        }
    }
}

macro_rules! from_payload {
    ($ty:ident, $v:expr) => {
        match ($v).clone() {
            Payload::$ty(value) => value.into(),
            _ => return Err(DeserializeError::UnexpectedType),
        }
    }
}

#[derive(Debug)]
pub enum Request {
    {% for m in requests -%}
    {{ m.doc }}
    {{ m.name | camelcase }} {
        {% for arg in m.args -%}
        {{ arg.doc }}
        {{ arg.name }}: {{ arg.type }},
        {% endfor %}
    },
    {% endfor %}
}

impl Message for Request {
    fn into_raw(self, sender: ObjectId) -> RawMessage {
        match self {
            {% for m in requests -%}
            Request::{{ m.name | camelcase }} { {{ m.args | args_list }} } => {
                RawMessage {
                    sender,
                    opcode: Opcode({{ m.opcode }}),
                    args: smallvec![{{ m.args | args_list(".into()") }}],
                }
            }
            {% endfor %}
        }
    }

    fn from_raw(con: Rc<RefCell<Connection>>, m: &RawMessage) -> Result<Self, DeserializeError> {
        match m.opcode {
            {% for msg in requests -%}
            Opcode({{ msg.opcode }}) => {
                Ok(Request::{{ msg.name | camelcase }} {
                    {% for arg in msg.args -%}
                    {% if arg.is_object %}
                    {% if arg.nullable %}
                    {{ arg.name }}: from_optional_object_payload!({{ arg.interface | camelcase }}, con.clone(), m.args[{{ loop.index0 }}]),
                    {% else %}
                    {{ arg.name }}: from_object_payload!({{ arg.interface | camelcase }}, con.clone(), m.args[{{ loop.index0 }}]),
                    {% endif %}
                    {% else %}
                    {{ arg.name }}: from_payload!({{ arg.payload_type }}, m.args[{{ loop.index0 }}]),
                    {% endif %}
                    {% endfor %}
                })
            }
            {% endfor %}
            _ => Err(DeserializeError::UnknownOpcode),
        }
    }

    fn into_received_event(self, con: Rc<RefCell<Connection>>, id: ObjectId) -> EventSet {
        panic!("not a event!");
    }

    fn into_received_request(self) -> RequestSet {
        RequestSet::{{ interface_name | camelcase }}(self)
    }
}

{% for enum in enums -%}
{{ enum.doc }}
#[repr({{ enum_types[enum.name] }})]
#[derive(Copy, Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum {{ enum.name | camelcase }} {
    {% for e in enum.entries -%}
    {{ e.doc }}
    {{ e.name | enum_name }} = {{ e.value }},
    {% endfor %}
}

impl Into<Payload> for {{ enum.name | camelcase }} {
    fn into(self) -> Payload {
        Payload::{{ enum_payload_types[enum.name] }}(self as {{ enum_types[enum.name] }})
    }
}

impl From<{{ enum_types[enum.name] }}> for {{ enum.name | camelcase }} {
    fn from(value: {{ enum_types[enum.name] }}) -> {{ enum.name | camelcase }} {
        match value {
            {% for e in enum.entries -%}
            {{ e.value }} => {{ enum.name | camelcase }}::{{ e.name | enum_name }},
            {% endfor %}
            _ => unreachable!(),
        }
    }
}

{% endfor %}
#[derive(Debug)]
pub enum Event {
    {% for m in events -%}
    {{ m.doc }}
    {{ m.name | camelcase }} {
        {% for arg in m.args -%}
        {{ arg.doc }}
        {{ arg.name }}: {{ arg.type }},
        {% endfor %}
    },
    {% endfor %}
}

impl Message for Event {
    fn into_raw(self, sender: ObjectId) -> RawMessage {
        match self {
            {% for m in events -%}
            Event::{{ m.name | camelcase }} { {{ m.args | args_list }} } => {
                RawMessage {
                    sender,
                    opcode: Opcode({{ m.opcode }}),
                    args: smallvec![{{ m.args | args_list(".into()") }}],
                }
            }
            {% endfor %}
        }
    }

    fn from_raw(con: Rc<RefCell<Connection>>, m: &RawMessage) -> Result<Self, DeserializeError> {
        match m.opcode {
            {% for msg in events -%}
            Opcode({{ msg.opcode }}) => {
                Ok(Event::{{ msg.name | camelcase }} {
                    {% for arg in msg.args -%}
                    {% if arg.is_object %}
                    {% if arg.nullable %}
                    {{ arg.name }}: from_optional_object_payload!({{ arg.interface | camelcase }}, con.clone(), m.args[{{ loop.index0 }}]),
                    {% else %}
                    {{ arg.name }}: from_object_payload!({{ arg.interface | camelcase }}, con.clone(), m.args[{{ loop.index0 }}]),
                    {% endif %}
                    {% else %}
                    {{ arg.name }}: from_payload!({{ arg.payload_type }}, m.args[{{ loop.index0 }}]),
                    {% endif %}
                    {% endfor %}
                })
            }
            {% endfor %}
            _ => Err(DeserializeError::UnknownOpcode),
        }
    }

    fn into_received_event(self, con: Rc<RefCell<Connection>>, id: ObjectId) -> EventSet {
        EventSet::{{ interface_name | camelcase }}({{ interface_name | camelcase }}::new(con, id), self)
    }

    fn into_received_request(self) -> RequestSet {
        panic!("not a request!");
    }
}

{{ interface["description"]["@summary"] | comment }}
#[derive(Clone)]
pub struct {{ interface_name | camelcase }} {
    con: Rc<RefCell<Connection>>,
    object_id: ObjectId,
}

impl PartialEq for {{ interface_name | camelcase }} {
    fn eq(&self, other: &Self) -> bool {
        self.id() == other.id()
    }
}

impl core::fmt::Debug for {{ interface_name | camelcase }} {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{{ interface_name | camelcase }}@{}", self.object_id.0)
    }
}

impl Into<Payload> for {{ interface_name | camelcase }} {
    fn into(self) -> Payload {
        Payload::ObjectId(self.id())
    }
}

impl Interface for {{ interface_name | camelcase }} {
    type Event = Event;
    type Request = Request;
    const NAME: &'static str = "{{ interface_name }}";
    const VERSION: u32 = {{ interface_version }};
    const PAYLOAD_TYPES: &'static [&'static [PayloadType]] = &[
        {% for m in requests -%}
        &[
            {% for arg in m.args -%}
            PayloadType::{{ arg.payload_type }},
            {% endfor %}
        ],
        {% endfor %}
        {% for m in events -%}
        &[
            {% for arg in m.args -%}
            PayloadType::{{ arg.payload_type }},
            {% endfor %}
        ],
        {% endfor %}
    ];

    fn new(con: Rc<RefCell<Connection>>, object_id: ObjectId) -> {{ interface_name | camelcase }} {
        {{ interface_name | camelcase }} {
            con,
            object_id,
        }
    }

    fn connection(&self) -> &Rc<RefCell<Connection>> {
        &self.con
    }

    fn id(&self) -> ObjectId {
        self.object_id
    }

    fn as_new_id(&self) -> NewId {
        NewId(self.object_id.0)
    }
}
"""

# Client and server modules differ only in which messages they can send:
# clients send requests, servers send events. `@MESSAGES@` / `@MESSAGE_ENUM@`
# are substituted below, before Jinja ever sees the text.
_EXT_TEMPLATE = _FILE_HEADER + """\
use crate::wl::{
    Connection, SendError, Interface, NewId, Array, Handle, Opcode, ObjectId,
    Message, RawMessage, Payload, PayloadType,
};
""" + _SIBLING_IMPORTS + """\
use crate::wl::protocols::common::{{ interface_name }}::*;

pub trait {{ interface_name | camelcase }}Ext {
    {% for m in @MESSAGES@ -%}
    {{ m.doc }}
    fn {{ m.name | func_name }}(&self,
        {% for arg in m.args -%}
        {{ arg.name }}: {{ arg.type }},
        {% endfor %}
    ) -> Result<(), SendError>;
    {% endfor %}
}

impl {{ interface_name | camelcase }}Ext for {{ interface_name | camelcase }} {
    {% for m in @MESSAGES@ -%}
    {{ m.doc }}
    fn {{ m.name | func_name }}(&self,
        {% for arg in m.args -%}
        {{ arg.name }}: {{ arg.type }},
        {% endfor %}
    ) -> Result<(), SendError> {
        self.connection().borrow_mut().send(
            @MESSAGE_ENUM@::{{ m.name | camelcase }} {
                {% for arg in m.args -%}
                {{ arg.name }},
                {% endfor %}
            }.into_raw(self.id())
        )
    }
    {% endfor %}
}
"""

CLIENT_TEMPLATE = (
    _EXT_TEMPLATE.replace("@MESSAGES@", "requests").replace("@MESSAGE_ENUM@", "Request")
)
SERVER_TEMPLATE = (
    _EXT_TEMPLATE.replace("@MESSAGES@", "events").replace("@MESSAGE_ENUM@", "Event")
)

# A mod.rs that only declares submodules.
MOD_TEMPLATE = """\
//! Wayland Protocol definitions.
//
//
// GENERATED BY OUR WAYLAND-SCANNER. DO NOT EDIT!
//
//
{% for name in interface_names -%}
pub mod {{ name }};
{% endfor %}
"""

# common/mod.rs additionally gathers every interface's messages into the
# EventSet / RequestSet enums.
COMMON_MOD_TEMPLATE = MOD_TEMPLATE + """
#[derive(Debug)]
pub enum EventSet {
    {% for name in interface_names -%}
    {{ name | camelcase }}({{ name }}::{{ name | camelcase }}, {{ name }}::Event),
    {% endfor %}
}

#[derive(Debug)]
pub enum RequestSet {
    {% for name in interface_names -%}
    {{ name | camelcase }}({{ name }}::Request),
    {% endfor %}
}
"""

TYPE_MAPPING = {
    "new_id": "NewId",
    "int": "i32",
    "uint": "u32",
    "fixed": "f32",
    "string": "String",
    "array": "Array",
    "fd": "Handle",
    "object": "ObjectId",
}

PAYLOAD_TYPE_MAPPING = {
    "new_id": "NewId",
    "int": "Int",
    "uint": "UInt",
    "fixed": "Fixed",
    "string": "String",
    "array": "Array",
    "fd": "Handle",
    "object": "ObjectId",
}

# Explicit renames for message names that are not valid Rust method names.
FUNC_NAME_MAPPING = {
    "move": "move_",
}

# Rust strict and reserved keywords. A message named after one of them gets a
# trailing underscore unless FUNC_NAME_MAPPING says otherwise.
RUST_KEYWORDS = frozenset("""
    as async await break const continue crate dyn else enum extern false fn for
    if impl in let loop match mod move mut pub ref return self Self static
    struct super trait true type unsafe use where while abstract become box do
    final macro override priv try typeof unsized virtual yield
""".split())


def comment(s, prefix="/// "):
    """Format free text as a wrapped Rust comment block.

    Runs of whitespace (including the newlines and tabs of XML descriptions)
    collapse into single spaces. Only the first character is upper-cased, so
    acronyms such as "XDG" survive. A trailing period is ensured, and the text
    is wrapped at 80 columns with `prefix` on every line. Empty or missing
    text becomes "(no document).".
    """
    text = re.sub(r"\s+", " ", s or "").strip() or "(no document)"
    text = text[0].upper() + text[1:]
    if not text.endswith("."):
        text += "."
    return textwrap.indent("\n".join(textwrap.wrap(text, 80)), prefix)


def camelcase(s):
    """Convert a snake_case wayland name into a CamelCase Rust type name."""
    return "".join(t.title() for t in s.split("_"))


def into_list(x):
    """Normalize an xmltodict child: missing -> [], single element -> [x]."""
    if x is None:
        return []
    return x if isinstance(x, list) else [x]


def description_text(node):
    """Return the text of a node's <description>, falling back to its summary."""
    desc = node.get("description")
    if desc is None:
        return ""
    if isinstance(desc, str):
        # xmltodict yields a bare string for an element without attributes.
        return desc
    return desc.get("#text") or desc.get("@summary", "")


def args_list(args, suffix=""):
    """Jinja filter: comma-separated argument names, each followed by `suffix`."""
    return ",".join(a["name"] + suffix for a in args)


def func_name(name):
    """Jinja filter: map a message name to a valid Rust method name."""
    if name in FUNC_NAME_MAPPING:
        return FUNC_NAME_MAPPING[name]
    return f"{name}_" if name in RUST_KEYWORDS else name


def enum_name(name):
    """Jinja filter: map an enum entry name to a valid Rust variant name."""
    # Entries such as "90" (transforms) cannot start an identifier.
    return f"_{name}" if name[:1].isdigit() else camelcase(name)


def _make_env():
    """Create a Jinja environment with the filters used by all templates."""
    env = jinja2.Environment()
    env.filters["camelcase"] = camelcase
    env.filters["comment"] = comment
    env.filters["args_list"] = args_list
    env.filters["func_name"] = func_name
    env.filters["enum_name"] = enum_name
    return env


def generate_interface_file(interface_names, out_dir, interface):
    """Render the common/client/server Rust modules for one interface.

    Args:
        interface_names: sorted names of every interface being generated; used
            to import the sibling interface types into each module.
        out_dir: output root (a Path) that already contains the `common`,
            `client` and `server` directories.
        interface: one <interface> element as parsed by xmltodict.

    Writes `<out_dir>/{common,client,server}/<interface name>.rs`. Keeps no
    global state, so several interfaces may be generated concurrently.
    """
    interface_name = interface["@name"]
    interface_version = interface["@version"]

    # The XML does not say whether an enum is signed. Every local enum
    # defaults to u32 and switches to i32 when an "int" argument refers to it.
    enum_types = {}
    enum_payload_types = {}

    def tidy_enum_defs(defs):
        """Collect the local enums with their entries and documentation."""
        enums = []
        for enum in into_list(defs):
            name = enum["@name"]
            enums.append({
                "name": name,
                "doc": comment(description_text(enum)),
                "entries": [
                    {
                        "name": e["@name"],
                        "doc": comment(e.get("@summary", "")),
                        "value": e["@value"],
                    }
                    for e in into_list(enum.get("entry"))
                ],
            })
            enum_types[name] = "u32"
            enum_payload_types[name] = "UInt"
        return enums

    def tidy_arg(a):
        """Describe one <arg> element for the templates."""
        wire_type = a["@type"]
        is_object = False
        interface_ref = None
        nullable = False
        if "@enum" in a:
            enum_ref = a["@enum"]
            if wire_type not in ("uint", "int"):
                raise ValueError(
                    f"{interface_name}: enum argument {a['@name']!r} has type {wire_type!r}"
                )
            if "." in enum_ref:
                # "<interface>.<enum>" names an enum owned by another
                # interface. Its signedness is not recorded here, so it cannot
                # clobber a local enum of the same name.
                owner, name = enum_ref.split(".", 1)
                type_ = f"super::super::common::{owner}::{camelcase(name)}"
            else:
                type_ = camelcase(enum_ref)
                enum_types[enum_ref] = "u32" if wire_type == "uint" else "i32"
                enum_payload_types[enum_ref] = PAYLOAD_TYPE_MAPPING[wire_type]
        else:
            type_ = TYPE_MAPPING[wire_type]
            if wire_type == "object":
                if "@interface" in a:
                    type_ = camelcase(a["@interface"])
                    is_object = True
                    interface_ref = a["@interface"]
                # Only object arguments honour allow-null here; a nullable
                # string keeps its plain `String` type.
                # NOTE(review): confirm that matches how crate::wl encodes nulls.
                if a.get("@allow-null", "false") == "true":
                    type_ = f"Option<{type_}>"
                    nullable = True
        return {
            "name": a["@name"],
            "type": type_,
            "is_object": is_object,
            "interface": interface_ref,
            "nullable": nullable,
            "payload_type": PAYLOAD_TYPE_MAPPING[wire_type],
            "doc": comment(a.get("@summary", "")),
        }

    def tidy_message_defs(defs, first_opcode):
        """Collect requests or events, numbering them from `first_opcode`."""
        return [
            {
                "opcode": opcode,
                "name": m["@name"],
                "doc": comment(description_text(m)),
                "args": [tidy_arg(a) for a in into_list(m.get("arg"))],
            }
            for opcode, m in enumerate(into_list(defs), start=first_opcode)
        ]

    enums = tidy_enum_defs(interface.get("enum"))
    # Opcodes are numbered per interface starting at 1: every request first,
    # then every event, in declaration order (PAYLOAD_TYPES uses the same order).
    requests = tidy_message_defs(interface.get("request"), first_opcode=1)
    events = tidy_message_defs(interface.get("event"), first_opcode=len(requests) + 1)

    # Template variables: the raw XML children/attributes (e.g. `description`)
    # plus the tidied definitions. Later keys win instead of raising on a clash.
    context = dict(interface)
    context.update(
        interface=interface,
        interface_names=interface_names,
        interface_name=interface_name,
        interface_version=interface_version,
        enums=enums,
        requests=requests,
        events=events,
        enum_types=enum_types,
        enum_payload_types=enum_payload_types,
    )

    env = _make_env()
    for subdir, template in (
        ("common", COMMON_TEMPLATE),
        ("client", CLIENT_TEMPLATE),
        ("server", SERVER_TEMPLATE),
    ):
        rendered = env.from_string(template).render(context)
        (out_dir / subdir / f"{interface_name}.rs").write_text(rendered, encoding="utf-8")


def generate_mod_file(outfile, interface_names):
    """Write a mod.rs declaring one `pub mod` per entry of `interface_names`."""
    rendered = _make_env().from_string(MOD_TEMPLATE).render(interface_names=interface_names)
    outfile.write_text(rendered, encoding="utf-8")


def generate_common_mod_file(outfile, interface_names):
    """Write common/mod.rs: module declarations plus EventSet/RequestSet."""
    rendered = _make_env().from_string(COMMON_MOD_TEMPLATE).render(
        interface_names=interface_names
    )
    outfile.write_text(rendered, encoding="utf-8")


def main():
    """Parse arguments, generate every module, then run rustfmt on the output."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("out_dir", help="directory to write the generated Rust modules into")
    parser.add_argument("protocol_xmls", nargs="+", help="Wayland protocol XML files")
    args = parser.parse_args()

    interfaces = []
    for protocol_xml in args.protocol_xmls:
        # Parse bytes so the XML declaration, not the locale, picks the encoding.
        xml = xmltodict.parse(Path(protocol_xml).read_bytes())
        # xmltodict returns a single dict (not a list) when a protocol declares
        # exactly one interface.
        interfaces.extend(into_list(xml["protocol"].get("interface")))
    interface_names = sorted({interface["@name"] for interface in interfaces})

    out_dir = Path(args.out_dir)
    for subdir in ("common", "client", "server"):
        os.makedirs(out_dir / subdir, exist_ok=True)

    # generate_interface_file shares no mutable state, so interfaces can be
    # rendered concurrently; map() re-raises the first worker exception.
    with ThreadPool() as pool:
        pool.map(
            lambda interface: generate_interface_file(interface_names, out_dir, interface),
            interfaces,
        )

    generate_common_mod_file(out_dir / "common" / "mod.rs", interface_names)
    generate_mod_file(out_dir / "client" / "mod.rs", interface_names)
    generate_mod_file(out_dir / "server" / "mod.rs", interface_names)
    generate_mod_file(out_dir / "mod.rs", ["common", "client", "server"])

    rust_files = [str(p) for p in sorted(out_dir.glob("**/*.rs"))]
    subprocess.run(["rustfmt", *rust_files], check=True)


if __name__ == "__main__":
    main()