In [741]:
from stringcase import pascalcase, snakecase, camelcase
from json import loads, load
from copy import deepcopy
from dataclasses import dataclass
import re

In [742]:
from enum import Enum
from typing import NamedTuple, List, Dict, Optional, Tuple, Any, Set

In [743]:
with open("../swagger.json") as f:
 swagger = load(f)

In [744]:
swagger.keys()

dict_keys(['swagger', 'info', 'basePath', 'paths', 'tags', 'consumes', 'host', 'produces', 'definitions', 'securityDefinitions', 'security'])

In [745]:
PROVIDED_ENUMS: Set[str] = {
 'side', 'pegPriceType', 
 'ordType', 'timeInForce', 
 'execInst', 'contingencyType', 
 'binSize'
}

def type_postfix(defs_: List[StructDef]):
 STRING_TYPEFIX = {
 "PutOrderBulkRequest": {
 "orders": "Vec"
 },
 "PostOrderBulkRequest": {
 "orders": "Vec"
 }
 }
 
 for def_ in defs_:
 if def_.name in STRING_TYPEFIX:
 for field in def_.fields:
 if field.name[1] in STRING_TYPEFIX[def_.name]:
 toty = STRING_TYPEFIX[def_.name][field.name[1]]
 print(f"Replacing {field.ty} to {toty} for {def_.name}.{field.name[1]}")
 field.ty = toty

SIGNED_ENDPOINTS = {
 "/announcement/urgent": ["get"],
 "/apiKey": ["get"],
 "/chat": ["post"],
 "/execution": ["get"],
 "/execution/tradeHistory": ["get"],
 "/globalNotification": ["get"],
 "/leaderboard/name": ["get"],
 "/order": ["get", "put", "post", "delete"],
 "/order/bulk": ["put", "post"],
 "/order/closePosition": ["post"],
 "/order/all": ["delete"],
 "/order/cancelAllAfter": ["post"],
 "/position": ["get"],
 "/position/isolate": ["post"],
 "/position/riskLimit": ["post"],
 "/position/transferMargin": ["post"],
 "/position/leverage": ["post"],
 "/user": ["get"],
 "/user/affiliateStatus": ["get"],
 "/user/commission": ["get"],
 "/user/communicationToken": ["post"],
 "/user/executionHistory": ["get"],
 "/user/depositAddress": ["get"],
 "/user/margin": ["get"],
 "/user/preferences": ["post"],
 "/user/quoteFillRatio": ["get"],
 "/user/requestWithdrawal": ["post"],
 "/user/wallet": ["get"],
 "/user/walletHistory": ["get"],
 "/user/walletSummary": ["get"],
 "/userEvent": ["get"],
}

string_formats = {
 "date-time": "DateTime",
 "guid": "Uuid",
 None: "String",
 "JSON": "Value",
}

number_formats = {
 "int64": "i64",
 "int32": "i32",
 "double": "f64"
}

def rustify(s: str) -> str:
 s = s.replace("ID", "Id")
 s = snakecase(s)
 if s == "type":
 s = "r#type"
 return s

# Definitions

