UniTO/anno3/apprendimento_automatico/esercizi/1/one_vs_rest.ipynb

327 lines
18 KiB
Text
Raw Normal View History

2020-06-16 18:27:43 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Experiments with the one vs rest multiclass classification scheme"
]
},
{
"cell_type": "code",
2020-07-03 19:08:23 +02:00
"execution_count": 1,
"metadata": {},
2020-06-16 18:27:43 +02:00
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from sklearn import datasets\n",
"from sklearn.multiclass import OneVsRestClassifier as OvR\n",
"from sklearn.svm import LinearSVC\n",
"import numpy as np\n",
"import copy\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
2020-07-03 19:08:23 +02:00
"execution_count": 2,
"metadata": {},
2020-06-16 18:27:43 +02:00
"outputs": [],
"source": [
"# Load the UCI handwritten-digits dataset bundled with scikit-learn; later cells use\n",
"# its `images` (2-D pixel grids), `data` (flattened pixel rows) and `target` (class labels)\n",
"digits = datasets.load_digits()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the \"Optical Recognition of Handwritten Digits Data Set\" from UCI (it ships with scikit-learn and was already loaded in the previous cell). Let us plot the first 10 images in the dataset."
]
},
{
"cell_type": "code",
2020-07-03 19:08:23 +02:00
"execution_count": 3,
"metadata": {},
2020-06-16 18:27:43 +02:00
"outputs": [
{
"data": {
2020-07-03 19:08:23 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAADQCAYAAAAu/itEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAT/ElEQVR4nO3dQWwc9fnG8ef9Z3FTKDW0DlHrJDgWJJUlFJranJAgUkEUDuFSxA16CRekwoncmt7CjUj0AKpockGoHEI4ICBVEzgaW00IUBxMME1WrUnUYkVRacLq/R+y1Etm5ze7453d1/X3I1Uk+9qe109m3242r2fM3QUAiOv/Bt0AACCNQQ0AwTGoASA4BjUABMegBoDgGNQAEFytkw8yswckHZC0TtLv3X1/6uNHRkZ8bGys62b+9a9/Jevnzp3LrX3/+9/PrW3atCm3tm7duuLG2jh9+rQuXrx4WhVnUmRubi631mg0cms//vGPc2s33XRTqV6WlpY0Pz/fkLSgAWZy8eLF3Nqnn36aW/vud7+bW9u+fXvpfmZnZ5ckLaqDc6VsJv/4xz+S9Xq9nlsbGhrKrU1MTOTWyj53pO4ykao7V1LPkc8++yy3dtttt/W8l4WFBV24cMHa1QoHtZmtk/Q7SfdJOifpPTN73d0/yvucsbExzczMdN3oq6++mqw/88wzubX77rsvt7Z/f/45cPPNNxc3do1Go6Hrr79ekn6hijMpcu+99+bWvvzyy9zab3/729za7t27u+6j0Who27ZtkvSRpEkNMJPjx4/n1h5++OHc2p133lnqa6Y0Gg3VarX16vBcKZvJs88+m6zv3bs3tzY6Oppb+/Of/5xbK/PckbrPRKruXEk9Rx5//PHc2muvvdbzXiYnJ3Nrnbz1cZekeXc/4+6XJb0iqftn8v+Q6elprV+/XmSybHp6+ptXGZfJZNn09LQk/YdzZRmZdK+TQT0q6WzL7881H1uz6vW6rrvuutaHyKRe1+bNm1sfWvOZSP99y+Fyy0NrPhcy6V7P/jHRzPaY2YyZzZw/f75XX3ZVI5MsMskik/bIZVkng7ouqfWl0qbmY9/i7i+6+6S7T27YsKFX/YU0OjqqK1eutD5EJqOjOnu29S9eZCL99/3f1n+ty+RCJpwrRToZ1O9Jut3MtprZkKRHJb1ebVuxTU1N6auvvhKZLJuamtInn3wiSUNksmxqakqS1nOuLCOT7hVufbj712b2pKS3dHWV5iV3/7CKZlJbHVJ6XSa12veDH/wgt/bHP/4xecxf/vKXmcdqtZq2bNmi+fn5yjMpklqle+edd3Jrx44dy62V2fqo1Wp6/vnn9dBDD22T9FdVmMmJEyeS9V27duXWhoeHc2sLCwtlW8pVq9Uk6W/qwfMntblRdB6/8MILubUnnngitzY7O5tb+/nPf548Zp5eZrJSBw8ezK2ltoD6raM9and/Q9IbFfeyqgwPD8vdtw26j0gefPBBSfrA3fP3jNamJTLJIJMu8JOJABAcgxoAgmNQA0BwDGoACI5BDQDBdbT10UupdZ/U+p2UvvLZ+Ph4bi11waZUP1L79bx+KlpFK3uxoEirR90quiDOjh07cmupizKlLlQVwZ49e3JrRautP/vZz3JrW7duza2VXcGLInXRJSm9nvfUU0/l1layylnmKoC8ogaA4BjUABAcgxoAgmNQA0BwDGoACI5BDQDBMagBILi+71GnLke6c+fO5OemdqVTUjukETz33HO5tX379iU/d2lpqdQxUzfFjS613yql91RTn1vm8q79lDr/z5w5k/zc1M8opHalU8/Xsje37afUnrSU3odO3dw2dR6lLj0sFT+n2+EVNQAEx6AGgOAY1AAQHIMaAIJjUANAcAxqAAgu1Hpe6nKkVR0zwopRatUntSIkle+/6PKPg5bqL7XOKBVfBjVP0SpXZEWrq//85z9za6n1vFTtT3/6U/KY/X
puHTlyJLf29NNPJz/3scceK3XMAwcO5Nb+8Ic/lPqaKbyiBoDgGNQAEByDGgCCY1ADQHAMagAIjkENAMF1tJ5nZguSLkpqSPra3SfLHjC1slN0R/CU1ArezMxMbu2RRx4pdbxTp07JzE6pB5kMQuru5iu8Q/kdvcgldYWx1GpUkdTqXtFVz1agJ5msROp5l1qze+KJJ3Jrzz77bPKY+/fvT5V7lsnw8HCpmiQdOnQot5Z6jqSk7nRfVjd71Lvc/ULPO1jdyKQ9cskikywy6RBvfQBAcJ0Oapf0tpnNmtmeKhtaZcikPXLJIpMsMulQp2993O3udTO7RdJRM/vY3d9t/YBm2HskacuWLT1uM57t27fr/fff30kmGR+7e24uZEImTclMpDWbS1sdvaJ293rzv19IOizprjYf86K7T7r75IYNG3rbZUBDQ0OSyKSNK1J+LmRCJk3JTJq1tZhLW4WD2sxuMLMbv/m1pPslfVB1Y5FdunRJjUZDEpm0unTpktQ8p8jlKjLJIpPudfLWx0ZJh83sm49/2d3fLHvA1FW+Umt0kvTqq6+WqqU888wzXX/O4uKi5ubmZGYn1YNM/lcsLi5K0k96kUvqqoHHjx9Pfu7Jkydza6nVqdTNbX/1q18lj5n3ub3MJGXv3r3Jetkb2B49ejS3Vna1tdeZpG7UXHSVyNQKXurrpq66V8WaZ+Ggdvczknb0/Mir2Pj4uCYmJjQzM0MuLZr/J/zRatsprxKZZJFJ91jPA4DgGNQAEByDGgCCY1ADQHAMagAIjkENAMH1/S7kqT3qossmpnaeJyfzN31WcvnUQSvayUzt/qbuzpzaRS6683k/pC61WnT5yVQ9dfnUVF5jY2PJY6b+HPqh6I7fe/aUu5xGalf6hRdeKPU1I0k9v5aWlnJr/X6O8IoaAIJjUANAcAxqAAiOQQ0AwTGoASA4BjUABGfu3vsvanZe0ufN345IinQDy171c6u7d3w18+CZSAPI5ZpMetlDr5BJFs+frMozqWRQf+sAZjORLmcYoZ8IPVwrQk8RemgVoZ8IPbSK0E+EHlr1ox/e+gCA4BjUABBcPwb1i304Rjci9BOhh2tF6ClCD60i9BOhh1YR+onQQ6vK+6n8PWoAwMrw1gcABFfpoDazB8xszszmzSx9m+Q+MLMFMztlZifMLH3L8+p6IJNsD2SS7SFUJhK55PTTn0zcvZL/SVon6VNJ45KGJJ2UNFHV8TrsaUHSyACPTyZksiozIZfBZlLlK+q7JM27+xl3vyzpFUmDvWjv4JFJFplkkUl7azaXKgf1qKSzLb8/13xskFzS22Y2a2blrqS+MmSSRSZZETORyKWdvmTS9zu8DNjd7l43s1skHTWzj9393UE3NWBkkkUm7ZFLVl8yqfIVdV3S5pbfb2o+NjDuXm/+9wtJh3X1r1L9RCZZZJIVLhOJXNrpVyZVDur3JN1uZlvNbEjSo5Jer/B4SWZ2g5nd+M2vJd0v6YM+t0EmWWSSFSoTiVza6Wcmlb314e5fm9mTkt7S1X+tfcndP6zqeB3YKOmwmUlXv++X3f3NfjZAJllkkhUwE4lc2ulbJvxkIgAEx08mAkBwDGoACI5BDQDBMagBIDgGNQAEx6AGgOAY1AAQHIMaAIJjUANAcAxqAAiOQQ0AwTGoASA4BjUABMegBoDgGNQAEByDGgCCY1ADQHAMagAIjkENAMExqAEgOAY1AATHoAaA4BjUABAcgxoAgmNQA0BwDGoACI5BDQDBMagBIDgGNQAEx6AGgOAY1AAQHIMaAIJjUANAcAxqAAiOQQ0AwTGoASA4BjUABMegBoDgGNQAEByDGgCCY1ADQHAMagAIjkENAMExqAEgOAY1AARX6+SDzOwBSQckrZP0e3ffn/r4kZERHxsb67qZubm5ZP073/lObq3M8Vbi9OnTunjx4mlVnEmRVGaNRiO3NjEx0fNelpaWND8/35C0oAozWVxcTNZT3/eXX36ZW/v3v/
+dW1u3bl3ymHfccUdu7cSJE0uSFtXBuVI2k7Nnzybrqe/7hz/8YW5t48aNubWiTFJmZ2c7zkQqn8v8/HyynjpXtm/
2020-06-16 18:27:43 +02:00
"text/plain": [
2020-07-03 19:08:23 +02:00
"<Figure size 432x288 with 10 Axes>"
2020-06-16 18:27:43 +02:00
]
},
2020-07-03 19:08:23 +02:00
"metadata": {
"needs_background": "light"
},
2020-06-16 18:27:43 +02:00
"output_type": "display_data"
}
],
"source": [
"# Draw the first 10 digit images on a 2x5 grid (subplot positions are 1-based);\n",
"# gray_r renders high pixel values dark, so digits appear as dark ink on white\n",
"for index, image in enumerate(digits.images[:10]):\n",
"    plt.subplot(2, 5, index + 1)\n",
"    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us create a training set using the first 1000 images and a test set using the rest of the data."
]
},
{
"cell_type": "code",
2020-07-03 19:08:23 +02:00
"execution_count": 4,
"metadata": {},
2020-06-16 18:27:43 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2020-07-03 19:08:23 +02:00
"[ 0. 0. 5. 13. 9. 1. 0. 0. 0. 0. 13. 15. 10. 15. 5. 0. 0. 3.\n",
" 15. 2. 0. 11. 8. 0. 0. 4. 12. 0. 0. 8. 8. 0. 0. 5. 8. 0.\n",
" 0. 9. 8. 0. 0. 4. 11. 0. 1. 12. 7. 0. 0. 2. 14. 5. 10. 12.\n",
" 0. 0. 0. 0. 6. 13. 10. 0. 0. 0.]\n"
2020-06-16 18:27:43 +02:00
]
}
],
"source": [
"# Ordered (not shuffled) hold-out split: first 1000 samples for training, the rest for testing\n",
"X,y = digits.data[0:1000], digits.target[0:1000]\n",
"X_test, y_test = digits.data[1000:], digits.target[1000:]\n",
"# each sample is one flat row of pixel intensities (64 values, i.e. the 8x8 image flattened)\n",
"print(X[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learn provides a one-vs-rest classifier, which we already imported under the name `OvR`. Let us use it to fit the training set and to make predictions on the test set:"
]
},
{
"cell_type": "code",
2020-07-03 19:08:23 +02:00
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n"
]
}
],
2020-06-16 18:27:43 +02:00
"source": [
"# LinearSVC is the binary learner; random_state=0 makes the run reproducible.\n",
"# NOTE(review): liblinear emits ConvergenceWarning here and suggests increasing the\n",
"# number of iterations (LinearSVC's max_iter) - changing it may change the accuracy below.\n",
"binaryLearner = LinearSVC(random_state=0)\n",
"\n",
"oneVrestLearningAlgorithm = OvR(binaryLearner)\n",
"oneVrestLearningAlgorithm.fit(X,y)\n",
"predicted_labels = oneVrestLearningAlgorithm.predict(X_test)\n",
"\n",
"# n.b.: the above is equivalent to:\n",
"# predicted_labels = OvR(LinearSVC(random_state=0)).fit(X,y).predict(X_test)"
]
},
{
"cell_type": "code",
2020-07-03 19:08:23 +02:00
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9058971141781681\n"
]
}
],
2020-06-16 18:27:43 +02:00
"source": [
2020-07-03 19:08:23 +02:00
"# accuracy = fraction of test samples whose predicted label equals the true one\n",
"# (compared with ==, so it does not rely on labels being numeric)\n",
"print(\"Accuracy: %s\" % np.mean(np.asarray(predicted_labels) == y_test))"
2020-06-16 18:27:43 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reimplement the OvR classifier by completing the methods in the following class [[1](#hint1)]:"
]
},
{
"cell_type": "code",
2020-07-03 19:08:23 +02:00
"execution_count": 7,
"metadata": {},
2020-06-16 18:27:43 +02:00
"outputs": [],
"source": [
"class OneVsRestClassifier:\n",
"    \"\"\"One-vs-rest multiclass classifier built on top of a binary learner.\n",
"\n",
"    `fit` trains one clone of `learner` per distinct label, teaching it to\n",
"    separate that label (+1) from all the other labels (-1). `predict` returns,\n",
"    for each sample, the first label (in sorted order) whose binary classifier\n",
"    answers +1; when no classifier answers +1 the smallest label is returned.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, learner):\n",
"        self.learner = learner      # unfitted binary estimator, cloned once per label\n",
"        self.labels = None          # sorted distinct training labels (set by fit)\n",
"        self.data = None            # training data of the last fit\n",
"        self.classifiers = None     # fitted binary estimators, aligned with self.labels\n",
"\n",
"    def _fit_cls(self, one, labels, data=None):\n",
"        \"\"\"Fit a fresh clone of the learner on the binary problem `one` vs rest.\n",
"\n",
"        `data` defaults to the training data stored by `fit`. Samples whose\n",
"        label equals `one` get target +1, all the others get -1.\n",
"        Returns the fitted clone.\n",
"        \"\"\"\n",
"        from sklearn.base import clone\n",
"        if data is None:\n",
"            data = self.data\n",
"        cls = clone(self.learner)\n",
"        cls.fit(data, [1 if label == one else -1 for label in labels])\n",
"        return cls\n",
"\n",
"    def fit(self, data, labels):\n",
"        \"\"\"Train one binary classifier per distinct label and return self.\n",
"\n",
"        Calling `fit` again replaces the previously fitted classifiers, as\n",
"        scikit-learn estimators do.\n",
"        \"\"\"\n",
"        self.labels = sorted(set(labels))\n",
"        self.data = data\n",
"        self.classifiers = [self._fit_cls(one, labels, data) for one in self.labels]\n",
"        return self\n",
"\n",
"    def predict(self, data):\n",
"        \"\"\"Return a numpy array with the predicted label of each row of `data`.\n",
"\n",
"        Raises RuntimeError if called before `fit`.\n",
"        \"\"\"\n",
"        if self.classifiers is None:\n",
"            raise RuntimeError('OneVsRestClassifier must be fitted before predict')\n",
"        # votes[i, k] is True when classifier k claims sample i for self.labels[k]\n",
"        votes = np.column_stack([cls.predict(data) for cls in self.classifiers]) == 1\n",
"        assert votes.shape[0] == len(data)\n",
"        # argmax returns the first True column of each row; a row without any\n",
"        # True yields 0, i.e. it falls back to the smallest label\n",
"        winners = votes.argmax(axis=1)\n",
"        return np.asarray(self.labels)[winners]"
2020-06-16 18:27:43 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calculate the accuracy of your solution using the following code [[2](#hint2)]:"
]
},
{
"cell_type": "code",
2020-07-03 19:08:23 +02:00
"execution_count": 8,
2020-06-16 18:27:43 +02:00
"metadata": {
2020-07-03 19:08:23 +02:00
"scrolled": true
2020-06-16 18:27:43 +02:00
},
2020-07-03 19:08:23 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8393977415307403\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n",
"/home/user/.local/lib/python3.7/site-packages/sklearn/svm/_base.py:977: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" \"the number of iterations.\", ConvergenceWarning)\n"
]
}
],
2020-06-16 18:27:43 +02:00
"source": [
"ovr = OneVsRestClassifier(LinearSVC(random_state=0))\n",
"# own variable name, so the scikit-learn predictions computed earlier are not overwritten\n",
"ovr_predicted_labels = ovr.fit(X,y).predict(X_test)\n",
"# accuracy = fraction of test samples whose predicted label equals the true one\n",
"print(\"Accuracy: %s\" % np.mean(np.asarray(ovr_predicted_labels) == y_test))"
2020-06-16 18:27:43 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a name=\"hint1\">Hint 1:</a> Feel free to organize your code as you like (add as many methods as you believe are necessary).\n",
"\n",
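"<a name=\"hint2\">Hint 2:</a> The scheme provided by scikit-learn is a little different from the one we have seen in the lessons: scikit-learn assigns each sample to the class whose binary classifier returns the highest confidence score, instead of only looking at the +1/-1 answers. It is normal if your accuracy is not as good as the one obtained above (expect it to be between 0.8 and 0.9)."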
]
},
2020-07-03 19:08:23 +02:00
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): `learner` is not read by any later cell - looks like leftover scratch code; consider removing\n",
"learner = LinearSVC(random_state=0)"
]
},
2020-06-16 18:27:43 +02:00
{
"cell_type": "code",
"execution_count": null,
"metadata": {
2020-07-03 19:08:23 +02:00
"scrolled": true
2020-06-16 18:27:43 +02:00
},
"outputs": [],
"source": []
2020-07-03 19:08:23 +02:00
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
2020-06-16 18:27:43 +02:00
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
2020-07-03 19:08:23 +02:00
"display_name": "Python 3",
2020-06-16 18:27:43 +02:00
"language": "python",
2020-07-03 19:08:23 +02:00
"name": "python3"
2020-06-16 18:27:43 +02:00
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
2020-07-03 19:08:23 +02:00
"version": 3
2020-06-16 18:27:43 +02:00
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
2020-07-03 19:08:23 +02:00
"pygments_lexer": "ipython3",
"version": "3.7.7"
2020-06-16 18:27:43 +02:00
}
},
"nbformat": 4,
"nbformat_minor": 1
}