1929 lines
125 KiB
Text
Executable file
1929 lines
125 KiB
Text
Executable file
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Classifiers introduction\n",
|
|
"\n",
|
|
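"In the following program we introduce the basic steps of classifying a dataset stored in a matrix (one example per row): define the data, fit a classifier, predict the class of new examples and evaluate the model."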
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Import scikit-learn's `tree` module, used to learn and model decision trees"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sklearn import tree"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Define the matrix containing the data (one example per row)\n",
|
|
"and the vector containing the corresponding target values (class labels)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Toy training set: row k of X holds the three binary features of example k,\n",
|
|
"# and Y[k] is the class label of that example.\n",
|
|
"X = [[0, 0, 0], [1, 1, 1], [0, 1, 0],\n",
|
|
"     [0, 0, 1], [1, 1, 0], [1, 0, 1]]\n",
|
|
"Y = [1, 0, 0, 0, 1, 1]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Declare the classification model you want to use and then fit the model to the data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# random_state fixes the random feature permutation used to break ties between equally good\n",
|
|
"# splits (on this toy data all three features tie at the root), so the tree is reproducible\n",
|
|
"clf = tree.DecisionTreeClassifier(random_state=0)\n",
|
|
"clf = clf.fit(X, Y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Predict (and print) the target value of new examples, using the model currently fitted in `clf`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"new_example = [[0, 1, 1]]  # a feature combination that does not appear in the training set\n",
|
|
"print(clf.predict(new_example))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[1 0]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# predict() accepts several examples at once and returns one label per row\n",
|
|
"new_examples = [[1, 0, 1], [0, 0, 1]]\n",
|
|
"print(clf.predict(new_examples))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": [
|
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
|
"<!-- Generated by graphviz version 2.42.3 (20191010.1750)\n",
|
|
" -->\n",
|
|
"<!-- Title: Tree Pages: 1 -->\n",
|
|
"<svg width=\"485pt\" height=\"373pt\"\n",
|
|
" viewBox=\"0.00 0.00 485.00 373.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 369)\">\n",
|
|
"<title>Tree</title>\n",
|
|
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-369 481,-369 481,4 -4,4\"/>\n",
|
|
"<!-- 0 -->\n",
|
|
"<g id=\"node1\" class=\"node\">\n",
|
|
"<title>0</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"257,-365 165,-365 165,-297 257,-297 257,-365\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"211\" y=\"-349.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[2] <= 0.5</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"211\" y=\"-334.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.5</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"211\" y=\"-319.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 6</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"211\" y=\"-304.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [3, 3]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1 -->\n",
|
|
"<g id=\"node2\" class=\"node\">\n",
|
|
"<title>1</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"202,-261 110,-261 110,-193 202,-193 202,-261\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-245.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[0] <= 0.5</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-230.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.444</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-215.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 3</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-200.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 2]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->1 -->\n",
|
|
"<g id=\"edge1\" class=\"edge\">\n",
|
|
"<title>0->1</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M193.14,-296.88C188.58,-288.42 183.61,-279.21 178.84,-270.35\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"181.78,-268.44 173.96,-261.3 175.62,-271.76 181.78,-268.44\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"166.79\" y=\"-281.55\" font-family=\"Times,serif\" font-size=\"14.00\">True</text>\n",
|
|
"</g>\n",
|
|
"<!-- 6 -->\n",
|
|
"<g id=\"node7\" class=\"node\">\n",
|
|
"<title>6</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"312,-261 220,-261 220,-193 312,-193 312,-261\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"266\" y=\"-245.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[0] <= 0.5</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"266\" y=\"-230.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.444</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"266\" y=\"-215.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 3</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"266\" y=\"-200.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [2, 1]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->6 -->\n",
|
|
"<g id=\"edge6\" class=\"edge\">\n",
|
|
"<title>0->6</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M228.86,-296.88C233.42,-288.42 238.39,-279.21 243.16,-270.35\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"246.38,-271.76 248.04,-261.3 240.22,-268.44 246.38,-271.76\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"255.21\" y=\"-281.55\" font-family=\"Times,serif\" font-size=\"14.00\">False</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2 -->\n",
|
|
"<g id=\"node3\" class=\"node\">\n",
|
|
"<title>2</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"92,-157 0,-157 0,-89 92,-89 92,-157\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"46\" y=\"-141.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[1] <= 0.5</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"46\" y=\"-126.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.5</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"46\" y=\"-111.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 2</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"46\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 1]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1->2 -->\n",
|
|
"<g id=\"edge2\" class=\"edge\">\n",
|
|
"<title>1->2</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M120.29,-192.88C110.39,-183.71 99.54,-173.65 89.27,-164.12\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"91.62,-161.53 81.91,-157.3 86.86,-166.67 91.62,-161.53\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 5 -->\n",
|
|
"<g id=\"node6\" class=\"node\">\n",
|
|
"<title>5</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"202,-149.5 110,-149.5 110,-96.5 202,-96.5 202,-149.5\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-134.3\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-119.3\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-104.3\" font-family=\"Times,serif\" font-size=\"14.00\">value = [0, 1]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1->5 -->\n",
|
|
"<g id=\"edge5\" class=\"edge\">\n",
|
|
"<title>1->5</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M156,-192.88C156,-182.33 156,-170.6 156,-159.85\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"159.5,-159.52 156,-149.52 152.5,-159.52 159.5,-159.52\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 3 -->\n",
|
|
"<g id=\"node4\" class=\"node\">\n",
|
|
"<title>3</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"92,-53 0,-53 0,0 92,0 92,-53\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"46\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"46\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"46\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [0, 1]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2->3 -->\n",
|
|
"<g id=\"edge3\" class=\"edge\">\n",
|
|
"<title>2->3</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M46,-88.95C46,-80.72 46,-71.85 46,-63.48\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"49.5,-63.24 46,-53.24 42.5,-63.24 49.5,-63.24\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 4 -->\n",
|
|
"<g id=\"node5\" class=\"node\">\n",
|
|
"<title>4</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"202,-53 110,-53 110,0 202,0 202,-53\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"156\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 0]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2->4 -->\n",
|
|
"<g id=\"edge4\" class=\"edge\">\n",
|
|
"<title>2->4</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M84.47,-88.95C95.44,-79.53 107.38,-69.27 118.3,-59.89\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"120.74,-62.41 126.04,-53.24 116.18,-57.1 120.74,-62.41\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 7 -->\n",
|
|
"<g id=\"node8\" class=\"node\">\n",
|
|
"<title>7</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"312,-149.5 220,-149.5 220,-96.5 312,-96.5 312,-149.5\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"266\" y=\"-134.3\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"266\" y=\"-119.3\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"266\" y=\"-104.3\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 0]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 6->7 -->\n",
|
|
"<g id=\"edge7\" class=\"edge\">\n",
|
|
"<title>6->7</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M266,-192.88C266,-182.33 266,-170.6 266,-159.85\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"269.5,-159.52 266,-149.52 262.5,-159.52 269.5,-159.52\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 8 -->\n",
|
|
"<g id=\"node9\" class=\"node\">\n",
|
|
"<title>8</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"422,-157 330,-157 330,-89 422,-89 422,-157\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"376\" y=\"-141.8\" font-family=\"Times,serif\" font-size=\"14.00\">X[1] <= 0.5</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"376\" y=\"-126.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.5</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"376\" y=\"-111.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 2</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"376\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 1]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 6->8 -->\n",
|
|
"<g id=\"edge8\" class=\"edge\">\n",
|
|
"<title>6->8</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M301.71,-192.88C311.61,-183.71 322.46,-173.65 332.73,-164.12\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"335.14,-166.67 340.09,-157.3 330.38,-161.53 335.14,-166.67\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 9 -->\n",
|
|
"<g id=\"node10\" class=\"node\">\n",
|
|
"<title>9</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"367,-53 275,-53 275,0 367,0 367,-53\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"321\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"321\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"321\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [0, 1]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 8->9 -->\n",
|
|
"<g id=\"edge9\" class=\"edge\">\n",
|
|
"<title>8->9</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M356.76,-88.95C351.71,-80.26 346.24,-70.86 341.13,-62.09\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"344.03,-60.12 335.98,-53.24 337.98,-63.64 344.03,-60.12\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 10 -->\n",
|
|
"<g id=\"node11\" class=\"node\">\n",
|
|
"<title>10</title>\n",
|
|
"<polygon fill=\"none\" stroke=\"black\" points=\"477,-53 385,-53 385,0 477,0 477,-53\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"431\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\">gini = 0.0</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"431\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\">samples = 1</text>\n",
|
|
"<text text-anchor=\"middle\" x=\"431\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\">value = [1, 0]</text>\n",
|
|
"</g>\n",
|
|
"<!-- 8->10 -->\n",
|
|
"<g id=\"edge10\" class=\"edge\">\n",
|
|
"<title>8->10</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M395.24,-88.95C400.29,-80.26 405.76,-70.86 410.87,-62.09\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"414.02,-63.64 416.02,-53.24 407.97,-60.12 414.02,-63.64\"/>\n",
|
|
"</g>\n",
|
|
"</g>\n",
|
|
"</svg>\n"
|
|
],
|
|
"text/plain": [
|
|
"<graphviz.files.Source at 0x7f35f4eb5290>"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import os\n",
|
|
"\n",
|
|
"# If the Graphviz executables are not on PATH, set the GRAPHVIZ_BIN environment variable to\n",
|
|
"# their folder (e.g. <conda env>/Library/bin/graphviz on Windows) instead of hard-coding a\n",
|
|
"# user-specific path here; the folder is appended only once, so re-running the cell is harmless.\n",
|
|
"graphviz_bin = os.environ.get(\"GRAPHVIZ_BIN\")\n",
|
|
"current_path = os.environ.get(\"PATH\", \"\")\n",
|
|
"if graphviz_bin and graphviz_bin not in current_path.split(os.pathsep):\n",
|
|
"    os.environ[\"PATH\"] = current_path + os.pathsep + graphviz_bin\n",
|
|
"\n",
|
|
"import graphviz\n",
|
|
"dot_data = tree.export_graphviz(clf, out_file=None)\n",
|
|
"graph = graphviz.Source(dot_data)\n",
|
|
"graph"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"In the following we use a real dataset: the Iris dataset from the UCI Machine Learning Repository, loaded with scikit-learn's `load_iris`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sklearn.datasets import load_iris\n",
|
|
"\n",
|
|
"# feature matrix -> iris.data, class labels -> iris.target (names: iris.feature_names, iris.target_names)\n",
|
|
"iris = load_iris()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Declare the type of prediction model and the hyperparameters of the model induction algorithm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clf = tree.DecisionTreeClassifier(\n",
|
|
"    criterion=\"entropy\",              # choose splits by information gain\n",
|
|
"    random_state=300,                 # reproducible choice among equally good splits\n",
|
|
"    min_samples_leaf=5,               # every leaf keeps at least 5 training examples\n",
|
|
"    class_weight={0: 1, 1: 1, 2: 1},  # all three classes weighted equally\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Split the dataset into a training set and a test set"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Shuffle the example indices once (fixed seed, so the split is reproducible) and cut the\n",
|
|
"# shuffled order into a training part and a test part of n_test examples.\n",
|
|
"import numpy as np\n",
|
|
"np.random.seed(0)\n",
|
|
"indices = np.random.permutation(len(iris.data))\n",
|
|
"\n",
|
|
"n_test = 10  # how many examples to hold out for testing\n",
|
|
"# n_test = 0 would make indices[:-0] empty, so guard both ends\n",
|
|
"assert 0 < n_test < len(indices), \"n_test must leave examples for both training and test\"\n",
|
|
"indices_training = indices[:-n_test]\n",
|
|
"indices_test = indices[-n_test:]\n",
|
|
"\n",
|
|
"# Because the indices were shuffled, the test set is n_test randomly chosen rows of iris.data\n",
|
|
"# (not its last rows), and the training set is every other row.\n",
|
|
"iris_X_train = iris.data[indices_training]\n",
|
|
"iris_y_train = iris.target[indices_training]\n",
|
|
"iris_X_test = iris.data[indices_test]\n",
|
|
"iris_y_test = iris.target[indices_test]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Fit the learning model on the training set"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# learn the tree from the training portion only (the test rows stay unseen)\n",
|
|
"clf = clf.fit(X=iris_X_train, y=iris_y_train)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Obtain predictions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Predictions:\n",
|
|
"[1 2 1 0 0 0 2 1 2 0]\n",
|
|
"True classes:\n",
|
|
"[1 1 1 0 0 0 2 1 2 0]\n",
|
|
"['setosa' 'versicolor' 'virginica']\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# apply the fitted model \"clf\" to the held-out test examples\n",
|
|
"predicted_y_test = clf.predict(iris_X_test)\n",
|
|
"\n",
|
|
"# show predicted vs. true class numbers; number k is the class named iris.target_names[k]\n",
|
|
"for heading, labels in ((\"Predictions:\", predicted_y_test), (\"True classes:\", iris_y_test)):\n",
|
|
"    print(heading)\n",
|
|
"    print(labels)\n",
|
|
"print(iris.target_names)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Print, for each test instance, its index in the original dataset, its feature values, and the predicted and true classes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Look at the specific examples"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Instance # 88: \n",
|
|
"sepal length (cm)=5.6, sepal width (cm)=3.0, petal length (cm)=4.1, petal width (cm)=1.3\n",
|
|
"Predicted: versicolor\t True: versicolor\n",
|
|
"\n",
|
|
"Instance # 70: \n",
|
|
"sepal length (cm)=5.9, sepal width (cm)=3.2, petal length (cm)=4.8, petal width (cm)=1.8\n",
|
|
"Predicted: virginica\t True: versicolor\n",
|
|
"\n",
|
|
"Instance # 87: \n",
|
|
"sepal length (cm)=6.3, sepal width (cm)=2.3, petal length (cm)=4.4, petal width (cm)=1.3\n",
|
|
"Predicted: versicolor\t True: versicolor\n",
|
|
"\n",
|
|
"Instance # 36: \n",
|
|
"sepal length (cm)=5.5, sepal width (cm)=3.5, petal length (cm)=1.3, petal width (cm)=0.2\n",
|
|
"Predicted: setosa\t True: setosa\n",
|
|
"\n",
|
|
"Instance # 21: \n",
|
|
"sepal length (cm)=5.1, sepal width (cm)=3.7, petal length (cm)=1.5, petal width (cm)=0.4\n",
|
|
"Predicted: setosa\t True: setosa\n",
|
|
"\n",
|
|
"Instance # 9: \n",
|
|
"sepal length (cm)=4.9, sepal width (cm)=3.1, petal length (cm)=1.5, petal width (cm)=0.1\n",
|
|
"Predicted: setosa\t True: setosa\n",
|
|
"\n",
|
|
"Instance # 103: \n",
|
|
"sepal length (cm)=6.3, sepal width (cm)=2.9, petal length (cm)=5.6, petal width (cm)=1.8\n",
|
|
"Predicted: virginica\t True: virginica\n",
|
|
"\n",
|
|
"Instance # 67: \n",
|
|
"sepal length (cm)=5.8, sepal width (cm)=2.7, petal length (cm)=4.1, petal width (cm)=1.0\n",
|
|
"Predicted: versicolor\t True: versicolor\n",
|
|
"\n",
|
|
"Instance # 117: \n",
|
|
"sepal length (cm)=7.7, sepal width (cm)=3.8, petal length (cm)=6.7, petal width (cm)=2.2\n",
|
|
"Predicted: virginica\t True: virginica\n",
|
|
"\n",
|
|
"Instance # 47: \n",
|
|
"sepal length (cm)=4.6, sepal width (cm)=3.2, petal length (cm)=1.4, petal width (cm)=0.2\n",
|
|
"Predicted: setosa\t True: setosa\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# one report per test example: its row index in iris.data, its feature values, predicted and true class\n",
|
|
"for position, instance_index in enumerate(indices_test):\n",
|
|
"    print(\"Instance # \" + str(instance_index) + \": \")\n",
|
|
"    feature_values = iris_X_test[position]\n",
|
|
"    print(\", \".join(name + \"=\" + str(value) for name, value in zip(iris.feature_names, feature_values)))\n",
|
|
"    predicted_name = iris.target_names[predicted_y_test[position]]\n",
|
|
"    true_name = iris.target_names[iris_y_test[position]]\n",
|
|
"    print(\"Predicted: \" + predicted_name + \"\\t True: \" + true_name + \"\\n\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Obtain model performance results"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Accuracy score: 0.9\n",
|
|
"F1 score: 0.8857142857142858\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# test-set metrics\n",
|
|
"from sklearn.metrics import accuracy_score, f1_score\n",
|
|
"\n",
|
|
"acc_score = accuracy_score(iris_y_test, predicted_y_test)\n",
|
|
"print(\"Accuracy score: \" + str(acc_score))\n",
|
|
"# macro average: F1 is computed per class, then averaged without weighting by class size\n",
|
|
"f1 = f1_score(iris_y_test, predicted_y_test, average='macro')\n",
|
|
"print(\"F1 score: \" + str(f1))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Use Cross Validation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0.96666667 1. 0.86666667 0.86666667 1. ]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sklearn.datasets import load_iris\n",
|
|
"from sklearn.model_selection import cross_val_score\n",
|
|
"\n",
|
|
"iris = load_iris()\n",
|
|
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\", random_state=300, min_samples_leaf=5,\n",
|
|
"                                  class_weight={0: 1, 1: 1, 2: 1})\n",
|
|
"# fit on all 150 examples: this is the tree exported with graphviz in \"Show the resulting tree\" below\n",
|
|
"clf = clf.fit(iris.data, iris.target)\n",
|
|
"\n",
|
|
"# 5-fold cross-validation on the whole dataset; without a scoring argument the score is the accuracy\n",
|
|
"scores = cross_val_score(clf, iris.data, iris.target, cv=5)\n",
|
|
"print(scores)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0.96658312 1. 0.86111111 0.86666667 1. ]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# 5-fold cross-validation again, this time scored with the macro-averaged F1\n",
|
|
"f1_scores = cross_val_score(clf, iris.data, iris.target, scoring='f1_macro', cv=5)\n",
|
|
"print(f1_scores)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Show the resulting tree "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 1. Save the picture to a PDF file"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {
|
|
"scrolled": false
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'my_iris_predictions.pdf'"
|
|
]
|
|
},
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import graphviz\n",
|
|
"\n",
|
|
"# DOT description of the tree fitted above, rendered to my_iris_predictions.pdf\n",
|
|
"dot_data = tree.export_graphviz(clf, out_file=None)\n",
|
|
"graph = graphviz.Source(dot_data)\n",
|
|
"pdf_path = graph.render(\"my_iris_predictions\")\n",
|
|
"pdf_path"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 2. Display the picture in the notebook"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n",
|
|
"['setosa', 'versicolor', 'virginica']\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# feature names first, then class names, each shown as a plain list\n",
|
|
"for names in (iris.feature_names, iris.target_names):\n",
|
|
"    print(list(names))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": [
|
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
|
"<!-- Generated by graphviz version 2.42.3 (20191010.1750)\n",
|
|
" -->\n",
|
|
"<!-- Title: Tree Pages: 1 -->\n",
|
|
"<svg width=\"639pt\" height=\"552pt\"\n",
|
|
" viewBox=\"0.00 0.00 639.00 552.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 548)\">\n",
|
|
"<title>Tree</title>\n",
|
|
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-548 635,-548 635,4 -4,4\"/>\n",
|
|
"<!-- 0 -->\n",
|
|
"<g id=\"node1\" class=\"node\">\n",
|
|
"<title>0</title>\n",
|
|
"<path fill=\"#ffffff\" stroke=\"black\" d=\"M353.5,-544C353.5,-544 217.5,-544 217.5,-544 211.5,-544 205.5,-538 205.5,-532 205.5,-532 205.5,-473 205.5,-473 205.5,-467 211.5,-461 217.5,-461 217.5,-461 353.5,-461 353.5,-461 359.5,-461 365.5,-467 365.5,-473 365.5,-473 365.5,-532 365.5,-532 365.5,-538 359.5,-544 353.5,-544\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"213.5\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 2.45</text>\n",
|
|
"<text text-anchor=\"start\" x=\"238\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.585</text>\n",
|
|
"<text text-anchor=\"start\" x=\"240.5\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 150</text>\n",
|
|
"<text text-anchor=\"start\" x=\"227.5\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [50, 50, 50]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"242\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1 -->\n",
|
|
"<g id=\"node2\" class=\"node\">\n",
|
|
"<title>1</title>\n",
|
|
"<path fill=\"#e58139\" stroke=\"black\" d=\"M255,-417.5C255,-417.5 162,-417.5 162,-417.5 156,-417.5 150,-411.5 150,-405.5 150,-405.5 150,-361.5 150,-361.5 150,-355.5 156,-349.5 162,-349.5 162,-349.5 255,-349.5 255,-349.5 261,-349.5 267,-355.5 267,-361.5 267,-361.5 267,-405.5 267,-405.5 267,-411.5 261,-417.5 255,-417.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"168.5\" y=\"-402.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"167.5\" y=\"-387.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 50</text>\n",
|
|
"<text text-anchor=\"start\" x=\"158\" y=\"-372.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [50, 0, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"165\" y=\"-357.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->1 -->\n",
|
|
"<g id=\"edge1\" class=\"edge\">\n",
|
|
"<title>0->1</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M258.79,-460.91C251.38,-449.65 243.33,-437.42 235.88,-426.11\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"238.75,-424.1 230.33,-417.67 232.9,-427.94 238.75,-424.1\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"225.28\" y=\"-438.45\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2 -->\n",
|
|
"<g id=\"node3\" class=\"node\">\n",
|
|
"<title>2</title>\n",
|
|
"<path fill=\"#ffffff\" stroke=\"black\" d=\"M428,-425C428,-425 297,-425 297,-425 291,-425 285,-419 285,-413 285,-413 285,-354 285,-354 285,-348 291,-342 297,-342 297,-342 428,-342 428,-342 434,-342 440,-348 440,-354 440,-354 440,-413 440,-413 440,-419 434,-425 428,-425\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"293\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.75</text>\n",
|
|
"<text text-anchor=\"start\" x=\"322.5\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"317.5\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 100</text>\n",
|
|
"<text text-anchor=\"start\" x=\"308\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 50, 50]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"310\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->2 -->\n",
|
|
"<g id=\"edge2\" class=\"edge\">\n",
|
|
"<title>0->2</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M312.21,-460.91C318.01,-452.1 324.2,-442.7 330.18,-433.61\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"333.26,-435.3 335.83,-425.02 327.41,-431.45 333.26,-435.3\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"340.88\" y=\"-445.81\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
|
|
"</g>\n",
|
|
"<!-- 3 -->\n",
|
|
"<g id=\"node4\" class=\"node\">\n",
|
|
"<title>3</title>\n",
|
|
"<path fill=\"#4de88e\" stroke=\"black\" d=\"M341.5,-306C341.5,-306 205.5,-306 205.5,-306 199.5,-306 193.5,-300 193.5,-294 193.5,-294 193.5,-235 193.5,-235 193.5,-229 199.5,-223 205.5,-223 205.5,-223 341.5,-223 341.5,-223 347.5,-223 353.5,-229 353.5,-235 353.5,-235 353.5,-294 353.5,-294 353.5,-300 347.5,-306 341.5,-306\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"201.5\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.95</text>\n",
|
|
"<text text-anchor=\"start\" x=\"226\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.445</text>\n",
|
|
"<text text-anchor=\"start\" x=\"232.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 54</text>\n",
|
|
"<text text-anchor=\"start\" x=\"223\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 49, 5]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"221\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2->3 -->\n",
|
|
"<g id=\"edge3\" class=\"edge\">\n",
|
|
"<title>2->3</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M331.62,-341.91C324.79,-332.92 317.48,-323.32 310.43,-314.05\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"313.16,-311.86 304.32,-306.02 307.59,-316.1 313.16,-311.86\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 8 -->\n",
|
|
"<g id=\"node9\" class=\"node\">\n",
|
|
"<title>8</title>\n",
|
|
"<path fill=\"#843de6\" stroke=\"black\" d=\"M519.5,-306C519.5,-306 383.5,-306 383.5,-306 377.5,-306 371.5,-300 371.5,-294 371.5,-294 371.5,-235 371.5,-235 371.5,-229 377.5,-223 383.5,-223 383.5,-223 519.5,-223 519.5,-223 525.5,-223 531.5,-229 531.5,-235 531.5,-235 531.5,-294 531.5,-294 531.5,-300 525.5,-306 519.5,-306\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"379.5\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.95</text>\n",
|
|
"<text text-anchor=\"start\" x=\"404\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.151</text>\n",
|
|
"<text text-anchor=\"start\" x=\"410.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 46</text>\n",
|
|
"<text text-anchor=\"start\" x=\"401\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1, 45]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"403\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2->8 -->\n",
|
|
"<g id=\"edge8\" class=\"edge\">\n",
|
|
"<title>2->8</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M393.38,-341.91C400.21,-332.92 407.52,-323.32 414.57,-314.05\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"417.41,-316.1 420.68,-306.02 411.84,-311.86 417.41,-316.1\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 4 -->\n",
|
|
"<g id=\"node5\" class=\"node\">\n",
|
|
"<title>4</title>\n",
|
|
"<path fill=\"#3de684\" stroke=\"black\" d=\"M199,-187C199,-187 60,-187 60,-187 54,-187 48,-181 48,-175 48,-175 48,-116 48,-116 48,-110 54,-104 60,-104 60,-104 199,-104 199,-104 205,-104 211,-110 211,-116 211,-116 211,-175 211,-175 211,-181 205,-187 199,-187\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"56\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sepal length (cm) ≤ 5.15</text>\n",
|
|
"<text text-anchor=\"start\" x=\"82\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.146</text>\n",
|
|
"<text text-anchor=\"start\" x=\"88.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 48</text>\n",
|
|
"<text text-anchor=\"start\" x=\"79\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 47, 1]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"77\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 3->4 -->\n",
|
|
"<g id=\"edge4\" class=\"edge\">\n",
|
|
"<title>3->4</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M223.54,-222.91C211.81,-213.38 199.22,-203.15 187.19,-193.37\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"189.34,-190.61 179.37,-187.02 184.93,-196.04 189.34,-190.61\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 7 -->\n",
|
|
"<g id=\"node8\" class=\"node\">\n",
|
|
"<title>7</title>\n",
|
|
"<path fill=\"#c09cf2\" stroke=\"black\" d=\"M330,-179.5C330,-179.5 241,-179.5 241,-179.5 235,-179.5 229,-173.5 229,-167.5 229,-167.5 229,-123.5 229,-123.5 229,-117.5 235,-111.5 241,-111.5 241,-111.5 330,-111.5 330,-111.5 336,-111.5 342,-117.5 342,-123.5 342,-123.5 342,-167.5 342,-167.5 342,-173.5 336,-179.5 330,-179.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"238\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\n",
|
|
"<text text-anchor=\"start\" x=\"248\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\n",
|
|
"<text text-anchor=\"start\" x=\"238.5\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 2, 4]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"237\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 3->7 -->\n",
|
|
"<g id=\"edge7\" class=\"edge\">\n",
|
|
"<title>3->7</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M277.66,-222.91C278.76,-212.2 279.95,-200.62 281.06,-189.78\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"284.56,-189.97 282.1,-179.67 277.6,-189.26 284.56,-189.97\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 5 -->\n",
|
|
"<g id=\"node6\" class=\"node\">\n",
|
|
"<title>5</title>\n",
|
|
"<path fill=\"#6aeca0\" stroke=\"black\" d=\"M109,-68C109,-68 12,-68 12,-68 6,-68 0,-62 0,-56 0,-56 0,-12 0,-12 0,-6 6,0 12,0 12,0 109,0 109,0 115,0 121,-6 121,-12 121,-12 121,-56 121,-56 121,-62 115,-68 109,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"13\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.722</text>\n",
|
|
"<text text-anchor=\"start\" x=\"23\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 5</text>\n",
|
|
"<text text-anchor=\"start\" x=\"13.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 4, 1]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"8\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 4->5 -->\n",
|
|
"<g id=\"edge5\" class=\"edge\">\n",
|
|
"<title>4->5</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M103.81,-103.73C98.29,-94.97 92.45,-85.7 86.91,-76.91\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"89.78,-74.89 81.48,-68.3 83.85,-78.63 89.78,-74.89\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 6 -->\n",
|
|
"<g id=\"node7\" class=\"node\">\n",
|
|
"<title>6</title>\n",
|
|
"<path fill=\"#39e581\" stroke=\"black\" d=\"M248,-68C248,-68 151,-68 151,-68 145,-68 139,-62 139,-56 139,-56 139,-12 139,-12 139,-6 145,0 151,0 151,0 248,0 248,0 254,0 260,-6 260,-12 260,-12 260,-56 260,-56 260,-62 254,-68 248,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"159.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"158.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\n",
|
|
"<text text-anchor=\"start\" x=\"149\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 43, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"147\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 4->6 -->\n",
|
|
"<g id=\"edge6\" class=\"edge\">\n",
|
|
"<title>4->6</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M155.57,-103.73C161.16,-94.97 167.09,-85.7 172.71,-76.91\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"175.78,-78.61 178.21,-68.3 169.88,-74.84 175.78,-78.61\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 9 -->\n",
|
|
"<g id=\"node10\" class=\"node\">\n",
|
|
"<title>9</title>\n",
|
|
"<path fill=\"#9a61ea\" stroke=\"black\" d=\"M484,-179.5C484,-179.5 395,-179.5 395,-179.5 389,-179.5 383,-173.5 383,-167.5 383,-167.5 383,-123.5 383,-123.5 383,-117.5 389,-111.5 395,-111.5 395,-111.5 484,-111.5 484,-111.5 490,-111.5 496,-117.5 496,-123.5 496,-123.5 496,-167.5 496,-167.5 496,-173.5 490,-179.5 484,-179.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"395.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.65</text>\n",
|
|
"<text text-anchor=\"start\" x=\"402\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\n",
|
|
"<text text-anchor=\"start\" x=\"392.5\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1, 5]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"391\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 8->9 -->\n",
|
|
"<g id=\"edge9\" class=\"edge\">\n",
|
|
"<title>8->9</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M447.34,-222.91C446.24,-212.2 445.05,-200.62 443.94,-189.78\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"447.4,-189.26 442.9,-179.67 440.44,-189.97 447.4,-189.26\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 10 -->\n",
|
|
"<g id=\"node11\" class=\"node\">\n",
|
|
"<title>10</title>\n",
|
|
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M619,-179.5C619,-179.5 526,-179.5 526,-179.5 520,-179.5 514,-173.5 514,-167.5 514,-167.5 514,-123.5 514,-123.5 514,-117.5 520,-111.5 526,-111.5 526,-111.5 619,-111.5 619,-111.5 625,-111.5 631,-117.5 631,-123.5 631,-123.5 631,-167.5 631,-167.5 631,-173.5 625,-179.5 619,-179.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"532.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"531.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 40</text>\n",
|
|
"<text text-anchor=\"start\" x=\"522\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 40]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"524\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 8->10 -->\n",
|
|
"<g id=\"edge10\" class=\"edge\">\n",
|
|
"<title>8->10</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M493.48,-222.91C505.58,-211.21 518.77,-198.46 530.84,-186.78\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"533.44,-189.13 538.2,-179.67 528.58,-184.1 533.44,-189.13\"/>\n",
|
|
"</g>\n",
|
|
"</g>\n",
|
|
"</svg>\n"
|
|
],
|
|
"text/plain": [
|
|
"<graphviz.files.Source at 0x7f36064a3450>"
|
|
]
|
|
},
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"dot_data = tree.export_graphviz(clf, out_file=None, \n",
|
|
" feature_names=iris.feature_names, \n",
|
|
" class_names=iris.target_names, \n",
|
|
" filled=True, rounded=True, \n",
|
|
" special_characters=True) \n",
|
|
"graph = graphviz.Source(dot_data) \n",
|
|
"graph"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
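"# 1. Artificial inflation\n\nOversample the training set: every training example of class 1 (versicolor) and class 2 (virginica) is appended 9 more times, so each of those examples appears 10 times in the inflated training set, while setosa examples appear only once. The 10 test examples are not inflated."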
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Generate a random permutation of the indices of examples that will be later used \n",
|
|
"# for the training and the test set\n",
|
|
"import numpy as np\n",
|
|
"np.random.seed(42)\n",
|
|
"indices = np.random.permutation(len(iris.data))\n",
|
|
"\n",
|
|
"# We now decide to keep the last 10 indices for test set, the remaining for the training set\n",
|
|
"indices_training=indices[:-10]\n",
|
|
"indices_test=indices[-10:]\n",
|
|
"\n",
|
|
"iris_X_train = iris.data[indices_training] # training set: the examples at all permuted indices except the last 10\n",
|
|
"iris_y_train = iris.target[indices_training]\n",
|
|
"iris_X_test = iris.data[indices_test] # test set: the examples at the last 10 permuted indices\n",
|
|
"iris_y_test = iris.target[indices_test]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"samples_x = []\n",
|
|
"samples_y = []\n",
|
|
"for i in range(0, len(iris_y_train)):\n",
|
|
" if iris_y_train[i] == 1:\n",
|
|
" for _ in range(9):\n",
|
|
" samples_x.append(iris_X_train[i])\n",
|
|
" samples_y.append(1)\n",
|
|
" elif iris_y_train[i] == 2:\n",
|
|
" for _ in range(9):\n",
|
|
" samples_x.append(iris_X_train[i])\n",
|
|
" samples_y.append(2)\n",
|
|
"\n",
|
|
"#Samples inflation\n",
|
|
"iris_X_train = np.append(iris_X_train, samples_x, axis = 0)\n",
|
|
"iris_y_train = np.append(iris_y_train, samples_y, axis = 0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Accuracy: 0.9\n",
|
|
"F1: 0.9153439153439153\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\",random_state=300,min_samples_leaf=10,class_weight={0:1,1:1,2:1})\n",
|
|
"clf = clf.fit(iris_X_train, iris_y_train)\n",
|
|
"predicted_y_test = clf.predict(iris_X_test)\n",
|
|
"acc_score = accuracy_score(iris_y_test, predicted_y_test)\n",
|
|
"f1 = f1_score(iris_y_test, predicted_y_test, average='macro')\n",
|
|
"print(\"Accuracy: \", acc_score)\n",
|
|
"print(\"F1: \", f1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": [
|
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
|
"<!-- Generated by graphviz version 2.42.3 (20191010.1750)\n",
|
|
" -->\n",
|
|
"<!-- Title: Tree Pages: 1 -->\n",
|
|
"<svg width=\"767pt\" height=\"671pt\"\n",
|
|
" viewBox=\"0.00 0.00 767.00 671.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 667)\">\n",
|
|
"<title>Tree</title>\n",
|
|
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-667 763,-667 763,4 -4,4\"/>\n",
|
|
"<!-- 0 -->\n",
|
|
"<g id=\"node1\" class=\"node\">\n",
|
|
"<title>0</title>\n",
|
|
"<path fill=\"#ffffff\" stroke=\"black\" d=\"M349.5,-663C349.5,-663 213.5,-663 213.5,-663 207.5,-663 201.5,-657 201.5,-651 201.5,-651 201.5,-592 201.5,-592 201.5,-586 207.5,-580 213.5,-580 213.5,-580 349.5,-580 349.5,-580 355.5,-580 361.5,-586 361.5,-592 361.5,-592 361.5,-651 361.5,-651 361.5,-657 355.5,-663 349.5,-663\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"209.5\" y=\"-647.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.75</text>\n",
|
|
"<text text-anchor=\"start\" x=\"234\" y=\"-632.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.235</text>\n",
|
|
"<text text-anchor=\"start\" x=\"236.5\" y=\"-617.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 968</text>\n",
|
|
"<text text-anchor=\"start\" x=\"216\" y=\"-602.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [48, 460, 460]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"229\" y=\"-587.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1 -->\n",
|
|
"<g id=\"node2\" class=\"node\">\n",
|
|
"<title>1</title>\n",
|
|
"<path fill=\"#51e890\" stroke=\"black\" d=\"M261.5,-544C261.5,-544 125.5,-544 125.5,-544 119.5,-544 113.5,-538 113.5,-532 113.5,-532 113.5,-473 113.5,-473 113.5,-467 119.5,-461 125.5,-461 125.5,-461 261.5,-461 261.5,-461 267.5,-461 273.5,-467 273.5,-473 273.5,-473 273.5,-532 273.5,-532 273.5,-538 267.5,-544 261.5,-544\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"121.5\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 2.45</text>\n",
|
|
"<text text-anchor=\"start\" x=\"146\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.491</text>\n",
|
|
"<text text-anchor=\"start\" x=\"148.5\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 448</text>\n",
|
|
"<text text-anchor=\"start\" x=\"135.5\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [48, 400, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"141\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->1 -->\n",
|
|
"<g id=\"edge1\" class=\"edge\">\n",
|
|
"<title>0->1</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M250.97,-579.91C244.21,-570.92 236.98,-561.32 230.02,-552.05\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"232.79,-549.91 223.98,-544.02 227.19,-554.12 232.79,-549.91\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"220.48\" y=\"-565.07\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
|
|
"</g>\n",
|
|
"<!-- 4 -->\n",
|
|
"<g id=\"node5\" class=\"node\">\n",
|
|
"<title>4</title>\n",
|
|
"<path fill=\"#9153e8\" stroke=\"black\" d=\"M435,-544C435,-544 304,-544 304,-544 298,-544 292,-538 292,-532 292,-532 292,-473 292,-473 292,-467 298,-461 304,-461 304,-461 435,-461 435,-461 441,-461 447,-467 447,-473 447,-473 447,-532 447,-532 447,-538 441,-544 435,-544\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"300\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.75</text>\n",
|
|
"<text text-anchor=\"start\" x=\"322\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.516</text>\n",
|
|
"<text text-anchor=\"start\" x=\"324.5\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 520</text>\n",
|
|
"<text text-anchor=\"start\" x=\"311.5\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 60, 460]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"321\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->4 -->\n",
|
|
"<g id=\"edge4\" class=\"edge\">\n",
|
|
"<title>0->4</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M312.03,-579.91C318.79,-570.92 326.02,-561.32 332.98,-552.05\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"335.81,-554.12 339.02,-544.02 330.21,-549.91 335.81,-554.12\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"342.52\" y=\"-565.07\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2 -->\n",
|
|
"<g id=\"node3\" class=\"node\">\n",
|
|
"<title>2</title>\n",
|
|
"<path fill=\"#e58139\" stroke=\"black\" d=\"M105,-417.5C105,-417.5 12,-417.5 12,-417.5 6,-417.5 0,-411.5 0,-405.5 0,-405.5 0,-361.5 0,-361.5 0,-355.5 6,-349.5 12,-349.5 12,-349.5 105,-349.5 105,-349.5 111,-349.5 117,-355.5 117,-361.5 117,-361.5 117,-405.5 117,-405.5 117,-411.5 111,-417.5 105,-417.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"18.5\" y=\"-402.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"17.5\" y=\"-387.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 48</text>\n",
|
|
"<text text-anchor=\"start\" x=\"8\" y=\"-372.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [48, 0, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"15\" y=\"-357.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1->2 -->\n",
|
|
"<g id=\"edge2\" class=\"edge\">\n",
|
|
"<title>1->2</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M146.66,-460.91C133.04,-449.1 118.17,-436.22 104.6,-424.45\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"106.62,-421.57 96.77,-417.67 102.03,-426.86 106.62,-421.57\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 3 -->\n",
|
|
"<g id=\"node4\" class=\"node\">\n",
|
|
"<title>3</title>\n",
|
|
"<path fill=\"#39e581\" stroke=\"black\" d=\"M248,-417.5C248,-417.5 147,-417.5 147,-417.5 141,-417.5 135,-411.5 135,-405.5 135,-405.5 135,-361.5 135,-361.5 135,-355.5 141,-349.5 147,-349.5 147,-349.5 248,-349.5 248,-349.5 254,-349.5 260,-355.5 260,-361.5 260,-361.5 260,-405.5 260,-405.5 260,-411.5 254,-417.5 248,-417.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"157.5\" y=\"-402.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"152.5\" y=\"-387.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 400</text>\n",
|
|
"<text text-anchor=\"start\" x=\"143\" y=\"-372.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 400, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"145\" y=\"-357.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1->3 -->\n",
|
|
"<g id=\"edge3\" class=\"edge\">\n",
|
|
"<title>1->3</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M194.89,-460.91C195.25,-450.2 195.65,-438.62 196.02,-427.78\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"199.52,-427.78 196.37,-417.67 192.53,-427.54 199.52,-427.78\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 5 -->\n",
|
|
"<g id=\"node6\" class=\"node\">\n",
|
|
"<title>5</title>\n",
|
|
"<path fill=\"#d7fae6\" stroke=\"black\" d=\"M433.5,-425C433.5,-425 297.5,-425 297.5,-425 291.5,-425 285.5,-419 285.5,-413 285.5,-413 285.5,-354 285.5,-354 285.5,-348 291.5,-342 297.5,-342 297.5,-342 433.5,-342 433.5,-342 439.5,-342 445.5,-348 445.5,-354 445.5,-354 445.5,-413 445.5,-413 445.5,-419 439.5,-425 433.5,-425\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"293.5\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.95</text>\n",
|
|
"<text text-anchor=\"start\" x=\"318\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.991</text>\n",
|
|
"<text text-anchor=\"start\" x=\"324.5\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 90</text>\n",
|
|
"<text text-anchor=\"start\" x=\"311\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 50, 40]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"313\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 4->5 -->\n",
|
|
"<g id=\"edge5\" class=\"edge\">\n",
|
|
"<title>4->5</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M368.11,-460.91C367.83,-452.56 367.52,-443.67 367.23,-435.02\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"370.73,-434.9 366.89,-425.02 363.73,-435.13 370.73,-434.9\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 12 -->\n",
|
|
"<g id=\"node13\" class=\"node\">\n",
|
|
"<title>12</title>\n",
|
|
"<path fill=\"#843ee6\" stroke=\"black\" d=\"M611.5,-425C611.5,-425 475.5,-425 475.5,-425 469.5,-425 463.5,-419 463.5,-413 463.5,-413 463.5,-354 463.5,-354 463.5,-348 469.5,-342 475.5,-342 475.5,-342 611.5,-342 611.5,-342 617.5,-342 623.5,-348 623.5,-354 623.5,-354 623.5,-413 623.5,-413 623.5,-419 617.5,-425 611.5,-425\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"471.5\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.85</text>\n",
|
|
"<text text-anchor=\"start\" x=\"496\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.159</text>\n",
|
|
"<text text-anchor=\"start\" x=\"498.5\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 430</text>\n",
|
|
"<text text-anchor=\"start\" x=\"485.5\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 420]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"495\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 4->12 -->\n",
|
|
"<g id=\"edge12\" class=\"edge\">\n",
|
|
"<title>4->12</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M429.87,-460.91C444.31,-451.2 459.83,-440.76 474.63,-430.81\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"476.89,-433.51 483.24,-425.02 472.99,-427.7 476.89,-433.51\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 6 -->\n",
|
|
"<g id=\"node7\" class=\"node\">\n",
|
|
"<title>6</title>\n",
|
|
"<path fill=\"#39e581\" stroke=\"black\" d=\"M261,-298.5C261,-298.5 164,-298.5 164,-298.5 158,-298.5 152,-292.5 152,-286.5 152,-286.5 152,-242.5 152,-242.5 152,-236.5 158,-230.5 164,-230.5 164,-230.5 261,-230.5 261,-230.5 267,-230.5 273,-236.5 273,-242.5 273,-242.5 273,-286.5 273,-286.5 273,-292.5 267,-298.5 261,-298.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"172.5\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"171.5\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 30</text>\n",
|
|
"<text text-anchor=\"start\" x=\"162\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"160\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 5->6 -->\n",
|
|
"<g id=\"edge6\" class=\"edge\">\n",
|
|
"<title>5->6</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M312.42,-341.91C296.69,-329.88 279.5,-316.73 263.88,-304.79\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"265.94,-301.96 255.87,-298.67 261.69,-307.52 265.94,-301.96\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 7 -->\n",
|
|
"<g id=\"node8\" class=\"node\">\n",
|
|
"<title>7</title>\n",
|
|
"<path fill=\"#c09cf2\" stroke=\"black\" d=\"M434,-306C434,-306 303,-306 303,-306 297,-306 291,-300 291,-294 291,-294 291,-235 291,-235 291,-229 297,-223 303,-223 303,-223 434,-223 434,-223 440,-223 446,-229 446,-235 446,-235 446,-294 446,-294 446,-300 440,-306 434,-306\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"299\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.55</text>\n",
|
|
"<text text-anchor=\"start\" x=\"321\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\n",
|
|
"<text text-anchor=\"start\" x=\"327.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 60</text>\n",
|
|
"<text text-anchor=\"start\" x=\"314\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 20, 40]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"320\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 5->7 -->\n",
|
|
"<g id=\"edge7\" class=\"edge\">\n",
|
|
"<title>5->7</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M366.54,-341.91C366.75,-333.56 366.98,-324.67 367.2,-316.02\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"370.7,-316.11 367.46,-306.02 363.71,-315.93 370.7,-316.11\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 8 -->\n",
|
|
"<g id=\"node9\" class=\"node\">\n",
|
|
"<title>8</title>\n",
|
|
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M262,-179.5C262,-179.5 169,-179.5 169,-179.5 163,-179.5 157,-173.5 157,-167.5 157,-167.5 157,-123.5 157,-123.5 157,-117.5 163,-111.5 169,-111.5 169,-111.5 262,-111.5 262,-111.5 268,-111.5 274,-117.5 274,-123.5 274,-123.5 274,-167.5 274,-167.5 274,-173.5 268,-179.5 262,-179.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"175.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"174.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 30</text>\n",
|
|
"<text text-anchor=\"start\" x=\"165\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 30]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"167\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 7->8 -->\n",
|
|
"<g id=\"edge8\" class=\"edge\">\n",
|
|
"<title>7->8</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M315.42,-222.91C299.69,-210.88 282.5,-197.73 266.88,-185.79\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"268.94,-182.96 258.87,-179.67 264.69,-188.52 268.94,-182.96\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 9 -->\n",
|
|
"<g id=\"node10\" class=\"node\">\n",
|
|
"<title>9</title>\n",
|
|
"<path fill=\"#9cf2c0\" stroke=\"black\" d=\"M440.5,-187C440.5,-187 304.5,-187 304.5,-187 298.5,-187 292.5,-181 292.5,-175 292.5,-175 292.5,-116 292.5,-116 292.5,-110 298.5,-104 304.5,-104 304.5,-104 440.5,-104 440.5,-104 446.5,-104 452.5,-110 452.5,-116 452.5,-116 452.5,-175 452.5,-175 452.5,-181 446.5,-187 440.5,-187\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"300.5\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 5.45</text>\n",
|
|
"<text text-anchor=\"start\" x=\"325\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\n",
|
|
"<text text-anchor=\"start\" x=\"331.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 30</text>\n",
|
|
"<text text-anchor=\"start\" x=\"318\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 20, 10]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"320\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 7->9 -->\n",
|
|
"<g id=\"edge9\" class=\"edge\">\n",
|
|
"<title>7->9</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M369.89,-222.91C370.17,-214.56 370.48,-205.67 370.77,-197.02\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"374.27,-197.13 371.11,-187.02 367.27,-196.9 374.27,-197.13\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 10 -->\n",
|
|
"<g id=\"node11\" class=\"node\">\n",
|
|
"<title>10</title>\n",
|
|
"<path fill=\"#39e581\" stroke=\"black\" d=\"M353,-68C353,-68 256,-68 256,-68 250,-68 244,-62 244,-56 244,-56 244,-12 244,-12 244,-6 250,0 256,0 256,0 353,0 353,0 359,0 365,-6 365,-12 365,-12 365,-56 365,-56 365,-62 359,-68 353,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"264.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"263.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 20</text>\n",
|
|
"<text text-anchor=\"start\" x=\"254\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 20, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"252\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 9->10 -->\n",
|
|
"<g id=\"edge10\" class=\"edge\">\n",
|
|
"<title>9->10</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M347.18,-103.73C341.74,-94.97 335.99,-85.7 330.52,-76.91\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"333.43,-74.95 325.18,-68.3 327.48,-78.64 333.43,-74.95\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 11 -->\n",
|
|
"<g id=\"node12\" class=\"node\">\n",
|
|
"<title>11</title>\n",
|
|
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M488,-68C488,-68 395,-68 395,-68 389,-68 383,-62 383,-56 383,-56 383,-12 383,-12 383,-6 389,0 395,0 395,0 488,0 488,0 494,0 500,-6 500,-12 500,-12 500,-56 500,-56 500,-62 494,-68 488,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"401.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"400.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 10</text>\n",
|
|
"<text text-anchor=\"start\" x=\"391\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 10]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"393\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 9->11 -->\n",
|
|
"<g id=\"edge11\" class=\"edge\">\n",
|
|
"<title>9->11</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M398.19,-103.73C403.71,-94.97 409.55,-85.7 415.09,-76.91\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"418.15,-78.63 420.52,-68.3 412.22,-74.89 418.15,-78.63\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 13 -->\n",
|
|
"<g id=\"node14\" class=\"node\">\n",
|
|
"<title>13</title>\n",
|
|
"<path fill=\"#c09cf2\" stroke=\"black\" d=\"M604.5,-306C604.5,-306 478.5,-306 478.5,-306 472.5,-306 466.5,-300 466.5,-294 466.5,-294 466.5,-235 466.5,-235 466.5,-229 472.5,-223 478.5,-223 478.5,-223 604.5,-223 604.5,-223 610.5,-223 616.5,-229 616.5,-235 616.5,-235 616.5,-294 616.5,-294 616.5,-300 610.5,-306 604.5,-306\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"474.5\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sepal width (cm) ≤ 3.1</text>\n",
|
|
"<text text-anchor=\"start\" x=\"494\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\n",
|
|
"<text text-anchor=\"start\" x=\"500.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 30</text>\n",
|
|
"<text text-anchor=\"start\" x=\"487\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 20]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"493\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 12->13 -->\n",
|
|
"<g id=\"edge13\" class=\"edge\">\n",
|
|
"<title>12->13</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M542.81,-341.91C542.66,-333.56 542.51,-324.67 542.36,-316.02\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"545.86,-315.96 542.19,-306.02 538.86,-316.08 545.86,-315.96\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 16 -->\n",
|
|
"<g id=\"node17\" class=\"node\">\n",
|
|
"<title>16</title>\n",
|
|
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M747,-298.5C747,-298.5 646,-298.5 646,-298.5 640,-298.5 634,-292.5 634,-286.5 634,-286.5 634,-242.5 634,-242.5 634,-236.5 640,-230.5 646,-230.5 646,-230.5 747,-230.5 747,-230.5 753,-230.5 759,-236.5 759,-242.5 759,-242.5 759,-286.5 759,-286.5 759,-292.5 753,-298.5 747,-298.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"656.5\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"651.5\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 400</text>\n",
|
|
"<text text-anchor=\"start\" x=\"642\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 400]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"648\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 12->16 -->\n",
|
|
"<g id=\"edge16\" class=\"edge\">\n",
|
|
"<title>12->16</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M596.58,-341.91C612.31,-329.88 629.5,-316.73 645.12,-304.79\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"647.31,-307.52 653.13,-298.67 643.06,-301.96 647.31,-307.52\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 14 -->\n",
|
|
"<g id=\"node15\" class=\"node\">\n",
|
|
"<title>14</title>\n",
|
|
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M586,-179.5C586,-179.5 493,-179.5 493,-179.5 487,-179.5 481,-173.5 481,-167.5 481,-167.5 481,-123.5 481,-123.5 481,-117.5 487,-111.5 493,-111.5 493,-111.5 586,-111.5 586,-111.5 592,-111.5 598,-117.5 598,-123.5 598,-123.5 598,-167.5 598,-167.5 598,-173.5 592,-179.5 586,-179.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"499.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"498.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 20</text>\n",
|
|
"<text text-anchor=\"start\" x=\"489\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 20]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"491\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 13->14 -->\n",
|
|
"<g id=\"edge14\" class=\"edge\">\n",
|
|
"<title>13->14</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M540.81,-222.91C540.62,-212.2 540.43,-200.62 540.24,-189.78\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"543.74,-189.61 540.07,-179.67 536.74,-189.73 543.74,-189.61\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 15 -->\n",
|
|
"<g id=\"node16\" class=\"node\">\n",
|
|
"<title>15</title>\n",
|
|
"<path fill=\"#39e581\" stroke=\"black\" d=\"M725,-179.5C725,-179.5 628,-179.5 628,-179.5 622,-179.5 616,-173.5 616,-167.5 616,-167.5 616,-123.5 616,-123.5 616,-117.5 622,-111.5 628,-111.5 628,-111.5 725,-111.5 725,-111.5 731,-111.5 737,-117.5 737,-123.5 737,-123.5 737,-167.5 737,-167.5 737,-173.5 731,-179.5 725,-179.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"636.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"635.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 10</text>\n",
|
|
"<text text-anchor=\"start\" x=\"626\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"624\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 13->15 -->\n",
|
|
"<g id=\"edge15\" class=\"edge\">\n",
|
|
"<title>13->15</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M588.34,-222.91C601.96,-211.1 616.83,-198.22 630.4,-186.45\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"632.97,-188.86 638.23,-179.67 628.38,-183.57 632.97,-188.86\"/>\n",
|
|
"</g>\n",
|
|
"</g>\n",
|
|
"</svg>\n"
|
|
],
|
|
"text/plain": [
|
|
"<graphviz.files.Source at 0x7f360644d750>"
|
|
]
|
|
},
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"dot_data = tree.export_graphviz(clf, out_file=None, \n",
|
|
" feature_names=iris.feature_names, \n",
|
|
" class_names=iris.target_names, \n",
|
|
" filled=True, rounded=True, \n",
|
|
" special_characters=True) \n",
|
|
"graph = graphviz.Source(dot_data) \n",
|
|
"graph"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 2. Class weights"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Generate a random permutation of the indices of examples that will be later used \n",
|
|
"# for the training and the test set\n",
|
|
"import numpy as np\n",
|
|
"np.random.seed(1231)\n",
|
|
"indices = np.random.permutation(len(iris.data))\n",
|
|
"\n",
|
|
"# We now decide to keep the last 10 indices for test set, the remaining for the training set\n",
|
|
"indices_training=indices[:-10]\n",
|
|
"indices_test=indices[-10:]\n",
|
|
"\n",
|
|
"iris_X_train = iris.data[indices_training] # keep for training all the matrix elements with the exception of the last 10 \n",
|
|
"iris_y_train = iris.target[indices_training]\n",
|
|
"iris_X_test = iris.data[indices_test] # keep the last 10 elements for test set\n",
|
|
"iris_y_test = iris.target[indices_test]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Accuracy: 0.8\n",
|
|
"F1: 0.5\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\",random_state=300,min_samples_leaf=5,class_weight={0:1,1:10,2:10})\n",
|
|
"clf = clf.fit(iris_X_train, iris_y_train)\n",
|
|
"predicted_y_test = clf.predict(iris_X_test)\n",
|
|
"acc_score = accuracy_score(iris_y_test, predicted_y_test)\n",
|
|
"f1 = f1_score(iris_y_test, predicted_y_test, average='macro')\n",
|
|
"print(\"Accuracy: \", acc_score)\n",
|
|
"print(\"F1: \", f1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": [
|
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
|
"<!-- Generated by graphviz version 2.42.3 (20191010.1750)\n",
|
|
" -->\n",
|
|
"<!-- Title: Tree Pages: 1 -->\n",
|
|
"<svg width=\"593pt\" height=\"552pt\"\n",
|
|
" viewBox=\"0.00 0.00 593.00 552.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 548)\">\n",
|
|
"<title>Tree</title>\n",
|
|
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-548 589,-548 589,4 -4,4\"/>\n",
|
|
"<!-- 0 -->\n",
|
|
"<g id=\"node1\" class=\"node\">\n",
|
|
"<title>0</title>\n",
|
|
"<path fill=\"#fdfbff\" stroke=\"black\" d=\"M363.5,-544C363.5,-544 227.5,-544 227.5,-544 221.5,-544 215.5,-538 215.5,-532 215.5,-532 215.5,-473 215.5,-473 215.5,-467 221.5,-461 227.5,-461 227.5,-461 363.5,-461 363.5,-461 369.5,-461 375.5,-467 375.5,-473 375.5,-473 375.5,-532 375.5,-532 375.5,-538 369.5,-544 363.5,-544\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"223.5\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.85</text>\n",
|
|
"<text text-anchor=\"start\" x=\"248\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.211</text>\n",
|
|
"<text text-anchor=\"start\" x=\"250.5\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 140</text>\n",
|
|
"<text text-anchor=\"start\" x=\"230\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 480, 490]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"247\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1 -->\n",
|
|
"<g id=\"node2\" class=\"node\">\n",
|
|
"<title>1</title>\n",
|
|
"<path fill=\"#54e892\" stroke=\"black\" d=\"M275.5,-425C275.5,-425 139.5,-425 139.5,-425 133.5,-425 127.5,-419 127.5,-413 127.5,-413 127.5,-354 127.5,-354 127.5,-348 133.5,-342 139.5,-342 139.5,-342 275.5,-342 275.5,-342 281.5,-342 287.5,-348 287.5,-354 287.5,-354 287.5,-413 287.5,-413 287.5,-419 281.5,-425 275.5,-425\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"135.5\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 2.45</text>\n",
|
|
"<text text-anchor=\"start\" x=\"160\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.648</text>\n",
|
|
"<text text-anchor=\"start\" x=\"166.5\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 90</text>\n",
|
|
"<text text-anchor=\"start\" x=\"145.5\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 450, 20]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"155\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->1 -->\n",
|
|
"<g id=\"edge1\" class=\"edge\">\n",
|
|
"<title>0->1</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M264.97,-460.91C258.21,-451.92 250.98,-442.32 244.02,-433.05\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"246.79,-430.91 237.98,-425.02 241.19,-435.12 246.79,-430.91\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"234.48\" y=\"-446.07\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
|
|
"</g>\n",
|
|
"<!-- 8 -->\n",
|
|
"<g id=\"node9\" class=\"node\">\n",
|
|
"<title>8</title>\n",
|
|
"<path fill=\"#8946e7\" stroke=\"black\" d=\"M449,-425C449,-425 318,-425 318,-425 312,-425 306,-419 306,-413 306,-413 306,-354 306,-354 306,-348 312,-342 318,-342 318,-342 449,-342 449,-342 455,-342 461,-348 461,-354 461,-354 461,-413 461,-413 461,-419 455,-425 449,-425\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"314\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.75</text>\n",
|
|
"<text text-anchor=\"start\" x=\"336\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.327</text>\n",
|
|
"<text text-anchor=\"start\" x=\"342.5\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 50</text>\n",
|
|
"<text text-anchor=\"start\" x=\"325.5\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 470]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"335\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->8 -->\n",
|
|
"<g id=\"edge8\" class=\"edge\">\n",
|
|
"<title>0->8</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M326.03,-460.91C332.79,-451.92 340.02,-442.32 346.98,-433.05\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"349.81,-435.12 353.02,-425.02 344.21,-430.91 349.81,-435.12\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"356.52\" y=\"-446.07\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2 -->\n",
|
|
"<g id=\"node3\" class=\"node\">\n",
|
|
"<title>2</title>\n",
|
|
"<path fill=\"#e58139\" stroke=\"black\" d=\"M105,-298.5C105,-298.5 12,-298.5 12,-298.5 6,-298.5 0,-292.5 0,-286.5 0,-286.5 0,-242.5 0,-242.5 0,-236.5 6,-230.5 12,-230.5 12,-230.5 105,-230.5 105,-230.5 111,-230.5 117,-236.5 117,-242.5 117,-242.5 117,-286.5 117,-286.5 117,-292.5 111,-298.5 105,-298.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"18.5\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"17.5\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\n",
|
|
"<text text-anchor=\"start\" x=\"8\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 0, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"15\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1->2 -->\n",
|
|
"<g id=\"edge2\" class=\"edge\">\n",
|
|
"<title>1->2</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M155.81,-341.91C140.63,-329.99 124.05,-316.98 108.96,-305.12\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"110.76,-302.09 100.74,-298.67 106.44,-307.6 110.76,-302.09\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 3 -->\n",
|
|
"<g id=\"node4\" class=\"node\">\n",
|
|
"<title>3</title>\n",
|
|
"<path fill=\"#42e687\" stroke=\"black\" d=\"M278,-306C278,-306 147,-306 147,-306 141,-306 135,-300 135,-294 135,-294 135,-235 135,-235 135,-229 141,-223 147,-223 147,-223 278,-223 278,-223 284,-223 290,-229 290,-235 290,-235 290,-294 290,-294 290,-300 284,-306 278,-306\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"143\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.45</text>\n",
|
|
"<text text-anchor=\"start\" x=\"165\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.254</text>\n",
|
|
"<text text-anchor=\"start\" x=\"171.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 47</text>\n",
|
|
"<text text-anchor=\"start\" x=\"154.5\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 450, 20]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"160\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1->3 -->\n",
|
|
"<g id=\"edge3\" class=\"edge\">\n",
|
|
"<title>1->3</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M209.23,-341.91C209.59,-333.56 209.97,-324.67 210.34,-316.02\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"213.84,-316.16 210.77,-306.02 206.84,-315.86 213.84,-316.16\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 4 -->\n",
|
|
"<g id=\"node5\" class=\"node\">\n",
|
|
"<title>4</title>\n",
|
|
"<path fill=\"#39e581\" stroke=\"black\" d=\"M184,-179.5C184,-179.5 83,-179.5 83,-179.5 77,-179.5 71,-173.5 71,-167.5 71,-167.5 71,-123.5 71,-123.5 71,-117.5 77,-111.5 83,-111.5 83,-111.5 184,-111.5 184,-111.5 190,-111.5 196,-117.5 196,-123.5 196,-123.5 196,-167.5 196,-167.5 196,-173.5 190,-179.5 184,-179.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"93.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"92.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 35</text>\n",
|
|
"<text text-anchor=\"start\" x=\"79\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 350, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"81\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 3->4 -->\n",
|
|
"<g id=\"edge4\" class=\"edge\">\n",
|
|
"<title>3->4</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M185.09,-222.91C177.49,-211.65 169.23,-199.42 161.59,-188.11\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"164.39,-186 155.89,-179.67 158.59,-189.91 164.39,-186\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 5 -->\n",
|
|
"<g id=\"node6\" class=\"node\">\n",
|
|
"<title>5</title>\n",
|
|
"<path fill=\"#61ea9a\" stroke=\"black\" d=\"M357,-187C357,-187 226,-187 226,-187 220,-187 214,-181 214,-175 214,-175 214,-116 214,-116 214,-110 220,-104 226,-104 226,-104 357,-104 357,-104 363,-104 369,-110 369,-116 369,-116 369,-175 369,-175 369,-181 363,-187 357,-187\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"222\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sepal length (cm) ≤ 6.1</text>\n",
|
|
"<text text-anchor=\"start\" x=\"247.5\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.65</text>\n",
|
|
"<text text-anchor=\"start\" x=\"250.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 12</text>\n",
|
|
"<text text-anchor=\"start\" x=\"233.5\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 100, 20]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"239\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 3->5 -->\n",
|
|
"<g id=\"edge5\" class=\"edge\">\n",
|
|
"<title>3->5</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M239.91,-222.91C245.91,-214.01 252.33,-204.51 258.53,-195.33\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"261.44,-197.27 264.14,-187.02 255.64,-193.35 261.44,-197.27\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 6 -->\n",
|
|
"<g id=\"node7\" class=\"node\">\n",
|
|
"<title>6</title>\n",
|
|
"<path fill=\"#88efb3\" stroke=\"black\" d=\"M272,-68C272,-68 171,-68 171,-68 165,-68 159,-62 159,-56 159,-56 159,-12 159,-12 159,-6 165,0 171,0 171,0 272,0 272,0 278,0 284,-6 284,-12 284,-12 284,-56 284,-56 284,-62 278,-68 272,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"174\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.863</text>\n",
|
|
"<text text-anchor=\"start\" x=\"184\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 7</text>\n",
|
|
"<text text-anchor=\"start\" x=\"167\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 50, 20]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"169\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 5->6 -->\n",
|
|
"<g id=\"edge6\" class=\"edge\">\n",
|
|
"<title>5->6</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M265.43,-103.73C259.84,-94.97 253.91,-85.7 248.29,-76.91\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"251.12,-74.84 242.79,-68.3 245.22,-78.61 251.12,-74.84\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 7 -->\n",
|
|
"<g id=\"node8\" class=\"node\">\n",
|
|
"<title>7</title>\n",
|
|
"<path fill=\"#39e581\" stroke=\"black\" d=\"M411,-68C411,-68 314,-68 314,-68 308,-68 302,-62 302,-56 302,-56 302,-12 302,-12 302,-6 308,0 314,0 314,0 411,0 411,0 417,0 423,-6 423,-12 423,-12 423,-56 423,-56 423,-62 417,-68 411,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"322.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"325\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 5</text>\n",
|
|
"<text text-anchor=\"start\" x=\"312\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 50, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"310\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 5->7 -->\n",
|
|
"<g id=\"edge7\" class=\"edge\">\n",
|
|
"<title>5->7</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M317.94,-103.73C323.62,-94.97 329.62,-85.7 335.33,-76.91\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"338.4,-78.59 340.91,-68.3 332.53,-74.79 338.4,-78.59\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 9 -->\n",
|
|
"<g id=\"node10\" class=\"node\">\n",
|
|
"<title>9</title>\n",
|
|
"<path fill=\"#e0cef8\" stroke=\"black\" d=\"M430,-298.5C430,-298.5 329,-298.5 329,-298.5 323,-298.5 317,-292.5 317,-286.5 317,-286.5 317,-242.5 317,-242.5 317,-236.5 323,-230.5 329,-230.5 329,-230.5 430,-230.5 430,-230.5 436,-230.5 442,-236.5 442,-242.5 442,-242.5 442,-286.5 442,-286.5 442,-292.5 436,-298.5 430,-298.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"332\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.985</text>\n",
|
|
"<text text-anchor=\"start\" x=\"342\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 7</text>\n",
|
|
"<text text-anchor=\"start\" x=\"325\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 40]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"331\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 8->9 -->\n",
|
|
"<g id=\"edge9\" class=\"edge\">\n",
|
|
"<title>8->9</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M382.11,-341.91C381.75,-331.2 381.35,-319.62 380.98,-308.78\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"384.47,-308.54 380.63,-298.67 377.48,-308.78 384.47,-308.54\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 10 -->\n",
|
|
"<g id=\"node11\" class=\"node\">\n",
|
|
"<title>10</title>\n",
|
|
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M573,-298.5C573,-298.5 472,-298.5 472,-298.5 466,-298.5 460,-292.5 460,-286.5 460,-286.5 460,-242.5 460,-242.5 460,-236.5 466,-230.5 472,-230.5 472,-230.5 573,-230.5 573,-230.5 579,-230.5 585,-236.5 585,-242.5 585,-242.5 585,-286.5 585,-286.5 585,-292.5 579,-298.5 573,-298.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"482.5\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"481.5\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\n",
|
|
"<text text-anchor=\"start\" x=\"468\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 430]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"474\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 8->10 -->\n",
|
|
"<g id=\"edge10\" class=\"edge\">\n",
|
|
"<title>8->10</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M431.73,-341.91C445.88,-329.99 461.35,-316.98 475.43,-305.12\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"477.7,-307.78 483.1,-298.67 473.19,-302.43 477.7,-307.78\"/>\n",
|
|
"</g>\n",
|
|
"</g>\n",
|
|
"</svg>\n"
|
|
],
|
|
"text/plain": [
|
|
"<graphviz.files.Source at 0x7f36064a3cd0>"
|
|
]
|
|
},
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"dot_data = tree.export_graphviz(clf, out_file=None, \n",
|
|
" feature_names=iris.feature_names, \n",
|
|
" class_names=iris.target_names, \n",
|
|
" filled=True, rounded=True, \n",
|
|
" special_characters=True) \n",
|
|
"graph = graphviz.Source(dot_data) \n",
|
|
"graph"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 3. Avoid overfitting"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Generate a random permutation of the indices of examples that will be later used \n",
|
|
"# for the training and the test set\n",
|
|
"import numpy as np\n",
|
|
"np.random.seed(42)\n",
|
|
"indices = np.random.permutation(len(iris.data))\n",
|
|
"\n",
|
|
"# We now decide to keep the last 10 indices for test set, the remaining for the training set\n",
|
|
"indices_training=indices[:-10]\n",
|
|
"indices_test=indices[-10:]\n",
|
|
"\n",
|
|
"iris_X_train = iris.data[indices_training] # keep for training all the matrix elements with the exception of the last 10 \n",
|
|
"iris_y_train = iris.target[indices_training]\n",
|
|
"iris_X_test = iris.data[indices_test] # keep the last 10 elements for test set\n",
|
|
"iris_y_test = iris.target[indices_test]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Accuracy: 0.9\n",
|
|
"F1: 0.9153439153439153\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"clf = tree.DecisionTreeClassifier(criterion=\"entropy\",random_state=300,min_samples_leaf=3,class_weight={0:1,1:10,2:10}, min_impurity_decrease = 0.005, max_depth = 4, max_leaf_nodes = 6)\n",
|
|
"clf = clf.fit(iris_X_train, iris_y_train)\n",
|
|
"predicted_y_test = clf.predict(iris_X_test)\n",
|
|
"acc_score = accuracy_score(iris_y_test, predicted_y_test)\n",
|
|
"f1 = f1_score(iris_y_test, predicted_y_test, average='macro')\n",
|
|
"print(\"Accuracy: \", acc_score)\n",
|
|
"print(\"F1: \", f1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": [
|
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
|
"<!-- Generated by graphviz version 2.42.3 (20191010.1750)\n",
|
|
" -->\n",
|
|
"<!-- Title: Tree Pages: 1 -->\n",
|
|
"<svg width=\"620pt\" height=\"433pt\"\n",
|
|
" viewBox=\"0.00 0.00 620.00 433.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 429)\">\n",
|
|
"<title>Tree</title>\n",
|
|
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-429 616,-429 616,4 -4,4\"/>\n",
|
|
"<!-- 0 -->\n",
|
|
"<g id=\"node1\" class=\"node\">\n",
|
|
"<title>0</title>\n",
|
|
"<path fill=\"#fdfbff\" stroke=\"black\" d=\"M368.5,-425C368.5,-425 232.5,-425 232.5,-425 226.5,-425 220.5,-419 220.5,-413 220.5,-413 220.5,-354 220.5,-354 220.5,-348 226.5,-342 232.5,-342 232.5,-342 368.5,-342 368.5,-342 374.5,-342 380.5,-348 380.5,-354 380.5,-354 380.5,-413 380.5,-413 380.5,-419 374.5,-425 368.5,-425\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"228.5\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 4.85</text>\n",
|
|
"<text text-anchor=\"start\" x=\"253\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 1.211</text>\n",
|
|
"<text text-anchor=\"start\" x=\"255.5\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 140</text>\n",
|
|
"<text text-anchor=\"start\" x=\"235\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 480, 490]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"252\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1 -->\n",
|
|
"<g id=\"node2\" class=\"node\">\n",
|
|
"<title>1</title>\n",
|
|
"<path fill=\"#54e892\" stroke=\"black\" d=\"M280.5,-306C280.5,-306 144.5,-306 144.5,-306 138.5,-306 132.5,-300 132.5,-294 132.5,-294 132.5,-235 132.5,-235 132.5,-229 138.5,-223 144.5,-223 144.5,-223 280.5,-223 280.5,-223 286.5,-223 292.5,-229 292.5,-235 292.5,-235 292.5,-294 292.5,-294 292.5,-300 286.5,-306 280.5,-306\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"140.5\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 2.45</text>\n",
|
|
"<text text-anchor=\"start\" x=\"165\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.648</text>\n",
|
|
"<text text-anchor=\"start\" x=\"171.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 90</text>\n",
|
|
"<text text-anchor=\"start\" x=\"150.5\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 450, 20]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"160\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->1 -->\n",
|
|
"<g id=\"edge1\" class=\"edge\">\n",
|
|
"<title>0->1</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M269.97,-341.91C263.21,-332.92 255.98,-323.32 249.02,-314.05\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"251.79,-311.91 242.98,-306.02 246.19,-316.12 251.79,-311.91\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"239.48\" y=\"-327.07\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2 -->\n",
|
|
"<g id=\"node7\" class=\"node\">\n",
|
|
"<title>2</title>\n",
|
|
"<path fill=\"#8946e7\" stroke=\"black\" d=\"M454,-306C454,-306 323,-306 323,-306 317,-306 311,-300 311,-294 311,-294 311,-235 311,-235 311,-229 317,-223 323,-223 323,-223 454,-223 454,-223 460,-223 466,-229 466,-235 466,-235 466,-294 466,-294 466,-300 460,-306 454,-306\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"319\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.75</text>\n",
|
|
"<text text-anchor=\"start\" x=\"341\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.327</text>\n",
|
|
"<text text-anchor=\"start\" x=\"347.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 50</text>\n",
|
|
"<text text-anchor=\"start\" x=\"330.5\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 470]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"340\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 0->2 -->\n",
|
|
"<g id=\"edge6\" class=\"edge\">\n",
|
|
"<title>0->2</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M331.03,-341.91C337.79,-332.92 345.02,-323.32 351.98,-314.05\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"354.81,-316.12 358.02,-306.02 349.21,-311.91 354.81,-316.12\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"361.52\" y=\"-327.07\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
|
|
"</g>\n",
|
|
"<!-- 3 -->\n",
|
|
"<g id=\"node3\" class=\"node\">\n",
|
|
"<title>3</title>\n",
|
|
"<path fill=\"#e58139\" stroke=\"black\" d=\"M105,-179.5C105,-179.5 12,-179.5 12,-179.5 6,-179.5 0,-173.5 0,-167.5 0,-167.5 0,-123.5 0,-123.5 0,-117.5 6,-111.5 12,-111.5 12,-111.5 105,-111.5 105,-111.5 111,-111.5 117,-117.5 117,-123.5 117,-123.5 117,-167.5 117,-167.5 117,-173.5 111,-179.5 105,-179.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"18.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"17.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\n",
|
|
"<text text-anchor=\"start\" x=\"8\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [43, 0, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"15\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = setosa</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1->3 -->\n",
|
|
"<g id=\"edge2\" class=\"edge\">\n",
|
|
"<title>1->3</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M159.07,-222.91C143.24,-210.88 125.94,-197.73 110.22,-185.79\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"112.24,-182.93 102.16,-179.67 108,-188.5 112.24,-182.93\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 4 -->\n",
|
|
"<g id=\"node4\" class=\"node\">\n",
|
|
"<title>4</title>\n",
|
|
"<path fill=\"#42e687\" stroke=\"black\" d=\"M278,-187C278,-187 147,-187 147,-187 141,-187 135,-181 135,-175 135,-175 135,-116 135,-116 135,-110 141,-104 147,-104 147,-104 278,-104 278,-104 284,-104 290,-110 290,-116 290,-116 290,-175 290,-175 290,-181 284,-187 278,-187\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"143\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal width (cm) ≤ 1.65</text>\n",
|
|
"<text text-anchor=\"start\" x=\"165\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.254</text>\n",
|
|
"<text text-anchor=\"start\" x=\"171.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 47</text>\n",
|
|
"<text text-anchor=\"start\" x=\"154.5\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 450, 20]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"160\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 1->4 -->\n",
|
|
"<g id=\"edge3\" class=\"edge\">\n",
|
|
"<title>1->4</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M212.5,-222.91C212.5,-214.65 212.5,-205.86 212.5,-197.3\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"216,-197.02 212.5,-187.02 209,-197.02 216,-197.02\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 7 -->\n",
|
|
"<g id=\"node5\" class=\"node\">\n",
|
|
"<title>7</title>\n",
|
|
"<path fill=\"#39e581\" stroke=\"black\" d=\"M128,-68C128,-68 27,-68 27,-68 21,-68 15,-62 15,-56 15,-56 15,-12 15,-12 15,-6 21,0 27,0 27,0 128,0 128,0 134,0 140,-6 140,-12 140,-12 140,-56 140,-56 140,-62 134,-68 128,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"37.5\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"36.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 44</text>\n",
|
|
"<text text-anchor=\"start\" x=\"23\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 440, 0]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"25\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 4->7 -->\n",
|
|
"<g id=\"edge4\" class=\"edge\">\n",
|
|
"<title>4->7</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M162.23,-103.73C150.54,-94.24 138.1,-84.16 126.46,-74.72\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"128.53,-71.88 118.55,-68.3 124.12,-77.32 128.53,-71.88\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 8 -->\n",
|
|
"<g id=\"node6\" class=\"node\">\n",
|
|
"<title>8</title>\n",
|
|
"<path fill=\"#c09cf2\" stroke=\"black\" d=\"M271,-68C271,-68 170,-68 170,-68 164,-68 158,-62 158,-56 158,-56 158,-12 158,-12 158,-6 164,0 170,0 170,0 271,0 271,0 277,0 283,-6 283,-12 283,-12 283,-56 283,-56 283,-62 277,-68 271,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"173\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\n",
|
|
"<text text-anchor=\"start\" x=\"183\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\n",
|
|
"<text text-anchor=\"start\" x=\"166\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 20]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"172\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 4->8 -->\n",
|
|
"<g id=\"edge5\" class=\"edge\">\n",
|
|
"<title>4->8</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M215.48,-103.73C216.09,-95.43 216.73,-86.67 217.34,-78.28\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"220.83,-78.53 218.07,-68.3 213.85,-78.02 220.83,-78.53\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 5 -->\n",
|
|
"<g id=\"node8\" class=\"node\">\n",
|
|
"<title>5</title>\n",
|
|
"<path fill=\"#e0cef8\" stroke=\"black\" d=\"M456.5,-187C456.5,-187 320.5,-187 320.5,-187 314.5,-187 308.5,-181 308.5,-175 308.5,-175 308.5,-116 308.5,-116 308.5,-110 314.5,-104 320.5,-104 320.5,-104 456.5,-104 456.5,-104 462.5,-104 468.5,-110 468.5,-116 468.5,-116 468.5,-175 468.5,-175 468.5,-181 462.5,-187 456.5,-187\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"316.5\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">petal length (cm) ≤ 5.05</text>\n",
|
|
"<text text-anchor=\"start\" x=\"341\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.985</text>\n",
|
|
"<text text-anchor=\"start\" x=\"351\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 7</text>\n",
|
|
"<text text-anchor=\"start\" x=\"334\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 30, 40]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"340\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2->5 -->\n",
|
|
"<g id=\"edge7\" class=\"edge\">\n",
|
|
"<title>2->5</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M388.5,-222.91C388.5,-214.65 388.5,-205.86 388.5,-197.3\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"392,-197.02 388.5,-187.02 385,-197.02 392,-197.02\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 6 -->\n",
|
|
"<g id=\"node11\" class=\"node\">\n",
|
|
"<title>6</title>\n",
|
|
"<path fill=\"#8139e5\" stroke=\"black\" d=\"M600,-179.5C600,-179.5 499,-179.5 499,-179.5 493,-179.5 487,-173.5 487,-167.5 487,-167.5 487,-123.5 487,-123.5 487,-117.5 493,-111.5 499,-111.5 499,-111.5 600,-111.5 600,-111.5 606,-111.5 612,-117.5 612,-123.5 612,-123.5 612,-167.5 612,-167.5 612,-173.5 606,-179.5 600,-179.5\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"509.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.0</text>\n",
|
|
"<text text-anchor=\"start\" x=\"508.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 43</text>\n",
|
|
"<text text-anchor=\"start\" x=\"495\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 0, 430]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"501\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 2->6 -->\n",
|
|
"<g id=\"edge10\" class=\"edge\">\n",
|
|
"<title>2->6</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M444.36,-222.91C460.91,-210.88 479,-197.73 495.43,-185.79\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"497.83,-188.38 503.86,-179.67 493.71,-182.71 497.83,-188.38\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 9 -->\n",
|
|
"<g id=\"node9\" class=\"node\">\n",
|
|
"<title>9</title>\n",
|
|
"<path fill=\"#9cf2c0\" stroke=\"black\" d=\"M430,-68C430,-68 329,-68 329,-68 323,-68 317,-62 317,-56 317,-56 317,-12 317,-12 317,-6 323,0 329,0 329,0 430,0 430,0 436,0 442,-6 442,-12 442,-12 442,-56 442,-56 442,-62 436,-68 430,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"332\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.918</text>\n",
|
|
"<text text-anchor=\"start\" x=\"342\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\n",
|
|
"<text text-anchor=\"start\" x=\"325\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 20, 10]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"327\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = versicolor</text>\n",
|
|
"</g>\n",
|
|
"<!-- 5->9 -->\n",
|
|
"<g id=\"edge8\" class=\"edge\">\n",
|
|
"<title>5->9</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M385.15,-103.73C384.47,-95.43 383.75,-86.67 383.06,-78.28\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"386.54,-77.98 382.24,-68.3 379.57,-78.55 386.54,-77.98\"/>\n",
|
|
"</g>\n",
|
|
"<!-- 10 -->\n",
|
|
"<g id=\"node10\" class=\"node\">\n",
|
|
"<title>10</title>\n",
|
|
"<path fill=\"#ab7bee\" stroke=\"black\" d=\"M573,-68C573,-68 472,-68 472,-68 466,-68 460,-62 460,-56 460,-56 460,-12 460,-12 460,-6 466,0 472,0 472,0 573,0 573,0 579,0 585,-6 585,-12 585,-12 585,-56 585,-56 585,-62 579,-68 573,-68\"/>\n",
|
|
"<text text-anchor=\"start\" x=\"475\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">entropy = 0.811</text>\n",
|
|
"<text text-anchor=\"start\" x=\"485\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 4</text>\n",
|
|
"<text text-anchor=\"start\" x=\"468\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 10, 30]</text>\n",
|
|
"<text text-anchor=\"start\" x=\"474\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = virginica</text>\n",
|
|
"</g>\n",
|
|
"<!-- 5->10 -->\n",
|
|
"<g id=\"edge9\" class=\"edge\">\n",
|
|
"<title>5->10</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M438.4,-103.73C450,-94.24 462.35,-84.16 473.9,-74.72\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"476.22,-77.34 481.75,-68.3 471.79,-71.92 476.22,-77.34\"/>\n",
|
|
"</g>\n",
|
|
"</g>\n",
|
|
"</svg>\n"
|
|
],
|
|
"text/plain": [
|
|
"<graphviz.files.Source at 0x7f360648aad0>"
|
|
]
|
|
},
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"dot_data = tree.export_graphviz(clf, out_file=None, \n",
|
|
" feature_names=iris.feature_names, \n",
|
|
" class_names=iris.target_names, \n",
|
|
" filled=True, rounded=True, \n",
|
|
" special_characters=True) \n",
|
|
"graph = graphviz.Source(dot_data) \n",
|
|
"graph"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 4. Confusion Matrix"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[2, 0, 0],\n",
|
|
" [0, 4, 0],\n",
|
|
" [0, 1, 3]])"
|
|
]
|
|
},
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# initializes the confusion matrix\n",
|
|
"confusion = np.zeros([3, 3], dtype = int)\n",
|
|
"\n",
|
|
"# print the corresponding instances indexes and class names\n",
|
|
"for i in range(len(iris_y_test)): \n",
|
|
" #increments the indexed cell value\n",
|
|
" confusion[iris_y_test[i], predicted_y_test[i]]+=1\n",
|
|
"confusion"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 5. ROC Curves"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[[(0.0, 48.0), (30.0, 0.0), (30.0, 0.0), (60.0, 0.0), (400.0, 0.0), (400.0, 0.0)], [(0.0, 400.0), (0.0, 30.0), (20.0, 10.0), (40.0, 20.0), (48.0, 0.0), (400.0, 0.0)], [(0.0, 400.0), (10.0, 20.0), (20.0, 40.0), (30.0, 0.0), (48.0, 0.0), (400.0, 0.0)]]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[[[0, 0.0, 30.0, 60.0, 120.0, 520.0, 920.0],\n",
|
|
" [0, 48.0, 48.0, 48.0, 48.0, 48.0, 48.0]],\n",
|
|
" [[0, 0.0, 0.0, 20.0, 60.0, 108.0, 508.0],\n",
|
|
" [0, 400.0, 430.0, 440.0, 460.0, 460.0, 460.0]],\n",
|
|
" [[0, 0.0, 10.0, 30.0, 60.0, 108.0, 508.0],\n",
|
|
" [0, 400.0, 420.0, 460.0, 460.0, 460.0, 460.0]]]"
|
|
]
|
|
},
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Calculates the ROC curves (x, y)\n",
|
|
"leafs = []\n",
|
|
"class_pairs = [[],[],[]]\n",
|
|
"roc_curves = [[[0], [0]], [[0], [0]], [[0], [0]]]\n",
|
|
"for i in range(clf.tree_.node_count):\n",
|
|
" if (clf.tree_.feature[i] == -2):\n",
|
|
" leafs.append(i)\n",
|
|
"\n",
|
|
"# c = class index\n",
|
|
"for leaf in leafs:\n",
|
|
" for c in range(3):\n",
|
|
" #pairs(neg, pos)\n",
|
|
" class_pairs[c].append((clf.tree_.value[leaf][0].sum() - clf.tree_.value[leaf][0][c], clf.tree_.value[leaf][0][c]))\n",
|
|
"\n",
|
|
"#pairs sorting\n",
|
|
"for c in range(3):\n",
|
|
" class_pairs[c] = sorted(class_pairs[c], key=lambda t: t[0]/max(1,t[1]))\n",
|
|
"print(class_pairs)\n",
|
|
"\n",
|
|
"for i in range(1, len(leafs) + 1):\n",
|
|
" for c in range(3):\n",
|
|
" roc_curves[c][0].append(class_pairs[c][i - 1][0] + roc_curves[c][0][i - 1])\n",
|
|
" roc_curves[c][1].append(class_pairs[c][i - 1][1] + roc_curves[c][1][i - 1])\n",
|
|
"\n",
|
|
"roc_curves"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD7CAYAAABzGc+QAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANSklEQVR4nO3dX4xc9XmH8edbm0BLqgBha7kY10RYRKgSJlohELlIIbQ0jQIXCIGi1heWfJOqpI2UQnsVqRdBqgJUqqJYIa1VhQAlpCAUJaUOqKpUOVkXSgBDMQQSLIOXFJK0F22dvL2YY7wstne8nvX69T4fabTn3+z85vj48dmzM55UFZKkfn5puQcgSVocAy5JTRlwSWrKgEtSUwZckpoy4JLU1OpxNkryMvAz4OfAgaqaTnIOcB+wAXgZuLGq3lyaYUqS5juWM/DfqqpNVTU9zN8K7KiqjcCOYV6SdIJknDfyDGfg01X1xpxlzwMfqap9SdYCj1fVRUf7Pueee25t2LDh+EYsSSvMrl273qiqqfnLx7qEAhTwj0kK+FJVbQPWVNW+Yf1rwJqFvsmGDRuYmZkZd8ySJCDJK4dbPm7AP1xVe5P8GvBokufmrqyqGuJ+uAfeCmwFWL9+/TEMWZJ0NGNdA6+qvcPX/cA3gMuA14dLJwxf9x/hvtuqarqqpqem3vUTgCRpkRYMeJIzk/zqwWngt4GngYeBzcNmm4GHlmqQkqR3G+cSyhrgG0kObn9PVX0ryfeA+5NsAV4Bbly6YUqS5lsw4FX1EnDJYZb/GLh6KQYlSVqY78SUpKYMuCQ1Ne7LCJfXtm1wzz3LPQpJWpxNm+DOOyf+bXucgd9zDzz55HKPQpJOKj3OwGH0L9jjjy/3KCTppNHjDFyS9C4GXJKaMuCS1JQBl6SmDLgkNWXAJakpAy5JTRlwSWrKgEtSUwZckpoy4JLUlAGXpKYMuCQ1ZcAlqSkDLklNGXBJasqAS1JTBlySmjLgktSUAZekpgy4JDVlwCWpKQMuSU0ZcElqyoBLUlMGXJKaMuCS1NTYAU+yKskTSR4Z5i9IsjPJniT3JXnP0g1TkjTfsZyB3wLsnjN/O3BHVV0IvAlsmeTAJElHN1bAk6wDfg/48jAf4CrggWGT7cD1SzFASdLhjXsGfifwWeAXw/z7gbeq6sAw/ypw3oTHJkk6igUDnuTjwP6q2rWYB0iyNclMkpnZ2dnFfAtJ0mGMcwZ+JfCJJC8D9zK6dHIXcFaS1cM264C9h7tzVW2rqumqmp6amprAkCVJMEbAq+q2qlpXVRuAm4DvVNUngceAG4bNNgMPLdkoJUnvcjyvA/9T4E+S7GF0TfzuyQxJkjSO1QtvckhVPQ48Pky/BFw2+SFJksbhOzElqSkDLklNGXBJasqAS1JTBlySmjLgktSUAZekpgy4JDVlwCWpKQMuSU0ZcElqyoBLUlMGXJKaMuCS1JQBl6SmDLgkNWXAJakpAy5JTRlwSWrKgEtSUwZckpoy4JLUlAGXpKYMuCQ1ZcAlqSkDLklNGXBJasqAS1JTBlySmjLgktSUAZekpgy4JDVlwCWpqQUDnuSMJN9N8u9JnknyuWH5BUl2JtmT5L4k71n64UqSDhrnDPx/gKuq6hJgE3BtksuB24E7qupC4E1gy9INU5I034IBr5H/GmZPG24FXAU8MCzfDly/JCOUJB3WWNfAk6xK8iSwH3gUeBF4q6oODJu8Cpx3hPtuTTKTZGZ2dnYSY5YkMWbAq+rnVbUJWAdcBnxw3Aeoqm1VNV1V01NTU4scpiRpvmN6FUpVvQU8BlwBnJVk9bBqHbB3wmOTJB3FOK9CmUpy1jD9y8A1wG5GIb9h2Gwz8NBSDVKS9G6rF96EtcD2JKsYBf/+qnokybPAvUn+AngCuHsJxylJmmfBgFfVU8Clh1n+EqPr4ZKkZeA7MSWpKQMuSU0ZcElqyoBLUlMGXJ
KaMuCS1JQBl6SmDLgkNWXAJakpAy5JTRlwSWrKgEtSUwZckpoy4JLUlAGXpKYMuCQ1ZcAlqSkDLklNGXBJasqAS1JTBlySmjLgktSUAZekpgy4JDVlwCWpKQMuSU0ZcElqyoBLUlMGXJKaMuCS1JQBl6SmDLgkNbVgwJOcn+SxJM8meSbJLcPyc5I8muSF4evZSz9cSdJB45yBHwA+U1UXA5cDn0pyMXArsKOqNgI7hnlJ0gmyYMCral9V/dsw/TNgN3AecB2wfdhsO3D9Ug1SkvRux3QNPMkG4FJgJ7CmqvYNq14D1kx0ZJKkoxo74EneC3wd+HRV/XTuuqoqoI5wv61JZpLMzM7OHtdgJUmHjBXwJKcxivdXq+rBYfHrSdYO69cC+w9336raVlXTVTU9NTU1iTFLkhjvVSgB7gZ2V9UX5qx6GNg8TG8GHpr88CRJR7J6jG2uBH4f+H6SJ4dlfwZ8Hrg/yRbgFeDGpRmiJOlwFgx4Vf0LkCOsvnqyw5Ekjct3YkpSUwZckpoy4JLUlAGXpKYMuCQ1ZcAlqSkDLklNGXBJasqAS1JTBlySmjLgktSUAZekpgy4JDVlwCWpKQMuSU0ZcElqyoBLUlMGXJKaMuCS1JQBl6SmDLgkNWXAJakpAy5JTRlwSWrKgEtSUwZckpoy4JLUlAGXpKYMuCQ1ZcAlqSkDLklNGXBJamrBgCf5SpL9SZ6es+ycJI8meWH4evbSDlOSNN84Z+B/C1w7b9mtwI6q2gjsGOYlSSfQggGvqn8G/nPe4uuA7cP0duD6CY9LkrSAxV4DX1NV+4bp14A1ExqPJGlMx/1LzKoqoI60PsnWJDNJZmZnZ4/34SRJg8UG/PUkawGGr/uPtGFVbauq6aqanpqaWuTDSZLmW2zAHwY2D9ObgYcmMxxJ0rjGeRnh14B/BS5K8mqSLcDngWuSvAB8dJiXJJ1AqxfaoKpuPsKqqyc8FknSMfCdmJLUlAGXpKYMuCQ1ZcAlqSkDLklNGXBJasqAS1JTBlySmjLgktSUAZekpgy4JDVlwCWpKQMuSU0ZcElqyoBLUlMGXJKaMuCS1JQBl6SmDLgkNWXAJakpAy5JTRlwSWrKgEtSUwZckpoy4JLUlAGXpKYMuCQ1ZcAlqSkDLklNGXBJasqAS1JTBlySmjqugCe5NsnzSfYkuXVSg5IkLWzRAU+yCvhr4HeBi4Gbk1w8qYFJko7ueM7ALwP2VNVLVfW/wL3AdZMZliRpIccT8POAH82Zf3VYJkk6AVYv9QMk2QpsBVi/fv3ivsmmTRMckSSdGo4n4HuB8+fMrxuWvUNVbQO2AUxPT9eiHunOOxd1N0k6lR3PJZTvARuTXJDkPcBNwMOTGZYkaSGLPgOvqgNJ/hD4NrAK+EpVPTOxkUmSjuq4roFX1TeBb05oLJKkY+A7MSWpKQMuSU0ZcElqyoBLUlMGXJKaStXi3luzqAdLZoFXFnn3c4E3JjicrtwPI+6HQ9wXI6fyfviNqpqav/CEBvx4JJmpqunlHsdycz+MuB8OcV+MrMT94CUUSWrKgEtSU50Cvm25B3CScD+MuB8OcV+MrLj90OYauCTpnTqdgUuS5mgR8JX04clJzk/yWJJnkzyT5JZh+TlJHk3ywvD17GF5kvzVsG+eSvKh5X0Gk5NkVZInkjwyzF+QZOfwXO8b/htjkpw+zO8Z1m9YznFPWpKzkjyQ5Lkku5NcsUKPhz8e/k48neRrSc5YqcfEQSd9wFfghycfAD5TVRcDlwOfGp7vrcCOqtoI7BjmYbRfNg63rcAXT/yQl8wtwO4587cDd1TVhcCbwJZh+RbgzWH5HcN2p5K7gG9V1QeBSxjtkxV1PCQ5D/gjYLqqfpPRf2F9Eyv3mBipqpP6BlwBfHvO/G3Abcs9rhP4/B8CrgGeB9YOy9YCzw/TXwJunrP929t1vjH6hKcdwFXAI0AYvUlj9fzjgtH/SX/FML162C7L/RwmtB/eB/xg/vNZgc
fDwc/gPWf4M34E+J2VeEzMvZ30Z+Cs4A9PHn7suxTYCaypqn3DqteANcP0qbp/7gQ+C/ximH8/8FZVHRjm5z7Pt/fBsP4nw/angguAWeBvhstJX05yJivseKiqvcBfAj8E9jH6M97Fyjwm3tYh4CtSkvcCXwc+XVU/nbuuRqcVp+zLh5J8HNhfVbuWeywngdXAh4AvVtWlwH9z6HIJcOofDwDDNf7rGP2D9uvAmcC1yzqok0CHgI/14cmnkiSnMYr3V6vqwWHx60nWDuvXAvuH5afi/rkS+ESSl4F7GV1GuQs4K8nBT5Ga+zzf3gfD+vcBPz6RA15CrwKvVtXOYf4BRkFfSccDwEeBH1TVbFX9H/Ago+NkJR4Tb+sQ8BX14clJAtwN7K6qL8xZ9TCweZjezOja+MHlfzC8+uBy4CdzfrRuqapuq6p1VbWB0Z/3d6rqk8BjwA3DZvP3wcF9c8Ow/SlxRlpVrwE/SnLRsOhq4FlW0PEw+CFweZJfGf6OHNwPK+6YeIflvgg/5i8wPgb8B/Ai8OfLPZ4lfq4fZvTj8FPAk8PtY4yu3+0AXgD+CThn2D6MXqXzIvB9Rr+lX/bnMcH98RHgkWH6A8B3gT3A3wOnD8vPGOb3DOs/sNzjnvA+2ATMDMfEPwBnr8TjAfgc8BzwNPB3wOkr9Zg4ePOdmJLUVIdLKJKkwzDgktSUAZekpgy4JDVlwCWpKQMuSU0ZcElqyoBLUlP/DwNqiysvGuY6AAAAAElFTkSuQmCC\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPIElEQVR4nO3dfYydZZmA8eu2BaqySwUGxLbZwVijJC7VVKxREoG4FlBKtBIQtTFNmhg2wUjiAia7MVkTPxLxIxsCWYx1+RKQTRskYbttje4fVAeptdBFRgNpm2pHPuoSglq494/3KTnUlpnOnJnTuXv9ksm87/O+M+d5yuHq6TvnzInMRJJUy2sGPQFJUv8Zd0kqyLhLUkHGXZIKMu6SVNDcQU8A4NRTT83h4eFBT0OSZpWHHnroD5k5dKhjR0Xch4eHGRkZGfQ0JGlWiYgnD3fMyzKSVJBxl6SCjLskFWTcJakg4y5JBRl3SSrIuEtSQUfF89xL+MtfYMMG2LIF/DXKkibqIx+Bd7+779/WuE9FJvz853DrrXDnnTA21o1HDHZekmaPN73JuB81Rkfhttu6qI+Owrx5cMkl8MlPwoc+BMcfP+gZSjrGGfeJGhuDu+7qgv7gg92j8/POg+uvh49+FE46adAzlKSXGfdX8/zzsH59F/QHHoD9++Hss+HrX4crroAFCwY9Q0k6JON+sBdfhE2buqDfey889xwsXAjXXANXXgnveMegZyhJ4zLu0P1gdOvWLuh33AF79nSXWS6/vLuOfu658BqfNSpp9ji24/7EE3D77V3Ud+yA446Diy/ugn7xxd0PSiVpFjr24v7003DPPV3Qf/rTbuzcc+Gmm2DlSjj55MHOT5L64NiI+wsvwI9+1AX9/vvhz3+Gt78dvvxl+MQnwHeBklRM3bi/9BL85Cfd89Hvvhv27YM3vhGuugo+9SlYssQXG0kqq17ct2/vHqHffjvs3Aknnggf+1h3Hf2882DOnEHPUJKmXY2479rVPcvl1lth27Yu4MuXw9e+1r1y9HWvG/QMJWlGze64//rX8NnPwubN3dMZ3/Me+M534LLL4LTTBj07SRqY2R33deu6Fxx98YuwahUsXjzoGUnSUWF2x/2A666D179+0LOQpKOGL7uUpIKMuyQVZNwlqSDjLkkFGXdJKsi4S1JBxl2SCjLuklTQhOMeEXMi4uGIuK/tnxkRWyJiNCJ+EBHHt/ET2v5oOz48PVOXJB3OkTxyvxrY0bP/VeCGzHwL8Aywuo2vBp5p4ze08yRJM2hCcY+IhcDFwL+3/QDOB+5pp6wFLm3bK9o+7fgF7XxJ0gyZ6CP3bwJfAF5q+6cAz2bm/ra/C1jQthcAOwHa8X3t/FeIiDURMRIRI2NjY5OcviTpUMaNe0R8GNibmQ/184Yz8+bMXJqZS4eGhvr5rSXpmDeR3wr5PuCSiLgImAf8LfAtYH5EzG2PzhcCu9v5u4FFwK6ImAucBDzV95lLkg5r3EfumXldZi7MzGHgcmBTZl4JbAZWttNWAeva9vq2Tzu+KTOzr7OWJL2qqTzP/Z+Az0fEKN019Vva+C3AKW3888C1U5uiJOlIHdGbdWTmj4Eft+3fAucc4pwXgI/3YW6SpEnyFaqSVJBxl6SCjLskFWTcJakg4y5JBRl3SSrIuEtSQcZdkgoy7pJUkHGXpIKMuyQVZNwlqSDjLkkFGXdJKsi4S1JBxl2SCjLuklSQcZekgoy7JBVk3CWpIOMuSQUZd0kqyLhLUkHGXZIKMu6SVJBxl6SCjLskFWTcJakg4y5JBRl3SSrIuEtSQcZdkgoy7pJUkHGXpILGjXtEzIuIn0XELyPikYj4Uhs/MyK2RMRoRPwgIo5v4ye0/dF2fHh6lyBJOthEHrn/CTg/M88GlgDLI2IZ8FXghsx8C/AMsLqdvxp4po3f0M6TJM2gceOenefa7nHtI4HzgXva+Frg0r
a9ou3Tjl8QEdG3GUuSxjWha+4RMScitgJ7gQ3Ab4BnM3N/O2UXsKBtLwB2ArTj+4BTDvE910TESESMjI2NTW0VkqRXmFDcM/PFzFwCLATOAd421RvOzJszc2lmLh0aGprqt5Mk9TiiZ8tk5rPAZuC9wPyImNsOLQR2t+3dwCKAdvwk4Km+zFaSNCETebbMUETMb9uvBT4I7KCL/Mp22ipgXdte3/ZpxzdlZvZz0pKkVzd3/FM4A1gbEXPo/jK4KzPvi4hHgTsj4l+Bh4Fb2vm3AP8REaPA08Dl0zBvSdKrGDfumbkNeOchxn9Ld/394PEXgI/3ZXaSpEnxFaqSVJBxl6SCjLskFWTcJakg4y5JBRl3SSrIuEtSQcZdkgoy7pJUkHGXpIKMuyQVZNwlqSDjLkkFGXdJKsi4S1JBxl2SCjLuklSQcZekgoy7JBVk3CWpIOMuSQUZd0kqyLhLUkHGXZIKMu6SVJBxl6SCjLskFWTcJakg4y5JBRl3SSrIuEtSQcZdkgoy7pJUkHGXpILGjXtELIqIzRHxaEQ8EhFXt/GTI2JDRDzePr+hjUdEfDsiRiNiW0S8a7oXIUl6pYk8ct8PXJOZZwHLgKsi4izgWmBjZi4GNrZ9gAuBxe1jDXBj32ctSXpV48Y9M/dk5i/a9v8BO4AFwApgbTttLXBp214BfD87DwLzI+KMvs9cknRYR3TNPSKGgXcCW4DTM3NPO/Q74PS2vQDY2fNlu9rYwd9rTUSMRMTI2NjYEU5bkvRqJhz3iDgR+CHwucz8Y++xzEwgj+SGM/PmzFyamUuHhoaO5EslSeOYUNwj4ji6sN+Wmfe24d8fuNzSPu9t47uBRT1fvrCNSZJmyESeLRPALcCOzPxGz6H1wKq2vQpY1zP+6fasmWXAvp7LN5KkGTB3Aue8D/gU8KuI2NrGrge+AtwVEauBJ4HL2rH7gYuAUeB54DN9nbEkaVzjxj0z/weIwxy+4BDnJ3DVFOclSZoCX6EqSQUZd0kqyLhLUkHGXZIKMu6SVJBxl6SCjLskFWTcJakg4y5JBRl3SSrIuEtSQcZdkgoy7pJUkHGXpIKMuyQVZNwlqSDjLkkFGXdJKsi4S1JBxl2SCjLuklSQcZekgoy7JBVk3CWpIOMuSQUZd0kqyLhLUkHGXZIKMu6SVJBxl6SCjLskFWTcJakg4y5JBRl3SSpo3LhHxHcjYm9EbO8ZOzkiNkTE4+3zG9p4RMS3I2I0IrZFxLumc/KSpEObyCP37wHLDxq7FtiYmYuBjW0f4EJgcftYA9zYn2lKko7EuHHPzJ8ATx80vAJY27bXApf2jH8/Ow8C8yPijH5NVpI0MZO95n56Zu5p278DTm/bC4CdPeftamN/JSLWRMRIRIyMjY1NchqSpEOZ8g9UMzOBnMTX3ZyZSzNz6dDQ0FSnIUnqMdm4//7A5Zb2eW8b3w0s6jlvYRuTJM2gycZ9PbCqba8C1vWMf7o9a2YZsK/n8o0kaYbMHe+EiLgD+ABwakTsAv4F+ApwV0SsBp4ELmun3w9cBIwCzwOfmYY5S5LGMW7cM/OKwxy64BDnJnDVVCclSZoaX6EqSQUZd0kqyLhLUkHGXZIKMu6SVJBxl6SCjLskFWTcJakg4y5JBRl3SSrIuEtSQcZdkgoy7pJUkHGXpIKMuyQVZNwlqSDjLkkFGXdJKsi4S1JBxl2SCjLuklSQcZekgoy7JBVk3CWpIOMuSQUZd0kqyLhLUkHGXZIKMu6SVJBxl6SCjLskFWTcJakg4y5JBRl3SSpoWuIeEcsj4rGIGI2Ia6fjNiRJh9f3uEfEHODfgAuBs4ArIuKsft+OJOnwpuOR+znAaGb+NjP/DNwJrJiG25EkHcZ0xH0BsLNnf1cbe4WIWBMRIxExMjY2NrlbeutbYeVKmDNncl8vSUUN7AeqmXlzZi7NzKVDQ0OT+yYrVsDdd8O8ef2dnCTNctMR993Aop79hW1MkjRDpiPuPwcWR8SZEX
E8cDmwfhpuR5J0GHP7/Q0zc39E/CPwADAH+G5mPtLv25EkHV7f4w6QmfcD90/H95Ykjc9XqEpSQcZdkgoy7pJUkHGXpIIiMwc9ByJiDHhykl9+KvCHPk7naOd6a3O9tfV7vX+XmYd8FehREfepiIiRzFw66HnMFNdbm+utbSbX62UZSSrIuEtSQRXifvOgJzDDXG9trre2GVvvrL/mLkn6axUeuUuSDmLcJamgWR33im/EHRHfjYi9EbG9Z+zkiNgQEY+3z29o4xER327r3xYR7xrczI9cRCyKiM0R8WhEPBIRV7fxquudFxE/i4hftvV+qY2fGRFb2rp+0H5VNhFxQtsfbceHBzn/yYqIORHxcETc1/bLrjcinoiIX0XE1ogYaWMDuT/P2rgXfiPu7wHLDxq7FtiYmYuBjW0furUvbh9rgBtnaI79sh+4JjPPApYBV7X/hlXX+yfg/Mw8G1gCLI+IZcBXgRsy8y3AM8Dqdv5q4Jk2fkM7bza6GtjRs199vedl5pKe57MP5v6cmbPyA3gv8EDP/nXAdYOeV5/WNgxs79l/DDijbZ8BPNa2bwKuONR5s/EDWAd88FhYL/A64BfAe+hesTi3jb98v6Z7T4T3tu257bwY9NyPcJ0L6YJ2PnAfEMXX+wRw6kFjA7k/z9pH7kzwjbiLOD0z97Tt3wGnt+0yfwbtn+DvBLZQeL3tEsVWYC+wAfgN8Gxm7m+n9K7p5fW24/uAU2Z2xlP2TeALwEtt/xRqrzeB/4qIhyJiTRsbyP15Wt6sQ9MnMzMiSj1/NSJOBH4IfC4z/xgRLx+rtt7MfBFYEhHzgf8E3jbgKU2biPgwsDczH4qIDwx6PjPk/Zm5OyJOAzZExP/2HpzJ+/NsfuR+LL0R9+8j4gyA9nlvG5/1fwYRcRxd2G/LzHvbcNn1HpCZzwKb6S5LzI+IAw+0etf08nrb8ZOAp2Z4qlPxPuCSiHgCuJPu0sy3qLteMnN3+7yX7i/vcxjQ/Xk2x/1YeiPu9cCqtr2K7tr0gfFPt5+6LwP29fzz76gX3UP0W4AdmfmNnkNV1zvUHrETEa+l+/nCDrrIr2ynHbzeA38OK4FN2S7OzgaZeV1mLszMYbr/Pzdl5pUUXW9EvD4i/ubANvAPwHYGdX8e9A8gpvjDi4uAX9Ndt/zioOfTpzXdAewB/kJ3DW413XXHjcDjwH8DJ7dzg+4ZQ78BfgUsHfT8j3Ct76e7RrkN2No+Liq83r8HHm7r3Q78cxt/M/AzYBS4Gzihjc9r+6Pt+JsHvYYprP0DwH2V19vW9cv28ciBJg3q/uyvH5CkgmbzZRlJ0mEYd0kqyLhLUkHGXZIKMu6SVJBxl6SCjLskFfT/ZVxDnwNir8YAAAAASUVORK5CYII=\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPPUlEQVR4nO3dfYydZZnH8e9ly4u7ulTaWYJt42CsURIRSWWrmMhCVGDV8gcYjUKjTeofGEEwLrCJK3E1EhKLxA1KFqSsRl58g5AmyraY9SWWHaQi0FXGF6QN2BELrK9QvPaPc5ccastMZ86Zh7nm+0lOzvPc933Oc93T01+f3uc5cyIzkSTV8ryuC5AkDZ7hLkkFGe6SVJDhLkkFGe6SVNDCrgsAWLJkSY6OjnZdhiTNKXfeeedvMnNkX33PiXAfHR1lbGys6zIkaU6JiAf21+eyjCQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQV9Jy4zr1z27fDNdfA7t1dVyJpvnnb2+C1rx340xruTzzR++Fu3QoRXVcjab558YsN96H42Md6wf6Nb8Dq1V1XI0kDMb/X3L/7Xbj0Uli71mCXVMr8DffHH4ezzoLRUVi/vutqJGmg5u+yzHnnwa9+Bd/5DrzwhV1XI0kDNT/P3L/+dfjCF+Cii+D1r++6GkkauPkX7g8/DOvWwXHHwUc/2nU1kjQU8yvcM3tvnv7ud/DFL8LBB3ddkSQNxfxac//852HjRrjiCnjlK7uuRpKGZv6cuf/0p3DBBfDmN8M553RdjSQN1fwI9yef7F32eMghvV8z8Lz5MW1J89f8WJb55Cfhjjvgxhth6dKuq5Gkoat/CrtlC3z84/Ce98CZZ3ZdjSTNitrh/vvf95Zjli6Fz36262okadbUXpb58IdhfBw2b4bDDuu6GkmaNXXP3DduhM99Ds4/H048setqJGlW1Qz3iQl43/vgVa+CT3yi62okadbVW5bJhPe/H3btgm99q3f5oyTNM/XC/dpre78Y7LLL4Jhjuq5GkjpRa1nmF7+AD34Q3vhG+NCHuq5GkjpTJ9yfegrOPrv36dMNG2DBgq4rkqTO1FmWueyy3tfmXXcdvOQlXVcjSZ2qceZ+1129381+5pm9T6JK0jw35XCPiAURcVdE3Nr2j4qILRExHhE3RMTBrf2Qtj/e+keHU3rzxz/2An3JErjySogY6uEkaS44kDP3c4FtffuXAusz82XALmBta18L7Grt69u44bn4Yrjvvt7X5i1ePNRDSdJcMaVwj4hlwD8B/9H2AzgJ+EobsgE4vW2vbvu0/pPb+MH73vfg8svhAx+At7xlKIeQpLloqmfulwMfAf7S9hcDj2bm7ra/Hdjzu3SXAg8CtP7H2vhniIh1ETEWEWMTExPTq/773+/dX3LJ9B4vSUVNGu4R8VZgZ2beOcgDZ+ZVmbkyM1eOjIzM7Mn8FKokPcNULoU8AXh7RJwGHAr8HfAZYFFELGxn58uAHW38DmA5sD0iFgKHAY8MvHJJ0n5NeuaemRdl5rLMHAXeCWzOzHcDtwNntGFrgJvb9i1tn9a/OTNzoFVLkp7VTK5z/2fg/IgYp7emfnVrvxpY3NrPBy6cWYmSpAN1QJ9QzcxvA99u2z8Hjt/HmD8Bfp+dJHWoxidUJUnPYLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGGuyQVZLhLUkGThntEHBoRd0TEjyLi3oi4pLUfFRFbImI8Im6IiINb+yFtf7z1jw53CpKkvU3lzP3PwEmZ+WrgWOCUiFgFXAqsz8yXAbuAtW38WmBXa1
/fxkmSZtGk4Z49v2u7B7VbAicBX2ntG4DT2/bqtk/rPzkiYmAVS5ImNaU194hYEBFbgZ3AbcDPgEczc3cbsh1Y2raXAg8CtP7HgMX7eM51ETEWEWMTExMzm4Uk6RmmFO6Z+VRmHgssA44HXjHTA2fmVZm5MjNXjoyMzPTpJEl9Duhqmcx8FLgdeB2wKCIWtq5lwI62vQNYDtD6DwMeGUi1kqQpmcrVMiMRsahtPx94E7CNXsif0YatAW5u27e0fVr/5szMQRYtSXp2CycfwpHAhohYQO8fgxsz89aIuA+4PiL+DbgLuLqNvxr4z4gYB34LvHMIdUuSnsWk4Z6ZdwOv2Uf7z+mtv+/d/ifgzIFUJ0maFj+hKkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVNCk4R4RyyPi9oi4LyLujYhzW/vhEXFbRNzf7l/U2iMiroiI8Yi4OyKOG/YkJEnPNJUz993ABZl5NLAKOCcijgYuBDZl5gpgU9sHOBVY0W7rgCsHXrUk6VlNGu6Z+VBm/rBt/x+wDVgKrAY2tGEbgNPb9mrguuz5AbAoIo4ceOWSpP06oDX3iBgFXgNsAY7IzIda18PAEW17KfBg38O2t7a9n2tdRIxFxNjExMQBli1JejZTDveIeAHwVeC8zHy8vy8zE8gDOXBmXpWZKzNz5cjIyIE8VJI0iSmFe0QcRC/Yv5SZX2vNv96z3NLud7b2HcDyvocva22SpFkylatlArga2JaZn+7rugVY07bXADf3tZ/drppZBTzWt3wjSZoFC6cw5gTgLODHEbG1tV0MfAq4MSLWAg8A72h9G4HTgHHgD8B7B1qxJGlSk4Z7Zn4XiP10n7yP8QmcM8O6JEkz4CdUJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCpo03CPimojYGRH39LUdHhG3RcT97f5FrT0i4oqIGI+IuyPiuGEWL0nat6mcuV8LnLJX24XApsxcAWxq+wCnAivabR1w5WDKlCQdiEnDPTP/G/jtXs2rgQ1tewNwel/7ddnzA2BRRBw5qGIlSVMz3TX3IzLzobb9MHBE214KPNg3bntr+ysRsS4ixiJibGJiYpplSJL2ZcZvqGZmAjmNx12VmSszc+XIyMhMy5Ak9ZluuP96z3JLu9/Z2ncAy/vGLWttkqRZNN1wvwVY07bXADf3tZ/drppZBTzWt3wjSZolCycbEBFfBk4ElkTEduBfgU8BN0bEWuAB4B1t+EbgNGAc+APw3iHULEmaxKThnpnv2k/XyfsYm8A5My1KkjQzfkJVkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpoKGEe0ScEhE/iYjxiLhwGMeQJO3fwMM9IhYA/w6cChwNvCsijh70cSRJ+zeMM/fjgfHM/HlmPgFcD6wewnEkSfsxjHBfCjzYt7+9tT1DRKyLiLGIGJuYmJjekV7+cjjjDFiwYHqPl6SiOntDNTOvysyVmblyZGRkek+yejXcdB
Mceuhgi5OkOW4Y4b4DWN63v6y1SZJmyTDC/X+AFRFxVEQcDLwTuGUIx5Ek7cfCQT9hZu6OiA8A3wQWANdk5r2DPo4kaf8GHu4AmbkR2DiM55YkTc5PqEpSQYa7JBVkuEtSQYa7JBUUmdl1DUTEBPDANB++BPjNAMt5rnO+tTnf2gY935dk5j4/BfqcCPeZiIixzFzZdR2zxfnW5nxrm835uiwjSQUZ7pJUUIVwv6rrAmaZ863N+dY2a/Od82vukqS/VuHMXZK0F8Ndkgqa0+Fe8Yu4I+KaiNgZEff0tR0eEbdFxP3t/kWtPSLiijb/uyPiuO4qP3ARsTwibo+I+yLi3og4t7VXne+hEXFHRPyozfeS1n5URGxp87qh/apsIuKQtj/e+ke7rH+6ImJBRNwVEbe2/bLzjYhfRsSPI2JrRIy1tk5ez3M23At/Efe1wCl7tV0IbMrMFcCmtg+9ua9ot3XAlbNU46DsBi7IzKOBVcA57c+w6nz/DJyUma8GjgVOiYhVwKXA+sx8GbALWNvGrwV2tfb1bdxcdC6wrW+/+nz/MTOP7buevZvXc2bOyRvwOuCbffsXARd1XdeA5jYK3NO3/xPgyLZ9JPCTtv154F37GjcXb8DNwJvmw3yBvwF+CPwDvU8sLmztT7+u6X0nwuva9sI2Lrqu/QDnuYxeoJ0E3ApE8fn+EliyV1snr+c5e+bOFL+Iu4gjMvOhtv0wcETbLvMzaP8Ffw2whcLzbUsUW4GdwG3Az4BHM3N3G9I/p6fn2/ofAxbPbsUzdjnwEeAvbX8xteebwLci4s6IWNfaOnk9D+XLOjQ8mZkRUer61Yh4AfBV4LzMfDwinu6rNt/MfAo4NiIWAV8HXtFxSUMTEW8FdmbmnRFxYtf1zJI3ZOaOiPh74LaI+N/+ztl8Pc/lM/f59EXcv46IIwHa/c7WPud/BhFxEL1g/1Jmfq01l53vHpn5KHA7vWWJRRGx50Srf05Pz7f1HwY8MsulzsQJwNsj4pfA9fSWZj5D3fmSmTva/U56/3gfT0ev57kc7vPpi7hvAda07TX01qb3tJ/d3nVfBTzW99+/57zonaJfDWzLzE/3dVWd70g7Yycink/v/YVt9EL+jDZs7/nu+TmcAWzOtjg7F2TmRZm5LDNH6f393JyZ76bofCPibyPihXu2gTcD99DV67nrNyBm+ObFacBP6a1b/kvX9QxoTl8GHgKepLcGt5beuuMm4H7gv4DD29igd8XQz4AfAyu7rv8A5/oGemuUdwNb2+20wvM9Brirzfce4KOt/aXAHcA4cBNwSGs/tO2Pt/6Xdj2HGcz9RODWyvNt8/pRu927J5O6ej376wckqaC5vCwjSdoPw12SCjLcJakgw12SCjLcJakgw12SCjLcJamg/weUVTeZzfbL2wAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"# Not ordered\n",
|
|
"for c in range(3):\n",
|
|
" plt.plot(roc_curves[c][0], roc_curves[c][1], color = \"red\")\n",
|
|
" plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"anaconda-cloud": {},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|