In [769]:
@dataclass
class FieldDef:
 name: Tuple[str, str]
 ty: str
 optional: Optional[bool]
 modifiers: Dict[str, str]
 desc: Optional[str] = None

 @classmethod
 def from_swagger(cls, parent_name: str, name: str, tydesc: Dict[str, Any]) -> Tuple[FieldDef, List[StructDef]]:
 common_keys = ["required", "type", "default", "description"]
 if "type" in tydesc:
 ty = tydesc["type"] 
 desc = tydesc.get("description")
 
 optional = not tydesc.get("required", False)

 if ty == "string":
 assert_keys(name, tydesc, *common_keys, "enum", "maxLength", "format")

 if name in PROVIDED_ENUMS:
 rty = f"super::{pascalcase(name)}"
 else:
 rty = string_formats[tydesc.get("format")]

 modifiers = {}
 
 return FieldDef((rustify(name), name), rty, optional, modifiers, desc), []

 elif ty == "number":
 assert_keys(name, tydesc, *common_keys, "format", "minimum")

 modifiers = {}
 if "defaulyt" in tydesc and tydesc["default"] == 0:
 modifiers["default"] = None
 
 rty = number_formats[tydesc.get("format")]

 return FieldDef((rustify(name), name), rty, optional, modifiers, desc), []
 
 elif ty == "boolean":
 assert_keys(name, tydesc, *common_keys)

 modifiers = {}
 if "default" in tydesc and not tydesc["default"]:
 modifiers["default"] = None
 
 return FieldDef((rustify(name), name), "bool", optional, modifiers, desc), []
 
 elif ty == "object":
 assert_keys(name, tydesc, *common_keys, "properties")
 sdfs = StructDef.from_swagger(parent_name, name, tydesc)
 
 modifiers = {}
 
 assert sdfs
 
 fdf = FieldDef((rustify(name), name), sdfs[0].name, optional, modifiers, desc)
 
 return fdf, sdfs

 elif ty == "array":
 assert_keys(name, tydesc, *common_keys, "items")
 
 items = tydesc["items"]
 
 fdf, sdfs = FieldDef.from_swagger(parent_name, name, items)
 fdf.ty = f"Vec<{fdf.ty}>"
 if name == "disableEmails":
 print(fdf)
 
 if ("default" in tydesc and tydesc["default"] == []) or "default" not in tydesc:
 fdf.modifiers["default"] = None
 else:
 raise NotImplementedError(tydesc)
 
 return fdf, sdfs
 
 elif ty == "null":
 return FieldDef((rustify(name), name), "()", False, {}, None), []
 else:
 raise RuntimeError(f"Unimplemented for {ty}")


 elif "$ref" in tydesc:
 assert_keys(name, tydesc, "$ref", *common_keys)
 ref = tydesc["$ref"]

 if ref.startswith("#/definitions/"):
 ty = ref.lstrip("#/definitions/")
 if ty == "x-any":
 ty = "Value"
 return FieldDef((rustify(name), name), ty, False, {}, None), []
 else:
 raise NotImplementedError
 else:
 raise NotImplementedError(f"{name}, {tydesc}")
 
 def can_default(self) -> bool:
 if "default" in self.modifiers or self.optional:
 return True
 return False
 
 def __str__(self) -> str:
 if self.optional:
 ty = f"Option<{self.ty}>"
 else:
 ty = self.ty
 
 mods = []
 if self.name[0] != self.name[1]:
 mods.append(f"rename = \"{self.name[1]}\"")
 
 for mod, modv in fdef.modifiers.items():
 if modv is None:
 mods.append(mod)
 else:
 mods.append(f"{mod} = \"{modv}\"")
 
 serde_header = ""
 if mods:
 mods = ", ".join(mods)
 serde_header = f"#[serde({mods})]\n"
 field = serde_header + f"""pub {self.name[0]}: {ty}"""
 if self.desc:
 desc = self.desc.replace("\n", " ")
 field = f"""/// {desc}\n""" + field
 
 return field
 
@dataclass
class StructDef:
 name: str
 fields: List[FieldDef]
 derives: List[str]
 desc: Optional[str] = None
 value: bool = False
 
 @classmethod
 def from_swagger(cls, parent_name: str, name: str, defs: Dict[str, Any]) -> List[StructDef]:
 assert defs["type"] == "object"
 derives = ["Clone", "Debug", "Deserialize", "Serialize"]
 
 desc = defs.get("description")
 
 
 
 if "properties" in defs and not defs["properties"]:
 return [StructDef(f"{parent_name}{pascalcase(name)}", [], derives + ["Default"], desc, True)]
 
 if name == "x-any":
 return [StructDef(f"{parent_name}XAny", [], derives + ["Default"], desc, True)]
 
 if "properties" not in defs:
 defs["properties"] = {}
 
 required = set(defs.get("required", []))
 fdfs, sdfs = [], []
 for subname, def_ in defs["properties"].items():
 fdf, sdfs_ = FieldDef.from_swagger(f"{parent_name}{pascalcase(subname)}", subname, def_)
 
 sdfs.extend(sdfs_)
 fdfs.append(fdf)
 if fdf.name[1] == "disableEmails":
 print(fdf)
 print(fdf.optional, required, fdf.modifiers)
 
 fdf.optional &= fdf.name[1] not in required
 
 fields_can_default = all([fdf.can_default() for fdf in fdfs])
 if fields_can_default:
 derives.append("Default")
 
 
 return [StructDef(f"{parent_name}{pascalcase(name)}", fdfs, derives, desc), *sdfs]
 
 def __str__(self) -> str:
 if self.desc:
 desc = self.desc.replace("\n", " ")
 desc = f"\n/// {desc}"
 else:
 desc = ""
 
 derives = ", ".join(self.derives)
 
 if self.value:
 code = f"""#[derive({derives})]{desc}
pub struct {self.name}(serde_json::Value);"""
 return code
 elif not self.fields:
 code = f"""#[derive({derives})]{desc}
pub struct {self.name};"""
 return code
 
 
 fields = [str(fdef).replace("\n", "\n ") for fdef in self.fields]
 fields = ",\n ".join(fields)
 
 code = f"""#[derive({derives})]{desc}
pub struct {self.name} {{
 {fields}
}}"""
 return code

