Francesco Mecca 2020-07-03 19:08:23 +02:00
parent 2540e7c3ee
commit 84096d29f6
22 changed files with 8105 additions and 4245 deletions

File diff suppressed because one or more lines are too long

@@ -1,893 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classifiers introduction\n",
"\n",
"This notebook introduces the basic steps for classifying a dataset stored as a matrix."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the scikit-learn module for learning and modeling decision trees"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn import tree "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define the matrix containing the data (one example per row)\n",
"and the vector containing the corresponding target value"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"X = [[0, 0, 0], [1, 1, 1], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]\n",
"Y = [1, 0, 0, 0, 1, 1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Declare the classification model you want to use and then fit the model to the data"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf = tree.DecisionTreeClassifier()\n",
"clf = clf.fit(X, Y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict the target value (and print it) for the passed data, using the fitted model currently in clf"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]\n"
]
}
],
"source": [
"print(clf.predict([[0, 1, 1]]))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1 0]\n"
]
}
],
"source": [
"print(clf.predict([[1, 0, 1],[0, 0, 1]]))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
" -->\n",
"<!-- Title: Tree Pages: 1 -->\n",
"<svg width=\"485pt\" height=\"358pt\"\n",
" viewBox=\"0.00 0.00 484.54 358.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 354)\">\n",
"<title>Tree</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-354 480.537,-354 480.537,4 -4,4\"/>\n",
"<!-- 0 -->\n",
"<g id=\"node1\" class=\"node\"><title>0</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"311.307,-350 220.23,-350 220.23,-286 311.307,-286 311.307,-350\"/>\n",
"<text text-anchor=\"middle\" x=\"265.769\" y=\"-334.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[1] &lt;= 0.5</text>\n",
"<text text-anchor=\"middle\" x=\"265.769\" y=\"-320.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.5</text>\n",
"<text text-anchor=\"middle\" x=\"265.769\" y=\"-306.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 6</text>\n",
"<text text-anchor=\"middle\" x=\"265.769\" y=\"-292.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [3, 3]</text>\n",
"</g>\n",
"<!-- 1 -->\n",
"<g id=\"node2\" class=\"node\"><title>1</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"256.307,-250 165.23,-250 165.23,-186 256.307,-186 256.307,-250\"/>\n",
"<text text-anchor=\"middle\" x=\"210.769\" y=\"-234.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[0] &lt;= 0.5</text>\n",
"<text text-anchor=\"middle\" x=\"210.769\" y=\"-220.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.444</text>\n",
"<text text-anchor=\"middle\" x=\"210.769\" y=\"-206.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 3</text>\n",
"<text text-anchor=\"middle\" x=\"210.769\" y=\"-192.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 2]</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;1 -->\n",
"<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M248.366,-285.992C243.551,-277.413 238.256,-267.978 233.199,-258.966\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"236.193,-257.15 228.246,-250.142 230.088,-260.576 236.193,-257.15\"/>\n",
"<text text-anchor=\"middle\" x=\"221.482\" y=\"-270.01\" font-family=\"Times,serif\" font-size=\"14.00\">True</text>\n",
"</g>\n",
"<!-- 6 -->\n",
"<g id=\"node7\" class=\"node\"><title>6</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"366.307,-250 275.23,-250 275.23,-186 366.307,-186 366.307,-250\"/>\n",
"<text text-anchor=\"middle\" x=\"320.769\" y=\"-234.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[2] &lt;= 0.5</text>\n",
"<text text-anchor=\"middle\" x=\"320.769\" y=\"-220.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.444</text>\n",
"<text text-anchor=\"middle\" x=\"320.769\" y=\"-206.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 3</text>\n",
"<text text-anchor=\"middle\" x=\"320.769\" y=\"-192.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [2, 1]</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;6 -->\n",
"<g id=\"edge6\" class=\"edge\"><title>0&#45;&gt;6</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M283.171,-285.992C287.986,-277.413 293.281,-267.978 298.338,-258.966\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"301.449,-260.576 303.291,-250.142 295.344,-257.15 301.449,-260.576\"/>\n",
"<text text-anchor=\"middle\" x=\"310.055\" y=\"-270.01\" font-family=\"Times,serif\" font-size=\"14.00\">False</text>\n",
"</g>\n",
"<!-- 2 -->\n",
"<g id=\"node3\" class=\"node\"><title>2</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"146.307,-150 55.2303,-150 55.2303,-86 146.307,-86 146.307,-150\"/>\n",
"<text text-anchor=\"middle\" x=\"100.769\" y=\"-134.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[2] &lt;= 0.5</text>\n",
"<text text-anchor=\"middle\" x=\"100.769\" y=\"-120.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.5</text>\n",
"<text text-anchor=\"middle\" x=\"100.769\" y=\"-106.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 2</text>\n",
"<text text-anchor=\"middle\" x=\"100.769\" y=\"-92.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 1]</text>\n",
"</g>\n",
"<!-- 1&#45;&gt;2 -->\n",
"<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M175.964,-185.992C165.631,-176.787 154.195,-166.598 143.424,-157.002\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"145.519,-154.181 135.724,-150.142 140.863,-159.408 145.519,-154.181\"/>\n",
"</g>\n",
"<!-- 5 -->\n",
"<g id=\"node6\" class=\"node\"><title>5</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"256.307,-143 165.23,-143 165.23,-93 256.307,-93 256.307,-143\"/>\n",
"<text text-anchor=\"middle\" x=\"210.769\" y=\"-127.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
"<text text-anchor=\"middle\" x=\"210.769\" y=\"-113.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
"<text text-anchor=\"middle\" x=\"210.769\" y=\"-99.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [0, 1]</text>\n",
"</g>\n",
"<!-- 1&#45;&gt;5 -->\n",
"<g id=\"edge5\" class=\"edge\"><title>1&#45;&gt;5</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M210.769,-185.992C210.769,-175.646 210.769,-164.057 210.769,-153.465\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"214.269,-153.288 210.769,-143.288 207.269,-153.288 214.269,-153.288\"/>\n",
"</g>\n",
"<!-- 3 -->\n",
"<g id=\"node4\" class=\"node\"><title>3</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"91.3068,-50 0.230281,-50 0.230281,-0 91.3068,-0 91.3068,-50\"/>\n",
"<text text-anchor=\"middle\" x=\"45.7686\" y=\"-34.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
"<text text-anchor=\"middle\" x=\"45.7686\" y=\"-20.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
"<text text-anchor=\"middle\" x=\"45.7686\" y=\"-6.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [0, 1]</text>\n",
"</g>\n",
"<!-- 2&#45;&gt;3 -->\n",
"<g id=\"edge3\" class=\"edge\"><title>2&#45;&gt;3</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M81.9945,-85.9375C76.6731,-77.133 70.8654,-67.5239 65.4898,-58.6297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"68.4603,-56.778 60.2922,-50.0301 62.4695,-60.3988 68.4603,-56.778\"/>\n",
"</g>\n",
"<!-- 4 -->\n",
"<g id=\"node5\" class=\"node\"><title>4</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"201.307,-50 110.23,-50 110.23,-0 201.307,-0 201.307,-50\"/>\n",
"<text text-anchor=\"middle\" x=\"155.769\" y=\"-34.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
"<text text-anchor=\"middle\" x=\"155.769\" y=\"-20.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
"<text text-anchor=\"middle\" x=\"155.769\" y=\"-6.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 0]</text>\n",
"</g>\n",
"<!-- 2&#45;&gt;4 -->\n",
"<g id=\"edge4\" class=\"edge\"><title>2&#45;&gt;4</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M119.543,-85.9375C124.864,-77.133 130.672,-67.5239 136.047,-58.6297\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"139.068,-60.3988 141.245,-50.0301 133.077,-56.778 139.068,-60.3988\"/>\n",
"</g>\n",
"<!-- 7 -->\n",
"<g id=\"node8\" class=\"node\"><title>7</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"366.307,-150 275.23,-150 275.23,-86 366.307,-86 366.307,-150\"/>\n",
"<text text-anchor=\"middle\" x=\"320.769\" y=\"-134.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[0] &lt;= 0.5</text>\n",
"<text text-anchor=\"middle\" x=\"320.769\" y=\"-120.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.5</text>\n",
"<text text-anchor=\"middle\" x=\"320.769\" y=\"-106.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 2</text>\n",
"<text text-anchor=\"middle\" x=\"320.769\" y=\"-92.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 1]</text>\n",
"</g>\n",
"<!-- 6&#45;&gt;7 -->\n",
"<g id=\"edge7\" class=\"edge\"><title>6&#45;&gt;7</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M320.769,-185.992C320.769,-177.859 320.769,-168.959 320.769,-160.378\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"324.269,-160.142 320.769,-150.142 317.269,-160.142 324.269,-160.142\"/>\n",
"</g>\n",
"<!-- 10 -->\n",
"<g id=\"node11\" class=\"node\"><title>10</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"476.307,-143 385.23,-143 385.23,-93 476.307,-93 476.307,-143\"/>\n",
"<text text-anchor=\"middle\" x=\"430.769\" y=\"-127.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
"<text text-anchor=\"middle\" x=\"430.769\" y=\"-113.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
"<text text-anchor=\"middle\" x=\"430.769\" y=\"-99.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 0]</text>\n",
"</g>\n",
"<!-- 6&#45;&gt;10 -->\n",
"<g id=\"edge10\" class=\"edge\"><title>6&#45;&gt;10</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M355.573,-185.992C368.396,-174.568 382.919,-161.63 395.76,-150.19\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"398.369,-152.553 403.507,-143.288 393.712,-147.326 398.369,-152.553\"/>\n",
"</g>\n",
"<!-- 8 -->\n",
"<g id=\"node9\" class=\"node\"><title>8</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"338.307,-50 247.23,-50 247.23,-0 338.307,-0 338.307,-50\"/>\n",
"<text text-anchor=\"middle\" x=\"292.769\" y=\"-34.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
"<text text-anchor=\"middle\" x=\"292.769\" y=\"-20.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
"<text text-anchor=\"middle\" x=\"292.769\" y=\"-6.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 0]</text>\n",
"</g>\n",
"<!-- 7&#45;&gt;8 -->\n",
"<g id=\"edge8\" class=\"edge\"><title>7&#45;&gt;8</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M311.211,-85.9375C308.615,-77.4998 305.791,-68.3232 303.152,-59.7451\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"306.449,-58.5585 300.162,-50.0301 299.758,-60.6172 306.449,-58.5585\"/>\n",
"</g>\n",
"<!-- 9 -->\n",
"<g id=\"node10\" class=\"node\"><title>9</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"448.307,-50 357.23,-50 357.23,-0 448.307,-0 448.307,-50\"/>\n",
"<text text-anchor=\"middle\" x=\"402.769\" y=\"-34.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
"<text text-anchor=\"middle\" x=\"402.769\" y=\"-20.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
"<text text-anchor=\"middle\" x=\"402.769\" y=\"-6.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [0, 1]</text>\n",
"</g>\n",
"<!-- 7&#45;&gt;9 -->\n",
"<g id=\"edge9\" class=\"edge\"><title>7&#45;&gt;9</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M348.759,-85.9375C357.023,-76.7661 366.074,-66.7217 374.364,-57.5217\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"377.021,-59.8019 381.115,-50.0301 371.821,-55.116 377.021,-59.8019\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.files.Source at 0x10dfc3510>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import graphviz \n",
"dot_data = tree.export_graphviz(clf, out_file=None) \n",
"graph = graphviz.Source(dot_data) \n",
"graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From here on we use the Iris dataset (from the UCI Machine Learning Repository), loaded through scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"iris = load_iris()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Declare the type of prediction model and the working criteria for the model induction algorithm"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\",\n",
" random_state=300,\n",
" min_samples_leaf=5,\n",
" class_weight={0:1,1:10,2:10}) # setosa, versicolor, virginica"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Split the dataset into training and test sets"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Generate a random permutation of the indices of examples that will be later used \n",
"# for the training and the test set\n",
"import numpy as np\n",
"np.random.seed(0)\n",
"indices = np.random.permutation(len(iris.data))\n",
"\n",
"# We now decide to keep the last 10 indices for test set, the remaining for the training set\n",
"indices_training=indices[:-10]\n",
"indices_test=indices[-10:]\n",
"\n",
"iris_X_train = iris.data[indices_training] # keep for training all the matrix elements with the exception of the last 10 \n",
"iris_y_train = iris.target[indices_training]\n",
"iris_X_test = iris.data[indices_test] # keep the last 10 elements for test set\n",
"iris_y_test = iris.target[indices_test]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fit the learning model on the training set"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# fit the model to the training data\n",
"clf = clf.fit(iris_X_train, iris_y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Obtain predictions"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predictions:\n",
"[1 2 1 0 0 0 2 1 2 0]\n",
"True classes:\n",
"[1 1 1 0 0 0 2 1 2 0]\n",
"['setosa' 'versicolor' 'virginica']\n"
]
}
],
"source": [
"# apply fitted model \"clf\" to the test set \n",
"predicted_y_test = clf.predict(iris_X_test)\n",
"\n",
"# print the predictions (class numbers index into iris.target_names)\n",
"print(\"Predictions:\")\n",
"print(predicted_y_test)\n",
"print(\"True classes:\")\n",
"print(iris_y_test) \n",
"print(iris.target_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Print the index of the test instances and the corresponding predictions"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Instance # 88: \n",
"Predicted: versicolor\t True: versicolor\n",
"\n",
"Instance # 70: \n",
"Predicted: virginica\t True: versicolor\n",
"\n",
"Instance # 87: \n",
"Predicted: versicolor\t True: versicolor\n",
"\n",
"Instance # 36: \n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 21: \n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 9: \n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 103: \n",
"Predicted: virginica\t True: virginica\n",
"\n",
"Instance # 67: \n",
"Predicted: versicolor\t True: versicolor\n",
"\n",
"Instance # 117: \n",
"Predicted: virginica\t True: virginica\n",
"\n",
"Instance # 47: \n",
"Predicted: setosa\t True: setosa\n",
"\n"
]
}
],
"source": [
"# print the index of each test instance with its predicted and true class names \n",
"for i in range(len(iris_y_test)): \n",
" print(\"Instance # \"+str(indices_test[i])+\": \")\n",
" print(\"Predicted: \"+iris.target_names[predicted_y_test[i]]+\"\\t True: \"+iris.target_names[iris_y_test[i]]+\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Look at the specific examples"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Instance # 88: \n",
"sepal length (cm)=5.6, sepal width (cm)=3.0, petal length (cm)=4.1, petal width (cm)=1.3\n",
"Predicted: versicolor\t True: versicolor\n",
"\n",
"Instance # 70: \n",
"sepal length (cm)=5.9, sepal width (cm)=3.2, petal length (cm)=4.8, petal width (cm)=1.8\n",
"Predicted: virginica\t True: versicolor\n",
"\n",
"Instance # 87: \n",
"sepal length (cm)=6.3, sepal width (cm)=2.3, petal length (cm)=4.4, petal width (cm)=1.3\n",
"Predicted: versicolor\t True: versicolor\n",
"\n",
"Instance # 36: \n",
"sepal length (cm)=5.5, sepal width (cm)=3.5, petal length (cm)=1.3, petal width (cm)=0.2\n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 21: \n",
"sepal length (cm)=5.1, sepal width (cm)=3.7, petal length (cm)=1.5, petal width (cm)=0.4\n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 9: \n",
"sepal length (cm)=4.9, sepal width (cm)=3.1, petal length (cm)=1.5, petal width (cm)=0.1\n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 103: \n",
"sepal length (cm)=6.3, sepal width (cm)=2.9, petal length (cm)=5.6, petal width (cm)=1.8\n",
"Predicted: virginica\t True: virginica\n",
"\n",
"Instance # 67: \n",
"sepal length (cm)=5.8, sepal width (cm)=2.7, petal length (cm)=4.1, petal width (cm)=1.0\n",
"Predicted: versicolor\t True: versicolor\n",
"\n",
"Instance # 117: \n",
"sepal length (cm)=7.7, sepal width (cm)=3.8, petal length (cm)=6.7, petal width (cm)=2.2\n",
"Predicted: virginica\t True: virginica\n",
"\n",
"Instance # 47: \n",
"sepal length (cm)=4.6, sepal width (cm)=3.2, petal length (cm)=1.4, petal width (cm)=0.2\n",
"Predicted: setosa\t True: setosa\n",
"\n"
]
}
],
"source": [
"for i in range(len(iris_y_test)): \n",
"    # print the index of this test instance, not the whole index array\n",
"    print(\"Instance # \"+str(indices_test[i])+\": \")\n",
"    s=\"\"\n",
"    for j in range(len(iris.feature_names)):\n",
"        s=s+iris.feature_names[j]+\"=\"+str(iris_X_test[i][j])\n",
"        if (j<len(iris.feature_names)-1): s=s+\", \"\n",
"    print(s)\n",
"    print(\"Predicted: \"+iris.target_names[predicted_y_test[i]]+\"\\t True: \"+iris.target_names[iris_y_test[i]]+\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Obtain model performance results"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy score: 0.9\n",
"F1 score: 0.885714285714\n"
]
}
],
"source": [
"# print some metrics results\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.metrics import f1_score\n",
"acc_score = accuracy_score(iris_y_test, predicted_y_test)\n",
"print(\"Accuracy score: \"+ str(acc_score))\n",
"f1=f1_score(iris_y_test, predicted_y_test, average='macro')\n",
"print(\"F1 score: \"+str(f1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Use Cross Validation"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.96666667 1. 0.9 0.86666667 1. ]\n"
]
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import cross_val_score # will be used to separate training and test\n",
"iris = load_iris()\n",
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\",random_state=300,min_samples_leaf=5,class_weight={0:1,1:1,2:1})\n",
"clf = clf.fit(iris.data, iris.target)\n",
"scores = cross_val_score(clf, iris.data, iris.target, cv=5) # score will be the accuracy\n",
"print(scores)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.96658312 1. 0.89769821 0.86666667 1. ]\n"
]
}
],
"source": [
"# compute the macro-averaged F1 score for each fold\n",
"f1_scores = cross_val_score(clf, iris.data, iris.target, cv=5, scoring='f1_macro')\n",
"print(f1_scores)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Show the resulting tree "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Print the picture in a PDF file"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"'my_iris_predictions.pdf'"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import graphviz \n",
"dot_data = tree.export_graphviz(clf, out_file=None) \n",
"graph = graphviz.Source(dot_data) \n",
"graph.render(\"my_iris_predictions\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Generate a picture here"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n",
"['setosa', 'versicolor', 'virginica']\n"
]
}
],
"source": [
"print(list(iris.feature_names))\n",
"print(list(iris.target_names))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
" -->\n",
"<!-- Title: Tree Pages: 1 -->\n",
"<svg width=\"653pt\" height=\"528pt\"\n",
" viewBox=\"0.00 0.00 652.83 528.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 524)\">\n",
"<title>Tree</title>\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-524 648.831,-524 648.831,4 -4,4\"/>\n",
"<!-- 0 -->\n",
"<g id=\"node1\" class=\"node\"><title>0</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M360.357,-520C360.357,-520 221.398,-520 221.398,-520 215.398,-520 209.398,-514 209.398,-508 209.398,-508 209.398,-454 209.398,-454 209.398,-448 215.398,-442 221.398,-442 221.398,-442 360.357,-442 360.357,-442 366.357,-442 372.357,-448 372.357,-454 372.357,-454 372.357,-508 372.357,-508 372.357,-514 366.357,-520 360.357,-520\"/>\n",
"<text text-anchor=\"start\" x=\"217.388\" y=\"-504.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 2.45</text>\n",
"<text text-anchor=\"start\" x=\"242.035\" y=\"-490.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.585</text>\n",
"<text text-anchor=\"start\" x=\"245.155\" y=\"-476.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 150</text>\n",
"<text text-anchor=\"start\" x=\"231.138\" y=\"-462.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [50, 50, 50]</text>\n",
"<text text-anchor=\"start\" x=\"246.328\" y=\"-448.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\n",
"</g>\n",
"<!-- 1 -->\n",
"<g id=\"node2\" class=\"node\"><title>1</title>\n",
"<path fill=\"#e58139\" stroke=\"black\" d=\"M260.784,-399C260.784,-399 164.971,-399 164.971,-399 158.971,-399 152.971,-393 152.971,-387 152.971,-387 152.971,-347 152.971,-347 152.971,-341 158.971,-335 164.971,-335 164.971,-335 260.784,-335 260.784,-335 266.784,-335 272.784,-341 272.784,-347 272.784,-347 272.784,-387 272.784,-387 272.784,-393 266.784,-399 260.784,-399\"/>\n",
"<text text-anchor=\"start\" x=\"171.821\" y=\"-383.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"171.048\" y=\"-369.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 50</text>\n",
"<text text-anchor=\"start\" x=\"160.924\" y=\"-355.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [50, 0, 0]</text>\n",
"<text text-anchor=\"start\" x=\"168.328\" y=\"-341.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;1 -->\n",
"<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M264.252,-441.769C256.515,-430.66 248.053,-418.509 240.27,-407.333\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"243.061,-405.216 234.474,-399.01 237.317,-409.216 243.061,-405.216\"/>\n",
"<text text-anchor=\"middle\" x=\"230.07\" y=\"-419.419\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
"</g>\n",
"<!-- 2 -->\n",
"<g id=\"node3\" class=\"node\"><title>2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M436.396,-406C436.396,-406 303.359,-406 303.359,-406 297.359,-406 291.359,-400 291.359,-394 291.359,-394 291.359,-340 291.359,-340 291.359,-334 297.359,-328 303.359,-328 303.359,-328 436.396,-328 436.396,-328 442.396,-328 448.396,-334 448.396,-340 448.396,-340 448.396,-394 448.396,-394 448.396,-400 442.396,-406 436.396,-406\"/>\n",
"<text text-anchor=\"start\" x=\"299.119\" y=\"-390.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.75</text>\n",
"<text text-anchor=\"start\" x=\"328.821\" y=\"-376.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.0</text>\n",
"<text text-anchor=\"start\" x=\"324.155\" y=\"-362.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 100</text>\n",
"<text text-anchor=\"start\" x=\"314.031\" y=\"-348.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 50, 50]</text>\n",
"<text text-anchor=\"start\" x=\"316\" y=\"-334.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;2 -->\n",
"<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;2</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M317.844,-441.769C324.072,-432.939 330.765,-423.451 337.207,-414.318\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"340.129,-416.248 343.033,-406.058 334.409,-412.213 340.129,-416.248\"/>\n",
"<text text-anchor=\"middle\" x=\"347.289\" y=\"-426.494\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
"</g>\n",
"<!-- 3 -->\n",
"<g id=\"node4\" class=\"node\"><title>3</title>\n",
"<path fill=\"#39e581\" fill-opacity=\"0.898039\" stroke=\"black\" d=\"M349.357,-292C349.357,-292 210.398,-292 210.398,-292 204.398,-292 198.398,-286 198.398,-280 198.398,-280 198.398,-226 198.398,-226 198.398,-220 204.398,-214 210.398,-214 210.398,-214 349.357,-214 349.357,-214 355.357,-214 361.357,-220 361.357,-226 361.357,-226 361.357,-280 361.357,-280 361.357,-286 355.357,-292 349.357,-292\"/>\n",
"<text text-anchor=\"start\" x=\"206.388\" y=\"-276.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.95</text>\n",
"<text text-anchor=\"start\" x=\"231.035\" y=\"-262.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.445</text>\n",
"<text text-anchor=\"start\" x=\"238.048\" y=\"-248.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 54</text>\n",
"<text text-anchor=\"start\" x=\"227.924\" y=\"-234.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 49, 5]</text>\n",
"<text text-anchor=\"start\" x=\"226\" y=\"-220.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
"</g>\n",
"<!-- 2&#45;&gt;3 -->\n",
"<g id=\"edge3\" class=\"edge\"><title>2&#45;&gt;3</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M339.156,-327.769C331.988,-318.849 324.281,-309.257 316.872,-300.038\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"319.452,-297.661 310.46,-292.058 313.996,-302.046 319.452,-297.661\"/>\n",
"</g>\n",
"<!-- 8 -->\n",
"<g id=\"node9\" class=\"node\"><title>8</title>\n",
"<path fill=\"#8139e5\" fill-opacity=\"0.976471\" stroke=\"black\" d=\"M530.357,-292C530.357,-292 391.398,-292 391.398,-292 385.398,-292 379.398,-286 379.398,-280 379.398,-280 379.398,-226 379.398,-226 379.398,-220 385.398,-214 391.398,-214 391.398,-214 530.357,-214 530.357,-214 536.357,-214 542.357,-220 542.357,-226 542.357,-226 542.357,-280 542.357,-280 542.357,-286 536.357,-292 530.357,-292\"/>\n",
"<text text-anchor=\"start\" x=\"387.388\" y=\"-276.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.95</text>\n",
"<text text-anchor=\"start\" x=\"412.035\" y=\"-262.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.151</text>\n",
"<text text-anchor=\"start\" x=\"419.048\" y=\"-248.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 46</text>\n",
"<text text-anchor=\"start\" x=\"408.924\" y=\"-234.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1, 45]</text>\n",
"<text text-anchor=\"start\" x=\"411.276\" y=\"-220.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
"</g>\n",
"<!-- 2&#45;&gt;8 -->\n",
"<g id=\"edge8\" class=\"edge\"><title>2&#45;&gt;8</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M400.94,-327.769C408.188,-318.849 415.981,-309.257 423.471,-300.038\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"426.365,-302.027 429.955,-292.058 420.933,-297.613 426.365,-302.027\"/>\n",
"</g>\n",
"<!-- 4 -->\n",
"<g id=\"node5\" class=\"node\"><title>4</title>\n",
"<path fill=\"#39e581\" fill-opacity=\"0.980392\" stroke=\"black\" d=\"M203.967,-178C203.967,-178 61.7876,-178 61.7876,-178 55.7876,-178 49.7876,-172 49.7876,-166 49.7876,-166 49.7876,-112 49.7876,-112 49.7876,-106 55.7876,-100 61.7876,-100 61.7876,-100 203.967,-100 203.967,-100 209.967,-100 215.967,-106 215.967,-112 215.967,-112 215.967,-166 215.967,-166 215.967,-172 209.967,-178 203.967,-178\"/>\n",
"<text text-anchor=\"start\" x=\"57.8325\" y=\"-162.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sepal length (cm) ≤ 5.15</text>\n",
"<text text-anchor=\"start\" x=\"84.0347\" y=\"-148.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.146</text>\n",
"<text text-anchor=\"start\" x=\"91.0483\" y=\"-134.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 48</text>\n",
"<text text-anchor=\"start\" x=\"80.9243\" y=\"-120.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 47, 1]</text>\n",
"<text text-anchor=\"start\" x=\"79\" y=\"-106.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
"</g>\n",
"<!-- 3&#45;&gt;4 -->\n",
"<g id=\"edge4\" class=\"edge\"><title>3&#45;&gt;4</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M229.7,-213.769C217.282,-204.308 203.874,-194.092 191.108,-184.366\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"192.905,-181.335 182.829,-178.058 188.662,-186.903 192.905,-181.335\"/>\n",
"</g>\n",
"<!-- 7 -->\n",
"<g id=\"node8\" class=\"node\"><title>7</title>\n",
"<path fill=\"#8139e5\" fill-opacity=\"0.498039\" stroke=\"black\" d=\"M337.581,-171C337.581,-171 246.174,-171 246.174,-171 240.174,-171 234.174,-165 234.174,-159 234.174,-159 234.174,-119 234.174,-119 234.174,-113 240.174,-107 246.174,-107 246.174,-107 337.581,-107 337.581,-107 343.581,-107 349.581,-113 349.581,-119 349.581,-119 349.581,-159 349.581,-159 349.581,-165 343.581,-171 337.581,-171\"/>\n",
"<text text-anchor=\"start\" x=\"243.035\" y=\"-155.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\n",
"<text text-anchor=\"start\" x=\"253.941\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\n",
"<text text-anchor=\"start\" x=\"243.817\" y=\"-127.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 2, 4]</text>\n",
"<text text-anchor=\"start\" x=\"242.276\" y=\"-113.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
"</g>\n",
"<!-- 3&#45;&gt;7 -->\n",
"<g id=\"edge7\" class=\"edge\"><title>3&#45;&gt;7</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M283.974,-213.769C285.106,-203.204 286.338,-191.698 287.486,-180.983\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"290.97,-181.326 288.555,-171.01 284.009,-180.58 290.97,-181.326\"/>\n",
"</g>\n",
"<!-- 5 -->\n",
"<g id=\"node6\" class=\"node\"><title>5</title>\n",
"<path fill=\"#39e581\" fill-opacity=\"0.749020\" stroke=\"black\" d=\"M111.633,-64C111.633,-64 12.1223,-64 12.1223,-64 6.12232,-64 0.122316,-58 0.122316,-52 0.122316,-52 0.122316,-12 0.122316,-12 0.122316,-6 6.12232,-0 12.1223,-0 12.1223,-0 111.633,-0 111.633,-0 117.633,-0 123.633,-6 123.633,-12 123.633,-12 123.633,-52 123.633,-52 123.633,-58 117.633,-64 111.633,-64\"/>\n",
"<text text-anchor=\"start\" x=\"13.0347\" y=\"-48.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.722</text>\n",
"<text text-anchor=\"start\" x=\"23.9414\" y=\"-34.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 5</text>\n",
"<text text-anchor=\"start\" x=\"13.8174\" y=\"-20.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 4, 1]</text>\n",
"<text text-anchor=\"start\" x=\"8\" y=\"-6.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
"</g>\n",
"<!-- 4&#45;&gt;5 -->\n",
"<g id=\"edge5\" class=\"edge\"><title>4&#45;&gt;5</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M107.044,-99.7956C101.035,-90.9084 94.6251,-81.4296 88.579,-72.4883\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"91.4195,-70.4406 82.9186,-64.1172 85.6207,-74.3616 91.4195,-70.4406\"/>\n",
"</g>\n",
"<!-- 6 -->\n",
"<g id=\"node7\" class=\"node\"><title>6</title>\n",
"<path fill=\"#39e581\" stroke=\"black\" d=\"M253.633,-64C253.633,-64 154.122,-64 154.122,-64 148.122,-64 142.122,-58 142.122,-52 142.122,-52 142.122,-12 142.122,-12 142.122,-6 148.122,-0 154.122,-0 154.122,-0 253.633,-0 253.633,-0 259.633,-0 265.633,-6 265.633,-12 265.633,-12 265.633,-52 265.633,-52 265.633,-58 259.633,-64 253.633,-64\"/>\n",
"<text text-anchor=\"start\" x=\"162.821\" y=\"-48.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"162.048\" y=\"-34.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\n",
"<text text-anchor=\"start\" x=\"151.924\" y=\"-20.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 43, 0]</text>\n",
"<text text-anchor=\"start\" x=\"150\" y=\"-6.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
"</g>\n",
"<!-- 4&#45;&gt;6 -->\n",
"<g id=\"edge6\" class=\"edge\"><title>4&#45;&gt;6</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M158.711,-99.7956C164.72,-90.9084 171.13,-81.4296 177.176,-72.4883\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"180.134,-74.3616 182.836,-64.1172 174.335,-70.4406 180.134,-74.3616\"/>\n",
"</g>\n",
"<!-- 9 -->\n",
"<g id=\"node10\" class=\"node\"><title>9</title>\n",
"<path fill=\"#8139e5\" fill-opacity=\"0.800000\" stroke=\"black\" d=\"M494.581,-171C494.581,-171 403.174,-171 403.174,-171 397.174,-171 391.174,-165 391.174,-159 391.174,-159 391.174,-119 391.174,-119 391.174,-113 397.174,-107 403.174,-107 403.174,-107 494.581,-107 494.581,-107 500.581,-107 506.581,-113 506.581,-119 506.581,-119 506.581,-159 506.581,-159 506.581,-165 500.581,-171 494.581,-171\"/>\n",
"<text text-anchor=\"start\" x=\"403.928\" y=\"-155.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.65</text>\n",
"<text text-anchor=\"start\" x=\"410.941\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\n",
"<text text-anchor=\"start\" x=\"400.817\" y=\"-127.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1, 5]</text>\n",
"<text text-anchor=\"start\" x=\"399.276\" y=\"-113.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
"</g>\n",
"<!-- 8&#45;&gt;9 -->\n",
"<g id=\"edge9\" class=\"edge\"><title>8&#45;&gt;9</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M456.781,-213.769C455.649,-203.204 454.416,-191.698 453.269,-180.983\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"456.745,-180.58 452.2,-171.01 449.785,-181.326 456.745,-180.58\"/>\n",
"</g>\n",
"<!-- 10 -->\n",
"<g id=\"node11\" class=\"node\"><title>10</title>\n",
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M632.784,-171C632.784,-171 536.971,-171 536.971,-171 530.971,-171 524.971,-165 524.971,-159 524.971,-159 524.971,-119 524.971,-119 524.971,-113 530.971,-107 536.971,-107 536.971,-107 632.784,-107 632.784,-107 638.784,-107 644.784,-113 644.784,-119 644.784,-119 644.784,-159 644.784,-159 644.784,-165 638.784,-171 632.784,-171\"/>\n",
"<text text-anchor=\"start\" x=\"543.821\" y=\"-155.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"543.048\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 40</text>\n",
"<text text-anchor=\"start\" x=\"532.924\" y=\"-127.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 40]</text>\n",
"<text text-anchor=\"start\" x=\"535.276\" y=\"-113.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
"</g>\n",
"<!-- 8&#45;&gt;10 -->\n",
"<g id=\"edge10\" class=\"edge\"><title>8&#45;&gt;10</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M503.204,-213.769C515.987,-202.224 530.013,-189.555 542.781,-178.023\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"545.47,-180.31 550.545,-171.01 540.778,-175.116 545.47,-180.31\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.files.Source at 0x10e209fd0>"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dot_data = tree.export_graphviz(clf, out_file=None, \n",
" feature_names=iris.feature_names, \n",
" class_names=iris.target_names, \n",
" filled=True, rounded=True, \n",
" special_characters=True) \n",
"graph = graphviz.Source(dot_data) \n",
"graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 1
}

File diff suppressed because one or more lines are too long

@ -0,0 +1,24 @@
digraph Tree {
node [shape=box] ;
0 [label="X[2] <= 2.45\nentropy = 1.585\nsamples = 150\nvalue = [50, 50, 50]"] ;
1 [label="entropy = 0.0\nsamples = 50\nvalue = [50, 0, 0]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[3] <= 1.75\nentropy = 1.0\nsamples = 100\nvalue = [0, 50, 50]"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="X[2] <= 4.95\nentropy = 0.445\nsamples = 54\nvalue = [0, 49, 5]"] ;
2 -> 3 ;
4 [label="X[0] <= 5.15\nentropy = 0.146\nsamples = 48\nvalue = [0, 47, 1]"] ;
3 -> 4 ;
5 [label="entropy = 0.722\nsamples = 5\nvalue = [0, 4, 1]"] ;
4 -> 5 ;
6 [label="entropy = 0.0\nsamples = 43\nvalue = [0, 43, 0]"] ;
4 -> 6 ;
7 [label="entropy = 0.918\nsamples = 6\nvalue = [0, 2, 4]"] ;
3 -> 7 ;
8 [label="X[2] <= 4.95\nentropy = 0.151\nsamples = 46\nvalue = [0, 1, 45]"] ;
2 -> 8 ;
9 [label="entropy = 0.65\nsamples = 6\nvalue = [0, 1, 5]"] ;
8 -> 9 ;
10 [label="entropy = 0.0\nsamples = 40\nvalue = [0, 0, 40]"] ;
8 -> 10 ;
}

File diff suppressed because one or more lines are too long

Binary file not shown. After: 427 B

Binary file not shown. After: 892 B

@ -336,7 +336,83 @@ mai viste.
It makes it possible to turn an inductive system into a deductive one.
** TODO Path Through hyp. space
Check what they want to know
** TODO Trees (ranking and regression trees still missing)
Decision trees are very expressive and correspond to logical
propositions in DNF.
To avoid overfitting a bias must be introduced, by choosing a
restrictive language for the hypotheses and by penalising the
complexity of each hypothesis in the target function.
*** Feature tree
In feature trees every internal node is labelled with a feature and
every edge with a literal.
The set of literals at a node is called a ~split~.
From a leaf we can build a logical expression as the conjunction of
the literals met on the way back up to the root.
The set of instances covered by that expression is called an ~instance
space segment~.
Tree learners perform a top-down search over all the concepts.
*** Algoritmo Grow Tree
Generic procedure
- Homogeneous: D → bool; true if hom. enough to be labelled with a
single label
- Label: D → label; most appropriate label for a set of instances
- BestSplit: D×F → set of literals; best set of literals to be put at the
root of the tree
#+BEGIN_SRC
Input: Dataset D, set of features F
if Homogeneous(D) then return Label(D)
S ← BestSplit(D, F)
split D into subsets Dᵢ according to the literals in S
foreach i do:
    if Dᵢ ≠ ∅ then Tᵢ ← GrowTree(Dᵢ, F)
    else Tᵢ is a leaf labelled with Label(D)
return tree whose root is labelled with S and whose children are Tᵢ
#+END_SRC
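A minimal Python sketch of the GrowTree procedure above, assuming categorical features, instances stored as =(features, label)= pairs and a weighted-entropy =best_split=. All function names and the nested-tuple tree representation are illustrative choices, not part of the notes. Unlike the pseudocode, a feature is dropped once it has been used, so that the recursion is guaranteed to stop.

```python
from collections import Counter
from math import log2


def homogeneous(data):
    """True when all instances share one label (strict purity; an assumption)."""
    return len({lbl for _, lbl in data}) <= 1


def label(data):
    """Majority label of a non-empty list of (features, label) pairs."""
    return Counter(lbl for _, lbl in data).most_common(1)[0][0]


def entropy(data):
    n = len(data)
    counts = Counter(lbl for _, lbl in data)
    return -sum(c / n * log2(c / n) for c in counts.values())


def partition(data, feature):
    """Group the instances by the value they take on `feature`."""
    parts = {}
    for x, lbl in data:
        parts.setdefault(x[feature], []).append((x, lbl))
    return parts


def best_split(data, features):
    """Feature whose split has the lowest weighted entropy."""
    def weighted(feature):
        parts = partition(data, feature).values()
        return sum(len(p) / len(data) * entropy(p) for p in parts)
    return min(features, key=weighted)


def grow_tree(data, features):
    # Extra stopping rule (assumption): stop when no feature is left.
    if homogeneous(data) or not features:
        return label(data)
    f = best_split(data, features)
    rest = [g for g in features if g != f]
    # partition() only yields non-empty subsets, so the "Dᵢ = ∅" branch
    # of the pseudocode never arises here.
    children = {v: grow_tree(p, rest) for v, p in partition(data, f).items()}
    return (f, children)


def predict(tree, x):
    """Follow the branches until a leaf (a bare label) is reached."""
    while isinstance(tree, tuple):
        feature, children = tree
        tree = children[x[feature]]
    return tree


toy = [((0, 0), "n"), ((0, 1), "y"), ((1, 0), "n"), ((1, 1), "y")]
print(grow_tree(toy, [0, 1]))  # (1, {0: 'n', 1: 'y'})
```

In the toy data the label depends only on the second feature, so the tree splits on it once and both children are pure leaves.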
*** Purity
The quality of a split is determined by its purity.
For example, with two classes ⊕ and ⊖, purity can be defined in terms
of the empirical probability.
In trees purity is measured over all the children; in rule learning it
is the purity of a single child, the one where the literal is true.
The tree purity measures can be reused, but with no need to average.
In the case of classes:
| minority-class: min{p̣, 1-p̣}
| Gini-index: ∑p̣ᵢ(1-p̣ᵢ); expected error rate if examples on leaves were labelled randomly
| Entropy: -∑p̣ᵢ·log₂(p̣ᵢ)
Impurity of a set: $Imp(D_1, D_2, ..., D_l) = \sum_{j=1}^l
\frac{|D_j|}{|D|} Imp(D_j)$
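The three measures and the weighted impurity of a split can be sketched in Python as follows. The function names are illustrative, class distributions are passed as lists of counts, and the multi-class form of the minority-class measure (1 minus the largest proportion) is an assumption that reduces to min{p, 1-p} for two classes. The entropy values can be checked against the ones reported in the tree of this commit.

```python
from math import log2


def proportions(counts):
    """Empirical class probabilities, skipping empty classes."""
    total = sum(counts)
    return [c / total for c in counts if c > 0]


def minority_class(counts):
    # For two classes this equals min(p, 1 - p).
    return 1 - max(proportions(counts))


def gini(counts):
    return sum(p * (1 - p) for p in proportions(counts))


def entropy(counts):
    return -sum(p * log2(p) for p in proportions(counts))


def split_impurity(children, imp=entropy):
    """Imp(D_1, ..., D_l): impurity of each child weighted by |D_j| / |D|."""
    total = sum(sum(c) for c in children)
    return sum(sum(c) / total * imp(c) for c in children)


print(round(entropy([50, 50, 50]), 3))  # 1.585, the root node
print(round(entropy([0, 2, 4]), 3))     # 0.918, one of the leaves
print(gini([50, 50]))                   # 0.5
```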
*** Decision Trees
A decision tree separates the dataset into disjoint partitions using
the objective function (each partition is pure in its target attribute).
The objective function measures the purity of the partitions obtained
after the split.
- Information of an event
I(E) = log₂(1/p)
If an event is very likely (p≊1), we gain little information from it,
and vice versa.
If an experiment has n outcomes, each with probability pᵢ, the
average amount of information gained is exactly the entropy:
| ∑pᵢlog₂(1/pᵢ) = -∑pᵢlog₂(pᵢ)
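As a worked check of the two formulas above (the last value is the one reported at the root of the tree in this commit, =entropy = 1.585= for =value = [50, 50, 50]=):
| $p = 1/2$: $I(E) = \log_2 2 = 1$ bit
| $p = 1/8$: $I(E) = \log_2 8 = 3$ bits
| three equiprobable classes: $\sum_{i=1}^{3} \frac{1}{3} \log_2 3 = \log_2 3 \approx 1.585$ bits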
**** BestSplit-Class Algorithm
#+BEGIN_SRC
input: dataset D, set of features F
Iₘᵢₙ ← 1
foreach f∈F:
    split D into subsets D₁,...,Dₗ according to the values υⱼ of f
    if Imp({D₁, ..., Dₗ}) < Iₘᵢₙ:
        Iₘᵢₙ ← Imp({D₁, ..., Dₗ})
        f_{best} ← f
return f_{best} (feature f to split on)
#+END_SRC
The best split minimises the impurity of the subsets D₁, ..., Dₗ.
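A small Python sketch of BestSplit-Class, assuming categorical features stored in a dict per instance and entropy as the impurity measure; the names are illustrative. One deliberate deviation: Iₘᵢₙ starts at infinity rather than 1, because entropy over more than two classes can exceed 1.

```python
from collections import Counter
from math import inf, log2


def entropy(labels):
    n = len(labels)
    return -sum(c / n * log2(c / n) for c in Counter(labels).values())


def best_split_class(data, features, imp=entropy):
    """Return the feature whose split has the lowest weighted impurity."""
    i_min, f_best = inf, None
    for f in features:
        # Split D into D_1, ..., D_l according to the values of f.
        subsets = {}
        for x, lbl in data:
            subsets.setdefault(x[f], []).append(lbl)
        impurity = sum(len(s) / len(data) * imp(s) for s in subsets.values())
        if impurity < i_min:
            i_min, f_best = impurity, f
    return f_best


# Hypothetical toy data: the label is fully determined by feature "b".
toy = [
    ({"a": 0, "b": 0}, "n"),
    ({"a": 0, "b": 1}, "y"),
    ({"a": 1, "b": 0}, "n"),
    ({"a": 1, "b": 1}, "y"),
]
print(best_split_class(toy, ["a", "b"]))  # b
```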
*** TODO Ranking Trees
- The space is divided into segments
- Trees can become rankers if they learn an ordering of the
segments
- The leaves must be ordered
** Rules
Ordered rules are a chain of /if-then-else/.
#+BEGIN_SRC
@ -344,6 +420,58 @@ Ordered rules are a chain of /if-then-else/.
2. Select the label as the rule consequent
3. Delete the instance segment from the data, restart from 1
#+END_SRC
*** LearnRuleList
Learn an ordered list of rules.
- LearnRuleList:
#+BEGIN_SRC
Input: Labelled training dataset D
R ← ∅
while D ≠ ∅ :
    r ← LearnRule(D)
    append r to end of R
    D ← D \ {x∈D | x is covered by r}
return R
#+END_SRC
- LearnRule(D):
#+BEGIN_SRC
b ← true
L ← set of available literals
while not Homogeneous(D):
    l ← BestLiteral(D,L)
    b ← b ∧ l
    D ← {x∈D | x is covered by b}
    L ← L \ {l'∈L | l' uses same features as l}
C ← Label(D)
r ← if b then Class = C
return r
#+END_SRC
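The two procedures above can be sketched in Python as follows. This is only an illustration under stated assumptions: literals are equality tests written as =(feature, value)= pairs, instances are dicts, and BestLiteral picks the literal whose covered subset has the highest majority-class frequency (ties broken by coverage). The extra =break= when no literal covers anything is also an addition, to guarantee termination.

```python
from collections import Counter


def covers(body, x):
    """A rule body is a list of (feature, value) literals, read as a conjunction."""
    return all(x[f] == v for f, v in body)


def purity(data):
    """Relative frequency of the majority class in a non-empty list."""
    counts = Counter(lbl for _, lbl in data)
    return max(counts.values()) / len(data)


def learn_rule(data, literals):
    body, covered, available = [], list(data), sorted(literals)
    while len({lbl for _, lbl in covered}) > 1:  # not Homogeneous(D)
        candidates = []
        for lit in available:
            subset = [d for d in covered if covers([lit], d[0])]
            if subset:
                candidates.append((lit, subset))
        if not candidates:  # assumption: give up on an impure rule
            break
        lit, covered = max(candidates, key=lambda c: (purity(c[1]), len(c[1])))
        body.append(lit)  # b ← b ∧ l
        available = [l for l in available if l[0] != lit[0]]
    head = Counter(lbl for _, lbl in covered).most_common(1)[0][0]
    return body, head


def learn_rule_list(data, literals):
    rules, remaining = [], list(data)
    while remaining:
        body, head = learn_rule(remaining, literals)
        rules.append((body, head))
        # Remove every instance covered by the new rule.
        remaining = [d for d in remaining if not covers(body, d[0])]
    return rules


def classify(rules, x):
    """The first rule whose body covers x decides the class."""
    for body, head in rules:
        if covers(body, x):
            return head
    return None
```

The last rule of the list typically has an empty body, which covers everything and plays the role of the final /else/.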
*** Unordered rules
Rules can also refer to the same class and we can collect them in a
rule set.
- LearnRuleSet(D):
#+BEGIN_SRC
Input: Labelled training data D
R ← ∅
for every class Cᵢ :
    Dᵢ ← D
    while Dᵢ contains examples of class Cᵢ:
        r ← LearnRuleForClass(Dᵢ, Cᵢ)
        R ← R ∪ {r}
        Dᵢ ← Dᵢ \ {x∈Cᵢ | x is covered by r} ;; remove only positives
return R
#+END_SRC
- LearnRuleForClass(Dᵢ, Cᵢ):
Same as LearnRule(D), but it uses Cᵢ instead of C←Label(D).
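A Python sketch of LearnRuleSet and LearnRuleForClass, under the same illustrative assumptions as before: equality literals written as =(feature, value)= pairs, instances as dicts, and a best literal chosen by precision on the target class (ties broken by coverage). None of these names come from the notes.

```python
def covers(body, x):
    """A rule body is a list of (feature, value) literals, read as a conjunction."""
    return all(x[f] == v for f, v in body)


def precision(data, cls):
    """Fraction of a non-empty list of examples that belongs to class cls."""
    return sum(lbl == cls for _, lbl in data) / len(data)


def learn_rule_for_class(data, literals, cls):
    body, covered, available = [], list(data), sorted(literals)
    while any(lbl != cls for _, lbl in covered):
        candidates = []
        for lit in available:
            subset = [d for d in covered if covers([lit], d[0])]
            # Only keep literals that still cover a positive example.
            if any(lbl == cls for _, lbl in subset):
                candidates.append((lit, subset))
        if not candidates:  # assumption: accept an impure rule
            break
        lit, covered = max(
            candidates, key=lambda c: (precision(c[1], cls), len(c[1]))
        )
        body.append(lit)
        available = [l for l in available if l[0] != lit[0]]
    return body, cls  # the head is fixed to Cᵢ, no Label(D) needed


def learn_rule_set(data, literals):
    rules = []
    for cls in sorted({lbl for _, lbl in data}):
        remaining = list(data)  # Dᵢ ← D
        while any(lbl == cls for _, lbl in remaining):
            body, head = learn_rule_for_class(remaining, literals, cls)
            rules.append((body, head))
            # Remove only the covered positives; negatives stay in Dᵢ.
            remaining = [
                (x, lbl) for x, lbl in remaining
                if not (lbl == cls and covers(body, x))
            ]
    return rules
```

Because each class restarts from the full dataset, the resulting rules may overlap, which is what distinguishes a rule set from the ordered list learned by LearnRuleList.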
The problem with these rules is that they focus too much on purity,
even when there are almost pure rules that cannot be
generalised: use smoothing.
- Laplace correction: $\dot{p}_i^+ = \frac{n_i^+ + 1}{n_i + 2}$
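A short Python illustration of the Laplace correction, with a hypothetical function name. The =n_classes= parameter is an assumed generalisation; with its default of 2 the function is exactly the formula above.

```python
def laplace_precision(n_pos, n_total, n_classes=2):
    """Smoothed precision of a rule covering n_total examples, n_pos of them positive."""
    return (n_pos + 1) / (n_total + n_classes)


# A rule covering a single positive example is perfectly pure (1/1),
# but after smoothing it scores lower than a rule covering 20 positives
# out of 21 examples.
print(laplace_precision(1, 1))
print(laplace_precision(20, 21))
```

The correction pulls the precision of rules with small coverage towards 1/2, so a pure rule that covers very few examples no longer automatically beats an almost pure rule with wide coverage.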
Rule sets usually have a higher ranking performance (n versus 2ⁿ
distinguishable instances) but they can yield a non-convex coverage
curve.
** TODO Subgroup discovery
Subgroups are subsets of the instance space whose class
distribution differs from that of D.
Mapping ĝ: X → C; D = (xᵢ, l(xᵢ))ⁱ

@ -1,4 +1,4 @@
* Machine Learning [3/6]
- [X] Write to her about the exam dates
- [X] Request exam dates
- [ ] Slides [0/5]
@ -26,6 +26,10 @@
+ [ ] (w_0,w_1) orthogonal to the hyperplane
+ [ ] proof of Lagrangian duality
+ [ ] Mercer condition
- [ ] Meo [0/3]
+ [ ] Study the Gini index carefully
+ [ ] Ranking and regression trees
+ [ ] subgroup discovery and ongoing
- [X] Exercises [3/3]
- [X] ex1: why min_impurity decrease
- [X] ask Galla`, Marco and Naz what all the exercises are