{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Error-bounded piecewise linear representation (PLR)\n",
    "\n",
    "Fits line segments to a sampled $\\sin(x)$ curve with two streaming algorithms:\n",
    "\n",
    "* `GreedyPLR` rotates two bounding lines around one fixed pivot point per segment.\n",
    "* `OptimalPLR` also keeps convex hulls of the upper and lower error bounds, so the bounding lines can be re-anchored on earlier bounds.\n",
    "\n",
    "Both take points one at a time through `process(pt)`, which returns a finished segment or `None`, and flush the last open segment with `finish()`. A segment is a tuple `(x_start, x_stop, slope, intercept)`; `gamma` is the target maximum vertical error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample data\n",
    "\n",
    "1000 evenly spaced samples of $\\sin(x)$ on $[0, 7]$, stored in `data` as `(x, y)` tuples in increasing `x` order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.linspace(0, 7, 1000)\n",
    "y = np.sin(x)\n",
    "data = list(zip(x, y))\n",
    "\n",
    "plt.scatter(x, y)\n",
    "plt.title(f\"sin(x) samples (n={len(data)})\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Geometry helpers\n",
    "\n",
    "A point is a tuple `(x, y)`; a line is a tuple `(a, b)` meaning $y = a x + b$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def slope(p1, p2):\n",
    "    \"\"\"Slope of the line through points p1 and p2.\n",
    "\n",
    "    Divides by zero if both points share an x: ZeroDivisionError for Python\n",
    "    floats, inf/nan with a RuntimeWarning for NumPy scalars.\n",
    "    \"\"\"\n",
    "    x1, y1 = p1\n",
    "    x2, y2 = p2\n",
    "    return (y2 - y1) / (x2 - x1)\n",
    "\n",
    "\n",
    "def line(p1, p2):\n",
    "    \"\"\"Return the line (a, b), i.e. y = a*x + b, through points p1 and p2.\"\"\"\n",
    "    a = slope(p1, p2)\n",
    "    # y1 = a*x1 + b  =>  b = -a*x1 + y1\n",
    "    b = -a * p1[0] + p1[1]\n",
    "    return (a, b)\n",
    "\n",
    "\n",
    "def intersection(l1, l2):\n",
    "    \"\"\"Return the point (x, y) where lines l1 and l2 cross.\n",
    "\n",
    "    For l1: y = a*x + c and l2: y = b*x + d the crossing is at\n",
    "    x = (d - c) / (a - b), y = (a*d - b*c) / (a - b).\n",
    "    Parallel lines (a == b) divide by zero.\n",
    "    \"\"\"\n",
    "    a, c = l1\n",
    "    b, d = l2\n",
    "    return ((d - c) / (a - b)), ((a*d - b*c) / (a - b))\n",
    "\n",
    "\n",
    "def above(pt, line):\n",
    "    \"\"\"True if pt lies strictly above the line (a, b).\"\"\"\n",
    "    return pt[1] > line[0] * pt[0] + line[1]\n",
    "\n",
    "\n",
    "def below(pt, line):\n",
    "    \"\"\"True if pt lies strictly below the line (a, b).\"\"\"\n",
    "    return pt[1] < line[0] * pt[0] + line[1]\n",
    "\n",
    "\n",
    "def upper_bound(pt, gamma):\n",
    "    \"\"\"pt moved up by gamma.\"\"\"\n",
    "    return (pt[0], pt[1] + gamma)\n",
    "\n",
    "\n",
    "def lower_bound(pt, gamma):\n",
    "    \"\"\"pt moved down by gamma.\"\"\"\n",
    "    return (pt[0], pt[1] - gamma)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy PLR\n",
    "\n",
    "The first two points of a segment give two bounding lines through their $\\pm\\gamma$ bounds: the minimum-slope line `rho_lower` and the maximum-slope line `rho_upper`. Their crossing `sint` is the segment's pivot. Each later point that lies strictly between the bounding lines rotates them around `sint` onto its own bounds where that narrows them. A point outside them closes the segment and starts a new one. A segment is emitted as the line through `sint` whose slope is the average of the two bounding slopes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GreedyPLR:\n",
    "    \"\"\"Streaming greedy piecewise linear approximation with error bound `gamma`.\n",
    "\n",
    "    `process(pt)` returns the finished segment (x_start, x_stop, slope, intercept)\n",
    "    when pt does not fit the current segment, otherwise None. Call `finish()` once\n",
    "    after the last point to get the final segment. Points are assumed to arrive in\n",
    "    increasing x order.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, gamma):\n",
    "        # gamma == 0 makes the two initial bounding lines parallel (their intersection\n",
    "        # divides by zero); gamma < 0 swaps them so no point to the right can fit.\n",
    "        if gamma <= 0:\n",
    "            raise ValueError(\"gamma must be positive\")\n",
    "        self.__state = \"need2\"\n",
    "        self.__gamma = gamma\n",
    "\n",
    "    def process(self, pt):\n",
    "        \"\"\"Add pt; return the previous segment if pt starts a new one, else None.\"\"\"\n",
    "        self.__last_pt = pt\n",
    "        if self.__state == \"need2\":\n",
    "            self.__s0 = pt\n",
    "            self.__state = \"need1\"\n",
    "        elif self.__state == \"need1\":\n",
    "            self.__s1 = pt\n",
    "            self.__setup()\n",
    "            self.__state = \"ready\"\n",
    "        elif self.__state == \"ready\":\n",
    "            return self.__process(pt)\n",
    "        else:\n",
    "            raise RuntimeError(\"process() called after finish()\")\n",
    "\n",
    "    def __setup(self):\n",
    "        # Minimum- and maximum-slope lines through the ±gamma bounds of s0 and s1.\n",
    "        self.__rho_lower = line(upper_bound(self.__s0, self.__gamma),\n",
    "                                lower_bound(self.__s1, self.__gamma))\n",
    "        self.__rho_upper = line(lower_bound(self.__s0, self.__gamma),\n",
    "                                upper_bound(self.__s1, self.__gamma))\n",
    "\n",
    "        # Every later update rotates the bounding lines around this point.\n",
    "        self.__sint = intersection(self.__rho_lower, self.__rho_upper)\n",
    "\n",
    "    def __current_segment(self):\n",
    "        segment_start = self.__s0[0]\n",
    "        # On a break __last_pt is the point that opens the next segment, so\n",
    "        # consecutive segments share their boundary x.\n",
    "        segment_stop = self.__last_pt[0]\n",
    "        avg_slope = (self.__rho_lower[0] + self.__rho_upper[0]) / 2\n",
    "        intercept = -avg_slope * self.__sint[0] + self.__sint[1]\n",
    "        return (segment_start, segment_stop, avg_slope, intercept)\n",
    "\n",
    "    def __process(self, pt):\n",
    "        if not (above(pt, self.__rho_lower) and below(pt, self.__rho_upper)):\n",
    "            # pt falls outside the cone: close the current segment and start a\n",
    "            # new one at pt.\n",
    "            prev_segment = self.__current_segment()\n",
    "            self.__s0 = pt\n",
    "            self.__state = \"need1\"\n",
    "            return prev_segment\n",
    "\n",
    "        # pt fits: rotate each bounding line around sint onto pt's bound when\n",
    "        # that bound is tighter than the current line.\n",
    "        s_upper = upper_bound(pt, self.__gamma)\n",
    "        s_lower = lower_bound(pt, self.__gamma)\n",
    "        if below(s_upper, self.__rho_upper):\n",
    "            self.__rho_upper = line(self.__sint, s_upper)\n",
    "        if above(s_lower, self.__rho_lower):\n",
    "            self.__rho_lower = line(self.__sint, s_lower)\n",
    "        return None\n",
    "\n",
    "    def finish(self):\n",
    "        \"\"\"Return the last open segment (None if no point was ever processed).\"\"\"\n",
    "        if self.__state == \"need2\":\n",
    "            self.__state = \"finished\"\n",
    "            return None\n",
    "        elif self.__state == \"need1\":\n",
    "            self.__state = \"finished\"\n",
    "            # Only one point in the open segment: emit a flat segment of width 1.\n",
    "            return (self.__s0[0], self.__s0[0] + 1, 0, self.__s0[1])\n",
    "        elif self.__state == \"ready\":\n",
    "            self.__state = \"finished\"\n",
    "            return self.__current_segment()\n",
    "        else:\n",
    "            raise RuntimeError(\"finish() called twice\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Greedy fit with `gamma = 0.0005`; the cell shows the number of segments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "77"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "plr = GreedyPLR(0.0005)\n",
    "lines = []\n",
    "for pt in data:\n",
    "    segment = plr.process(pt)\n",
    "    if segment is not None:\n",
    "        lines.append(segment)\n",
    "\n",
    "last = plr.finish()\n",
    "if last is not None:\n",
    "    lines.append(last)\n",
    "\n",
    "len(lines)"
   ]
  },
  {
   "cell_type": "code",
"execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(x, y)\n",
    "for x_start, x_stop, seg_slope, seg_intercept in lines:\n",
    "    xl = np.linspace(x_start, x_stop, 100)\n",
    "    plt.scatter(xl, seg_slope * xl + seg_intercept)\n",
    "plt.title(f\"Greedy PLR ({len(lines)} segments)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimal PLR\n",
    "\n",
    "`OptimalPLR` keeps, for the current segment, a convex chain of the upper bounds (`upper_hull`) and one of the lower bounds (`lower_hull`). When a new point tightens a bounding line, the line is re-anchored on the hull point that gives the tightest slope instead of rotating around a fixed pivot.\n",
    "\n",
    "`update_hull` restores convexity of such a chain (points sorted by `x`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(1, 1), (3, 3), (4, 3)]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def update_hull(hull, upper=True):\n",
    "    \"\"\"Return the convex chain of `hull`, a list of points sorted by x.\n",
    "\n",
    "    upper=True drops every point lying strictly above the line through its kept\n",
    "    neighbours; upper=False drops points strictly below it. Points exactly on\n",
    "    that line are kept, and the input list is not modified.\n",
    "    \"\"\"\n",
    "    kept_points = []\n",
    "    for pt in hull:\n",
    "        # Pop kept points that stop being convex once pt is added.\n",
    "        while len(kept_points) >= 2:\n",
    "            chord = line(pt, kept_points[-2])\n",
    "            mid = kept_points[-1]\n",
    "            if (upper and above(mid, chord)) or (not upper and below(mid, chord)):\n",
    "                kept_points.pop()\n",
    "            else:\n",
    "                break\n",
    "        kept_points.append(pt)\n",
    "    return kept_points\n",
    "\n",
    "\n",
    "example_hull = update_hull([(1, 1), (2, 1), (3, 3), (4, 3)], upper=False)\n",
    "assert example_hull == [(1, 1), (3, 3), (4, 3)]\n",
    "# A non-convex input where more than one earlier point must be removed:\n",
    "assert update_hull([(0, 0), (1, -10), (2, 0), (3, 0.5)], upper=False) == [(0, 0), (3, 0.5)]\n",
    "example_hull"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def argmax(l):\n",
    "    \"\"\"Index of the largest item of l (the first one on ties).\"\"\"\n",
    "    return max(enumerate(l), key=lambda item: item[1])[0]\n",
    "\n",
    "\n",
    "def argmin(l):\n",
    "    \"\"\"Index of the smallest item of l (the first one on ties).\"\"\"\n",
    "    return min(enumerate(l), key=lambda item: item[1])[0]\n",
    "\n",
    "\n",
    "class OptimalPLR:\n",
    "    \"\"\"Streaming piecewise linear approximation with error bound `gamma`.\n",
    "\n",
    "    Same interface as GreedyPLR, but the bounding lines are re-anchored on the\n",
    "    convex hulls of the current segment's lower bounds (for rho_upper) and upper\n",
    "    bounds (for rho_lower) instead of rotating around one fixed point.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, gamma):\n",
    "        # gamma == 0 makes the two initial bounding lines parallel (their intersection\n",
    "        # divides by zero); gamma < 0 swaps them so no point to the right can fit.\n",
    "        if gamma <= 0:\n",
    "            raise ValueError(\"gamma must be positive\")\n",
    "        self.__state = \"need2\"\n",
    "        self.__gamma = gamma\n",
    "\n",
    "    def process(self, pt):\n",
    "        \"\"\"Add pt; return the previous segment if pt starts a new one, else None.\"\"\"\n",
    "        self.__last_pt = pt\n",
    "        if self.__state == \"need2\":\n",
    "            self.__s0 = pt\n",
    "            self.__state = \"need1\"\n",
    "        elif self.__state == \"need1\":\n",
    "            self.__s1 = pt\n",
    "            self.__setup()\n",
    "            self.__state = \"ready\"\n",
    "        elif self.__state == \"ready\":\n",
    "            return self.__process(pt)\n",
    "        else:\n",
    "            raise RuntimeError(\"process() called after finish()\")\n",
    "\n",
    "    def __setup(self):\n",
    "        # Minimum- and maximum-slope lines through the ±gamma bounds of s0 and s1.\n",
    "        self.__rho_lower = line(upper_bound(self.__s0, self.__gamma),\n",
    "                                lower_bound(self.__s1, self.__gamma))\n",
    "        self.__rho_upper = line(lower_bound(self.__s0, self.__gamma),\n",
    "                                upper_bound(self.__s1, self.__gamma))\n",
    "\n",
    "        self.__upper_hull = [upper_bound(self.__s0, self.__gamma),\n",
    "                             upper_bound(self.__s1, self.__gamma)]\n",
    "        self.__lower_hull = [lower_bound(self.__s0, self.__gamma),\n",
    "                             lower_bound(self.__s1, self.__gamma)]\n",
    "\n",
    "    def __current_segment(self):\n",
    "        # The bounding lines may have been re-anchored, so recompute their crossing.\n",
    "        sint = intersection(self.__rho_lower, self.__rho_upper)\n",
    "        segment_start = self.__s0[0]\n",
    "        # On a break __last_pt is the point that opens the next segment.\n",
    "        segment_stop = self.__last_pt[0]\n",
    "        avg_slope = (self.__rho_lower[0] + self.__rho_upper[0]) / 2\n",
    "        intercept = -avg_slope * sint[0] + sint[1]\n",
    "        return (segment_start, segment_stop, avg_slope, intercept)\n",
    "\n",
    "    def __process(self, pt):\n",
    "        if not (above(pt, self.__rho_lower) and below(pt, self.__rho_upper)):\n",
    "            # pt falls outside the cone: close the current segment and start a\n",
    "            # new one at pt.\n",
    "            prev_segment = self.__current_segment()\n",
    "            self.__s0 = pt\n",
    "            self.__state = \"need1\"\n",
    "            return prev_segment\n",
    "\n",
    "        s_upper = upper_bound(pt, self.__gamma)\n",
    "        s_lower = lower_bound(pt, self.__gamma)\n",
    "\n",
    "        if below(s_upper, self.__rho_upper):\n",
    "            # Re-anchor rho_upper on the lower-hull point with the smallest slope\n",
    "            # to s_upper, and drop hull points before that pivot.\n",
    "            resulting_slopes = [line(p, s_upper)[0] for p in self.__lower_hull]\n",
    "            idx = argmin(resulting_slopes)\n",
    "            self.__rho_upper = line(self.__lower_hull[idx], s_upper)\n",
    "            self.__lower_hull = self.__lower_hull[idx:]\n",
    "\n",
    "        if above(s_lower, self.__rho_lower):\n",
    "            # Re-anchor rho_lower on the upper-hull point with the largest slope\n",
    "            # to s_lower, and drop hull points before that pivot.\n",
    "            resulting_slopes = [line(p, s_lower)[0] for p in self.__upper_hull]\n",
    "            idx = argmax(resulting_slopes)\n",
    "            self.__rho_lower = line(self.__upper_hull[idx], s_lower)\n",
    "            self.__upper_hull = self.__upper_hull[idx:]\n",
    "\n",
    "        # Both bounds of every accepted point join the hulls, even when no line\n",
    "        # moved, because a later point may need to pivot on them. Adding a bound\n",
    "        # only when the opposite line moves is not enough: with gamma=1 and points\n",
    "        # (0, 0), (1, 0), (2, 2.9), (10, 9.2) the point at x=2 would end up about\n",
    "        # 1.33 away from its segment. Appended after both branches so neither\n",
    "        # slope scan sees a bound with the same x as pt.\n",
    "        self.__lower_hull = update_hull(self.__lower_hull + [s_lower], upper=False)\n",
    "        self.__upper_hull = update_hull(self.__upper_hull + [s_upper])\n",
    "        return None\n",
    "\n",
    "    def finish(self):\n",
    "        \"\"\"Return the last open segment (None if no point was ever processed).\"\"\"\n",
    "        if self.__state == \"need2\":\n",
    "            self.__state = \"finished\"\n",
    "            return None\n",
    "        elif self.__state == \"need1\":\n",
    "            self.__state = \"finished\"\n",
    "            # Only one point in the open segment: emit a flat segment of width 1.\n",
    "            return (self.__s0[0], self.__s0[0] + 1, 0, self.__s0[1])\n",
    "        elif self.__state == \"ready\":\n",
    "            self.__state = \"finished\"\n",
    "            return self.__current_segment()\n",
    "        else:\n",
    "            raise RuntimeError(\"finish() called twice\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optimal fits with `gamma = 0.25` (stored in `lines`) and `gamma = 0.025` (stored in `lines2`); each cell shows the number of segments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "plr = OptimalPLR(0.25)\n",
    "lines = []\n",
    "for pt in data:\n",
    "    segment = plr.process(pt)\n",
    "    if segment is not None:\n",
    "        lines.append(segment)\n",
    "\n",
    "last = plr.finish()\n",
    "if last is not None:\n",
    "    lines.append(last)\n",
    "\n",
    "len(lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
"outputs": [],
   "source": [
    "plr = OptimalPLR(0.025)\n",
    "lines2 = []\n",
    "for pt in data:\n",
    "    segment = plr.process(pt)\n",
    "    if segment is not None:\n",
    "        lines2.append(segment)\n",
    "\n",
    "last = plr.finish()\n",
    "if last is not None:\n",
    "    lines2.append(last)\n",
    "\n",
    "len(lines2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Effect of `gamma`\n",
    "\n",
    "Original samples next to the two optimal fits. Segment counts in the titles are computed from the fitted segment lists."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels must match the gamma values used in the two OptimalPLR cells above.\n",
    "panels = [(\"Optimal PLR, γ = 0.25\", lines), (\"Optimal PLR, γ = 0.025\", lines2)]\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(20, 5))\n",
    "axes[0].scatter(x, y)\n",
    "axes[0].set_title(f\"Original data (n={len(data)})\", size=20)\n",
    "\n",
    "for ax, (label, segments) in zip(axes[1:], panels):\n",
    "    for x_start, x_stop, seg_slope, seg_intercept in segments:\n",
    "        xl = np.linspace(x_start, x_stop, 100)\n",
    "        ax.scatter(xl, seg_slope * xl + seg_intercept)\n",
    "    # Counts come from the fitted lists, so the titles stay in sync with the fit.\n",
    "    ax.set_title(f\"{label} ({len(segments)} segments)\", size=20)\n",
    "\n",
    "for ax in axes:\n",
    "    ax.tick_params(axis='x', labelrotation=0, labelsize=16)\n",
    "    ax.tick_params(axis='y', labelsize=16)\n",
    "\n",
    "fig.savefig(\"plot.png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segments in `from_rust`\n",
    "\n",
    "Pasted segment tuples `(x_start, x_stop, slope, intercept)`, presumably produced by a Rust port of these algorithms. The `gamma` and input data that generated them are not recorded here.\n",
    "\n",
    "**TODO:** record the source and parameters of these values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from_rust = [\n",
    "(0, 0.224, 0.9925214438051018, 0.00014488723457774244),\n",
    "(0.224, 0.371, 0.9564359979470114, 0.00811544111468232),\n",
    "(0.371, 0.49000000000000005, 0.9098257927831952, 0.025184110625250444),\n",
    "(0.49000000000000005, 0.602, 0.8560902362553744, 0.051374096095132604),\n",
    "(0.602, 0.7000000000000001, 0.7973679477778012, 0.08644199332413038),\n",
    "(0.7000000000000001, 0.798, 0.7345403303455742, 0.13026493807044426),\n",
    "(0.798, 0.889, 0.6672938613391461, 0.18364968490552003),\n",
    "(0.889, 0.9800000000000001, 0.5968809282166496, 0.2460461759022603),\n",
    "(0.9800000000000001, 1.064, 0.52453963029918, 0.3166206305259157),\n",
    "(1.064, 1.1480000000000001, 0.4512773049507166, 0.3943392750949695),\n",
    "(1.1480000000000001, 1.232, 0.37483263881525447, 0.48185046990923225),\n",
    "(1.232, 1.316, 0.2957447083682151, 0.5790264497986415),\n",
    "(1.316, 
1.4000000000000001, 0.21456509612801, 0.6855881183987156),\n", "(1.4000000000000001, 1.4769999999999999, 0.13535613338353905, 0.7961030925848586),\n", "(1.4769999999999999, 1.554, 0.05875237893886226, 0.9089827857523513),\n", "(1.554, 1.631, -0.018199546284158596, 1.0282975942267096),\n", "(1.631, 1.708, -0.09504361970083763, 1.153358954523449),\n", "(1.708, 1.792, -0.17476140876329194, 1.2893415834624926),\n", "(1.792, 1.8760000000000001, -0.25674332111792314, 1.4359537229597428),\n", "(1.8760000000000001, 1.9600000000000002, -0.3369097215772785, 1.5860507383639242),\n", "(1.9600000000000002, 2.044, -0.41470028452663665, 1.7382295287324734),\n", "(2.044, 2.128, -0.4895664424253739, 1.890971731375827),\n", "(2.128, 2.219, -0.5638422878887794, 2.0488442399382443),\n", "(2.219, 2.31, -0.6365330755574079, 2.20984715465532),\n", "(2.31, 2.408, -0.7064150208952986, 2.371076767650351),\n", "(2.408, 2.506, -0.7722395133713664, 2.529290458375451),\n", "(2.506, 2.6109999999999998, -0.832566340878024, 2.680265889786253),\n", "(2.6109999999999998, 2.73, -0.8890772083181679, 2.827667324426729),\n", "(2.73, 2.863, -0.9393402024298271, 2.9646838555034947),\n", "(2.863, 3.045, -0.980719181256966, 3.0830673326267295),\n", "(3.045, 3.346, -1.0587004536456222, 3.3386367109432715),\n", "(3.346, 3.5, -1.0194089527822754, 3.2163564540974594),\n", "(3.5, 3.6260000000000003, -0.969537899233178, 3.0490937912774494),\n", "(3.6260000000000003, 3.745, -0.9112648557380234, 2.8443939825940068),\n", "(3.745, 3.8500000000000005, -0.8470640051072593, 2.6097982889007594),\n", "(3.8500000000000005, 3.9549999999999996, -0.7778436510635867, 2.3488828625107105),\n", "(3.9549999999999996, 4.053, -0.7035404258383973, 2.0602635640493663),\n", "(4.053, 4.151, -0.6255241060600776, 1.7491566288953644),\n", "(4.151, 4.242, -0.5452012677617251, 1.420545750946877),\n", "(4.242, 4.333, -0.4637127145158647, 1.0795392357639941),\n", "(4.333, 4.424, -0.3788934556974868, 0.7166593459207461),\n", "(4.424, 
4.515000000000001, -0.29144539501621314, 0.3344128066449319),\n", "(4.515000000000001, 4.606, -0.20211904114644377, -0.06428224649977465),\n", "(4.606, 4.69, -0.1151167858286416, -0.4606319398756733),\n", "(4.69, 4.774, -0.031174188126666322, -0.8500583320514762),\n", "(4.774, 4.858, 0.05256849223960554, -1.2455724883595696),\n", "(4.858, 4.949, 0.13893433546476147, -1.6606130712268432),\n", "(4.949, 5.04, 0.2271286146332705, -2.092412737173717),\n", "(5.04, 5.131, 0.3129166735955567, -2.520077948156951),\n", "(5.131, 5.2219999999999995, 0.3956086070533476, -2.939622030057251),\n", "(5.2219999999999995, 5.313, 0.4745201155249519, -3.346901263923985),\n", "(5.313, 5.4110000000000005, 0.5517095003597423, -3.751884706679768),\n", "(5.4110000000000005, 5.509, 0.6260229407482628, -4.148680392368301),\n", "(5.509, 5.614000000000001, 0.6958450294401022, -4.527636690886439),\n", "(5.614000000000001, 5.726, 0.761797755685506, -4.891664002832056),\n", "(5.726, 5.845, 0.8211801307009096, -5.224864212667721),\n", "(5.845, 5.984999999999999, 0.873255945628706, -5.52102481985113),\n", "(5.984999999999999, 6.167, 0.9148761058720644, -5.758745208484667),\n", "(6.167, 6.503, 0.9929896535897769, -6.239043539927865),\n", "(6.503, 6.6499999999999995, 0.9576396162376866, -6.009263861693788),\n", "(6.6499999999999995, 6.769, 0.9115503947797831, -5.702988862860086),\n", "(6.769, 6.881, 0.8582426830616655, -5.342287319093105),\n", "(6.881, 6.986, 0.7977272612145305, -4.926080687706908),\n", "(6.986, 6.993, 0.7607573751780505, -4.6682830946896905)\n", "]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
   "source": [
    "plt.scatter(x, y)\n",
    "for x_start, x_stop, seg_slope, seg_intercept in from_rust:\n",
    "    xl = np.linspace(x_start, x_stop, 100)\n",
    "    plt.scatter(xl, seg_slope * xl + seg_intercept)\n",
    "plt.title(f\"from_rust ({len(from_rust)} segments)\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}