UniTO/anno3/altro_muovi/marco/classification_iris_aa_19_20-checkpoint.ipynb

1811 lines
126 KiB
Text
Raw Normal View History

2020-06-17 20:01:41 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classifiers introduction\n",
"\n",
"In this notebook we introduce the basic steps for classifying a dataset stored as a matrix (one example per row)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the scikit-learn module for learning and modeling decision trees"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from sklearn import tree"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define the matrix containing the data (one example per row)\n",
"and the vector containing the corresponding target values"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"X = [[0, 0, 0], [1, 1, 1], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]\n",
"Y = [1, 0, 0, 0, 1, 1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Declare the classification model you want to use, then fit it to the data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"clf = tree.DecisionTreeClassifier()\n",
"clf = clf.fit(X, Y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict (and print) the target value for the given data, using the model currently fitted in `clf`"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]\n"
]
}
],
"source": [
"print(clf.predict([[0, 1, 1]]))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1 0]\n"
]
}
],
"source": [
"print(clf.predict([[1, 0, 1],[0, 0, 1]]))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
" -->\r\n",
"<!-- Title: Tree Pages: 1 -->\r\n",
"<svg width=\"480pt\" height=\"373pt\"\r\n",
" viewBox=\"0.00 0.00 480.00 373.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 369)\">\r\n",
"<title>Tree</title>\r\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-369 476,-369 476,4 -4,4\"/>\r\n",
"<!-- 0 -->\r\n",
"<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"254,-365 163,-365 163,-297 254,-297 254,-365\"/>\r\n",
"<text text-anchor=\"middle\" x=\"208.5\" y=\"-349.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 0.5</text>\r\n",
"<text text-anchor=\"middle\" x=\"208.5\" y=\"-334.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.5</text>\r\n",
"<text text-anchor=\"middle\" x=\"208.5\" y=\"-319.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 6</text>\r\n",
"<text text-anchor=\"middle\" x=\"208.5\" y=\"-304.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [3, 3]</text>\r\n",
"</g>\r\n",
"<!-- 1 -->\r\n",
"<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"200,-261 109,-261 109,-193 200,-193 200,-261\"/>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 0.5</text>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.444</text>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 1]</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;1 -->\r\n",
"<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M190.968,-296.884C186.488,-288.422 181.61,-279.207 176.921,-270.352\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"179.901,-268.5 172.129,-261.299 173.715,-271.775 179.901,-268.5\"/>\r\n",
"<text text-anchor=\"middle\" x=\"164.777\" y=\"-281.494\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">True</text>\r\n",
"</g>\r\n",
"<!-- 6 -->\r\n",
"<g id=\"node7\" class=\"node\"><title>6</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"309,-261 218,-261 218,-193 309,-193 309,-261\"/>\r\n",
"<text text-anchor=\"middle\" x=\"263.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[2] &lt;= 0.5</text>\r\n",
"<text text-anchor=\"middle\" x=\"263.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.444</text>\r\n",
"<text text-anchor=\"middle\" x=\"263.5\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
"<text text-anchor=\"middle\" x=\"263.5\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 2]</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;6 -->\r\n",
"<g id=\"edge6\" class=\"edge\"><title>0&#45;&gt;6</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M226.357,-296.884C230.92,-288.422 235.888,-279.207 240.663,-270.352\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"243.879,-271.762 245.544,-261.299 237.718,-268.44 243.879,-271.762\"/>\r\n",
"<text text-anchor=\"middle\" x=\"252.714\" y=\"-281.549\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">False</text>\r\n",
"</g>\r\n",
"<!-- 2 -->\r\n",
"<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"91,-157 0,-157 0,-89 91,-89 91,-157\"/>\r\n",
"<text text-anchor=\"middle\" x=\"45.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[2] &lt;= 0.5</text>\r\n",
"<text text-anchor=\"middle\" x=\"45.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.5</text>\r\n",
"<text text-anchor=\"middle\" x=\"45.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
"<text text-anchor=\"middle\" x=\"45.5\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 1]</text>\r\n",
"</g>\r\n",
"<!-- 1&#45;&gt;2 -->\r\n",
"<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M119.111,-192.884C109.402,-183.798 98.7662,-173.845 88.673,-164.4\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"90.7777,-161.577 81.0846,-157.299 85.9948,-166.688 90.7777,-161.577\"/>\r\n",
"</g>\r\n",
"<!-- 5 -->\r\n",
"<g id=\"node6\" class=\"node\"><title>5</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"200,-149.5 109,-149.5 109,-96.5 200,-96.5 200,-149.5\"/>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 0]</text>\r\n",
"</g>\r\n",
"<!-- 1&#45;&gt;5 -->\r\n",
"<g id=\"edge5\" class=\"edge\"><title>1&#45;&gt;5</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M154.5,-192.884C154.5,-182.326 154.5,-170.597 154.5,-159.854\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"158,-159.52 154.5,-149.52 151,-159.52 158,-159.52\"/>\r\n",
"</g>\r\n",
"<!-- 3 -->\r\n",
"<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"91,-53 0,-53 0,-0 91,-0 91,-53\"/>\r\n",
"<text text-anchor=\"middle\" x=\"45.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
"<text text-anchor=\"middle\" x=\"45.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
"<text text-anchor=\"middle\" x=\"45.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1]</text>\r\n",
"</g>\r\n",
"<!-- 2&#45;&gt;3 -->\r\n",
"<g id=\"edge3\" class=\"edge\"><title>2&#45;&gt;3</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M45.5,-88.9485C45.5,-80.7153 45.5,-71.848 45.5,-63.4814\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"49.0001,-63.2367 45.5,-53.2367 42.0001,-63.2367 49.0001,-63.2367\"/>\r\n",
"</g>\r\n",
"<!-- 4 -->\r\n",
"<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"200,-53 109,-53 109,-0 200,-0 200,-53\"/>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
"<text text-anchor=\"middle\" x=\"154.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 0]</text>\r\n",
"</g>\r\n",
"<!-- 2&#45;&gt;4 -->\r\n",
"<g id=\"edge4\" class=\"edge\"><title>2&#45;&gt;4</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M83.6229,-88.9485C94.4911,-79.526 106.317,-69.2731 117.139,-59.8906\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"119.551,-62.4319 124.814,-53.2367 114.966,-57.1428 119.551,-62.4319\"/>\r\n",
"</g>\r\n",
"<!-- 7 -->\r\n",
"<g id=\"node8\" class=\"node\"><title>7</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"309,-149.5 218,-149.5 218,-96.5 309,-96.5 309,-149.5\"/>\r\n",
"<text text-anchor=\"middle\" x=\"263.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
"<text text-anchor=\"middle\" x=\"263.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
"<text text-anchor=\"middle\" x=\"263.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1]</text>\r\n",
"</g>\r\n",
"<!-- 6&#45;&gt;7 -->\r\n",
"<g id=\"edge7\" class=\"edge\"><title>6&#45;&gt;7</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M263.5,-192.884C263.5,-182.326 263.5,-170.597 263.5,-159.854\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"267,-159.52 263.5,-149.52 260,-159.52 267,-159.52\"/>\r\n",
"</g>\r\n",
"<!-- 8 -->\r\n",
"<g id=\"node9\" class=\"node\"><title>8</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"418,-157 327,-157 327,-89 418,-89 418,-157\"/>\r\n",
"<text text-anchor=\"middle\" x=\"372.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 0.5</text>\r\n",
"<text text-anchor=\"middle\" x=\"372.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.5</text>\r\n",
"<text text-anchor=\"middle\" x=\"372.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
"<text text-anchor=\"middle\" x=\"372.5\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 1]</text>\r\n",
"</g>\r\n",
"<!-- 6&#45;&gt;8 -->\r\n",
"<g id=\"edge8\" class=\"edge\"><title>6&#45;&gt;8</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M298.889,-192.884C308.598,-183.798 319.234,-173.845 329.327,-164.4\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"332.005,-166.688 336.915,-157.299 327.222,-161.577 332.005,-166.688\"/>\r\n",
"</g>\r\n",
"<!-- 9 -->\r\n",
"<g id=\"node10\" class=\"node\"><title>9</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"363,-53 272,-53 272,-0 363,-0 363,-53\"/>\r\n",
"<text text-anchor=\"middle\" x=\"317.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
"<text text-anchor=\"middle\" x=\"317.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
"<text text-anchor=\"middle\" x=\"317.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1]</text>\r\n",
"</g>\r\n",
"<!-- 8&#45;&gt;9 -->\r\n",
"<g id=\"edge9\" class=\"edge\"><title>8&#45;&gt;9</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M353.264,-88.9485C348.206,-80.2579 342.736,-70.8608 337.633,-62.0917\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"340.534,-60.1189 332.479,-53.2367 334.484,-63.6401 340.534,-60.1189\"/>\r\n",
"</g>\r\n",
"<!-- 10 -->\r\n",
"<g id=\"node11\" class=\"node\"><title>10</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"472,-53 381,-53 381,-0 472,-0 472,-53\"/>\r\n",
"<text text-anchor=\"middle\" x=\"426.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
"<text text-anchor=\"middle\" x=\"426.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
"<text text-anchor=\"middle\" x=\"426.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [1, 0]</text>\r\n",
"</g>\r\n",
"<!-- 8&#45;&gt;10 -->\r\n",
"<g id=\"edge10\" class=\"edge\"><title>8&#45;&gt;10</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M391.387,-88.9485C396.353,-80.2579 401.722,-70.8608 406.733,-62.0917\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"409.871,-63.6557 411.793,-53.2367 403.793,-60.1826 409.871,-63.6557\"/>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n"
],
"text/plain": [
"<graphviz.files.Source at 0x1b867311fc8>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"os.environ[\"PATH\"] += os.pathsep + 'C:/Users/galat/.conda/envs/aaut/Library/bin/graphviz'\n",
"import graphviz\n",
"dot_data = tree.export_graphviz(clf, out_file=None) \n",
"graph = graphviz.Source(dot_data) \n",
"graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the following we use a real dataset, Iris (from the UCI Machine Learning Repository)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"iris = load_iris()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Declare the type of prediction model and the parameters of the tree induction algorithm"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\",random_state=300,min_samples_leaf=5,class_weight={0:1,1:1,2:1})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Split the dataset into training and test sets"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Generate a random permutation of the indices of examples that will be later used \n",
"# for the training and the test set\n",
"import numpy as np\n",
"np.random.seed(1231)\n",
"indices = np.random.permutation(len(iris.data))\n",
"\n",
"# We now decide to keep the last 10 indices for test set, the remaining for the training set\n",
"indices_training=indices[:-10]\n",
"indices_test=indices[-10:]\n",
"\n",
"iris_X_train = iris.data[indices_training] # keep for training all the matrix elements with the exception of the last 10 \n",
"iris_y_train = iris.target[indices_training]\n",
"iris_X_test = iris.data[indices_test] # keep the last 10 elements for test set\n",
"iris_y_test = iris.target[indices_test]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fit the learning model on the training set"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# fit the model to the training data\n",
"clf = clf.fit(iris_X_train, iris_y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Obtain predictions"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predictions:\n",
"[0 0 0 1 0 0 1 2 0 0]\n",
"True classes:\n",
"[0 0 0 2 0 0 1 1 0 0]\n",
"['setosa' 'versicolor' 'virginica']\n"
]
}
],
"source": [
"# apply fitted model \"clf\" to the test set \n",
"predicted_y_test = clf.predict(iris_X_test)\n",
"\n",
"# print the predictions (class numbers associated to classes names in target names)\n",
"print(\"Predictions:\")\n",
"print(predicted_y_test)\n",
"print(\"True classes:\")\n",
"print(iris_y_test) \n",
"print(iris.target_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Print, for each test instance, its index in the dataset, its feature values, and the predicted and true classes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Look at the specific examples"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Instance # 33: \n",
"sepal length (cm)=5.5, sepal width (cm)=4.2, petal length (cm)=1.4, petal width (cm)=0.2\n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 2: \n",
"sepal length (cm)=4.7, sepal width (cm)=3.2, petal length (cm)=1.3, petal width (cm)=0.2\n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 11: \n",
"sepal length (cm)=4.8, sepal width (cm)=3.4, petal length (cm)=1.6, petal width (cm)=0.2\n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 126: \n",
"sepal length (cm)=6.2, sepal width (cm)=2.8, petal length (cm)=4.8, petal width (cm)=1.8\n",
"Predicted: versicolor\t True: virginica\n",
"\n",
"Instance # 49: \n",
"sepal length (cm)=5.0, sepal width (cm)=3.3, petal length (cm)=1.4, petal width (cm)=0.2\n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 10: \n",
"sepal length (cm)=5.4, sepal width (cm)=3.7, petal length (cm)=1.5, petal width (cm)=0.2\n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 85: \n",
"sepal length (cm)=6.0, sepal width (cm)=3.4, petal length (cm)=4.5, petal width (cm)=1.6\n",
"Predicted: versicolor\t True: versicolor\n",
"\n",
"Instance # 52: \n",
"sepal length (cm)=6.9, sepal width (cm)=3.1, petal length (cm)=4.9, petal width (cm)=1.5\n",
"Predicted: virginica\t True: versicolor\n",
"\n",
"Instance # 5: \n",
"sepal length (cm)=5.4, sepal width (cm)=3.9, petal length (cm)=1.7, petal width (cm)=0.4\n",
"Predicted: setosa\t True: setosa\n",
"\n",
"Instance # 21: \n",
"sepal length (cm)=5.1, sepal width (cm)=3.7, petal length (cm)=1.5, petal width (cm)=0.4\n",
"Predicted: setosa\t True: setosa\n",
"\n"
]
}
],
"source": [
"for i in range(len(iris_y_test)): \n",
" print(\"Instance # \"+str(indices_test[i])+\": \")\n",
" s=\"\"\n",
" for j in range(len(iris.feature_names)):\n",
" s=s+iris.feature_names[j]+\"=\"+str(iris_X_test[i][j])\n",
" if (j<len(iris.feature_names)-1): s=s+\", \"\n",
" print(s)\n",
" print(\"Predicted: \"+iris.target_names[predicted_y_test[i]]+\"\\t True: \"+iris.target_names[iris_y_test[i]]+\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Obtain model performance results"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy score: 0.8\n",
"F1 score: 0.5\n"
]
}
],
"source": [
"# print some metrics results\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.metrics import f1_score\n",
"acc_score = accuracy_score(iris_y_test, predicted_y_test)\n",
"print(\"Accuracy score: \"+ str(acc_score))\n",
"f1=f1_score(iris_y_test, predicted_y_test, average='macro')\n",
"print(\"F1 score: \"+str(f1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Use Cross Validation"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.96666667 1. 0.86666667 0.86666667 1. ]\n"
]
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import cross_val_score # will be used to separate training and test\n",
"iris = load_iris()\n",
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\",random_state=300,min_samples_leaf=5,class_weight={0:1,1:1,2:1})\n",
"clf = clf.fit(iris.data, iris.target)\n",
"scores = cross_val_score(clf, iris.data, iris.target, cv=5) # score will be the accuracy\n",
"print(scores)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.96658312 1. 0.86111111 0.86666667 1. ]\n"
]
}
],
"source": [
"# computes F1- score\n",
"f1_scores = cross_val_score(clf, iris.data, iris.target, cv=5, scoring='f1_macro')\n",
"print(f1_scores)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Show the resulting tree "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Render the tree picture to a PDF file"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"'my_iris_predictions.pdf'"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import graphviz \n",
"dot_data = tree.export_graphviz(clf, out_file=None) \n",
"graph = graphviz.Source(dot_data) \n",
"graph.render(\"my_iris_predictions\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Display the tree picture inline in the notebook"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n",
"['setosa', 'versicolor', 'virginica']\n"
]
}
],
"source": [
"print(list(iris.feature_names))\n",
"print(list(iris.target_names))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
" -->\r\n",
"<!-- Title: Tree Pages: 1 -->\r\n",
"<svg width=\"660pt\" height=\"552pt\"\r\n",
" viewBox=\"0.00 0.00 660.00 552.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 548)\">\r\n",
"<title>Tree</title>\r\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-548 656,-548 656,4 -4,4\"/>\r\n",
"<!-- 0 -->\r\n",
"<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
"<path fill=\"#ffffff\" stroke=\"black\" d=\"M366,-544C366,-544 225,-544 225,-544 219,-544 213,-538 213,-532 213,-532 213,-473 213,-473 213,-467 219,-461 225,-461 225,-461 366,-461 366,-461 372,-461 378,-467 378,-473 378,-473 378,-532 378,-532 378,-538 372,-544 366,-544\"/>\r\n",
"<text text-anchor=\"start\" x=\"221\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 2.45</text>\r\n",
"<text text-anchor=\"start\" x=\"245.5\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.585</text>\r\n",
"<text text-anchor=\"start\" x=\"248\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 150</text>\r\n",
"<text text-anchor=\"start\" x=\"235\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [50, 50, 50]</text>\r\n",
"<text text-anchor=\"start\" x=\"249.5\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\r\n",
"</g>\r\n",
"<!-- 1 -->\r\n",
"<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
"<path fill=\"#e58139\" stroke=\"black\" d=\"M265,-417.5C265,-417.5 168,-417.5 168,-417.5 162,-417.5 156,-411.5 156,-405.5 156,-405.5 156,-361.5 156,-361.5 156,-355.5 162,-349.5 168,-349.5 168,-349.5 265,-349.5 265,-349.5 271,-349.5 277,-355.5 277,-361.5 277,-361.5 277,-405.5 277,-405.5 277,-411.5 271,-417.5 265,-417.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"174.5\" y=\"-402.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"173\" y=\"-387.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 50</text>\r\n",
"<text text-anchor=\"start\" x=\"164\" y=\"-372.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [50, 0, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"170.5\" y=\"-357.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;1 -->\r\n",
"<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M268.091,-460.907C260.492,-449.652 252.231,-437.418 244.593,-426.106\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"247.391,-423.996 238.895,-417.667 241.59,-427.913 247.391,-423.996\"/>\r\n",
"<text text-anchor=\"middle\" x=\"234.136\" y=\"-438.51\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\r\n",
"</g>\r\n",
"<!-- 2 -->\r\n",
"<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
"<path fill=\"#ffffff\" stroke=\"black\" d=\"M442,-425C442,-425 307,-425 307,-425 301,-425 295,-419 295,-413 295,-413 295,-354 295,-354 295,-348 301,-342 307,-342 307,-342 442,-342 442,-342 448,-342 454,-348 454,-354 454,-354 454,-413 454,-413 454,-419 448,-425 442,-425\"/>\r\n",
"<text text-anchor=\"start\" x=\"303\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.75</text>\r\n",
"<text text-anchor=\"start\" x=\"332.5\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.0</text>\r\n",
"<text text-anchor=\"start\" x=\"327\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 100</text>\r\n",
"<text text-anchor=\"start\" x=\"318\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 50, 50]</text>\r\n",
"<text text-anchor=\"start\" x=\"319\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;2 -->\r\n",
"<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;2</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M322.909,-460.907C328.914,-452.014 335.331,-442.509 341.529,-433.331\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"344.444,-435.267 347.14,-425.021 338.643,-431.35 344.444,-435.267\"/>\r\n",
"<text text-anchor=\"middle\" x=\"351.898\" y=\"-445.864\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\r\n",
"</g>\r\n",
"<!-- 3 -->\r\n",
"<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
"<path fill=\"#4de88e\" stroke=\"black\" d=\"M354,-306C354,-306 213,-306 213,-306 207,-306 201,-300 201,-294 201,-294 201,-235 201,-235 201,-229 207,-223 213,-223 213,-223 354,-223 354,-223 360,-223 366,-229 366,-235 366,-235 366,-294 366,-294 366,-300 360,-306 354,-306\"/>\r\n",
"<text text-anchor=\"start\" x=\"209\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.95</text>\r\n",
"<text text-anchor=\"start\" x=\"233.5\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.445</text>\r\n",
"<text text-anchor=\"start\" x=\"240\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 54</text>\r\n",
"<text text-anchor=\"start\" x=\"231\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 49, 5]</text>\r\n",
"<text text-anchor=\"start\" x=\"228\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 2&#45;&gt;3 -->\r\n",
"<g id=\"edge3\" class=\"edge\"><title>2&#45;&gt;3</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M342.928,-341.907C335.94,-332.923 328.467,-323.315 321.261,-314.05\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"323.918,-311.766 315.016,-306.021 318.393,-316.063 323.918,-311.766\"/>\r\n",
"</g>\r\n",
"<!-- 8 -->\r\n",
"<g id=\"node9\" class=\"node\"><title>8</title>\r\n",
"<path fill=\"#843de6\" stroke=\"black\" d=\"M537,-306C537,-306 396,-306 396,-306 390,-306 384,-300 384,-294 384,-294 384,-235 384,-235 384,-229 390,-223 396,-223 396,-223 537,-223 537,-223 543,-223 549,-229 549,-235 549,-235 549,-294 549,-294 549,-300 543,-306 537,-306\"/>\r\n",
"<text text-anchor=\"start\" x=\"392\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.95</text>\r\n",
"<text text-anchor=\"start\" x=\"416.5\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.151</text>\r\n",
"<text text-anchor=\"start\" x=\"423\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 46</text>\r\n",
"<text text-anchor=\"start\" x=\"414\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1, 45]</text>\r\n",
"<text text-anchor=\"start\" x=\"416.5\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 2&#45;&gt;8 -->\r\n",
"<g id=\"edge8\" class=\"edge\"><title>2&#45;&gt;8</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M406.419,-341.907C413.484,-332.923 421.039,-323.315 428.324,-314.05\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"431.208,-316.045 434.637,-306.021 425.705,-311.718 431.208,-316.045\"/>\r\n",
"</g>\r\n",
"<!-- 4 -->\r\n",
"<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
"<path fill=\"#3de684\" stroke=\"black\" d=\"M208,-187C208,-187 63,-187 63,-187 57,-187 51,-181 51,-175 51,-175 51,-116 51,-116 51,-110 57,-104 63,-104 63,-104 208,-104 208,-104 214,-104 220,-110 220,-116 220,-116 220,-175 220,-175 220,-181 214,-187 208,-187\"/>\r\n",
"<text text-anchor=\"start\" x=\"59\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sepal length (cm) ≤ 5.15</text>\r\n",
"<text text-anchor=\"start\" x=\"85.5\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.146</text>\r\n",
"<text text-anchor=\"start\" x=\"92\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 48</text>\r\n",
"<text text-anchor=\"start\" x=\"83\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 47, 1]</text>\r\n",
"<text text-anchor=\"start\" x=\"80\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 3&#45;&gt;4 -->\r\n",
"<g id=\"edge4\" class=\"edge\"><title>3&#45;&gt;4</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M232.152,-222.907C220.099,-213.379 207.158,-203.148 194.788,-193.37\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"196.772,-190.477 186.757,-187.021 192.431,-195.968 196.772,-190.477\"/>\r\n",
"</g>\r\n",
"<!-- 7 -->\r\n",
"<g id=\"node8\" class=\"node\"><title>7</title>\r\n",
"<path fill=\"#c09cf2\" stroke=\"black\" d=\"M342.5,-179.5C342.5,-179.5 250.5,-179.5 250.5,-179.5 244.5,-179.5 238.5,-173.5 238.5,-167.5 238.5,-167.5 238.5,-123.5 238.5,-123.5 238.5,-117.5 244.5,-111.5 250.5,-111.5 250.5,-111.5 342.5,-111.5 342.5,-111.5 348.5,-111.5 354.5,-117.5 354.5,-123.5 354.5,-123.5 354.5,-167.5 354.5,-167.5 354.5,-173.5 348.5,-179.5 342.5,-179.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"246.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\r\n",
"<text text-anchor=\"start\" x=\"257\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\r\n",
"<text text-anchor=\"start\" x=\"248\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 2, 4]</text>\r\n",
"<text text-anchor=\"start\" x=\"246.5\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 3&#45;&gt;7 -->\r\n",
"<g id=\"edge7\" class=\"edge\"><title>3&#45;&gt;7</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M288.01,-222.907C289.2,-212.204 290.487,-200.615 291.692,-189.776\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"295.189,-189.992 292.815,-179.667 288.232,-189.219 295.189,-189.992\"/>\r\n",
"</g>\r\n",
"<!-- 5 -->\r\n",
"<g id=\"node6\" class=\"node\"><title>5</title>\r\n",
"<path fill=\"#6aeca0\" stroke=\"black\" d=\"M115,-68C115,-68 12,-68 12,-68 6,-68 -7.10543e-015,-62 -7.10543e-015,-56 -7.10543e-015,-56 -7.10543e-015,-12 -7.10543e-015,-12 -7.10543e-015,-6 6,-0 12,-0 12,-0 115,-0 115,-0 121,-0 127,-6 127,-12 127,-12 127,-56 127,-56 127,-62 121,-68 115,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"13.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.722</text>\r\n",
"<text text-anchor=\"start\" x=\"24\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 5</text>\r\n",
"<text text-anchor=\"start\" x=\"15\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 4, 1]</text>\r\n",
"<text text-anchor=\"start\" x=\"8\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 4&#45;&gt;5 -->\r\n",
"<g id=\"edge5\" class=\"edge\"><title>4&#45;&gt;5</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M108.69,-103.726C102.932,-94.9703 96.8391,-85.7032 91.054,-76.9051\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"93.8142,-74.7322 85.3956,-68.2996 87.9653,-78.5781 93.8142,-74.7322\"/>\r\n",
"</g>\r\n",
"<!-- 6 -->\r\n",
"<g id=\"node7\" class=\"node\"><title>6</title>\r\n",
"<path fill=\"#39e581\" stroke=\"black\" d=\"M260,-68C260,-68 157,-68 157,-68 151,-68 145,-62 145,-56 145,-56 145,-12 145,-12 145,-6 151,-0 157,-0 157,-0 260,-0 260,-0 266,-0 272,-6 272,-12 272,-12 272,-56 272,-56 272,-62 266,-68 260,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"166.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"165\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\r\n",
"<text text-anchor=\"start\" x=\"156\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 43, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"153\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 4&#45;&gt;6 -->\r\n",
"<g id=\"edge6\" class=\"edge\"><title>4&#45;&gt;6</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M162.683,-103.726C168.581,-94.879 174.827,-85.51 180.746,-76.6303\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"183.665,-78.5616 186.3,-68.2996 177.841,-74.6787 183.665,-78.5616\"/>\r\n",
"</g>\r\n",
"<!-- 9 -->\r\n",
"<g id=\"node10\" class=\"node\"><title>9</title>\r\n",
"<path fill=\"#9a61ea\" stroke=\"black\" d=\"M500.5,-179.5C500.5,-179.5 408.5,-179.5 408.5,-179.5 402.5,-179.5 396.5,-173.5 396.5,-167.5 396.5,-167.5 396.5,-123.5 396.5,-123.5 396.5,-117.5 402.5,-111.5 408.5,-111.5 408.5,-111.5 500.5,-111.5 500.5,-111.5 506.5,-111.5 512.5,-117.5 512.5,-123.5 512.5,-123.5 512.5,-167.5 512.5,-167.5 512.5,-173.5 506.5,-179.5 500.5,-179.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"408.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.65</text>\r\n",
"<text text-anchor=\"start\" x=\"415\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\r\n",
"<text text-anchor=\"start\" x=\"406\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1, 5]</text>\r\n",
"<text text-anchor=\"start\" x=\"404.5\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 8&#45;&gt;9 -->\r\n",
"<g id=\"edge9\" class=\"edge\"><title>8&#45;&gt;9</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M462.337,-222.907C461.239,-212.204 460.05,-200.615 458.939,-189.776\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"462.404,-189.258 457.902,-179.667 455.44,-189.972 462.404,-189.258\"/>\r\n",
"</g>\r\n",
"<!-- 10 -->\r\n",
"<g id=\"node11\" class=\"node\"><title>10</title>\r\n",
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M640,-179.5C640,-179.5 543,-179.5 543,-179.5 537,-179.5 531,-173.5 531,-167.5 531,-167.5 531,-123.5 531,-123.5 531,-117.5 537,-111.5 543,-111.5 543,-111.5 640,-111.5 640,-111.5 646,-111.5 652,-117.5 652,-123.5 652,-123.5 652,-167.5 652,-167.5 652,-173.5 646,-179.5 640,-179.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"549.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"548\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 40</text>\r\n",
"<text text-anchor=\"start\" x=\"539\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 40]</text>\r\n",
"<text text-anchor=\"start\" x=\"541.5\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 8&#45;&gt;10 -->\r\n",
"<g id=\"edge10\" class=\"edge\"><title>8&#45;&gt;10</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M509.868,-222.907C522.364,-211.211 535.99,-198.457 548.466,-186.78\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"551.156,-189.056 556.065,-179.667 546.373,-183.945 551.156,-189.056\"/>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n"
],
"text/plain": [
"<graphviz.files.Source at 0x1b86722ec08>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Export the fitted tree as DOT source and wrap it in a graphviz object for inline rendering\n",
"dot_data = tree.export_graphviz(\n",
"    clf,\n",
"    out_file=None,  # return the DOT text instead of writing it to a file\n",
"    feature_names=iris.feature_names,\n",
"    class_names=iris.target_names,\n",
"    filled=True,\n",
"    rounded=True,\n",
"    special_characters=True,\n",
")\n",
"graph = graphviz.Source(dot_data)\n",
"graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
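"# 1. Artificial inflation\n",
"\n",
"Re-split the data, then duplicate every training example of classes 1 (versicolor) and 2 (virginica) so that each one is counted ten times, and train a tree with uniform class weights on the inflated training set."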
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Generate a random permutation of the indices of examples that will be later used\n",
"# for the training and the test set\n",
"import numpy as np\n",
"\n",
"TEST_SIZE = 10  # number of shuffled examples held out for the test set\n",
"\n",
"np.random.seed(1231)  # fixed seed: the same permutation on every run\n",
"indices = np.random.permutation(len(iris.data))\n",
"\n",
"# The last TEST_SIZE indices go to the test set, the remaining ones to the training set.\n",
"# Cutting at an explicit position (instead of indices[:-TEST_SIZE]) also stays correct for TEST_SIZE = 0.\n",
"n_train = len(indices) - TEST_SIZE\n",
"indices_training = indices[:n_train]\n",
"indices_test = indices[n_train:]\n",
"assert len(indices_training) + len(indices_test) == len(iris.data)\n",
"\n",
"iris_X_train = iris.data[indices_training]\n",
"iris_y_train = iris.target[indices_training]\n",
"iris_X_test = iris.data[indices_test]\n",
"iris_y_test = iris.target[indices_test]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Artificial inflation: every training example of class 1 (versicolor) or 2 (virginica)\n",
"# is appended EXTRA_COPIES more times, so it is counted EXTRA_COPIES + 1 = 10 times in total.\n",
"INFLATED_CLASSES = [1, 2]\n",
"EXTRA_COPIES = 9\n",
"\n",
"# Start from the un-inflated split (not from iris_X_train / iris_y_train, which this cell\n",
"# overwrites) so that re-running the cell does not inflate the data a second time.\n",
"base_X = iris.data[indices_training]\n",
"base_y = iris.target[indices_training]\n",
"\n",
"# np.repeat keeps the original example order and places the copies of each example\n",
"# consecutively; an empty selection simply yields zero extra rows.\n",
"to_inflate = np.isin(base_y, INFLATED_CLASSES)\n",
"samples_x = np.repeat(base_X[to_inflate], EXTRA_COPIES, axis=0)\n",
"samples_y = np.repeat(base_y[to_inflate], EXTRA_COPIES)\n",
"\n",
"# Samples inflation\n",
"iris_X_train = np.concatenate([base_X, samples_x], axis=0)\n",
"iris_y_train = np.concatenate([base_y, samples_y])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 1.0\n",
"F1: 1.0\n"
]
}
],
"source": [
"# Decision tree with uniform class weights: in this section classes 1 and 2 are emphasized\n",
"# only through the duplicated training samples, not through class_weight\n",
"clf = tree.DecisionTreeClassifier(\n",
"    criterion=\"entropy\",\n",
"    random_state=300,\n",
"    min_samples_leaf=10,\n",
"    class_weight={0: 1, 1: 1, 2: 1},\n",
")\n",
"clf.fit(iris_X_train, iris_y_train)  # fit returns the estimator itself, so clf is already fitted\n",
"\n",
"# Evaluate on the held-out test examples\n",
"predicted_y_test = clf.predict(iris_X_test)\n",
"acc_score = accuracy_score(iris_y_test, predicted_y_test)\n",
"f1 = f1_score(iris_y_test, predicted_y_test, average='macro')  # unweighted mean of per-class F1\n",
"print(\"Accuracy: \", acc_score)\n",
"print(\"F1: \", f1)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
" -->\r\n",
"<!-- Title: Tree Pages: 1 -->\r\n",
"<svg width=\"641pt\" height=\"671pt\"\r\n",
" viewBox=\"0.00 0.00 641.00 671.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 667)\">\r\n",
"<title>Tree</title>\r\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-667 637,-667 637,4 -4,4\"/>\r\n",
"<!-- 0 -->\r\n",
"<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
"<path fill=\"#fdfbff\" stroke=\"black\" d=\"M384,-663C384,-663 243,-663 243,-663 237,-663 231,-657 231,-651 231,-651 231,-592 231,-592 231,-586 237,-580 243,-580 243,-580 384,-580 384,-580 390,-580 396,-586 396,-592 396,-592 396,-651 396,-651 396,-657 390,-663 384,-663\"/>\r\n",
"<text text-anchor=\"start\" x=\"239\" y=\"-647.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.85</text>\r\n",
"<text text-anchor=\"start\" x=\"263.5\" y=\"-632.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.211</text>\r\n",
"<text text-anchor=\"start\" x=\"262\" y=\"-617.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 1013</text>\r\n",
"<text text-anchor=\"start\" x=\"244.5\" y=\"-602.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 480, 490]</text>\r\n",
"<text text-anchor=\"start\" x=\"263.5\" y=\"-587.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 1 -->\r\n",
"<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
"<path fill=\"#54e892\" stroke=\"black\" d=\"M294,-544C294,-544 153,-544 153,-544 147,-544 141,-538 141,-532 141,-532 141,-473 141,-473 141,-467 147,-461 153,-461 153,-461 294,-461 294,-461 300,-461 306,-467 306,-473 306,-473 306,-532 306,-532 306,-538 300,-544 294,-544\"/>\r\n",
"<text text-anchor=\"start\" x=\"149\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 2.45</text>\r\n",
"<text text-anchor=\"start\" x=\"173.5\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.648</text>\r\n",
"<text text-anchor=\"start\" x=\"176\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 513</text>\r\n",
"<text text-anchor=\"start\" x=\"158.5\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 450, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"168\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;1 -->\r\n",
"<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M282.275,-579.907C275.364,-570.923 267.973,-561.315 260.846,-552.05\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"263.541,-549.813 254.67,-544.021 257.993,-554.081 263.541,-549.813\"/>\r\n",
"<text text-anchor=\"middle\" x=\"251.436\" y=\"-565.111\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\r\n",
"</g>\r\n",
"<!-- 8 -->\r\n",
"<g id=\"node9\" class=\"node\"><title>8</title>\r\n",
"<path fill=\"#8946e7\" stroke=\"black\" d=\"M471,-544C471,-544 336,-544 336,-544 330,-544 324,-538 324,-532 324,-532 324,-473 324,-473 324,-467 330,-461 336,-461 336,-461 471,-461 471,-461 477,-461 483,-467 483,-473 483,-473 483,-532 483,-532 483,-538 477,-544 471,-544\"/>\r\n",
"<text text-anchor=\"start\" x=\"332\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.75</text>\r\n",
"<text text-anchor=\"start\" x=\"353.5\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.327</text>\r\n",
"<text text-anchor=\"start\" x=\"356\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 500</text>\r\n",
"<text text-anchor=\"start\" x=\"343\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 470]</text>\r\n",
"<text text-anchor=\"start\" x=\"353.5\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;8 -->\r\n",
"<g id=\"edge8\" class=\"edge\"><title>0&#45;&gt;8</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M344.725,-579.907C351.636,-570.923 359.027,-561.315 366.154,-552.05\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"369.007,-554.081 372.33,-544.021 363.459,-549.813 369.007,-554.081\"/>\r\n",
"<text text-anchor=\"middle\" x=\"375.564\" y=\"-565.111\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\r\n",
"</g>\r\n",
"<!-- 2 -->\r\n",
"<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
"<path fill=\"#e58139\" stroke=\"black\" d=\"M114,-417.5C114,-417.5 17,-417.5 17,-417.5 11,-417.5 5,-411.5 5,-405.5 5,-405.5 5,-361.5 5,-361.5 5,-355.5 11,-349.5 17,-349.5 17,-349.5 114,-349.5 114,-349.5 120,-349.5 126,-355.5 126,-361.5 126,-361.5 126,-405.5 126,-405.5 126,-411.5 120,-417.5 114,-417.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"23.5\" y=\"-402.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"22\" y=\"-387.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\r\n",
"<text text-anchor=\"start\" x=\"13\" y=\"-372.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 0, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"19.5\" y=\"-357.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\r\n",
"</g>\r\n",
"<!-- 1&#45;&gt;2 -->\r\n",
"<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M168.683,-460.907C152.44,-448.88 134.689,-435.735 118.559,-423.791\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"120.409,-420.805 110.29,-417.667 116.243,-426.431 120.409,-420.805\"/>\r\n",
"</g>\r\n",
"<!-- 3 -->\r\n",
"<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
"<path fill=\"#42e687\" stroke=\"black\" d=\"M291,-425C291,-425 156,-425 156,-425 150,-425 144,-419 144,-413 144,-413 144,-354 144,-354 144,-348 150,-342 156,-342 156,-342 291,-342 291,-342 297,-342 303,-348 303,-354 303,-354 303,-413 303,-413 303,-419 297,-425 291,-425\"/>\r\n",
"<text text-anchor=\"start\" x=\"152\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.65</text>\r\n",
"<text text-anchor=\"start\" x=\"173.5\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.254</text>\r\n",
"<text text-anchor=\"start\" x=\"176\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 470</text>\r\n",
"<text text-anchor=\"start\" x=\"163\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 450, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"168\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 1&#45;&gt;3 -->\r\n",
"<g id=\"edge3\" class=\"edge\"><title>1&#45;&gt;3</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M223.5,-460.907C223.5,-452.649 223.5,-443.864 223.5,-435.302\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"227,-435.021 223.5,-425.021 220,-435.021 227,-435.021\"/>\r\n",
"</g>\r\n",
"<!-- 4 -->\r\n",
"<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
"<path fill=\"#39e581\" stroke=\"black\" d=\"M117,-298.5C117,-298.5 12,-298.5 12,-298.5 6,-298.5 0,-292.5 0,-286.5 0,-286.5 0,-242.5 0,-242.5 0,-236.5 6,-230.5 12,-230.5 12,-230.5 117,-230.5 117,-230.5 123,-230.5 129,-236.5 129,-242.5 129,-242.5 129,-286.5 129,-286.5 129,-292.5 123,-298.5 117,-298.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"22.5\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"17\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 440</text>\r\n",
"<text text-anchor=\"start\" x=\"8\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 440, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"9\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 3&#45;&gt;4 -->\r\n",
"<g id=\"edge4\" class=\"edge\"><title>3&#45;&gt;4</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M168.336,-341.907C151.991,-329.88 134.126,-316.735 117.895,-304.791\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"119.702,-301.775 109.573,-298.667 115.553,-307.413 119.702,-301.775\"/>\r\n",
"</g>\r\n",
"<!-- 5 -->\r\n",
"<g id=\"node6\" class=\"node\"><title>5</title>\r\n",
"<path fill=\"#c09cf2\" stroke=\"black\" d=\"M290,-306C290,-306 159,-306 159,-306 153,-306 147,-300 147,-294 147,-294 147,-235 147,-235 147,-229 153,-223 159,-223 159,-223 290,-223 290,-223 296,-223 302,-229 302,-235 302,-235 302,-294 302,-294 302,-300 296,-306 290,-306\"/>\r\n",
"<text text-anchor=\"start\" x=\"155\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sepal width (cm) ≤ 3.1</text>\r\n",
"<text text-anchor=\"start\" x=\"174.5\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\r\n",
"<text text-anchor=\"start\" x=\"181\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 30</text>\r\n",
"<text text-anchor=\"start\" x=\"168\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"174.5\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 3&#45;&gt;5 -->\r\n",
"<g id=\"edge5\" class=\"edge\"><title>3&#45;&gt;5</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M223.847,-341.907C223.918,-333.649 223.993,-324.864 224.066,-316.302\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"227.568,-316.05 224.154,-306.021 220.568,-315.99 227.568,-316.05\"/>\r\n",
"</g>\r\n",
"<!-- 6 -->\r\n",
"<g id=\"node7\" class=\"node\"><title>6</title>\r\n",
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M134,-179.5C134,-179.5 37,-179.5 37,-179.5 31,-179.5 25,-173.5 25,-167.5 25,-167.5 25,-123.5 25,-123.5 25,-117.5 31,-111.5 37,-111.5 37,-111.5 134,-111.5 134,-111.5 140,-111.5 146,-117.5 146,-123.5 146,-123.5 146,-167.5 146,-167.5 146,-173.5 140,-179.5 134,-179.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"43.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"42\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 20</text>\r\n",
"<text text-anchor=\"start\" x=\"33\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"35.5\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 5&#45;&gt;6 -->\r\n",
"<g id=\"edge6\" class=\"edge\"><title>5&#45;&gt;6</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M176.275,-222.907C162.117,-210.99 146.655,-197.976 132.57,-186.12\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"134.808,-183.429 124.903,-179.667 130.3,-188.784 134.808,-183.429\"/>\r\n",
"</g>\r\n",
"<!-- 7 -->\r\n",
"<g id=\"node8\" class=\"node\"><title>7</title>\r\n",
"<path fill=\"#39e581\" stroke=\"black\" d=\"M279,-179.5C279,-179.5 176,-179.5 176,-179.5 170,-179.5 164,-173.5 164,-167.5 164,-167.5 164,-123.5 164,-123.5 164,-117.5 170,-111.5 176,-111.5 176,-111.5 279,-111.5 279,-111.5 285,-111.5 291,-117.5 291,-123.5 291,-123.5 291,-167.5 291,-167.5 291,-173.5 285,-179.5 279,-179.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"185.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"184\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 10</text>\r\n",
"<text text-anchor=\"start\" x=\"175\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"172\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 5&#45;&gt;7 -->\r\n",
"<g id=\"edge7\" class=\"edge\"><title>5&#45;&gt;7</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M225.541,-222.907C225.815,-212.204 226.112,-200.615 226.39,-189.776\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"229.892,-189.753 226.65,-179.667 222.894,-189.574 229.892,-189.753\"/>\r\n",
"</g>\r\n",
"<!-- 9 -->\r\n",
"<g id=\"node10\" class=\"node\"><title>9</title>\r\n",
"<path fill=\"#e0cef8\" stroke=\"black\" d=\"M474,-425C474,-425 333,-425 333,-425 327,-425 321,-419 321,-413 321,-413 321,-354 321,-354 321,-348 327,-342 333,-342 333,-342 474,-342 474,-342 480,-342 486,-348 486,-354 486,-354 486,-413 486,-413 486,-419 480,-425 474,-425\"/>\r\n",
"<text text-anchor=\"start\" x=\"329\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 5.35</text>\r\n",
"<text text-anchor=\"start\" x=\"353.5\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.985</text>\r\n",
"<text text-anchor=\"start\" x=\"360\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 70</text>\r\n",
"<text text-anchor=\"start\" x=\"347\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 40]</text>\r\n",
"<text text-anchor=\"start\" x=\"353.5\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 8&#45;&gt;9 -->\r\n",
"<g id=\"edge9\" class=\"edge\"><title>8&#45;&gt;9</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M403.5,-460.907C403.5,-452.649 403.5,-443.864 403.5,-435.302\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"407,-435.021 403.5,-425.021 400,-435.021 407,-435.021\"/>\r\n",
"</g>\r\n",
"<!-- 16 -->\r\n",
"<g id=\"node17\" class=\"node\"><title>16</title>\r\n",
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M621,-417.5C621,-417.5 516,-417.5 516,-417.5 510,-417.5 504,-411.5 504,-405.5 504,-405.5 504,-361.5 504,-361.5 504,-355.5 510,-349.5 516,-349.5 516,-349.5 621,-349.5 621,-349.5 627,-349.5 633,-355.5 633,-361.5 633,-361.5 633,-405.5 633,-405.5 633,-411.5 627,-417.5 621,-417.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"526.5\" y=\"-402.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"521\" y=\"-387.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 430</text>\r\n",
"<text text-anchor=\"start\" x=\"512\" y=\"-372.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 430]</text>\r\n",
"<text text-anchor=\"start\" x=\"518.5\" y=\"-357.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 8&#45;&gt;16 -->\r\n",
"<g id=\"edge16\" class=\"edge\"><title>8&#45;&gt;16</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M460.746,-460.907C477.864,-448.769 496.586,-435.493 513.553,-423.462\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"515.593,-426.306 521.726,-417.667 511.544,-420.596 515.593,-426.306\"/>\r\n",
"</g>\r\n",
"<!-- 10 -->\r\n",
"<g id=\"node11\" class=\"node\"><title>10</title>\r\n",
"<path fill=\"#bdf6d5\" stroke=\"black\" d=\"M469,-306C469,-306 334,-306 334,-306 328,-306 322,-300 322,-294 322,-294 322,-235 322,-235 322,-229 328,-223 334,-223 334,-223 469,-223 469,-223 475,-223 481,-229 481,-235 481,-235 481,-294 481,-294 481,-300 475,-306 469,-306\"/>\r\n",
"<text text-anchor=\"start\" x=\"330\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.55</text>\r\n",
"<text text-anchor=\"start\" x=\"351.5\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.971</text>\r\n",
"<text text-anchor=\"start\" x=\"358\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 50</text>\r\n",
"<text text-anchor=\"start\" x=\"345\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"346\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 9&#45;&gt;10 -->\r\n",
"<g id=\"edge10\" class=\"edge\"><title>9&#45;&gt;10</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M402.806,-341.907C402.663,-333.558 402.511,-324.671 402.364,-316.02\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"405.863,-315.959 402.193,-306.021 398.864,-316.079 405.863,-315.959\"/>\r\n",
"</g>\r\n",
"<!-- 15 -->\r\n",
"<g id=\"node16\" class=\"node\"><title>15</title>\r\n",
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M608,-298.5C608,-298.5 511,-298.5 511,-298.5 505,-298.5 499,-292.5 499,-286.5 499,-286.5 499,-242.5 499,-242.5 499,-236.5 505,-230.5 511,-230.5 511,-230.5 608,-230.5 608,-230.5 614,-230.5 620,-236.5 620,-242.5 620,-242.5 620,-286.5 620,-286.5 620,-292.5 614,-298.5 608,-298.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"517.5\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"516\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 20</text>\r\n",
"<text text-anchor=\"start\" x=\"507\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"509.5\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 9&#45;&gt;15 -->\r\n",
"<g id=\"edge15\" class=\"edge\"><title>9&#45;&gt;15</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M457.623,-341.907C473.66,-329.88 491.187,-316.735 507.112,-304.791\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"509.377,-307.467 515.277,-298.667 505.177,-301.867 509.377,-307.467\"/>\r\n",
"</g>\r\n",
"<!-- 11 -->\r\n",
"<g id=\"node12\" class=\"node\"><title>11</title>\r\n",
"<path fill=\"#c09cf2\" stroke=\"black\" d=\"M467,-187C467,-187 326,-187 326,-187 320,-187 314,-181 314,-175 314,-175 314,-116 314,-116 314,-110 320,-104 326,-104 326,-104 467,-104 467,-104 473,-104 479,-110 479,-116 479,-116 479,-175 479,-175 479,-181 473,-187 467,-187\"/>\r\n",
"<text text-anchor=\"start\" x=\"322\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.95</text>\r\n",
"<text text-anchor=\"start\" x=\"346.5\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\r\n",
"<text text-anchor=\"start\" x=\"353\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 30</text>\r\n",
"<text text-anchor=\"start\" x=\"340\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"346.5\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 10&#45;&gt;11 -->\r\n",
"<g id=\"edge11\" class=\"edge\"><title>10&#45;&gt;11</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M399.765,-222.907C399.408,-214.558 399.029,-205.671 398.659,-197.02\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"402.156,-196.862 398.232,-187.021 395.162,-197.161 402.156,-196.862\"/>\r\n",
"</g>\r\n",
"<!-- 14 -->\r\n",
"<g id=\"node15\" class=\"node\"><title>14</title>\r\n",
"<path fill=\"#39e581\" stroke=\"black\" d=\"M612,-179.5C612,-179.5 509,-179.5 509,-179.5 503,-179.5 497,-173.5 497,-167.5 497,-167.5 497,-123.5 497,-123.5 497,-117.5 503,-111.5 509,-111.5 509,-111.5 612,-111.5 612,-111.5 618,-111.5 624,-117.5 624,-123.5 624,-123.5 624,-167.5 624,-167.5 624,-173.5 618,-179.5 612,-179.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"518.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"517\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 20</text>\r\n",
"<text text-anchor=\"start\" x=\"508\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 20, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"505\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 10&#45;&gt;14 -->\r\n",
"<g id=\"edge14\" class=\"edge\"><title>10&#45;&gt;14</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M456.664,-222.907C473.009,-210.88 490.874,-197.735 507.105,-185.791\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"509.447,-188.413 515.427,-179.667 505.298,-182.775 509.447,-188.413\"/>\r\n",
"</g>\r\n",
"<!-- 12 -->\r\n",
"<g id=\"node13\" class=\"node\"><title>12</title>\r\n",
"<path fill=\"#39e581\" stroke=\"black\" d=\"M377,-68C377,-68 274,-68 274,-68 268,-68 262,-62 262,-56 262,-56 262,-12 262,-12 262,-6 268,-0 274,-0 274,-0 377,-0 377,-0 383,-0 389,-6 389,-12 389,-12 389,-56 389,-56 389,-62 383,-68 377,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"283.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"282\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 10</text>\r\n",
"<text text-anchor=\"start\" x=\"273\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"270\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 11&#45;&gt;12 -->\r\n",
"<g id=\"edge12\" class=\"edge\"><title>11&#45;&gt;12</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M370.062,-103.726C364.385,-94.9703 358.376,-85.7032 352.671,-76.9051\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"355.469,-74.786 347.092,-68.2996 349.595,-78.5943 355.469,-74.786\"/>\r\n",
"</g>\r\n",
"<!-- 13 -->\r\n",
"<g id=\"node14\" class=\"node\"><title>13</title>\r\n",
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M516,-68C516,-68 419,-68 419,-68 413,-68 407,-62 407,-56 407,-56 407,-12 407,-12 407,-6 413,-0 419,-0 419,-0 516,-0 516,-0 522,-0 528,-6 528,-12 528,-12 528,-56 528,-56 528,-62 522,-68 516,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"425.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"424\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 20</text>\r\n",
"<text text-anchor=\"start\" x=\"415\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"417.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 11&#45;&gt;13 -->\r\n",
"<g id=\"edge13\" class=\"edge\"><title>11&#45;&gt;13</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M422.938,-103.726C428.615,-94.9703 434.624,-85.7032 440.329,-76.9051\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"443.405,-78.5943 445.908,-68.2996 437.531,-74.786 443.405,-78.5943\"/>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n"
],
"text/plain": [
"<graphviz.files.Source at 0x1b8673ec0c8>"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dot_data = tree.export_graphviz(clf, out_file=None, \n",
" feature_names=iris.feature_names, \n",
" class_names=iris.target_names, \n",
" filled=True, rounded=True, \n",
" special_characters=True) \n",
"graph = graphviz.Source(dot_data) \n",
"graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Class weights"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# Generate a random permutation of the indices of examples that will be later used \n",
"# for the training and the test set\n",
"import numpy as np\n",
"np.random.seed(1231)\n",
"indices = np.random.permutation(len(iris.data))\n",
"\n",
"# We now decide to keep the last 10 indices for test set, the remaining for the training set\n",
"indices_training=indices[:-10]\n",
"indices_test=indices[-10:]\n",
"\n",
"iris_X_train = iris.data[indices_training] # keep for training all the matrix elements with the exception of the last 10 \n",
"iris_y_train = iris.target[indices_training]\n",
"iris_X_test = iris.data[indices_test] # keep the last 10 elements for test set\n",
"iris_y_test = iris.target[indices_test]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.8\n",
"F1: 0.5\n"
]
}
],
"source": [
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\",random_state=300,min_samples_leaf=5,class_weight={0:1,1:10,2:10})\n",
"clf = clf.fit(iris_X_train, iris_y_train)\n",
"predicted_y_test = clf.predict(iris_X_test)\n",
"acc_score = accuracy_score(iris_y_test, predicted_y_test)\n",
"f1 = f1_score(iris_y_test, predicted_y_test, average='macro')\n",
"print(\"Accuracy: \", acc_score)\n",
"print(\"F1: \", f1)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
" -->\r\n",
"<!-- Title: Tree Pages: 1 -->\r\n",
"<svg width=\"609pt\" height=\"552pt\"\r\n",
" viewBox=\"0.00 0.00 609.00 552.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 548)\">\r\n",
"<title>Tree</title>\r\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-548 605,-548 605,4 -4,4\"/>\r\n",
"<!-- 0 -->\r\n",
"<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
"<path fill=\"#fdfbff\" stroke=\"black\" d=\"M374,-544C374,-544 233,-544 233,-544 227,-544 221,-538 221,-532 221,-532 221,-473 221,-473 221,-467 227,-461 233,-461 233,-461 374,-461 374,-461 380,-461 386,-467 386,-473 386,-473 386,-532 386,-532 386,-538 380,-544 374,-544\"/>\r\n",
"<text text-anchor=\"start\" x=\"229\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.85</text>\r\n",
"<text text-anchor=\"start\" x=\"253.5\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.211</text>\r\n",
"<text text-anchor=\"start\" x=\"256\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 140</text>\r\n",
"<text text-anchor=\"start\" x=\"234.5\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 480, 490]</text>\r\n",
"<text text-anchor=\"start\" x=\"253.5\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 1 -->\r\n",
"<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
"<path fill=\"#54e892\" stroke=\"black\" d=\"M284,-425C284,-425 143,-425 143,-425 137,-425 131,-419 131,-413 131,-413 131,-354 131,-354 131,-348 137,-342 143,-342 143,-342 284,-342 284,-342 290,-342 296,-348 296,-354 296,-354 296,-413 296,-413 296,-419 290,-425 284,-425\"/>\r\n",
"<text text-anchor=\"start\" x=\"139\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 2.45</text>\r\n",
"<text text-anchor=\"start\" x=\"163.5\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.648</text>\r\n",
"<text text-anchor=\"start\" x=\"170\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 90</text>\r\n",
"<text text-anchor=\"start\" x=\"148.5\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 450, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"158\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;1 -->\r\n",
"<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M272.275,-460.907C265.364,-451.923 257.973,-442.315 250.846,-433.05\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"253.541,-430.813 244.67,-425.021 247.993,-435.081 253.541,-430.813\"/>\r\n",
"<text text-anchor=\"middle\" x=\"241.436\" y=\"-446.111\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\r\n",
"</g>\r\n",
"<!-- 8 -->\r\n",
"<g id=\"node9\" class=\"node\"><title>8</title>\r\n",
"<path fill=\"#8946e7\" stroke=\"black\" d=\"M461,-425C461,-425 326,-425 326,-425 320,-425 314,-419 314,-413 314,-413 314,-354 314,-354 314,-348 320,-342 326,-342 326,-342 461,-342 461,-342 467,-342 473,-348 473,-354 473,-354 473,-413 473,-413 473,-419 467,-425 461,-425\"/>\r\n",
"<text text-anchor=\"start\" x=\"322\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.75</text>\r\n",
"<text text-anchor=\"start\" x=\"343.5\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.327</text>\r\n",
"<text text-anchor=\"start\" x=\"350\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 50</text>\r\n",
"<text text-anchor=\"start\" x=\"333\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 470]</text>\r\n",
"<text text-anchor=\"start\" x=\"343.5\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;8 -->\r\n",
"<g id=\"edge8\" class=\"edge\"><title>0&#45;&gt;8</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M334.725,-460.907C341.636,-451.923 349.027,-442.315 356.154,-433.05\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"359.007,-435.081 362.33,-425.021 353.459,-430.813 359.007,-435.081\"/>\r\n",
"<text text-anchor=\"middle\" x=\"365.564\" y=\"-446.111\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\r\n",
"</g>\r\n",
"<!-- 2 -->\r\n",
"<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
"<path fill=\"#e58139\" stroke=\"black\" d=\"M109,-298.5C109,-298.5 12,-298.5 12,-298.5 6,-298.5 0,-292.5 0,-286.5 0,-286.5 0,-242.5 0,-242.5 0,-236.5 6,-230.5 12,-230.5 12,-230.5 109,-230.5 109,-230.5 115,-230.5 121,-236.5 121,-242.5 121,-242.5 121,-286.5 121,-286.5 121,-292.5 115,-298.5 109,-298.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"18.5\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"17\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\r\n",
"<text text-anchor=\"start\" x=\"8\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 0, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"14.5\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\r\n",
"</g>\r\n",
"<!-- 1&#45;&gt;2 -->\r\n",
"<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M160.417,-341.907C144.689,-329.88 127.499,-316.735 111.88,-304.791\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"113.942,-301.961 103.872,-298.667 109.69,-307.522 113.942,-301.961\"/>\r\n",
"</g>\r\n",
"<!-- 3 -->\r\n",
"<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
"<path fill=\"#42e687\" stroke=\"black\" d=\"M286,-306C286,-306 151,-306 151,-306 145,-306 139,-300 139,-294 139,-294 139,-235 139,-235 139,-229 145,-223 151,-223 151,-223 286,-223 286,-223 292,-223 298,-229 298,-235 298,-235 298,-294 298,-294 298,-300 292,-306 286,-306\"/>\r\n",
"<text text-anchor=\"start\" x=\"147\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.45</text>\r\n",
"<text text-anchor=\"start\" x=\"168.5\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.254</text>\r\n",
"<text text-anchor=\"start\" x=\"175\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 47</text>\r\n",
"<text text-anchor=\"start\" x=\"158\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 450, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"163\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 1&#45;&gt;3 -->\r\n",
"<g id=\"edge3\" class=\"edge\"><title>1&#45;&gt;3</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M215.235,-341.907C215.592,-333.558 215.971,-324.671 216.341,-316.02\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"219.838,-316.161 216.768,-306.021 212.844,-315.862 219.838,-316.161\"/>\r\n",
"</g>\r\n",
"<!-- 4 -->\r\n",
"<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
"<path fill=\"#39e581\" stroke=\"black\" d=\"M189,-179.5C189,-179.5 84,-179.5 84,-179.5 78,-179.5 72,-173.5 72,-167.5 72,-167.5 72,-123.5 72,-123.5 72,-117.5 78,-111.5 84,-111.5 84,-111.5 189,-111.5 189,-111.5 195,-111.5 201,-117.5 201,-123.5 201,-123.5 201,-167.5 201,-167.5 201,-173.5 195,-179.5 189,-179.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"94.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"93\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 35</text>\r\n",
"<text text-anchor=\"start\" x=\"80\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 350, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"81\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 3&#45;&gt;4 -->\r\n",
"<g id=\"edge4\" class=\"edge\"><title>3&#45;&gt;4</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M190.051,-222.907C182.162,-211.652 173.588,-199.418 165.66,-188.106\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"168.351,-185.847 159.745,-179.667 162.618,-189.865 168.351,-185.847\"/>\r\n",
"</g>\r\n",
"<!-- 5 -->\r\n",
"<g id=\"node6\" class=\"node\"><title>5</title>\r\n",
"<path fill=\"#61ea9a\" stroke=\"black\" d=\"M368,-187C368,-187 231,-187 231,-187 225,-187 219,-181 219,-175 219,-175 219,-116 219,-116 219,-110 225,-104 231,-104 231,-104 368,-104 368,-104 374,-104 380,-110 380,-116 380,-116 380,-175 380,-175 380,-181 374,-187 368,-187\"/>\r\n",
"<text text-anchor=\"start\" x=\"227\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sepal length (cm) ≤ 6.1</text>\r\n",
"<text text-anchor=\"start\" x=\"253.5\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.65</text>\r\n",
"<text text-anchor=\"start\" x=\"256\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 12</text>\r\n",
"<text text-anchor=\"start\" x=\"239\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 100, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"244\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 3&#45;&gt;5 -->\r\n",
"<g id=\"edge5\" class=\"edge\"><title>3&#45;&gt;5</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M246.603,-222.907C252.76,-214.014 259.34,-204.509 265.694,-195.331\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"268.633,-197.235 271.447,-187.021 262.877,-193.251 268.633,-197.235\"/>\r\n",
"</g>\r\n",
"<!-- 6 -->\r\n",
"<g id=\"node7\" class=\"node\"><title>6</title>\r\n",
"<path fill=\"#88efb3\" stroke=\"black\" d=\"M279,-68C279,-68 174,-68 174,-68 168,-68 162,-62 162,-56 162,-56 162,-12 162,-12 162,-6 168,-0 174,-0 174,-0 279,-0 279,-0 285,-0 291,-6 291,-12 291,-12 291,-56 291,-56 291,-62 285,-68 279,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"176.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.863</text>\r\n",
"<text text-anchor=\"start\" x=\"187\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 7</text>\r\n",
"<text text-anchor=\"start\" x=\"170\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 50, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"171\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 5&#45;&gt;6 -->\r\n",
"<g id=\"edge6\" class=\"edge\"><title>5&#45;&gt;6</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M272.317,-103.726C266.419,-94.879 260.173,-85.51 254.254,-76.6303\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"257.159,-74.6787 248.7,-68.2996 251.335,-78.5616 257.159,-74.6787\"/>\r\n",
"</g>\r\n",
"<!-- 7 -->\r\n",
"<g id=\"node8\" class=\"node\"><title>7</title>\r\n",
"<path fill=\"#39e581\" stroke=\"black\" d=\"M424,-68C424,-68 321,-68 321,-68 315,-68 309,-62 309,-56 309,-56 309,-12 309,-12 309,-6 315,-0 321,-0 321,-0 424,-0 424,-0 430,-0 436,-6 436,-12 436,-12 436,-56 436,-56 436,-62 430,-68 424,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"330.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"333\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 5</text>\r\n",
"<text text-anchor=\"start\" x=\"320\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 50, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"317\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 5&#45;&gt;7 -->\r\n",
"<g id=\"edge7\" class=\"edge\"><title>5&#45;&gt;7</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M326.683,-103.726C332.581,-94.879 338.827,-85.51 344.746,-76.6303\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"347.665,-78.5616 350.3,-68.2996 341.841,-74.6787 347.665,-78.5616\"/>\r\n",
"</g>\r\n",
"<!-- 9 -->\r\n",
"<g id=\"node10\" class=\"node\"><title>9</title>\r\n",
"<path fill=\"#e0cef8\" stroke=\"black\" d=\"M442,-298.5C442,-298.5 337,-298.5 337,-298.5 331,-298.5 325,-292.5 325,-286.5 325,-286.5 325,-242.5 325,-242.5 325,-236.5 331,-230.5 337,-230.5 337,-230.5 442,-230.5 442,-230.5 448,-230.5 454,-236.5 454,-242.5 454,-242.5 454,-286.5 454,-286.5 454,-292.5 448,-298.5 442,-298.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"339.5\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.985</text>\r\n",
"<text text-anchor=\"start\" x=\"350\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 7</text>\r\n",
"<text text-anchor=\"start\" x=\"333\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 40]</text>\r\n",
"<text text-anchor=\"start\" x=\"339.5\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 8&#45;&gt;9 -->\r\n",
"<g id=\"edge9\" class=\"edge\"><title>8&#45;&gt;9</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M392.112,-341.907C391.746,-331.204 391.35,-319.615 390.98,-308.776\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"394.474,-308.541 390.634,-298.667 387.478,-308.781 394.474,-308.541\"/>\r\n",
"</g>\r\n",
"<!-- 10 -->\r\n",
"<g id=\"node11\" class=\"node\"><title>10</title>\r\n",
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M589,-298.5C589,-298.5 484,-298.5 484,-298.5 478,-298.5 472,-292.5 472,-286.5 472,-286.5 472,-242.5 472,-242.5 472,-236.5 478,-230.5 484,-230.5 484,-230.5 589,-230.5 589,-230.5 595,-230.5 601,-236.5 601,-242.5 601,-242.5 601,-286.5 601,-286.5 601,-292.5 595,-298.5 589,-298.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"494.5\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"493\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\r\n",
"<text text-anchor=\"start\" x=\"480\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 430]</text>\r\n",
"<text text-anchor=\"start\" x=\"486.5\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 8&#45;&gt;10 -->\r\n",
"<g id=\"edge10\" class=\"edge\"><title>8&#45;&gt;10</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M443.113,-341.907C457.679,-329.99 473.585,-316.976 488.076,-305.12\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"490.439,-307.708 495.963,-298.667 486.007,-302.29 490.439,-307.708\"/>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n"
],
"text/plain": [
"<graphviz.files.Source at 0x1b8673e1288>"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dot_data = tree.export_graphviz(clf, out_file=None, \n",
" feature_names=iris.feature_names, \n",
" class_names=iris.target_names, \n",
" filled=True, rounded=True, \n",
" special_characters=True) \n",
"graph = graphviz.Source(dot_data) \n",
"graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Avoid overfitting"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# Generate a random permutation of the indices of examples that will be later used \n",
"# for the training and the test set\n",
"import numpy as np\n",
"np.random.seed(1231)\n",
"indices = np.random.permutation(len(iris.data))\n",
"\n",
"# We now decide to keep the last 10 indices for test set, the remaining for the training set\n",
"indices_training=indices[:-10]\n",
"indices_test=indices[-10:]\n",
"\n",
"iris_X_train = iris.data[indices_training] # keep for training all the matrix elements with the exception of the last 10 \n",
"iris_y_train = iris.target[indices_training]\n",
"iris_X_test = iris.data[indices_test] # keep the last 10 elements for test set\n",
"iris_y_test = iris.target[indices_test]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 1.0\n",
"F1: 1.0\n"
]
}
],
"source": [
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\",random_state=300,min_samples_leaf=3,class_weight={0:1,1:10,2:10}, min_impurity_decrease = 0.005, max_depth = 4, max_leaf_nodes = 6)\n",
"clf = clf.fit(iris_X_train, iris_y_train)\n",
"predicted_y_test = clf.predict(iris_X_test)\n",
"acc_score = accuracy_score(iris_y_test, predicted_y_test)\n",
"f1 = f1_score(iris_y_test, predicted_y_test, average='macro')\n",
"print(\"Accuracy: \", acc_score)\n",
"print(\"F1: \", f1)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
"<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
" -->\r\n",
"<!-- Title: Tree Pages: 1 -->\r\n",
"<svg width=\"636pt\" height=\"433pt\"\r\n",
" viewBox=\"0.00 0.00 636.00 433.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 429)\">\r\n",
"<title>Tree</title>\r\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-429 632,-429 632,4 -4,4\"/>\r\n",
"<!-- 0 -->\r\n",
"<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
"<path fill=\"#fdfbff\" stroke=\"black\" d=\"M379,-425C379,-425 238,-425 238,-425 232,-425 226,-419 226,-413 226,-413 226,-354 226,-354 226,-348 232,-342 238,-342 238,-342 379,-342 379,-342 385,-342 391,-348 391,-354 391,-354 391,-413 391,-413 391,-419 385,-425 379,-425\"/>\r\n",
"<text text-anchor=\"start\" x=\"234\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.85</text>\r\n",
"<text text-anchor=\"start\" x=\"258.5\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.211</text>\r\n",
"<text text-anchor=\"start\" x=\"261\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 140</text>\r\n",
"<text text-anchor=\"start\" x=\"239.5\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 480, 490]</text>\r\n",
"<text text-anchor=\"start\" x=\"258.5\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 1 -->\r\n",
"<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
"<path fill=\"#54e892\" stroke=\"black\" d=\"M289,-306C289,-306 148,-306 148,-306 142,-306 136,-300 136,-294 136,-294 136,-235 136,-235 136,-229 142,-223 148,-223 148,-223 289,-223 289,-223 295,-223 301,-229 301,-235 301,-235 301,-294 301,-294 301,-300 295,-306 289,-306\"/>\r\n",
"<text text-anchor=\"start\" x=\"144\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 2.45</text>\r\n",
"<text text-anchor=\"start\" x=\"168.5\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.648</text>\r\n",
"<text text-anchor=\"start\" x=\"175\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 90</text>\r\n",
"<text text-anchor=\"start\" x=\"153.5\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 450, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"163\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;1 -->\r\n",
"<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M277.275,-341.907C270.364,-332.923 262.973,-323.315 255.846,-314.05\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"258.541,-311.813 249.67,-306.021 252.993,-316.081 258.541,-311.813\"/>\r\n",
"<text text-anchor=\"middle\" x=\"246.436\" y=\"-327.111\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\r\n",
"</g>\r\n",
"<!-- 2 -->\r\n",
"<g id=\"node7\" class=\"node\"><title>2</title>\r\n",
"<path fill=\"#8946e7\" stroke=\"black\" d=\"M466,-306C466,-306 331,-306 331,-306 325,-306 319,-300 319,-294 319,-294 319,-235 319,-235 319,-229 325,-223 331,-223 331,-223 466,-223 466,-223 472,-223 478,-229 478,-235 478,-235 478,-294 478,-294 478,-300 472,-306 466,-306\"/>\r\n",
"<text text-anchor=\"start\" x=\"327\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.75</text>\r\n",
"<text text-anchor=\"start\" x=\"348.5\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.327</text>\r\n",
"<text text-anchor=\"start\" x=\"355\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 50</text>\r\n",
"<text text-anchor=\"start\" x=\"338\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 470]</text>\r\n",
"<text text-anchor=\"start\" x=\"348.5\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;2 -->\r\n",
"<g id=\"edge6\" class=\"edge\"><title>0&#45;&gt;2</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M339.725,-341.907C346.636,-332.923 354.027,-323.315 361.154,-314.05\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"364.007,-316.081 367.33,-306.021 358.459,-311.813 364.007,-316.081\"/>\r\n",
"<text text-anchor=\"middle\" x=\"370.564\" y=\"-327.111\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\r\n",
"</g>\r\n",
"<!-- 3 -->\r\n",
"<g id=\"node3\" class=\"node\"><title>3</title>\r\n",
"<path fill=\"#e58139\" stroke=\"black\" d=\"M109,-179.5C109,-179.5 12,-179.5 12,-179.5 6,-179.5 0,-173.5 0,-167.5 0,-167.5 0,-123.5 0,-123.5 0,-117.5 6,-111.5 12,-111.5 12,-111.5 109,-111.5 109,-111.5 115,-111.5 121,-117.5 121,-123.5 121,-123.5 121,-167.5 121,-167.5 121,-173.5 115,-179.5 109,-179.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"18.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"17\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\r\n",
"<text text-anchor=\"start\" x=\"8\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 0, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"14.5\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\r\n",
"</g>\r\n",
"<!-- 1&#45;&gt;3 -->\r\n",
"<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;3</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M163.683,-222.907C147.44,-210.88 129.689,-197.735 113.559,-185.791\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"115.409,-182.805 105.29,-179.667 111.243,-188.431 115.409,-182.805\"/>\r\n",
"</g>\r\n",
"<!-- 4 -->\r\n",
"<g id=\"node4\" class=\"node\"><title>4</title>\r\n",
"<path fill=\"#42e687\" stroke=\"black\" d=\"M286,-187C286,-187 151,-187 151,-187 145,-187 139,-181 139,-175 139,-175 139,-116 139,-116 139,-110 145,-104 151,-104 151,-104 286,-104 286,-104 292,-104 298,-110 298,-116 298,-116 298,-175 298,-175 298,-181 292,-187 286,-187\"/>\r\n",
"<text text-anchor=\"start\" x=\"147\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.65</text>\r\n",
"<text text-anchor=\"start\" x=\"168.5\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.254</text>\r\n",
"<text text-anchor=\"start\" x=\"175\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 47</text>\r\n",
"<text text-anchor=\"start\" x=\"158\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 450, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"163\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 1&#45;&gt;4 -->\r\n",
"<g id=\"edge3\" class=\"edge\"><title>1&#45;&gt;4</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M218.5,-222.907C218.5,-214.649 218.5,-205.864 218.5,-197.302\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"222,-197.021 218.5,-187.021 215,-197.021 222,-197.021\"/>\r\n",
"</g>\r\n",
"<!-- 7 -->\r\n",
"<g id=\"node5\" class=\"node\"><title>7</title>\r\n",
"<path fill=\"#39e581\" stroke=\"black\" d=\"M140,-68C140,-68 35,-68 35,-68 29,-68 23,-62 23,-56 23,-56 23,-12 23,-12 23,-6 29,-0 35,-0 35,-0 140,-0 140,-0 146,-0 152,-6 152,-12 152,-12 152,-56 152,-56 152,-62 146,-68 140,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"45.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"44\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 44</text>\r\n",
"<text text-anchor=\"start\" x=\"31\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 440, 0]</text>\r\n",
"<text text-anchor=\"start\" x=\"32\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 4&#45;&gt;7 -->\r\n",
"<g id=\"edge4\" class=\"edge\"><title>4&#45;&gt;7</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M169.72,-103.726C158.372,-94.2406 146.307,-84.1551 135.014,-74.7159\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"137.255,-72.0276 127.338,-68.2996 132.766,-77.3984 137.255,-72.0276\"/>\r\n",
"</g>\r\n",
"<!-- 8 -->\r\n",
"<g id=\"node6\" class=\"node\"><title>8</title>\r\n",
"<path fill=\"#c09cf2\" stroke=\"black\" d=\"M287,-68C287,-68 182,-68 182,-68 176,-68 170,-62 170,-56 170,-56 170,-12 170,-12 170,-6 176,-0 182,-0 182,-0 287,-0 287,-0 293,-0 299,-6 299,-12 299,-12 299,-56 299,-56 299,-62 293,-68 287,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"184.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\r\n",
"<text text-anchor=\"start\" x=\"195\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
"<text text-anchor=\"start\" x=\"178\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 20]</text>\r\n",
"<text text-anchor=\"start\" x=\"184.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 4&#45;&gt;8 -->\r\n",
"<g id=\"edge5\" class=\"edge\"><title>4&#45;&gt;8</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M224.458,-103.726C225.671,-95.4263 226.95,-86.6671 228.175,-78.2834\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"231.652,-78.7006 229.634,-68.2996 224.725,-77.6885 231.652,-78.7006\"/>\r\n",
"</g>\r\n",
"<!-- 5 -->\r\n",
"<g id=\"node8\" class=\"node\"><title>5</title>\r\n",
"<path fill=\"#e0cef8\" stroke=\"black\" d=\"M469,-187C469,-187 328,-187 328,-187 322,-187 316,-181 316,-175 316,-175 316,-116 316,-116 316,-110 322,-104 328,-104 328,-104 469,-104 469,-104 475,-104 481,-110 481,-116 481,-116 481,-175 481,-175 481,-181 475,-187 469,-187\"/>\r\n",
"<text text-anchor=\"start\" x=\"324\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 5.05</text>\r\n",
"<text text-anchor=\"start\" x=\"348.5\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.985</text>\r\n",
"<text text-anchor=\"start\" x=\"359\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 7</text>\r\n",
"<text text-anchor=\"start\" x=\"342\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 40]</text>\r\n",
"<text text-anchor=\"start\" x=\"348.5\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 2&#45;&gt;5 -->\r\n",
"<g id=\"edge7\" class=\"edge\"><title>2&#45;&gt;5</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M398.5,-222.907C398.5,-214.649 398.5,-205.864 398.5,-197.302\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"402,-197.021 398.5,-187.021 395,-197.021 402,-197.021\"/>\r\n",
"</g>\r\n",
"<!-- 6 -->\r\n",
"<g id=\"node11\" class=\"node\"><title>6</title>\r\n",
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M616,-179.5C616,-179.5 511,-179.5 511,-179.5 505,-179.5 499,-173.5 499,-167.5 499,-167.5 499,-123.5 499,-123.5 499,-117.5 505,-111.5 511,-111.5 511,-111.5 616,-111.5 616,-111.5 622,-111.5 628,-117.5 628,-123.5 628,-123.5 628,-167.5 628,-167.5 628,-173.5 622,-179.5 616,-179.5\"/>\r\n",
"<text text-anchor=\"start\" x=\"521.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\r\n",
"<text text-anchor=\"start\" x=\"520\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\r\n",
"<text text-anchor=\"start\" x=\"507\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 430]</text>\r\n",
"<text text-anchor=\"start\" x=\"513.5\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 2&#45;&gt;6 -->\r\n",
"<g id=\"edge10\" class=\"edge\"><title>2&#45;&gt;6</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M455.746,-222.907C472.864,-210.769 491.586,-197.493 508.553,-185.462\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"510.593,-188.306 516.726,-179.667 506.544,-182.596 510.593,-188.306\"/>\r\n",
"</g>\r\n",
"<!-- 9 -->\r\n",
"<g id=\"node9\" class=\"node\"><title>9</title>\r\n",
"<path fill=\"#9cf2c0\" stroke=\"black\" d=\"M442,-68C442,-68 337,-68 337,-68 331,-68 325,-62 325,-56 325,-56 325,-12 325,-12 325,-6 331,-0 337,-0 337,-0 442,-0 442,-0 448,-0 454,-6 454,-12 454,-12 454,-56 454,-56 454,-62 448,-68 442,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"339.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\r\n",
"<text text-anchor=\"start\" x=\"350\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
"<text text-anchor=\"start\" x=\"333\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 20, 10]</text>\r\n",
"<text text-anchor=\"start\" x=\"334\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\r\n",
"</g>\r\n",
"<!-- 5&#45;&gt;9 -->\r\n",
"<g id=\"edge8\" class=\"edge\"><title>5&#45;&gt;9</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M395.149,-103.726C394.467,-95.4263 393.747,-86.6671 393.058,-78.2834\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"396.544,-77.9793 392.237,-68.2996 389.568,-78.5527 396.544,-77.9793\"/>\r\n",
"</g>\r\n",
"<!-- 10 -->\r\n",
"<g id=\"node10\" class=\"node\"><title>10</title>\r\n",
"<path fill=\"#ab7bee\" stroke=\"black\" d=\"M589,-68C589,-68 484,-68 484,-68 478,-68 472,-62 472,-56 472,-56 472,-12 472,-12 472,-6 478,-0 484,-0 484,-0 589,-0 589,-0 595,-0 601,-6 601,-12 601,-12 601,-56 601,-56 601,-62 595,-68 589,-68\"/>\r\n",
"<text text-anchor=\"start\" x=\"486.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.811</text>\r\n",
"<text text-anchor=\"start\" x=\"497\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 4</text>\r\n",
"<text text-anchor=\"start\" x=\"480\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 30]</text>\r\n",
"<text text-anchor=\"start\" x=\"486.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\r\n",
"</g>\r\n",
"<!-- 5&#45;&gt;10 -->\r\n",
"<g id=\"edge9\" class=\"edge\"><title>5&#45;&gt;10</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M449.886,-103.726C461.841,-94.2406 474.551,-84.1551 486.447,-74.7159\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"488.875,-77.2571 494.533,-68.2996 484.524,-71.7736 488.875,-77.2571\"/>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n"
],
"text/plain": [
"<graphviz.files.Source at 0x1b86742c348>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Render the fitted decision tree: export it in DOT format, then let graphviz draw it\n",
"dot_data = tree.export_graphviz(\n",
"    clf,\n",
"    out_file=None,\n",
"    feature_names=iris.feature_names,\n",
"    class_names=iris.target_names,\n",
"    filled=True,\n",
"    rounded=True,\n",
"    special_characters=True,\n",
")\n",
"graph = graphviz.Source(dot_data)\n",
"graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Confusion Matrix"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([[7, 0, 0],\n",
" [0, 2, 0],\n",
" [0, 0, 1]])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Confusion matrix of the test set: rows = true class, columns = predicted class\n",
"assert len(iris_y_test) == len(predicted_y_test), \"true and predicted labels must align\"\n",
"n_classes = len(iris.target_names)\n",
"confusion = np.zeros((n_classes, n_classes), dtype=int)\n",
"\n",
"# count every (true class, predicted class) pair of the test set\n",
"for true_label, predicted_label in zip(iris_y_test, predicted_y_test):\n",
"    confusion[true_label, predicted_label] += 1\n",
"confusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. ROC Curves"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[(0.0, 43.0), (30.0, 0.0), (30.0, 0.0), (40.0, 0.0), (430.0, 0.0), (440.0, 0.0)], [(0.0, 440.0), (10.0, 20.0), (20.0, 10.0), (30.0, 10.0), (43.0, 0.0), (430.0, 0.0)], [(0.0, 430.0), (10.0, 30.0), (10.0, 20.0), (20.0, 10.0), (43.0, 0.0), (440.0, 0.0)]]\n"
]
},
{
"data": {
"text/plain": [
"[[[0, 0.0, 30.0, 60.0, 100.0, 530.0, 970.0],\n",
" [0, 43.0, 43.0, 43.0, 43.0, 43.0, 43.0]],\n",
" [[0, 0.0, 10.0, 30.0, 60.0, 103.0, 533.0],\n",
" [0, 440.0, 460.0, 470.0, 480.0, 480.0, 480.0]],\n",
" [[0, 0.0, 10.0, 20.0, 40.0, 83.0, 523.0],\n",
" [0, 430.0, 460.0, 480.0, 490.0, 490.0, 490.0]]]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Calculates the ROC curves (x, y) of the fitted tree, one per class (one-vs-rest).\n",
"# Points are cumulative weighted counts: x = negatives covered, y = positives covered.\n",
"n_classes = clf.n_classes_\n",
"leafs = []\n",
"class_pairs = [[] for _ in range(n_classes)]\n",
"roc_curves = [[[0], [0]] for _ in range(n_classes)]\n",
"\n",
"# leaves have no split feature: scikit-learn marks them with feature == -2 (TREE_UNDEFINED)\n",
"for i in range(clf.tree_.node_count):\n",
"    if clf.tree_.feature[i] == -2:\n",
"        leafs.append(i)\n",
"\n",
"# pairs (neg, pos) per leaf: weight of the other classes vs weight of class c\n",
"# NOTE(review): recent scikit-learn releases may store class fractions instead of\n",
"# weighted counts in tree_.value; confirm for the installed version\n",
"for leaf in leafs:\n",
"    leaf_value = clf.tree_.value[leaf][0]\n",
"    for c in range(n_classes):\n",
"        class_pairs[c].append((leaf_value.sum() - leaf_value[c], leaf_value[c]))\n",
"\n",
"# rank leaves by decreasing fraction of positives (ties: fewer negatives first),\n",
"# so leaves without any positives always come last\n",
"for c in range(n_classes):\n",
"    class_pairs[c] = sorted(class_pairs[c], key=lambda t: (-t[1] / (t[0] + t[1]), t[0]))\n",
"print(class_pairs)\n",
"\n",
"# accumulate the ranked pairs into curve points, starting from (0, 0)\n",
"for c in range(n_classes):\n",
"    for neg, pos in class_pairs[c]:\n",
"        roc_curves[c][0].append(roc_curves[c][0][-1] + neg)\n",
"        roc_curves[c][1].append(roc_curves[c][1][-1] + pos)\n",
"\n",
"roc_curves"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAM3UlEQVR4nO3dX4yl9V3H8fdHloK2qUAZyMqCC3FTISZdmgmCeGGgVMSmcIEJpNGNbrI3NVJtUkEvmiZelMSU1cQ03RTsxpQ/lRIhpJGQLcSYGOogSKFb3IVauoLskEKrXmixXy/Os3Tc7jLnnDmzs/Od9yuZzHme8xzO73d+5D3PPDOzJ1WFJKmXn1jrAUiSZs+4S1JDxl2SGjLuktSQcZekhjadyCc7++yza+vWrSfyKSVp3XvyySdfq6q5SR5zQuO+detWFhYWTuRTStK6l+Tbkz7GyzKS1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQyf099yntmcP3H33Wo9CkqazfTvs3n1Cn3J9nLnffTc8/fRaj0KS1o31ceYOo698jz++1qOQpHVhfZy5S5ImYtwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDU0dtyTnJLkqSQPD9sXJnkiyYEk9yV5x+oNU5I0iUnO3G8B9i/Zvh24o6q2Aa8DO2c5MEnS9MaKe5ItwK8Dnx+2A1wF3D8cshe4YTUGKEma3Lhn7ruBTwA/HLbfA7xRVW8O24eA8471wCS7kiwkWVhcXFzRYCVJ41k27kk+BByuqieX7j7GoXWsx1fVnqqar6r5ubm5KYcpSZrEOP/k75XAh5NcB5wOvJvRmfwZSTYNZ+9bgJdXb5iSpEkse+ZeVbdV1Zaq2grcBHy1qj4CPAbcOBy2A3hw1UYpSZrISn7P/Q+BP0hykNE1+DtnMyRJ0kpN9E5MVfU48Phw+0XgstkPSZK0Uv6FqiQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPLxj3J6Um+luSfkzyX5FPD/guTPJHkQJL7krxj9YcrSRrHOGfu/w1cVVXvA7YD1ya5HLgduKOqtgGvAztXb5iSpEksG/ca+c9h89Tho4CrgPuH/XuBG1ZlhJKkiY11zT3JKUmeBg4DjwIvAG9U1ZvDIYeA847z2F1JFpIsLC4uzmLMkqRljBX3qvrfqtoObAEuAy4+1mHHeeyeqpqvqvm5ubnpRypJGttEvy1TVW8AjwOXA2ck2TTctQV4ebZDkyRNa5zflplLcsZw+yeBDwD7gceAG4fDdgAPrtYgJUmT2bT8IWwG9iY5hdEXgy9V1cNJvgHcm+RPgKeAO1dxnJKkCSwb96p6Brj0GPtfZHT9XZJ0kvEvVCWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8ZdkhpaNu5Jzk/yWJL9SZ5Lcsuw/6wkjyY5MHw+c/WHK0kaxzhn7m8CH6+qi4HLgY8muQS4FdhXVduAfcO2JOkksGzcq+qVqvqn4fZ/APuB84Drgb3DYXuBG1ZrkJKkyUx0zT3JVuBS4Ang3Kp6BUZfAIBzjvOYXUkWkiwsLi6ubLSSpLGMHfck7wK+DHysqr4/7uOqak
9VzVfV/Nzc3DRjlCRNaKy4JzmVUdi/WFUPDLtfTbJ5uH8zcHh1hihJmtQ4vy0T4E5gf1V9ZsldDwE7hts7gAdnPzxJ0jQ2jXHMlcBvAl9P8vSw74+ATwNfSrITeAn4jdUZoiRpUsvGvar+Hshx7r56tsORJM2Cf6EqSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLU0LJxT3JXksNJnl2y76wkjyY5MHw+c3WHKUmaxDhn7l8Arj1q363AvqraBuwbtiVJJ4ll415Vfwd896jd1wN7h9t7gRtmPC5J0gpMe8393Kp6BWD4fM7xDkyyK8lCkoXFxcUpn06SNIlV/4FqVe2pqvmqmp+bm1vtp5MkMX3cX02yGWD4fHh2Q5IkrdS0cX8I2DHc3gE8OJvhSJJmYZxfhbwH+AfgvUkOJdkJfBq4JskB4JphW5J0kti03AFVdfNx7rp6xmORJM2If6EqSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLUkHGXpIaMuyQ1ZNwlqSHjLkkNGXdJasi4S1JDxl2SGjLuktSQcZekhoy7JDVk3CWpIeMuSQ0Zd0lqyLhLUkPGXZIaMu6S1JBxl6SGjLskNWTcJakh4y5JDRl3SWrIuEtSQ8Zdkhoy7pLU0IrinuTaJM8nOZjk1lkNSpK0MlPHPckpwF8AvwZcAtyc5JJZDUySNL2VnLlfBhysqher6n+Ae4HrZzMsSdJKbFrBY88DvrNk+xDwi0cflGQXsAvgggsumO6Ztm+f7nGStEGtJO45xr76sR1Ve4A9APPz8z92/1h2757qYZK0Ua3ksswh4Pwl21uAl1c2HEnSLKwk7v8IbEtyYZJ3ADcBD81mWJKklZj6skxVvZnkd4FHgFOAu6rquZmNTJI0tZVcc6eqvgJ8ZUZjkSTNiH+hKkkNGXdJasi4S1JDxl2SGkrVdH9XNNWTJYvAt6d8+NnAazMcznri3DeejTpvcO7HmvvPVtXcJP+hExr3lUiyUFXzaz2OteDcN97cN+q8wbnPau5elpGkhoy7JDW0nuK+Z60HsIac+8azUecNzn0m1s01d0nS+NbTmbskaUzGXZIaWhdx7/xG3EnOT/JYkv1Jnktyy7D/rCSPJjkwfD5z2J8kfz68Fs8kef/azmDlkpyS5KkkDw/bFyZ5Ypj7fcM/KU2S04btg8P9W9dy3CuV5Iwk9yf55rD+V2yEdU/y+8P/688muSfJ6V3XPMldSQ4neXbJvonXOMmO4fgDSXaM89wnfdw3wBtxvwl8vKouBi4HPjrM71ZgX1VtA/YN2zB6HbYNH7uAz574Ic/cLcD+Jdu3A3cMc38d2Dns3wm8XlU/B9wxHLee/Rnwt1X188D7GL0Grdc9yXnA7wHzVfULjP658Jvou+ZfAK49at9Ea5zkLOCTjN7G9DLgk0e+ILytqjqpP4ArgEeWbN8G3LbW41rF+T4IXAM8D2we9m0Gnh9ufw64ecnxbx23Hj8YvYPXPuAq4GFGb9/4GrDp6PVn9N4BVwy3Nw3HZa3nMOW83w186+jxd193fvTey2cNa/gw8Kud1xzYCjw77RoDNwOfW7L//x13vI+T/sydY78R93lrNJZVNXzLeSnwBHBuVb0CMHw+Zzis2+uxG/gE8MNh+z3AG1X15rC9dH5vzX24/3vD8evRRcAi8JfDJanPJ3knzde9qv4N+FPgJeAVRmv4JBtjzY+YdI2nWvv1EPex3oh7vUvyLuDLwMeq6vtvd+gx9q3L1yPJh4DDVfXk0t3HOLTGuG+92QS8H/
hsVV0K/Bc/+vb8WFrMfbiccD1wIfAzwDsZXY44Wsc1X87x5jrVa7Ae4t7+jbiTnMoo7F+sqgeG3a8m2Tzcvxk4POz
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD7CAYAAACRxdTpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAQjUlEQVR4nO3da6xdZZ3H8e/PlosjKpceCLbVYigJvFDEwtSghIuDwKglBhIvkcY09g1j8JIojMbRZF7oGyEkE2IVpYoKKBIaQgZJucjEgB4EuYhKIUhrgRaBwsSIA/znxX6qh/a0Z7c9p6fn6feT7Ky1/uvZez//sPl1nXXWOjtVhSSpL6+Z7glIkiaf4S5JHTLcJalDhrskdchwl6QOGe6S1KGhwj3JY0nuT3JvktFWOzjJzUkebsuDWj1JLk2yJsl9SY6bygYkSVvbkSP3U6rq2Kpa1LYvBFZX1UJgddsGOBNY2B7Lgcsma7KSpOHM3oXnLgFObusrgduAL7T692pwd9SdSQ5McnhVPbGtF5ozZ04tWLBgF6YiSXufu+++++mqGhlv37DhXsDPkhTwzapaARy2ObCr6okkh7axc4G1Y567rtW2Ge4LFixgdHR0yKlIkgCS/HFb+4YN9xOran0L8JuT/G577zdObau/cZBkOYPTNrz5zW8echqSpGEMdc69qta35QbgOuAE4KkkhwO05YY2fB0wf8zT5wHrx3nNFVW1qKoWjYyM+1OFJGknTRjuSV6X5PWb14HTgQeAVcDSNmwpcH1bXwWc166aWQxs2t75dknS5BvmtMxhwHVJNo//YVX9d5JfAdckWQY8Dpzbxt8InAWsAf4CfGLSZy1J2q4Jw72qHgXePk79z8Bp49QLOH9SZidJ2ineoSpJHTLcJalDu3ITkwBeeAF+8QsYHYUXX5zu2UiaaT7wATj++El/WcN9Rz39NNxxx+Dx85/DPffAK68M9mW8S/wlaTve9CbDfVqsXfuPIL/jDvjtbwf1/feHxYvhi1+Ek04arB9wwPTOVZIaw32sKvjDH14d5o89Ntj3hjfAu98NH//4IMzf+U7Yb79pna4kbYvhvnYtXHfdP8J8Q7vR9tBD4T3vgc98ZrB829tg1qzpnaskDWnvDPdNm+Daa+H734fbbx8csb/lLfC+9w2C/KST4KijPIcuacbae8L9b3+Dm24aBPqqVYMrWxYuhK98BT76UTjyyOmeoSRNmr7DvQruvBOuvBKuvhr+/GeYMwc++cnBufPjj/foXFKX+gz3P/0JvvWtQag/8sjgypYlSwaBfvrpsM8+0z1DSZpS/YX788/DiSfC44/DKafAl74EH/rQ4GoXSdpL9BfuF1wwuALm9tsHvxyVpL1QX39b5rrr4Ior4KKLDHZJe7V+wv3JJ2H5cjjuOPjyl6d7NpI0rfoI96rBFTAvvDC41HHffad7RpI0rfo45/7tb8MNN8All8Axx0z3bCRp2s38I/dHHhn8iYDTToNPfWq6ZyNJe4SZHe4vvwznnQezZ8N3vwuvmdntSNJkmdmnZX74w8EXZVx5JcyfP92zkaQ9xsw+1H3yycHy7LOndx6StIeZ2eEuSRqX4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ0OHe5JZSe5JckPbPiLJXUkeTnJ1kn1bfb+2vabtXzA1U5ckbcuOHLlfADw0ZvvrwMVVtRB4FljW6suAZ6vqSODiNk6StBsNFe5J5gH/Cny7bQc4FfhJG7IS2PynGZe0bdr+09p4SdJuMuyR+yXA54FX2vYhwHNV9VLbXgfMbetzgbUAbf+mNl6StJtMGO5J3g9sqKq7x5bHGVpD7Bv7usuTjCYZ3bhx41CTlSQNZ5gj9xOBDyZ5DLiKwemYS4ADk2z+Jqd5wPq2vg6YD9D2vxF4ZssXraoVVbWoqhaNjIzsUhOSpFebMN
yr6qKqmldVC4APA7dU1ceAW4Fz2rClwPVtfVXbpu2/paq2OnKXJE2dXbnO/QvAZ5OsYXBO/fJWvxw4pNU/C1y4a1OUJO2oHfqC7Kq6DbitrT8KnDDOmL8C507C3CRJO8k7VCWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUMThnuS/ZP8MslvkjyY5KutfkSSu5I8nOTqJPu2+n5te03bv2BqW5AkbWmYI/cXgVOr6u3AscAZSRYDXwcurqqFwLPAsjZ+GfBsVR0JXNzGSZJ2ownDvQb+t23u0x4FnAr8pNVXAme39SVtm7b/tCSZtBlLkiY01Dn3JLOS3AtsAG4GHgGeq6qX2pB1wNy2PhdYC9D2bwIOGec1lycZTTK6cePGXetCkvQqQ4V7Vb1cVccC84ATgKPHG9aW4x2l11aFqhVVtaiqFo2MjAw7X0nSEHboapmqeg64DVgMHJhkdts1D1jf1tcB8wHa/jcCz0zGZCVJwxnmapmRJAe29dcC7wUeAm4FzmnDlgLXt/VVbZu2/5aq2urIXZI0dWZPPITDgZVJZjH4x+CaqrohyW+Bq5L8J3APcHkbfznw/SRrGByxf3gK5i1J2o4Jw72q7gPeMU79UQbn37es/xU4d1JmJ0naKd6hKkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SerQhOGeZH6SW5M8lOTBJBe0+sFJbk7ycFse1OpJcmmSNUnuS3LcVDchSXq1YY7cXwI+V1VHA4uB85McA1wIrK6qhcDqtg1wJrCwPZYDl036rCVJ2zVhuFfVE1X167b+AvAQMBdYAqxsw1YCZ7f1JcD3auBO4MAkh0/6zCVJ27RD59yTLADeAdwFHFZVT8DgHwDg0DZsLrB2zNPWtZokaTcZOtyTHABcC3y6qp7f3tBxajXO6y1PMppkdOPGjcNOQ5I0hKHCPck+DIL9B1X101Z+avPplrbc0OrrgPljnj4PWL/la1bViqpaVFWLRkZGdnb+kqRxDHO1TIDLgYeq6htjdq0Clrb1pcD1Y+rntatmFgObNp++kSTtHrOHGHMi8HHg/iT3ttq/A18DrkmyDHgcOLftuxE4C1gD/AX4xKTOWJI0oQnDvar+h/HPowOcNs74As7fxXlJknaBd6hKUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOjRhuCf5TpINSR4YUzs4yc1JHm7Lg1o9SS5NsibJfUmOm8rJS5LGN8yR+xXAGVvULgRWV9VCYHXbBjgTWNgey4HLJmeakqQdMWG4V9XPgWe2KC8BVrb1lcDZY+rfq4E7gQOTHD5Zk5UkDWdnz7kfVlVPALTloa0+F1g7Zty6VpMk7UaT/QvVjFOrcQcmy5OMJhnduHHjJE9DkvZuOxvuT20+3dKWG1p9HTB/zLh5wPrxXqCqVlTVoqpaNDIyspPTkCSNZ2fDfRWwtK0vBa4fUz+vXTWzGNi0+fSNJGn3mT3RgCQ/Ak4G5iRZB/wH8DXgmiTLgMeBc9vwG4GzgD
XAX4BPTMGcJUkTmDDcq+oj29h12jhjCzh/VyclSdo13qEqSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHd
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAQlklEQVR4nO3df4xdZZ3H8ffXlh+CK4V2qLXT7GAYE0xcfmTEGjZGwTWAaEmECNGlMSVNFBWD0YXdZFfWRSQxgCQbsnUh1I1LQURpSLPQFIhsosgUSilb2Y4E7aSEjgtU118IfveP84x7aW+Z25k7c2eeeb+Sm3PO9zz33u8Ths+cPnN/RGYiSarLG3rdgCSp+wx3SaqQ4S5JFTLcJalChrskVWhhrxsAWLJkSQ4MDPS6DUmaU7Zu3fqLzOxrd25WhPvAwADDw8O9bkOS5pSI+NnBzrksI0kV6ijcI+LZiHgyIrZFxHCpHRcRmyNiV9keW+oRETdFxEhEbI+I06ZzApKkAx3Klfv7M/OUzBwqx1cCWzJzENhSjgHOAQbLbS1wc7ealSR1ZirLMquA9WV/PXB+S/1b2fgRsCgilk3heSRJh6jTcE/g/ojYGhFrS21pZj4HULbHl/pyYHfLfUdL7TUiYm1EDEfE8NjY2OS6lyS11emrZc7IzD0RcTywOSJ+8jpjo03tgE8ny8x1wDqAoaEhP71Mkrqooyv3zNxTtnuB7wGnA8+PL7eU7d4yfBRY0XL3fmBPtxqWJE1swiv3iDgaeENm/qrsfxD4R2AjsBr4WtneU+6yEfhMRGwA3g3sG1++mTMyYft22LQJfvvbXncjqWYf/jC8611df9hOlmWWAt+LiPHx/56Z/xERjwJ3RsQa4OfAhWX8JuBcYAT4DfDJrnc9XXbtgg0b4PbbYefOphbtVpkkqUve+tbehHtmPgOc3Kb+P8BZbeoJXNaV7mbC6CjccUcT6Fu3NrX3vhc+9zn46Eehr+07eyVpVpsVHz8w48bG4K67mqv0hx9ulmGGhuDrX4ePfQz6+3vdoSRNyfwK982b4frrm+2rr8JJJ8HVV8NFF8HgYK+7k6SumT/hvm0bnHceLF0KX/wiXHwxvPOdrqlLqtL8CPdf/7oJ88WL4bHHYMmSXnckSdNqfoT7FVfA0083yzEGu6R5oP6P/L37bli3rlmKOeuAF/dIUpXqDvfdu+HSS5tXwnzlK73uRpJmTL3h/uqr8IlPwMsvN69hP/zwXnckSTOm3jX3a6+FH/wAbrsNTjyx191I0oyq88r9hz+EL3+5eYXMJZf0uhtJmnH1hXsmfPazsHw53Hyzr2OXNC/Vtyxz//3NZ8R885twzDG97kaSeqK+K/drrmk+G8blGEnzWF1X7g8/3Ny+8Q1fHSNpXqvryv2aa5qP6L300l53Ikk9VU+4P/oo3Hdf81EDRx3V624kqafqCfevfhUWLYJPf7rXnUhSz9UR7jt2wPe/33x70pvf3OtuJKnn6gj3a6+Fo49uwl2SVEG4j4w0X5f3qU81n9cuSaog3K+7Dg47rPlDqiQJmOvhvns3rF8Pa9bAsmW97kaSZo25He533w1/+INX7ZK0n7kd7i+/3Gzf8pbe9iFJs8zcDndJUluGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SapQx+EeEQsi4vGIuLccnxARj0TEroi4IyIOL/UjyvFIOT8wPa1Lkg7mUK7cLwd2thxfB9yQmYPAi8CaUl8DvJiZJwI3lHGSpBnUUbhHRD/wIeBfy3EAZwJ3lSHrgfPL/qpyTDl/VhkvSZohnV653wh8CfhjOV4MvJSZr5TjUWB52V8O7AYo5/eV8ZKkGTJhuEfEecDezNzaWm4zNDs41/q4ayNiOCKGx8bGOmpWktSZTq7czwA+EhHPAhtolmNuBBZFxMIyph/YU/ZHgRUA5fwxwAv7P2hmrsvMocwc6uvrm9IkJE
mvNWG4Z+ZVmdmfmQPARcADmflx4EHggjJsNXBP2d9YjinnH8jMA67cJUnTZyqvc/8b4IqIGKFZU7+l1G8BFpf6FcCVU2tRknSoFk485P9l5kPAQ2X/GeD0NmN+B1zYhd4kSZPkO1QlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFTLcJalChrskVchwl6QKGe6SVCHDXZIqZLhLUoUMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFZow3CPiyIj4cUQ8ERFPRcTVpX5CRDwSEbsi4o6IOLzUjyjHI+X8wPROQZK0v06u3H8PnJmZJwOnAGdHxErgOuCGzBwEXgTWlPFrgBcz80TghjJOkjSDJgz3bPxvOTys3BI4E7ir1NcD55f9VeWYcv6siIiudSxJmlBHa+4RsSAitgF7gc3AT4GXMvOVMmQUWF72lwO7Acr5fcDiNo+5NiKGI2J4bGxsarOQJL1GR+Gema9m5ilAP3A6cFK7YWXb7io9DyhkrsvMocwc6uvr67RfSVIHDunVMpn5EvAQsBJYFBELy6l+YE/ZHwVWAJTzxwAvdKNZSVJnOnm1TF9ELCr7bwQ+AOwEHgQuKMNWA/eU/Y3lmHL+gcw84MpdkjR9Fk48hGXA+ohYQPPL4M7MvDci/gvYEBH/BDwO3FLG3wL8W0SM0FyxXzQNfUuSXseE4Z6Z24FT29SfoVl/37/+O+DCrnQnSZoU36EqSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFTLcJalChrskVchwl6QKGe6SVCHDXZIqZLhLUoUMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFJgz3iFgREQ9GxM6IeCoiLi/14yJic0TsKttjSz0i4qaIGImI7RFx2nRPQpL0Wp1cub8CfCEzTwJWApdFxDuAK4EtmTkIbCnHAOcAg+W2Fri5611Lkl7XhOGemc9l5mNl/1fATmA5sApYX4atB84v+6uAb2XjR8CiiFjW9c4lSQd1SGvuETEAnAo8AizNzOeg+QUAHF+GLQd2t9xttNT2f6y1ETEcEcNjY2OH3rkk6aA6DveIeBPwXeDzmfnL1xvappYHFDLXZeZQZg719fV12oYkqQMdhXtEHEYT7N/OzLtL+fnx5Zay3Vvqo8CKlrv3A3u6064kqROdvFomgFuAnZl5fcupjcDqsr8auKelfkl51cxKYN/48o0kaWYs7GDMGcBfA09GxLZS+1vga8CdEbEG+DlwYTm3CTgXGAF+A3yyqx1LkiY0Ybhn5n/Sfh0d4Kw24xO4bIp9SZKmwHeoSlKFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAoZ7pJUIcNdkipkuEtShQx3SaqQ4S5JFTLcJalChrskVchwl6QKGe6SVCHDXZIqZLhLUoUMd0mqkOEuSRUy3CWpQoa7JFXIcJekChnuklQhw12SKmS4S1KFDHdJqpDhLkkVMtwlqUKGuyRVyHCXpAoZ7pJUoQnDPSJujYi9EbGjpXZcRGyOiF1le2ypR0TcFBEjEbE9Ik6bzuYlSe11cuV+G3D2frUrgS2ZOQhsKccA5wCD5bYWuLk7bUqSDsWE4Z6ZPwBe2K+8Clhf9tcD57fUv5WNHwGLImJZt5qVJHVmsmvuSzPzOYCyPb7UlwO7W8aNlpokaQZ1+w+q0aaWbQdGrI2I4YgYHhsb63IbkjS/TTbcnx9fbinbvaU+CqxoGdcP7Gn3AJm5LjOHMnOor69vkm1IktqZbLhvBFaX/dXAPS31S8qrZlYC+8aXbyRJM2fhRAMi4nbgfcCSiBgF/gH4GnBnRKwBfg5cWIZvAs4FRoDfAJ
+chp4lSROYMNwz8+KDnDqrzdgELptqU5KkqfEdqpJUIcNdkipkuEtShQx3SaqQ4S5JFTLcJalChrskVchwl6QKGe6
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# One figure per class (one-vs-rest); axes are cumulative weighted counts, not rates\n",
"for c, (x_values, y_values) in enumerate(roc_curves):\n",
"    fig, ax = plt.subplots()\n",
"    ax.plot(x_values, y_values, color=\"red\")\n",
"    ax.set(title=f\"ROC curve: {iris.target_names[c]} vs rest\",\n",
"           xlabel=\"negatives covered (weighted count)\",\n",
"           ylabel=\"positives covered (weighted count)\")\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}