{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "\n",
    "import graphviz\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.optimize import linprog\n",
    "from shapely.geometry import LineString, Point, Polygon\n",
    "from shapely.ops import split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of leading input dimensions pinned to a fixed value; the constraint\n",
    "# columns after these are the free dimensions that get plotted.\n",
    "N_FIXED_INPUTS = 60\n",
    "# Seed for the node colours so that re-running the notebook gives the same picture.\n",
    "COLOR_SEED = 0\n",
    "\n",
    "with open('trial_model.json') as in_file:\n",
    "    asterism = json.load(in_file)\n",
    "# Value substituted for the pinned dimensions (all zeros).\n",
    "input_point = np.zeros(N_FIXED_INPUTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_color():\n",
    "    \"\"\"Return a random colour as a '#RRGGBB' string of uppercase hex digits.\"\"\"\n",
    "    return '#' + ''.join(random.choice('0123456789ABCDEF') for _ in range(6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nodes = asterism['arena']\n",
    "root_node = nodes[0]\n",
    "if 'input_bounds_opt' not in asterism:\n",
    "    # Without this guard the cell died with a NameError on `bounds`, or silently\n",
    "    # reused a stale `bounds` left over in the kernel.\n",
    "    raise KeyError('input_bounds_opt is missing from the model; '\n",
    "                   'cannot determine the plot bounds')\n",
    "bounds = asterism['input_bounds_opt']['data']\n",
    "bounds = np.reshape(np.array(bounds['data']), bounds['dim'])\n",
    "# Keep the rows for the last two input dimensions. Below, bounds[0] is used as\n",
    "# the x extent and bounds[1] as the y extent (presumably lower, upper - confirm).\n",
    "bounds = bounds.T[-2:]\n",
    "print(f'{bounds=}')\n",
    "parents = asterism['parents']\n",
    "random.seed(COLOR_SEED)\n",
    "node_colors = [random_color() for _ in parents]\n",
    "\n",
    "# Per-node constraint system A x <= b over the free dimensions only.\n",
    "cons_coeffs = []\n",
    "cons_rhs = []\n",
    "n_fixed = input_point.shape[0]\n",
    "for node in nodes:\n",
    "    con = node['star']['constraints']\n",
    "    if con is None:\n",
    "        # Placeholder keeps the list indices equal to the node indices.\n",
    "        cons_coeffs.append(None)\n",
    "        cons_rhs.append(None)\n",
    "        continue\n",
    "\n",
    "    coeffs = con['coeffs']\n",
    "    coeffs = np.reshape(coeffs['data'], coeffs['dim'])\n",
    "    rhs = con['rhs']\n",
    "    rhs = np.reshape(rhs['data'], rhs['dim'])\n",
    "\n",
    "    # Substitute the pinned input values: their contribution moves to the\n",
    "    # right-hand side and only the free-dimension columns are kept.\n",
    "    rhs_offset = np.dot(coeffs[:, :n_fixed], input_point)\n",
    "    coeffs = coeffs[:, n_fixed:]\n",
    "    rhs = rhs - rhs_offset\n",
    "\n",
    "    cons_coeffs.append(coeffs)\n",
    "    cons_rhs.append(rhs)\n",
    "\n",
    "def check_ineq(i, point):\n",
    "    \"\"\"Return True if `point` satisfies every constraint `A x <= b` of node `i`.\"\"\"\n",
    "    ineq_test = np.matmul(cons_coeffs[i], point) <= cons_rhs[i]\n",
    "    return np.all(ineq_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# boxes[i] is the geometry (Polygon or Point) drawn for node i. It is stored by\n",
    "# node index instead of by append so it cannot drift out of step with `parents`.\n",
    "boxes = [None] * len(parents)\n",
    "points = []\n",
    "\n",
    "x_lo, x_hi = bounds[0, 0], bounds[0, 1]\n",
    "y_lo, y_hi = bounds[1, 0], bounds[1, 1]\n",
    "bounds_box = Polygon([(x_lo, y_lo), (x_hi, y_lo), (x_hi, y_hi),\n",
    "                      (x_lo, y_hi), (x_lo, y_lo)])\n",
    "\n",
    "for i, parent in enumerate(parents):\n",
    "    if parent is None:\n",
    "        # A root node covers the whole input box.\n",
    "        boxes[i] = bounds_box\n",
    "        continue\n",
    "\n",
    "    parent_geom = boxes[parent]\n",
    "    if parent_geom is None:\n",
    "        raise ValueError(f'node {i} is listed before its parent {parent}')\n",
    "\n",
    "    A = cons_coeffs[i]\n",
    "    b = cons_rhs[i]\n",
    "    # Only the last constraint row is turned into a cut line (presumably the\n",
    "    # constraint introduced by this node - confirm against the model format).\n",
    "    if A is None:\n",
    "        line = None\n",
    "    elif np.all(A[-1] == 0.) and b[-1] == 0.:\n",
    "        # Trivial row 0 <= 0: nothing to cut.\n",
    "        line = None\n",
    "    elif np.all(A[-1] == 0.):\n",
    "        print('Reached contradictory polygon')\n",
    "        line = None\n",
    "    elif isinstance(parent_geom, Point):\n",
    "        # A point cannot be split any further.\n",
    "        line = None\n",
    "    elif A[-1][1] == 0.:\n",
    "        # Vertical boundary a0 * x = b.\n",
    "        x = b[-1] / A[-1][0]\n",
    "        line = LineString([(x, y_lo), (x, y_hi)])\n",
    "    else:\n",
    "        # Boundary a0 * x + a1 * y = b written as y = slope * x + intercept.\n",
    "        slope = -A[-1][0] / A[-1][1]\n",
    "        intercept = b[-1] / A[-1][1]\n",
    "        left = (x_lo, slope * x_lo + intercept)\n",
    "        right = (x_hi, slope * x_hi + intercept)\n",
    "        line = LineString([left, right])\n",
    "\n",
    "    if line is None:\n",
    "        # No cut: the child covers the same region as its parent.\n",
    "        boxes[i] = parent_geom\n",
    "        continue\n",
    "\n",
    "    # Keep the first piece whose centroid satisfies all of the node's constraints.\n",
    "    for geom in split(parent_geom, line).geoms:\n",
    "        if check_ineq(i, np.array([geom.centroid.x, geom.centroid.y])):\n",
    "            boxes[i] = geom\n",
    "            break\n",
    "    else:\n",
    "        # No piece qualifies: fall back to a single feasible point from an LP.\n",
    "        # linprog defaults every variable to (0, None), so the input box must be\n",
    "        # passed explicitly or points with negative coordinates are excluded.\n",
    "        res = linprog([0, 1],\n",
    "                      A_ub=A,\n",
    "                      b_ub=b,\n",
    "                      bounds=[(x_lo, x_hi), (y_lo, y_hi)])\n",
    "        assert res.success, f'no feasible point for node {i}: {res.message}'\n",
    "        boxes[i] = Point(*res.x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per node: figure i shows the regions of nodes 0..i, with later\n",
    "# nodes painted over earlier ones.\n",
    "for i in range(len(parents)):\n",
    "    print(i)\n",
    "    fig, axs = plt.subplots()\n",
    "    for geom, color in zip(boxes[:i + 1], node_colors[:i + 1]):\n",
    "        if isinstance(geom, Point):\n",
    "            # Pass x and y separately; a (1, 2) coordinate array would be\n",
    "            # read by plot() as two y-series at index 0.\n",
    "            axs.plot(geom.x, geom.y, marker='o', color=color)\n",
    "            continue\n",
    "\n",
    "        xs, ys = geom.exterior.xy\n",
    "        axs.fill(xs, ys, color=color, ec='none')\n",
    "    axs.set_title(f'Nodes 0-{i}')\n",
    "    plt.show()\n",
    "    # Release the figure so a long node list does not accumulate open figures.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tree of nodes, each filled with the colour used for its region in the plots.\n",
    "g = graphviz.Graph('graph')\n",
    "\n",
    "for i, c in enumerate(node_colors):\n",
    "    g.node(str(i), style='filled', fillcolor=c)\n",
    "\n",
    "for i, p in enumerate(parents):\n",
    "    if p is None:\n",
    "        # Root nodes have no incoming edge.\n",
    "        continue\n",
    "\n",
    "    g.edge(str(p), str(i))\n",
    "\n",
    "# Writes the DOT source to graphviz's default filename for this graph.\n",
    "g.save()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
  },
  "kernelspec": {
   "display_name": "Python 3.8.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}