In [770]:
defs_ = []

for name, defs in swagger["definitions"].items():
 defs_.extend(StructDef.from_swagger("", name, defs))
type_postfix(defs_)

#[serde(rename = "disableEmails")]
pub disable_emails: Option>
#[serde(rename = "disableEmails")]
pub disable_emails: Option>
True set() {'default': None}


In [771]:
with open("../src/models/definitions.rs", "w") as f:
 f.write("""use chrono::{DateTime, Utc};
use serde_json::Value;
use uuid::Uuid;
use serde::{Deserialize, Serialize};
""")
 f.write("\n".join([str(d) for d in defs_]))


# Paths

In [772]:
@dataclass
class RequestImpl:
 method: str
 endpoint: str
 
 reqty: str
 respty: str
 signed: bool
 has_payload: bool

 def __str__(self) -> str:
 signed = "false"
 if self.signed:
 signed = "true"
 
 has_payload = "false"
 if self.has_payload:
 has_payload = "true"
 
 return f"""impl Request for {self.reqty} {{
 const METHOD: Method = Method::{self.method.upper()};
 const SIGNED: bool = {signed};
 const ENDPOINT: &'static str = "{self.endpoint}";
 const HAS_PAYLOAD: bool = {has_payload};
 type Response = {self.respty};
}}"""

In [773]:
defs_ = []
impls = []

for endpoint, defs in swagger["paths"].items():
 for method, defs in defs.items():
 # Request
 cmethod = method.capitalize()
 
 reqname = pascalcase(endpoint.lstrip("/").replace("/", "_"))
 reqname = reqname.replace("_", "")
 reqty = f"{cmethod}{reqname}Request"
 
 desc = defs.get("summary", "No description")
 
 tydesc_ = {"type": "object", "description": desc}
 
 for tydesc in defs["parameters"]:
 if "properties" not in tydesc_:
 tydesc_["properties"] = {}
 
 tydesc = deepcopy(tydesc)
 tydesc.pop("in")
 name = tydesc.pop("name")
 tydesc_["properties"][name] = tydesc
 
 reqdef, = StructDef.from_swagger("", reqty, tydesc_)
 
 # Response
 respty = f"{cmethod}{reqname}Response"
 
 schema = defs["responses"]["200"]["schema"]
 
 fdf, sdfs = FieldDef.from_swagger("", respty, schema)
 
 respdefs = sdfs

 if not sdfs:
 # If no struct created, the type of the fdf is the response type
 respty = f"{fdf.ty}"
 
 
 # Impls
 signed = False
 if method in SIGNED_ENDPOINTS.get(endpoint, []):
 signed = True
 
 if len(defs["parameters"]) == 0:
 has_payload = False
 else:
 has_payload = True
 
 impl = RequestImpl(method, endpoint, reqty, respty, signed, has_payload)
 
 defs_.extend([reqdef, *respdefs])
 impls.append(impl)
 
type_postfix(defs_)

Replacing Value to Vec for PostOrderBulkRequest.orders
Replacing Value to Vec for PutOrderBulkRequest.orders


In [774]:
with open("../src/models/requests.rs", "w") as f:
 f.write("""use http::Method;
use super::Request;
use super::definitions::*;
use serde_json::Value;
use serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
""")
 f.write("\n".join([str(d) for d in defs_]))
 f.write("\n")
 
 f.write("\n".join([str(d) for d in impls]))