UniTO/anno3/apprendimento_automatico/esercizi/all_m/one_vs_rest.ipynb

225 lines
18 KiB
Text
Raw Normal View History

2020-06-23 21:53:50 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Experiments with the one-vs-rest multiclass classification scheme"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"from sklearn import datasets\n",
"from sklearn.multiclass import OneVsRestClassifier as OvR\n",
"from sklearn.svm import LinearSVC\n",
"import numpy as np\n",
"import copy\n",
"from matplotlib import pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Bunch exposing .images (2-D pixel arrays), .data (one feature row per image) and .target (labels)\n",
"digits = datasets.load_digits(n_class=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the \"Optical Recognition of Handwritten Digits Data Set\" from UCI (bundled with scikit-learn and already loaded in the previous cell). Let us plot the first 10 images in the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the first 10 digits in a 2x5 grid, each panel titled with its true label\n",
"fig, axes = plt.subplots(2, 5, figsize=(8, 4))\n",
"for ax, image, label in zip(axes.ravel(), digits.images[:10], digits.target[:10]):\n",
"    ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')\n",
"    ax.set_title(str(label))\n",
"    ax.set_axis_off()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us create a training set using the first 1000 images and a test set using the rest of the data."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0. 0. 5. 13. 9. 1. 0. 0. 0. 0. 13. 15. 10. 15. 5.\n",
" 0. 0. 3. 15. 2. 0. 11. 8. 0. 0. 4. 12. 0. 0. 8.\n",
" 8. 0. 0. 5. 8. 0. 0. 9. 8. 0. 0. 4. 11. 0. 1.\n",
" 12. 7. 0. 0. 2. 14. 5. 10. 12. 0. 0. 0. 0. 6. 13.\n",
" 10. 0. 0. 0.]\n"
]
}
],
"source": [
"# The first n_train samples form the training set; everything after them is the test set\n",
"n_train = 1000\n",
"X, y = digits.data[:n_train], digits.target[:n_train]\n",
"X_test, y_test = digits.data[n_train:], digits.target[n_train:]\n",
"print(X[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learn provides a One-vs-Rest classifier, which we have already imported under the name `OvR`. Let us use it to fit the training set and to make predictions on the test set:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Binary base learner, wrapped by scikit-learn's one-vs-rest meta-classifier\n",
"base_svm = LinearSVC(random_state=0)\n",
"sklearn_ovr = OvR(base_svm).fit(X, y)  # fit() returns the fitted classifier\n",
"predicted_labels = sklearn_ovr.predict(X_test)\n",
"\n",
"# n.b.: the above is equivalent to the one-liner\n",
"# predicted_labels = OvR(LinearSVC(random_state=0)).fit(X, y).predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Accuracy = fraction of test samples whose predicted label equals the true label\n",
"print(\"Accuracy: %s\" % np.mean(predicted_labels == y_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reimplement the OvR classifier by completing the methods in the following class [[1](#hint1)]:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class OneVsRestClassifier:\n",
"    \"\"\"One-vs-rest multiclass classifier built on top of a binary learner.\n",
"\n",
"    fit() trains one independent copy of `learner` per class c, using c as the\n",
"    positive class (label 1) and every other class as negative (label 0).\n",
"    predict() assigns each sample to the class whose binary model scores it highest.\n",
"\n",
"    Args:\n",
"        learner: unfitted binary estimator exposing fit(data, labels) and\n",
"            predict(data), and optionally decision_function(data). It is\n",
"            deep-copied for every class and is never fitted itself.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, learner):\n",
"        self.learner = learner\n",
"        self.classes_ = None   # sorted distinct labels seen by fit()\n",
"        self.estimators_ = []  # one fitted binary model per entry of classes_\n",
"\n",
"    def fit(self, data, labels):\n",
"        \"\"\"Fit one binary model per class and return self (allows chaining).\n",
"\n",
"        Raises:\n",
"            ValueError: if `labels` contains fewer than two distinct classes.\n",
"        \"\"\"\n",
"        labels = np.asarray(labels)\n",
"        self.classes_ = np.unique(labels)\n",
"        if len(self.classes_) < 2:\n",
"            raise ValueError(\"one-vs-rest needs at least two classes\")\n",
"        self.estimators_ = []\n",
"        for cls in self.classes_:\n",
"            # Relabel: 1 for the current class, 0 for all the others.\n",
"            binary_labels = (labels == cls).astype(int)\n",
"            estimator = copy.deepcopy(self.learner)\n",
"            estimator.fit(data, binary_labels)\n",
"            self.estimators_.append(estimator)\n",
"        return self\n",
"\n",
"    @staticmethod\n",
"    def _positive_scores(estimator, data):\n",
"        \"\"\"Per-sample score that `estimator` gives to its positive class (label 1).\"\"\"\n",
"        # Prefer real-valued margins so samples claimed by several (or by no)\n",
"        # binary models still get a meaningful argmax; otherwise use 0/1 votes.\n",
"        if hasattr(estimator, \"decision_function\"):\n",
"            return estimator.decision_function(data)\n",
"        return estimator.predict(data)\n",
"\n",
"    def predict(self, data):\n",
"        \"\"\"Return, for each row of `data`, the class with the highest score.\n",
"\n",
"        Raises:\n",
"            RuntimeError: if called before fit().\n",
"        \"\"\"\n",
"        if not self.estimators_:\n",
"            raise RuntimeError(\"call fit() before predict()\")\n",
"        scores = np.column_stack(\n",
"            [self._positive_scores(est, data) for est in self.estimators_])\n",
"        return self.classes_[np.argmax(scores, axis=1)]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calculate the accuracy of your solution using the following code [[2](#hint2)]:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Separate name so scikit-learn's predicted_labels from the earlier cell is not overwritten\n",
"ovr = OneVsRestClassifier(LinearSVC(random_state=0))\n",
"ovr_predicted_labels = ovr.fit(X, y).predict(X_test)\n",
"print(\"Accuracy: %s\" % np.mean(ovr_predicted_labels == y_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a name=\"hint1\">Hint 1:</a> Feel free to organize your code as you like (add as many methods as you believe are necessary).\n",
"\n",
"<a name=\"hint2\">Hint 2:</a> The scheme provided by scikit-learn is a little different from the one we have seen in the lessons. It is normal if your accuracy is not as good as the one obtained above (expect the accuracy to be between 0.8 and 0.9)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 1
}