{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rust model generator\n",
    "\n",
    "Reads the Swagger spec at `../swagger.json` and writes Rust source:\n",
    "\n",
    "- `../src/models/definitions.rs`: one struct per entry of `swagger[\"definitions\"]`, plus structs for inline sub-objects.\n",
    "- `../src/models/requests.rs`: for every path and HTTP method, a request struct, its response type and an `impl Request`.\n",
    "\n",
    "Every helper is defined before it is used, so run the notebook top to bottom (Restart & Run All)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from copy import deepcopy\n",
    "from dataclasses import dataclass\n",
    "from enum import Enum\n",
    "from json import load, loads\n",
    "from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple\n",
    "\n",
    "from stringcase import camelcase, pascalcase, snakecase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../swagger.json\", encoding=\"utf-8\") as f:\n",
    "    swagger = load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Top-level sections of the spec; this notebook reads \"definitions\" and \"paths\".\n",
    "swagger.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Knowledge the spec itself does not carry: which fields map to enums provided by the parent Rust module, which endpoints are generated with `SIGNED = true`, and how Swagger `format`s map to Rust types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swagger field names whose Rust type is an enum provided by the parent\n",
    "# module (emitted as `super::<PascalCaseName>`) instead of a plain string.\n",
    "PROVIDED_ENUMS: Set[str] = {\n",
    "    \"side\", \"pegPriceType\",\n",
    "    \"ordType\", \"timeInForce\",\n",
    "    \"execInst\", \"contingencyType\",\n",
    "    \"binSize\",\n",
    "}\n",
    "\n",
    "# Path -> HTTP methods whose generated `impl Request` sets `SIGNED = true`.\n",
    "SIGNED_ENDPOINTS: Dict[str, List[str]] = {\n",
    "    \"/announcement/urgent\": [\"get\"],\n",
    "    \"/apiKey\": [\"get\"],\n",
    "    \"/chat\": [\"post\"],\n",
    "    \"/execution\": [\"get\"],\n",
    "    \"/execution/tradeHistory\": [\"get\"],\n",
    "    \"/globalNotification\": [\"get\"],\n",
    "    \"/leaderboard/name\": [\"get\"],\n",
    "    \"/order\": [\"get\", \"put\", \"post\", \"delete\"],\n",
    "    \"/order/bulk\": [\"put\", \"post\"],\n",
    "    \"/order/closePosition\": [\"post\"],\n",
    "    \"/order/all\": [\"delete\"],\n",
    "    \"/order/cancelAllAfter\": [\"post\"],\n",
    "    \"/position\": [\"get\"],\n",
    "    \"/position/isolate\": [\"post\"],\n",
    "    \"/position/riskLimit\": [\"post\"],\n",
    "    \"/position/transferMargin\": [\"post\"],\n",
    "    \"/position/leverage\": [\"post\"],\n",
    "    \"/user\": [\"get\"],\n",
    "    \"/user/affiliateStatus\": [\"get\"],\n",
    "    \"/user/commission\": [\"get\"],\n",
    "    \"/user/communicationToken\": [\"post\"],\n",
    "    \"/user/executionHistory\": [\"get\"],\n",
    "    \"/user/depositAddress\": [\"get\"],\n",
    "    \"/user/margin\": [\"get\"],\n",
    "    \"/user/preferences\": [\"post\"],\n",
    "    \"/user/quoteFillRatio\": [\"get\"],\n",
    "    \"/user/requestWithdrawal\": [\"post\"],\n",
    "    \"/user/wallet\": [\"get\"],\n",
    "    \"/user/walletHistory\": [\"get\"],\n",
    "    \"/user/walletSummary\": [\"get\"],\n",
    "    \"/userEvent\": [\"get\"],\n",
    "}\n",
    "\n",
    "# `format` of a Swagger string -> Rust type (`None`: no format given).\n",
    "# chrono's `DateTime` needs its time-zone parameter; the generated files\n",
    "# import `chrono::{DateTime, Utc}` for it.\n",
    "string_formats: Dict[Optional[str], str] = {\n",
    "    \"date-time\": \"DateTime<Utc>\",\n",
    "    \"guid\": \"Uuid\",\n",
    "    None: \"String\",\n",
    "    \"JSON\": \"Value\",\n",
    "}\n",
    "\n",
    "# `format` of a Swagger number -> Rust primitive.\n",
    "number_formats: Dict[str, str] = {\n",
    "    \"int64\": \"i64\",\n",
    "    \"int32\": \"i32\",\n",
    "    \"double\": \"f64\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rustify(s: str) -> str:\n",
    "    \"\"\"Turn a Swagger (camelCase) name into a Rust snake_case field name.\n",
    "\n",
    "    \"ID\" is normalised to \"Id\" first so that e.g. `orderID` becomes `order_id`,\n",
    "    and `type` is escaped as the raw identifier `r#type`.\n",
    "    \"\"\"\n",
    "    s = snakecase(s.replace(\"ID\", \"Id\"))\n",
    "    return \"r#type\" if s == \"type\" else s\n",
    "\n",
    "\n",
    "def assert_keys(name: str, tydesc: Dict[str, Any], *allowed: str) -> None:\n",
    "    \"\"\"Fail loudly if the schema `tydesc` of field `name` uses an unhandled key.\n",
    "\n",
    "    `allowed` lists the keys the calling branch of `FieldDef.from_swagger`\n",
    "    knows how to translate; any other key raises AssertionError, so new\n",
    "    schema features surface instead of being silently dropped.\n",
    "    \"\"\"\n",
    "    unexpected = set(tydesc) - set(allowed)\n",
    "    assert not unexpected, f\"{name}: unhandled schema keys {sorted(unexpected)} in {tydesc}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Definitions\n",
    "\n",
    "`FieldDef` and `StructDef` translate Swagger schemas into Rust fields and structs; `str()` of either renders the Rust source. They refer to each other, so they live in one cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class FieldDef:\n",
    "    \"\"\"One field of a generated Rust struct.\n",
    "\n",
    "    `name` is `(rust_name, swagger_name)`; when the two differ the field gets\n",
    "    `#[serde(rename = \"...\")]`. Each `modifiers` entry becomes one more serde\n",
    "    argument: a `None` value renders as a bare flag (e.g. `default`), anything\n",
    "    else as `key = \"value\"`.\n",
    "    \"\"\"\n",
    "\n",
    "    name: Tuple[str, str]\n",
    "    ty: str  # Rust type, without the `Option<...>` wrapper\n",
    "    optional: Optional[bool]  # truthy -> rendered as `Option<ty>`\n",
    "    modifiers: Dict[str, Optional[str]]\n",
    "    desc: Optional[str] = None  # rendered as a `///` doc comment\n",
    "\n",
    "    @classmethod\n",
    "    def from_swagger(\n",
    "        cls, parent_name: str, name: str, tydesc: Dict[str, Any]\n",
    "    ) -> Tuple[\"FieldDef\", List[\"StructDef\"]]:\n",
    "        \"\"\"Translate the Swagger schema `tydesc` of property `name` into a field.\n",
    "\n",
    "        `parent_name` prefixes the names of structs generated for inline\n",
    "        objects. Returns the field plus those structs (empty for scalars and\n",
    "        `$ref`s).\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: `tydesc` has a key the matching branch does not handle.\n",
    "            KeyError: a `format` missing from `string_formats` / `number_formats`.\n",
    "            NotImplementedError, RuntimeError: unsupported schema shapes.\n",
    "        \"\"\"\n",
    "        common_keys = [\"required\", \"type\", \"default\", \"description\"]\n",
    "        names = (rustify(name), name)\n",
    "\n",
    "        if \"type\" in tydesc:\n",
    "            ty = tydesc[\"type\"]\n",
    "            desc = tydesc.get(\"description\")\n",
    "            optional = not tydesc.get(\"required\", False)\n",
    "\n",
    "            if ty == \"string\":\n",
    "                assert_keys(name, tydesc, *common_keys, \"enum\", \"maxLength\", \"format\")\n",
    "                if name in PROVIDED_ENUMS:\n",
    "                    rty = f\"super::{pascalcase(name)}\"\n",
    "                else:\n",
    "                    rty = string_formats[tydesc.get(\"format\")]\n",
    "                return cls(names, rty, optional, {}, desc), []\n",
    "\n",
    "            if ty == \"number\":\n",
    "                assert_keys(name, tydesc, *common_keys, \"format\", \"minimum\")\n",
    "                modifiers = {}\n",
    "                # A zero default equals Rust's `Default`, so serde may fill it in.\n",
    "                if \"default\" in tydesc and tydesc[\"default\"] == 0:\n",
    "                    modifiers[\"default\"] = None\n",
    "                rty = number_formats[tydesc.get(\"format\")]\n",
    "                return cls(names, rty, optional, modifiers, desc), []\n",
    "\n",
    "            if ty == \"boolean\":\n",
    "                assert_keys(name, tydesc, *common_keys)\n",
    "                modifiers = {}\n",
    "                # A `false` default equals Rust's `Default` for bool.\n",
    "                if \"default\" in tydesc and not tydesc[\"default\"]:\n",
    "                    modifiers[\"default\"] = None\n",
    "                return cls(names, \"bool\", optional, modifiers, desc), []\n",
    "\n",
    "            if ty == \"object\":\n",
    "                assert_keys(name, tydesc, *common_keys, \"properties\")\n",
    "                structs = StructDef.from_swagger(parent_name, name, tydesc)\n",
    "                assert structs\n",
    "                # structs[0] is the struct generated for this object itself.\n",
    "                return cls(names, structs[0].name, optional, {}, desc), structs\n",
    "\n",
    "            if ty == \"array\":\n",
    "                assert_keys(name, tydesc, *common_keys, \"items\")\n",
    "                field, structs = cls.from_swagger(parent_name, name, tydesc[\"items\"])\n",
    "                field.ty = f\"Vec<{field.ty}>\"\n",
    "                if desc:\n",
    "                    field.desc = desc\n",
    "                # Only an absent or empty-list default is supported; serde's\n",
    "                # `default` then yields an empty Vec.\n",
    "                if tydesc.get(\"default\", []) != []:\n",
    "                    raise NotImplementedError(tydesc)\n",
    "                field.modifiers[\"default\"] = None\n",
    "                # NOTE(review): `optional` comes from the item schema, not from\n",
    "                # this array's own `required` key; confirm that is intended.\n",
    "                return field, structs\n",
    "\n",
    "            if ty == \"null\":\n",
    "                return cls(names, \"()\", False, {}, None), []\n",
    "            raise RuntimeError(f\"Unimplemented for {ty}\")\n",
    "\n",
    "        if \"$ref\" in tydesc:\n",
    "            assert_keys(name, tydesc, \"$ref\", *common_keys)\n",
    "            ref = tydesc[\"$ref\"]\n",
    "            prefix = \"#/definitions/\"\n",
    "            if not ref.startswith(prefix):\n",
    "                raise NotImplementedError(f\"{name}: unsupported $ref {ref!r}\")\n",
    "            # Slice the prefix off: `str.lstrip` removes a *set of characters* and\n",
    "            # would also eat leading letters of the name (d, e, f, i, n, o, s, t).\n",
    "            ty = ref[len(prefix):]\n",
    "            if ty == \"x-any\":\n",
    "                ty = \"Value\"\n",
    "            return cls(names, ty, False, {}, None), []\n",
    "\n",
    "        raise NotImplementedError(f\"{name}, {tydesc}\")\n",
    "\n",
    "    def can_default(self) -> bool:\n",
    "        \"\"\"Whether serde can fill this field in when it is absent from the input.\"\"\"\n",
    "        return \"default\" in self.modifiers or bool(self.optional)\n",
    "\n",
    "    def __str__(self) -> str:\n",
    "        \"\"\"Render as Rust: optional `///` line, optional `#[serde(...)]`, then the declaration.\"\"\"\n",
    "        ty = f\"Option<{self.ty}>\" if self.optional else self.ty\n",
    "\n",
    "        serde_args = []\n",
    "        if self.name[0] != self.name[1]:\n",
    "            serde_args.append(f\"rename = \\\"{self.name[1]}\\\"\")\n",
    "        for key, value in self.modifiers.items():\n",
    "            serde_args.append(key if value is None else f\"{key} = \\\"{value}\\\"\")\n",
    "\n",
    "        field = f\"pub {self.name[0]}: {ty}\"\n",
    "        if serde_args:\n",
    "            field = f\"#[serde({', '.join(serde_args)})]\\n\" + field\n",
    "        if self.desc:\n",
    "            desc = self.desc.replace(\"\\n\", \" \")\n",
    "            field = f\"/// {desc}\\n\" + field\n",
    "        return field\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class StructDef:\n",
    "    \"\"\"A generated Rust struct.\n",
    "\n",
    "    Rendering: `value=True` gives a newtype over `serde_json::Value` (objects\n",
    "    with an empty `properties`, and `x-any`); no fields gives a unit struct;\n",
    "    otherwise one line per `FieldDef`.\n",
    "    \"\"\"\n",
    "\n",
    "    name: str\n",
    "    fields: List[FieldDef]\n",
    "    derives: List[str]\n",
    "    desc: Optional[str] = None\n",
    "    value: bool = False\n",
    "\n",
    "    @classmethod\n",
    "    def from_swagger(cls, parent_name: str, name: str, defs: Dict[str, Any]) -> List[\"StructDef\"]:\n",
    "        \"\"\"Build the struct for object schema `defs` plus structs for inline sub-objects.\n",
    "\n",
    "        The struct is named `parent_name + PascalCase(name)` and is always the\n",
    "        first element of the result. `defs` is left unmodified.\n",
    "        \"\"\"\n",
    "        assert defs[\"type\"] == \"object\"\n",
    "        derives = [\"Clone\", \"Debug\", \"Deserialize\", \"Serialize\"]\n",
    "        desc = defs.get(\"description\")\n",
    "\n",
    "        if \"properties\" in defs and not defs[\"properties\"]:\n",
    "            return [cls(f\"{parent_name}{pascalcase(name)}\", [], derives + [\"Default\"], desc, True)]\n",
    "        if name == \"x-any\":\n",
    "            return [cls(f\"{parent_name}XAny\", [], derives + [\"Default\"], desc, True)]\n",
    "\n",
    "        # Read `properties` without writing a default back into `defs` (part of\n",
    "        # the loaded spec): that write made a re-run take the `Value` branch\n",
    "        # above instead of producing a unit struct.\n",
    "        properties = defs.get(\"properties\", {})\n",
    "        required = set(defs.get(\"required\", []))\n",
    "\n",
    "        fields = []\n",
    "        nested = []\n",
    "        for subname, subschema in properties.items():\n",
    "            field, substructs = FieldDef.from_swagger(f\"{parent_name}{pascalcase(subname)}\", subname, subschema)\n",
    "            # Names listed in the object's `required` array are never optional.\n",
    "            field.optional &= field.name[1] not in required\n",
    "            fields.append(field)\n",
    "            nested.extend(substructs)\n",
    "\n",
    "        if all(f.can_default() for f in fields):\n",
    "            derives.append(\"Default\")\n",
    "        return [cls(f\"{parent_name}{pascalcase(name)}\", fields, derives, desc), *nested]\n",
    "\n",
    "    def __str__(self) -> str:\n",
    "        \"\"\"Render the struct, with its derives and doc comment, as Rust source.\"\"\"\n",
    "        desc = \"\"\n",
    "        if self.desc:\n",
    "            desc = \"\\n/// \" + self.desc.replace(\"\\n\", \" \")\n",
    "        header = f\"#[derive({', '.join(self.derives)})]{desc}\\n\"\n",
    "\n",
    "        if self.value:\n",
    "            return header + f\"pub struct {self.name}(serde_json::Value);\"\n",
    "        if not self.fields:\n",
    "            return header + f\"pub struct {self.name};\"\n",
    "\n",
    "        body = \",\\n    \".join(str(field).replace(\"\\n\", \"\\n    \") for field in self.fields)\n",
    "        return header + f\"pub struct {self.name} {{\\n    {body}\\n}}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def type_postfix(defs_: List[StructDef]) -> None:\n",
    "    \"\"\"Apply hand-written overrides to generated field types, in place.\n",
    "\n",
    "    Fields are looked up by struct name and Swagger field name; each match has\n",
    "    its `ty` replaced and the replacement is printed.\n",
    "    \"\"\"\n",
    "    STRING_TYPEFIX = {\n",
    "        # NOTE(review): each value must be a complete Rust type; a bare `Vec`\n",
    "        # still needs its element type (`Vec<...>`); confirm.\n",
    "        \"PutOrderBulkRequest\": {\"orders\": \"Vec\"},\n",
    "        \"PostOrderBulkRequest\": {\"orders\": \"Vec\"},\n",
    "    }\n",
    "\n",
    "    for def_ in defs_:\n",
    "        overrides = STRING_TYPEFIX.get(def_.name, {})\n",
    "        for field in def_.fields:\n",
    "            swagger_name = field.name[1]\n",
    "            if swagger_name in overrides:\n",
    "                toty = overrides[swagger_name]\n",
    "                print(f\"Replacing {field.ty} to {toty} for {def_.name}.{swagger_name}\")\n",
    "                field.ty = toty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": 
{},
   "outputs": [],
   "source": [
    "# One struct per Swagger definition, plus structs for inline sub-objects.\n",
    "definition_structs: List[StructDef] = []\n",
    "for def_name, def_schema in swagger[\"definitions\"].items():\n",
    "    definition_structs.extend(StructDef.from_swagger(\"\", def_name, def_schema))\n",
    "type_postfix(definition_structs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../src/models/definitions.rs\", \"w\", encoding=\"utf-8\") as f:\n",
    "    f.write(\"\"\"use chrono::{DateTime, Utc};\n",
    "use serde_json::Value;\n",
    "use uuid::Uuid;\n",
    "use serde::{Deserialize, Serialize};\n",
    "\"\"\")\n",
    "    f.write(\"\\n\".join(str(d) for d in definition_structs))\n",
    "    f.write(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Paths\n",
    "\n",
    "For every path and HTTP method: a request struct whose fields are the operation's `parameters`, the response type taken from the `200` response schema, and an `impl Request` tying them to the endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class RequestImpl:\n",
    "    \"\"\"Data for one `impl Request for <reqty>` block in requests.rs.\"\"\"\n",
    "\n",
    "    method: str  # Swagger operation key, e.g. \"get\"; upper-cased into `Method::...`\n",
    "    endpoint: str  # Swagger path, emitted verbatim as ENDPOINT\n",
    "    reqty: str  # request struct name\n",
    "    respty: str  # Rust type of the 200 response\n",
    "    signed: bool\n",
    "    has_payload: bool\n",
    "\n",
    "    def __str__(self) -> str:\n",
    "        \"\"\"Render the impl block as Rust source.\"\"\"\n",
    "        signed = \"true\" if self.signed else \"false\"\n",
    "        has_payload = \"true\" if self.has_payload else \"false\"\n",
    "        return f\"\"\"impl Request for {self.reqty} {{\n",
    "    const METHOD: Method = Method::{self.method.upper()};\n",
    "    const SIGNED: bool = {signed};\n",
    "    const ENDPOINT: &'static str = \"{self.endpoint}\";\n",
    "    const HAS_PAYLOAD: bool = {has_payload};\n",
    "    type Response = {self.respty};\n",
    "}}\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Request and response structs for every operation, and their impl blocks.\n",
    "path_structs: List[StructDef] = []\n",
    "path_impls: List[RequestImpl] = []\n",
    "\n",
    "for endpoint, operations in swagger[\"paths\"].items():\n",
    "    for method, operation in operations.items():\n",
    "        # e.g. (\"/order/closePosition\", \"post\") -> \"PostOrderClosePosition\"\n",
    "        cmethod = method.capitalize()\n",
    "        reqname = pascalcase(endpoint.lstrip(\"/\").replace(\"/\", \"_\")).replace(\"_\", \"\")\n",
    "        reqty = f\"{cmethod}{reqname}Request\"\n",
    "        parameters = operation.get(\"parameters\", [])\n",
    "\n",
    "        # Request: an object schema whose properties are the operation's parameters.\n",
    "        request_schema = {\"type\": \"object\", \"description\": operation.get(\"summary\", \"No description\")}\n",
    "        if parameters:\n",
    "            request_schema[\"properties\"] = {}\n",
    "        for param in parameters:\n",
    "            prop = deepcopy(param)  # don't mutate the loaded spec\n",
    "            prop.pop(\"in\")\n",
    "            request_schema[\"properties\"][prop.pop(\"name\")] = prop\n",
    "        path_structs.extend(StructDef.from_swagger(\"\", reqty, request_schema))\n",
    "\n",
    "        # Response: the 200 schema. The field's type is the Rust response type\n",
    "        # (a generated struct, a `$ref`'d definition, a scalar, or `Vec<...>` of one).\n",
    "        schema = operation[\"responses\"][\"200\"][\"schema\"]\n",
    "        response_field, response_structs = FieldDef.from_swagger(\"\", f\"{cmethod}{reqname}Response\", schema)\n",
    "        path_structs.extend(response_structs)\n",
    "\n",
    "        path_impls.append(RequestImpl(\n",
    "            method=method,\n",
    "            endpoint=endpoint,\n",
    "            reqty=reqty,\n",
    "            respty=response_field.ty,\n",
    "            signed=method in SIGNED_ENDPOINTS.get(endpoint, []),\n",
    "            has_payload=bool(parameters),\n",
    "        ))\n",
    "\n",
    "type_postfix(path_structs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../src/models/requests.rs\", \"w\", encoding=\"utf-8\") as f:\n",
    "    f.write(\"\"\"use http::Method;\n",
    "use super::Request;\n",
    "use super::definitions::*;\n",
    "use serde_json::Value;\n",
    "use serde::{Deserialize, Serialize};\n",
    "use chrono::{DateTime, Utc};\n",
    "\"\"\")\n",
    "    f.write(\"\\n\".join(str(d) for d in path_structs))\n",
    "    f.write(\"\\n\")\n",
    "    f.write(\"\\n\".join(str(i) for i in path_impls))\n",
    "    f.write(\"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}