UniTO/anno3/altro_muovi/marco/.ipynb_checkpoints/one_vs_rest-checkpoint.ipynb
2024-10-29 09:11:05 +01:00

328 lines
20 KiB
Text
Executable file

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Experiments with the one vs rest multiclass classification scheme"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from sklearn import datasets\n",
"from sklearn.multiclass import OneVsRestClassifier as OvR\n",
"from sklearn.svm import LinearSVC\n",
"import numpy as np\n",
"import copy\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Load the UCI handwritten digits dataset that ships with scikit-learn\n",
"digits = datasets.load_digits(n_class=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the \"Optical Recognition of Handwritten Digits Data Set\" from UCI (included in scikit-learn and already loaded in the previous cell). Let us plot the first 10 images in the dataset."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAADQCAYAAAAu/itEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAT/ElEQVR4nO3dQWwc9fnG8ef9Z3FTKDW0DlHrJDgWJJUlFJranJAgUkEUDuFSxA16CRekwoncmt7CjUj0AKpockGoHEI4ICBVEzgaW00IUBxMME1WrUnUYkVRacLq/R+y1Etm5ze7453d1/X3I1Uk+9qe109m3242r2fM3QUAiOv/Bt0AACCNQQ0AwTGoASA4BjUABMegBoDgGNQAEFytkw8yswckHZC0TtLv3X1/6uNHRkZ8bGys62b+9a9/Jevnzp3LrX3/+9/PrW3atCm3tm7duuLG2jh9+rQuXrx4WhVnUmRubi631mg0cms//vGPc2s33XRTqV6WlpY0Pz/fkLSgAWZy8eLF3Nqnn36aW/vud7+bW9u+fXvpfmZnZ5ckLaqDc6VsJv/4xz+S9Xq9nlsbGhrKrU1MTOTWyj53pO4ykao7V1LPkc8++yy3dtttt/W8l4WFBV24cMHa1QoHtZmtk/Q7SfdJOifpPTN73d0/yvucsbExzczMdN3oq6++mqw/88wzubX77rsvt7Z/f/45cPPNNxc3do1Go6Hrr79ekn6hijMpcu+99+bWvvzyy9zab3/729za7t27u+6j0Who27ZtkvSRpEkNMJPjx4/n1h5++OHc2p133lnqa6Y0Gg3VarX16vBcKZvJs88+m6zv3bs3tzY6Oppb+/Of/5xbK/PckbrPRKruXEk9Rx5//PHc2muvvdbzXiYnJ3Nrnbz1cZekeXc/4+6XJb0iqftn8v+Q6elprV+/XmSybHp6+ptXGZfJZNn09LQk/YdzZRmZdK+TQT0q6WzL7881H1uz6vW6rrvuutaHyKRe1+bNm1sfWvOZSP99y+Fyy0NrPhcy6V7P/jHRzPaY2YyZzZw/f75XX3ZVI5MsMskik/bIZVkng7ouqfWl0qbmY9/i7i+6+6S7T27YsKFX/YU0OjqqK1eutD5EJqOjOnu29S9eZCL99/3f1n+ty+RCJpwrRToZ1O9Jut3MtprZkKRHJb1ebVuxTU1N6auvvhKZLJuamtInn3wiSUNksmxqakqS1nOuLCOT7hVufbj712b2pKS3dHWV5iV3/7CKZlJbHVJ6XSa12veDH/wgt/bHP/4xecxf/vKXmcdqtZq2bNmi+fn5yjMpklqle+edd3Jrx44dy62V2fqo1Wp6/vnn9dBDD22T9FdVmMmJEyeS9V27duXWhoeHc2sLCwtlW8pVq9Uk6W/qwfMntblRdB6/8MILubUnnngitzY7O5tb+/nPf548Zp5eZrJSBw8ezK2ltoD6raM9and/Q9IbFfeyqgwPD8vdtw26j0gefPBBSfrA3fP3jNamJTLJIJMu8JOJABAcgxoAgmNQA0BwDGoACI5BDQDBdbT10UupdZ/U+p2UvvLZ+Ph4bi11waZUP1L79bx+KlpFK3uxoEirR90quiDOjh07cmupizKlLlQVwZ49e3JrRautP/vZz3JrW7duza2VXcGLInXRJSm9nvfUU0/l1layylnmKoC8ogaA4BjUABAcgxoAgmNQA0BwDGoACI5BDQDBMagBILi+71GnLke6c+fO5OemdqVTUjukETz33HO5tX379iU/d2lpqdQxUzfFjS613yql91RTn1vm8q79lDr/z5w5k/zc1M8opHalU8/Xsje37afUnrSU3odO3dw2dR6lLj0sFT+n2+EVNQAEx6AGgOAY1AAQHIMaAIJjUANAcAxqAAgu1Hpe6nKkVR0zwopRatUntSIkle+/6PKPg5bqL7XOKBVfBjVP0SpXZEWrq//85z9za6n1vFTtT3/6U/KY/X
puHTlyJLf29NNPJz/3scceK3XMAwcO5Nb+8Ic/lPqaKbyiBoDgGNQAEByDGgCCY1ADQHAMagAIjkENAMF1tJ5nZguSLkpqSPra3SfLHjC1slN0R/CU1ArezMxMbu2RRx4pdbxTp07JzE6pB5kMQuru5iu8Q/kdvcgldYWx1GpUkdTqXtFVz1agJ5msROp5l1qze+KJJ3Jrzz77bPKY+/fvT5V7lsnw8HCpmiQdOnQot5Z6jqSk7nRfVjd71Lvc/ULPO1jdyKQ9cskikywy6RBvfQBAcJ0Oapf0tpnNmtmeKhtaZcikPXLJIpMsMulQp2993O3udTO7RdJRM/vY3d9t/YBm2HskacuWLT1uM57t27fr/fff30kmGR+7e24uZEImTclMpDWbS1sdvaJ293rzv19IOizprjYf86K7T7r75IYNG3rbZUBDQ0OSyKSNK1J+LmRCJk3JTJq1tZhLW4WD2sxuMLMbv/m1pPslfVB1Y5FdunRJjUZDEpm0unTpktQ8p8jlKjLJIpPudfLWx0ZJh83sm49/2d3fLHvA1FW+Umt0kvTqq6+WqqU888wzXX/O4uKi5ubmZGYn1YNM/lcsLi5K0k96kUvqqoHHjx9Pfu7Jkydza6nVqdTNbX/1q18lj5n3ub3MJGXv3r3Jetkb2B49ejS3Vna1tdeZpG7UXHSVyNQKXurrpq66V8WaZ+Ggdvczknb0/Mir2Pj4uCYmJjQzM0MuLZr/J/zRatsprxKZZJFJ91jPA4DgGNQAEByDGgCCY1ADQHAMagAIjkENAMH1/S7kqT3qossmpnaeJyfzN31WcvnUQSvayUzt/qbuzpzaRS6683k/pC61WnT5yVQ9dfnUVF5jY2PJY6b+HPqh6I7fe/aUu5xGalf6hRdeKPU1I0k9v5aWlnJr/X6O8IoaAIJjUANAcAxqAAiOQQ0AwTGoASA4BjUABGfu3vsvanZe0ufN345IinQDy171c6u7d3w18+CZSAPI5ZpMetlDr5BJFs+frMozqWRQf+sAZjORLmcYoZ8IPVwrQk8RemgVoZ8IPbSK0E+EHlr1ox/e+gCA4BjUABBcPwb1i304Rjci9BOhh2tF6ClCD60i9BOhh1YR+onQQ6vK+6n8PWoAwMrw1gcABFfpoDazB8xszszmzSx9m+Q+MLMFMztlZifMLH3L8+p6IJNsD2SS7SFUJhK55PTTn0zcvZL/SVon6VNJ45KGJJ2UNFHV8TrsaUHSyACPTyZksiozIZfBZlLlK+q7JM27+xl3vyzpFUmDvWjv4JFJFplkkUl7azaXKgf1qKSzLb8/13xskFzS22Y2a2blrqS+MmSSRSZZETORyKWdvmTS9zu8DNjd7l43s1skHTWzj9393UE3NWBkkkUm7ZFLVl8yqfIVdV3S5pbfb2o+NjDuXm/+9wtJh3X1r1L9RCZZZJIVLhOJXNrpVyZVDur3JN1uZlvNbEjSo5Jer/B4SWZ2g5nd+M2vJd0v6YM+t0EmWWSSFSoTiVza6Wcmlb314e5fm9mTkt7S1X+tfcndP6zqeB3YKOmwmUlXv++X3f3NfjZAJllkkhUwE4lc2ulbJvxkIgAEx08mAkBwDGoACI5BDQDBMagBIDgGNQAEx6AGgOAY1AAQHIMaAIJjUANAcAxqAAiOQQ0AwTGoASA4BjUABMegBoDgGNQAEByDGgCCY1ADQHAMagAIjkENAMExqAEgOAY1AATHoAaA4BjUABAcgxoAgmNQA0BwDGoACI5BDQDBMagBIDgGNQAEx6AGgOAY1AAQHIMaAIJjUANAcAxqAAiOQQ0AwTGoASA4BjUABMegBoDgGNQAEByDGgCCY1ADQHAMagAIjkENAMExqAEgOAY1AARX6+SDzOwBSQckrZP0e3ffn/r4kZERHxsb67qZubm5ZP073/lObq3M8Vbi9OnTunjx4mlVnEmRVGaNRiO3NjEx0fNelpaWND8/35C0oAozWVxcTNZT3/eXX36ZW/v3v/
+dW1u3bl3ymHfccUdu7cSJE0uSFtXBuVI2k7Nnzybrqe/7hz/8YW5t48aNubWiTFJmZ2c7zkQqn8v8/HyynjpXtm/f3vXxVmJhYUEXLlywdrXCQW1m6yT9TtJ9ks5Jes/MXnf3j/I+Z2xsTDMzM103eu+99ybrqT+ogwcPdn28shqNhq6//npJ+oUqzqRIKrPUk7PXvTQaDW3btk2SPpI0qQozee6555L11Pf92muv5dZOnjyZW/ve976XPOaxY8faPt5oNDQyMrJeHZ4rZTN56qmnkvXU9/3444+X+ro33XRTYV/tNBoN1Wq1jjORyufy8MMPJ+upc+X48eNdH28lJicnc2udvPVxl6R5dz/j7pclvSJpd496W5Wmp6e1fv16kcmy6elp3XbbbZJ0mUyWzc7OStJ/OFeWTU9PS2TSlU4G9aik1r9XnWs+tmbV63Vdd911rQ+RSb2uzZs3tz605jORpL///e+SdLnloTWfS71el8ikKz37x0Qz22NmM2Y2c/78+V592VWNTLLIJItM2iOXZZ0M6rqk1pdKm5qPfYu7v+juk+4+uWHDhl71F9Lo6KiuXLnS+hCZjI5e+w9aaz4TSfrRj34kSUMtD2VyWWuZjI6OSgWZSGsvl5ROBvV7km43s61mNiTpUUmvV9tWbFNTU/rqq69EJsumpqb0ySefSNIQmSzbuXOnJK3nXFk2NTUlkUlXCrc+3P1rM3tS0lu6ukrzkrt/WEUzCwsLyfo777yTWzt06FBu7dZbby19zHZqtZq2bNmi+fn5yjM5cuRIsp7K5De/+U2v28lVq9X0/PPP66GHHtom6a+qMJMiqW2E1MZIqpbaDig6pqS/qeLnz4kTJ0p/bmpjKrX5UHYrolarST3MJPUcLnr+pJi13ZSTJO3YsSO3tpI/izwd7VG7+xuS3uj50Vex4eFhufu2QfcRyYMPPihJH7h7/p7R2rREJhlk0gV+MhEAgmNQA0BwDGoACI5BDQDBMagBILiOtj76pegiL59//nlubXh4OLdW9sJFnfRUtZWs2BVdkGa1KroAUcq+fftya6k1r35foKdbd955Z7Je9oJmqfO/KJOii6z1StFzOOWee+7JraUy6/f5wCtqAAiOQQ0AwTGoASA4BjUABMegBoDgGNQAEByDGgCCC7VHXXSX4dTNR5eWlnJrqR3TQe9JFynaEU1dbrFotzayKi6vKRXfGDdP6uawUvoGsf1QdPyf/vSnubXU/njq+VHmruBVWEkfqT/X1M8hrGR3uwxeUQNAcAxqAAiOQQ0AwTGoASA4BjUABMegBoDgQq3nFa1ApdayUnf+ffrpp8u2tKJLavZC0RpQajUptYqWWj2KsHaV6qHoLs9l1/dS51+/LtlZ1krWxVJ3sv/ss89yaxHOEym9QphaX5Wkm2++Obf261//OreWOgdT645Sudx4RQ0AwTGoASA4BjUABMegBoDgGNQAEByDGgCC62g9z8wWJF2U1JD0tbtPVtlUnipWpIpWafKcOnVKZnZKFWdStMqTWq1KrWylVhb/8pe/JI9ZcFW+O3qRS+r7LlrjNLNSn1vhCl5PMkmthO3atSv5uam72aeeA6k1zqI/h4JztyeZFCla5UzVy159smiltyi3drrZo97l7he6PsL/NjJpj1yyyCSLTDrEWx8AEFyng9olvW1ms2a2p8qGVhkyaY9cssgki0w61OlbH3e7e93MbpF01Mw+dvd3Wz+gGfYeSdqyZUuP24xn+/btev/993eSScbH7p6bC5mQSVMyE2nN5tJWR6+o3b3e/O8Xkg5LuqvNx7zo7pPuPrlhw4bedhnQ0NCQJDJp44qUnwuZkElTMpNmbS3m0lbhoDazG8zsxm9+Lel+SR9U3Vhkly5dUqPRkEQmrS5duiQ1zylyuYpMssike5289bFR0uHmylNN0svu/mYVzRw5ciRZHx4ezq3t27ev1DFT60d5FhcXNTc3JzM7qYozKbppaWrNLrUelVrJKl
ofyltbWlxclKSfVJ1L0fpT6jy55557et1OUi8zSf15pr5nKZ1Z6lxI3RT34MGDyWPmPSf7dZ50IrWCl8os9b2XWb8rUjio3f2MpPS1AteY8fFxTUxMaGZmhlxajI+PS9JHg9qzj4hMssike6znAUBwDGoACI5BDQDBMagBIDgGNQAEx6AGgOBC3YX82LFjyfqBAwdKfd3HHnsstxb97tJFe9SpHdjUrmfq+y6zW95PRXcZP3ToUG4tdcfq6FK9F53Hqbttp3awd+/enVsr2mePoKjH1GVOU5cJTp2DZS+PmsIragAIjkENAMExqAEgOAY1AATHoAaA4BjUABCcuXvvv6jZeUmfN387IinSDSx71c+t7t7x1cyDZyINIJdrMullD71CJlk8f7Iqz6SSQf2tA5jNRLqcYYR+IvRwrQg9ReihVYR+IvTQKkI/EXpo1Y9+eOsDAIJjUANAcP0Y1C/24RjdiNBPhB6uFaGnCD20itBPhB5aRegnQg+tKu+n8veoAQArw1sfABBcpYPazB4wszkzmzezvVUeq8N+FszslJmdMLOZAfVAJtkeyCTbQ6hMJHLJ6ac/mbh7Jf+TtE7Sp5LGJQ1JOilpoqrjddjTgqSRAR6fTMhkVWZCLoPNpMpX1HdJmnf3M+5+WdIrkvIvbrs2kEkWmWSRSXtrNpcqB/WopLMtvz/XfGyQXNLbZjZrZnsGcHwyySKTrIiZSOTSTl8yCXWHlz64293rZnaLpKNm9rG7vzvopgaMTLLIpD1yyepLJlW+oq5L2tzy+03NxwbG3evN/34h6bCu/lWqn8gki0yywmUikUs7/cqkykH9nqTbzWyrmQ1JelTS6xUeL8nMbjCzG7/5taT7JX3Q5zbIJItMskJlIpFLO/3MpLK3Ptz9azN7UtJbuvqvtS+5+4dVHa8DGyUdNjPp6vf9sru/2c8GyCSLTLICZiKRSzt9y4SfTASA4PjJRAAIjkENAMExqAEgOAY1AATHoAaA4BjUABAcgxoAgmNQA0Bw/w/9yfPpDXP9eAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Show the first 10 digits on a 2 x 5 grid, filled row by row\n",
"fig, axes = plt.subplots(2, 5)\n",
"for ax, image in zip(axes.flat, digits.images[:10]):\n",
"    ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us create a training set using the first 1000 images and a test set using the rest of the data."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# The first n_train samples form the training set; all remaining samples form the test set\n",
"n_train = 1000\n",
"X, y = digits.data[:n_train], digits.target[:n_train]\n",
"X_test, y_test = digits.data[n_train:], digits.target[n_train:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learn provides a One-vs-Rest classifier, which we have already imported under the name `OvR`. Let us use that classifier to fit the training set and to make predictions on the test set:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n"
]
}
],
"source": [
"# Binary base learner shared by this cell and the custom implementation below\n",
"binaryLearner = LinearSVC(random_state=0)\n",
"\n",
"# fit returns the fitted estimator, so training and prediction can be chained\n",
"oneVrestLearningAlgorithm = OvR(binaryLearner).fit(X, y)\n",
"predicted_labels = oneVrestLearningAlgorithm.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9058971141781681\n"
]
}
],
"source": [
"# Accuracy = 1 - (number of wrongly labelled test samples) / (number of test samples)\n",
"n_mistakes = np.count_nonzero(y_test - predicted_labels)\n",
"print(\"Accuracy:\", 1.0 - n_mistakes / float(len(predicted_labels)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reimplement the OvR classifier by completing the methods in the following class [[1](#hint1)]:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class OneVsRestClassifier:\n",
"    \"\"\"One-vs-rest multiclass classifier built on top of a binary learner.\n",
"\n",
"    `fit` trains one copy of the binary learner per class, relabelling that\n",
"    class as +1 and every other class as -1; `predict` combines their outputs.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, learner):\n",
"        # Prototype binary learner: copied once per class, never fitted itself.\n",
"        self.learner = learner\n",
"\n",
"    def fit(self, data, labels):\n",
"        \"\"\"Train one binary classifier per distinct label (in sorted order) and return self.\"\"\"\n",
"        self.labels = sorted(set(labels))\n",
"        self.classifiers = []\n",
"        for current in self.labels:\n",
"            # deepcopy rather than copy.copy: a shallow copy shares nested objects\n",
"            # (e.g. the steps of a Pipeline), so the per-class models could end up\n",
"            # sharing, and overwriting, the same fitted sub-estimators.\n",
"            classifier = copy.deepcopy(self.learner)\n",
"            # Current class -> +1, every other class -> -1.\n",
"            classifier.fit(data, [1 if label == current else -1 for label in labels])\n",
"            self.classifiers.append(classifier)\n",
"        return self\n",
"\n",
"    def predict(self, data):\n",
"        \"\"\"Return a list with one predicted label per row of `data`.\n",
"\n",
"        A sample goes to the class whose classifier is the only one voting +1.\n",
"        Ambiguous samples (no classifier, or several classifiers, voting +1) go\n",
"        to the class with the largest `decision_function` score when every\n",
"        classifier provides one; otherwise to the first voting class, or to the\n",
"        first (smallest) label when no classifier votes +1.\n",
"        \"\"\"\n",
"        # votes[i, k] is the +1/-1 output of classifier k on sample i.\n",
"        votes = np.array([classifier.predict(data) for classifier in self.classifiers]).T\n",
"        scores = None\n",
"        if all(hasattr(classifier, 'decision_function') for classifier in self.classifiers):\n",
"            # scores[i, k]: decision score of classifier k on sample i (assumed, as in\n",
"            # scikit-learn, that larger means more confident in the +1 class).\n",
"            scores = np.array([classifier.decision_function(data)\n",
"                               for classifier in self.classifiers]).T\n",
"        prediction = []\n",
"        for i, row in enumerate(votes):\n",
"            positives = np.flatnonzero(row == 1)\n",
"            if len(positives) == 1:\n",
"                k = positives[0]\n",
"            elif scores is not None:\n",
"                # Break ties by confidence instead of favouring the smallest labels.\n",
"                k = int(np.argmax(scores[i]))\n",
"            else:\n",
"                k = positives[0] if len(positives) > 0 else 0\n",
"            prediction.append(self.labels[k])\n",
"        return prediction"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n"
]
},
{
"data": {
"text/plain": [
"<__main__.OneVsRestClassifier at 0x7f7b83ded610>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# fit returns the classifier itself, so the bare name on the last line displays it\n",
"classifier = OneVsRestClassifier(binaryLearner).fit(X, y)\n",
"classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calculate the accuracy of your solution using the following code [[2](#hint2)]:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8393977415307403\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n"
]
}
],
"source": [
"# Train a fresh one-vs-rest model and score it on the held-out test set\n",
"ovr = OneVsRestClassifier(LinearSVC(random_state=0)).fit(X, y)\n",
"predicted_labels = ovr.predict(X_test)\n",
"n_mistakes = np.count_nonzero(predicted_labels - y_test)\n",
"print(\"Accuracy:\", 1.0 - n_mistakes / float(len(y_test)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a name=\"hint1\">Hint 1:</a> Feel free to organize your code as you like (add as many methods as you believe are necessary).\n",
"\n",
"<a name=\"hint2\">Hint 2:</a> The scheme provided by scikit-learn is a little different from the one we have seen in the lessons. It is normal if your accuracy is not as good as the one obtained above (expect the accuracy to be between 0.8 and 0.9)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 1
}