{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e34afb69-8f87-4073-9560-cb80c67e6ae5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import polars\n",
    "import sklearn\n",
    "from scipy import sparse\n",
    "from sklearn import linear_model, preprocessing\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b02b8789-e927-4df5-96e1-d26706ff5fa2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(49998, 4098)\n",
      "(599976, 4096)\n",
      "(599976,)\n",
      "(449982, 4096) (149994, 4096)\n"
     ]
    }
   ],
   "source": [
    "# Load train.csv batch by batch and build a sparse, row-normalised feature matrix.\n",
    "# Column 0 is kept as the row id, column 1 is used as the label and the remaining\n",
    "# columns are the features.\n",
    "N_BATCHES = 12  # upper bound on the number of CSV batches to read\n",
    "\n",
    "df_it = polars.read_csv_batched(\"train.csv\", has_header=True)\n",
    "ys = []\n",
    "Xs = []\n",
    "idxs = []\n",
    "for _ in range(N_BATCHES):\n",
    "    batches = df_it.next_batches(1)\n",
    "    if not batches:\n",
    "        # The reader is exhausted: the file holds fewer than N_BATCHES batches.\n",
    "        break\n",
    "    df = batches[0]\n",
    "    print(df.shape)\n",
    "    idxs.append(df[:, 0])\n",
    "    ys.append(df[:, 1].to_numpy())\n",
    "    features = sparse.csr_matrix(df[:, 2:].to_numpy().astype(np.float32))\n",
    "    # normalize() rescales every row (sample) to unit norm.\n",
    "    Xs.append(preprocessing.normalize(features))\n",
    "\n",
    "X = sparse.vstack(Xs)\n",
    "y = np.hstack(ys)\n",
    "# Ids of every loaded row, in the same order as the rows of X / y.\n",
    "idx = polars.concat(idxs)\n",
    "\n",
    "del Xs\n",
    "del ys\n",
    "\n",
    "print(X.shape)\n",
    "print(y.shape)\n",
    "# shuffle=False keeps the file order, so test row r is dataset row n_train + r.\n",
    "(X_train, X_test, y_train, y_test) = train_test_split(X, y, shuffle=False)\n",
    "del X\n",
    "del y\n",
    "n_train = X_train.shape[0]\n",
    "print(X_train.shape, X_test.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "4e6d2504-69d1-4135-844f-1d80a1d96e04",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "160, 171, 173, 187, 192, 196, 199, 200, 201, 202, 205, 214, 220, 223, 224, 225, 226, 227, 228, 231, 232, 233, 234, 235, 236, 237, 238, 239, 242, 
243, 244, 245, 246, 249, 250, 251, 252, 333, 339, 8201, 8211, 8212, 8217, 8220, 8221, 8222, 8239\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Languages whose sentences are scanned. Spanish is spelled \"spa\" (not \"esp\"),\n",
    "# matching the label set printed by model.classes_ in this notebook.\n",
    "SCANNED_LANGS = {\"fra\", \"eng\", \"ita\", \"deu\", \"spa\", \"por\"}\n",
    "MAX_SENTENCES = 100_000\n",
    "\n",
    "counter = Counter()\n",
    "\n",
    "c = 0\n",
    "# Explicit UTF-8 so that non-ASCII characters are decoded the same way on every platform;\n",
    "# the with-block closes the file even if a row fails to parse.\n",
    "with open(\"dataset/archive/sentences.prepared.csv\", encoding=\"utf-8\") as sentences_file:\n",
    "    for line in sentences_file:\n",
    "        fields = line.strip().split(\"\\t\", 2)\n",
    "        if len(fields) != 3:\n",
    "            # Row without all three tab-separated fields: skip it instead of crashing on unpack.\n",
    "            continue\n",
    "        (rid, lang, sentence) = fields\n",
    "        if lang not in SCANNED_LANGS:\n",
    "            continue\n",
    "        c += 1\n",
    "        if c > MAX_SENTENCES:\n",
    "            break\n",
    "        for char in sentence:\n",
    "            # Code points 0..127 are ASCII; everything from 128 upwards is counted.\n",
    "            if ord(char) >= 128:\n",
    "                counter[char] += 1\n",
    "# Keep the code points of the (at most 100) most frequent characters seen at least 10 times.\n",
    "letters = sorted(ord(letter) for (letter, count) in counter.most_common(100) if count >= 10)\n",
    "print(\", \".join(map(str, letters)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "5f7408bc-da6b-4608-84e0-af67b203a748",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
      " This problem is unconstrained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RUNNING THE L-BFGS-B CODE\n",
      "\n",
      " * * *\n",
      "\n",
      "Machine precision = 2.220D-16\n",
      " N = 65552 M = 10\n",
      "\n",
      "At X0 0 variables are exactly at the bounds\n",
      "\n",
      "At iterate 0 f= 1.24762D+06 |proj g|= 2.45139D+04\n",
      "\n",
      "At iterate 50 f= 1.21037D+04 |proj g|= 3.55325D+02\n",
      "\n",
      "At iterate 100 f= 4.94206D+03 |proj g|= 1.13277D+02\n",
      "\n",
      "At iterate 150 f= 3.82850D+03 |proj g|= 1.43956D+01\n",
      "\n",
      "At iterate 200 f= 3.43115D+03 |proj g|= 1.99731D+01\n",
      "\n",
      "At iterate 250 f= 3.29184D+03 |proj g|= 1.39921D+01\n",
      "\n",
      "At iterate 300 f= 3.23884D+03 |proj g|= 4.21846D+00\n",
      "\n",
      "At iterate 350 f= 3.21968D+03 |proj g|= 1.86934D+00\n",
      "\n",
      "At iterate 400 f= 3.21289D+03 |proj g|= 7.79969D-01\n",
      "\n",
      "At iterate 450 f= 3.21017D+03 |proj g|= 1.19910D+00\n",
      "\n",
      "At iterate 500 f= 3.20925D+03 |proj g|= 7.50076D-01\n",
      "\n",
      " * * *\n",
      "\n",
      "Tit = total number of iterations\n",
      "Tnf 
= total number of function evaluations\n",
      "Tnint = total number of segments explored during Cauchy searches\n",
      "Skip = number of BFGS updates skipped\n",
      "Nact = number of active bounds at final generalized Cauchy point\n",
      "Projg = norm of the final projected gradient\n",
      "F = final function value\n",
      "\n",
      " * * *\n",
      "\n",
      " N Tit Tnf Tnint Skip Nact Projg F\n",
      "65552 500 549 1 0 0 7.501D-01 3.209D+03\n",
      " F = 3209.2533390579683 \n",
      "\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT \n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "----------\n",
      "64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/fulmicoton/miniconda3/lib/python3.9/site-packages/sklearn/linear_model/_logistic.py:458: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      " https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      " n_iter_i = _check_optimize_result(\n",
      "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 8.0min finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9997377672884693\n",
      "0.9961465125271678\n"
     ]
    }
   ],
   "source": [
    "# Fit one multinomial logistic regression per candidate value of C and report the\n",
    "# train / test accuracy. `model` is left bound to the last classifier fitted.\n",
    "for C in [64]:\n",
    "    model = sklearn.linear_model.LogisticRegression(\n",
    "        C=C,\n",
    "        penalty='l2',\n",
    "        multi_class='multinomial',\n",
    "        class_weight='balanced',\n",
    "        max_iter=500,\n",
    "        verbose=1,\n",
    "    )\n",
    "    # Commented-out variant kept from the original cell:\n",
    "    # penalty='elasticnet', solver='saga', l1_ratio=0.1\n",
    "    model.fit(X_train, y_train)\n",
    "    print(\"\\n\\n\\n\\n\\n----------\")\n",
    "    print(C)\n",
    "    # Accuracy = fraction of rows whose predicted label equals the true label.\n",
    "    train_accuracy = (model.predict(X_train) == y_train).mean()\n",
    "    test_accuracy = (model.predict(X_test) == y_test).mean()\n",
    "    print(train_accuracy)\n",
    "    print(test_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d47b1c8a-2d26-41aa-a537-91c86ecf77c6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['ara' 
'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n",
      " 'spa' 'swe' 'tur' 'vie']\n",
      "[[13558 13 4 2 26 1 0 3]\n",
      " [ 20 36656 23 10 27 10 0 20]\n",
      " [ 2 9 11541 9 2 5 0 11]\n",
      " [ 1 10 25 18634 4 28 0 41]\n",
      " [ 14 9 2 4 3663 1 0 0]\n",
      " [ 3 3 16 28 0 9036 0 104]\n",
      " [ 0 0 0 0 0 0 21220 0]\n",
      " [ 1 4 10 29 1 68 0 8532]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[ 209, 0, 0],\n",
       " [ 0, 5127, 3],\n",
       " [ 0, 1, 1704]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import time\n",
    "from sklearn import metrics\n",
    "\n",
    "# NOTE(review): fixed 100 s delay kept from the original cell; its purpose is not\n",
    "# evident from this notebook - confirm whether it is still needed.\n",
    "time.sleep(100)\n",
    "print(\"start\")\n",
    "print(model.classes_)\n",
    "# Predict once and reuse the result for both confusion matrices.\n",
    "y_test_pred = model.predict(X_test)\n",
    "print(sklearn.metrics.confusion_matrix(y_test, y_test_pred, labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n",
    "\n",
    "sklearn.metrics.confusion_matrix(y_test, y_test_pred, labels=['kor', 'jpn', 'cmn'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "6ba8d76f-4dde-45d9-b07d-b12b4355b0d6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n",
      " 'spa' 'swe' 'tur' 'vie']\n",
      "[[13565 7 2 2 22 3 0 3]\n",
      " [ 13 36698 8 12 20 7 0 9]\n",
      " [ 1 6 11555 7 1 4 0 5]\n",
      " [ 3 6 20 18666 4 22 0 29]\n",
      " [ 11 12 2 4 3665 5 0 1]\n",
      " [ 4 3 6 22 1 9064 0 89]\n",
      " [ 0 0 0 0 0 0 21220 0]\n",
      " [ 2 3 9 34 0 58 0 8539]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[ 209, 0, 0],\n",
       " [ 0, 5128, 3],\n",
       " [ 0, 2, 1703]])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "print(model.classes_)\n",
    "# Predict once and reuse the result for both confusion matrices.\n",
    "y_test_pred = model.predict(X_test)\n",
    "print(sklearn.metrics.confusion_matrix(y_test, y_test_pred, labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n",
    "\n",
    "sklearn.metrics.confusion_matrix(y_test, y_test_pred, labels=['kor', 'jpn', 'cmn'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "0efb14c6-50db-4bd6-bb33-cd2dcee69493",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n",
      " 'spa' 'swe' 'tur' 'vie']\n",
      "[[13562 11 2 3 21 2 0 3]\n",
      " [ 14 36684 11 13 21 5 0 15]\n",
      " [ 1 6 11551 8 2 4 0 7]\n",
      " [ 3 6 21 18662 5 22 0 30]\n",
      " [ 11 12 2 4 3666 5 0 0]\n",
      " [ 4 3 8 23 1 9060 0 90]\n",
      " [ 0 0 0 0 0 0 21220 0]\n",
      " [ 2 3 8 35 0 57 0 8539]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[ 209, 0, 0],\n",
       " [ 0, 5128, 3],\n",
       " [ 0, 1, 1704]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "print(model.classes_)\n",
    "# Predict once and reuse the result for both confusion matrices.\n",
    "y_test_pred = model.predict(X_test)\n",
    "print(sklearn.metrics.confusion_matrix(y_test, y_test_pred, labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n",
    "\n",
    "sklearn.metrics.confusion_matrix(y_test, y_test_pred, labels=['kor', 'jpn', 'cmn'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c4252d1b-1ed4-4ffb-8f6a-0196b32622f9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['ara' 'cmn' 'deu' 'eng' 'fra' 'hin' 'ita' 'jpn' 'kor' 'nld' 'por' 'rus'\n",
      " 'spa' 'swe' 'tur' 'vie']\n",
      "[[ 6866 8 2 6 10 2 0 7]\n",
      " [ 4 18245 22 11 30 3 0 22]\n",
      " [ 0 2 5838 9 3 7 0 12]\n",
      " [ 1 6 14 9446 0 27 0 40]\n",
      " [ 10 7 1 3 1770 0 0 4]\n",
      " [ 1 2 10 31 4 4517 0 58]\n",
      " [ 0 0 0 0 0 0 10485 0]\n",
      " [ 2 4 10 34 1 43 0 4133]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[ 101, 0, 0],\n",
       " [ 0, 2561, 5],\n",
       " [ 0, 0, 812]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "print(model.classes_)\n",
    "# Predict once and reuse the result for both confusion matrices.\n",
    "y_test_pred = model.predict(X_test)\n",
    "print(sklearn.metrics.confusion_matrix(y_test, y_test_pred, labels=['deu', 'eng', 'fra', 'ita','nld', 'por', 'rus', 'spa']))\n",
    "\n",
    "sklearn.metrics.confusion_matrix(y_test, y_test_pred, 
labels=['kor', 'jpn', 'cmn'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "476077a6-2110-447a-8197-f0848777a508",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(565,)\n",
      "(75000,)\n",
      "ita por 8484775\n",
      "ita spa 3553712\n",
      "ita por 2614727\n",
      "ita spa 3213159\n",
      "tur nld 4129221\n",
      "por spa 6992976\n",
      "por spa 7888818\n",
      "deu fra 843713\n",
      "eng swe 1164430\n",
      "por spa 972364\n"
     ]
    }
   ],
   "source": [
    "# Show the first few misclassified test rows as: true label, predicted label, row id.\n",
    "MAX_SHOWN = 10\n",
    "\n",
    "y_predict = model.predict(X_test)\n",
    "# Positions (within the test set) where the prediction disagrees with the label.\n",
    "misclassified_rows = np.flatnonzero(y_predict != y_test)\n",
    "print(misclassified_rows.shape)\n",
    "print(y_predict.shape)\n",
    "\n",
    "# Build the id lookup from the per-batch id columns collected in `idxs`, so that it\n",
    "# covers every loaded row and can be indexed with dataset-wide positions.\n",
    "all_ids = np.concatenate([batch_ids.to_numpy() for batch_ids in idxs])\n",
    "\n",
    "for row in misclassified_rows[:MAX_SHOWN]:\n",
    "    # The train/test split is not shuffled, so test row `row` is dataset row `n_train + row`.\n",
    "    print(y_test[row], y_predict[row], all_ids[n_train + row])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "ac15d1d4-c839-4e96-89cc-724af639d231",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(16, 4096)\n",
      "-0.0047384733\n",
      "[-4.73847311e-03 -2.28266937e-03 -5.36394365e-01 5.88218350e-01\n",
      " -1.03133002e-01 -5.35878639e-04 -2.13580000e+00 -2.86555783e-03\n",
      " -1.19892844e-03 -3.64568806e-01 -3.39155160e-01 -5.44463552e-04\n",
      " 8.38150022e-01 1.53047103e+00 5.64755284e-01 -3.03773764e-02]\n"
     ]
    }
   ],
   "source": [
    "# Export the fitted model as Rust source (src/weights.rs): the Lang enum, its\n",
    "# three-letter codes, the LANGUAGES table, the flattened WEIGHTS and the INTERCEPTS.\n",
    "(LANG, DIM) = model.coef_.shape\n",
    "print(model.coef_.shape)\n",
    "coef = np.float32(model.coef_)\n",
    "\n",
    "print(coef[0,0])\n",
    "print(model.coef_[:,0])\n",
    "\n",
    "# Nine significant digits round-trip any float32 exactly; \"%f\" keeps only six\n",
    "# decimals and would truncate small weights (e.g. -5.35878639e-04 -> -0.000536).\n",
    "# The exponent form always contains a decimal point, so it is a valid Rust float literal.\n",
    "FLOAT_FMT = \"%.8e, \"\n",
    "\n",
    "# The with-block closes (and flushes) the file even if a write fails half-way.\n",
    "with open(\"src/weights.rs\", \"w\") as f:\n",
    "    f.write(\"#[derive(Clone, Copy, Debug, Eq, PartialEq)]\\n\")\n",
    "    f.write(\"pub enum Lang {\\n\")\n",
    "    for lang in model.classes_:\n",
    "        f.write(\"\\t%s,\\n\" % lang.capitalize())\n",
    "    f.write(\"}\\n\\n\")\n",
    "\n",
    "    f.write(\"\"\"\n",
    "impl Lang {\n",
    " pub fn three_letter_code(self)-> &'static str {\n",
    " match self {\n",
    "\"\"\")\n",
    "    for lang in model.classes_:\n",
    "        f.write(\"\\t\\t\\tLang::%s => \\\"%s\\\",\\n\" % (lang.capitalize(), lang))\n",
    "    f.write(\"\\t\\t}\\t}\\n}\\n\\n\\n\")\n",
    "\n",
    "    f.write(\"pub const LANGUAGES: [Lang; %d] = [\\n\\t\" % LANG)\n",
    "    for lang in model.classes_:\n",
    "        f.write(\"Lang::%s, \" % lang.capitalize())\n",
    "    f.write(\"];\\n\\n\")\n",
    "\n",
    "    # WEIGHTS is written feature by feature: one output line per feature, holding\n",
    "    # that feature's coefficient for every language in model.classes_ order.\n",
    "    f.write(\"pub const WEIGHTS: [f32; %d] = [\\n\" % (LANG * DIM))\n",
    "    for i in range(DIM):\n",
    "        f.write(\"\\t\")\n",
    "        for val in coef[:, i]:\n",
    "            f.write(FLOAT_FMT % val)\n",
    "        f.write(\"\\n\")\n",
    "    f.write(\"];\\n\\n\")\n",
    "\n",
    "    f.write(\"pub const INTERCEPTS: [f32; %d] = [\\n\\t\" % LANG)\n",
    "    for val in model.intercept_:\n",
    "        f.write(FLOAT_FMT % val)\n",
    "    f.write(\"];\\n\\n\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}