Visualización de un árbol de decisión

La librería sklearn incluye la función sklearn.tree.export_graphviz que permite la visualización de un árbol de decisión. Probémosla con el dataset de tips proveído por la librería seaborn:

import seaborn as sns

tips = sns.load_dataset("tips")
tips.sample(5)

Dataset tips

Como la implementación del árbol de decisión para escenarios de regresión de sklearn no permite trabajar con características categóricas, comenzamos codificándolas:

from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
tips.sex = le.fit_transform(tips.sex)
tips.smoker = le.fit_transform(tips.smoker)
tips.day = le.fit_transform(tips.day)
tips.time = le.fit_transform(tips.time)

Ahora podemos crear las estructuras X e y:

y = tips.pop("tip")
X = tips

...e importar el algoritmo, instanciarlo y entrenarlo (con los parámetros por defecto salvo la profundidad máxima, que fijamos en 3 para no crear un árbol demasiado grande):

from sklearn.tree import DecisionTreeRegressor
model = DecisionTreeRegressor(max_depth = 3)
model.fit(X, y)

DecisionTreeRegressor entrenado

Ahora importamos la función export_graphviz y un par de clases adicionales de sklearn, IPython y pydotplus:

from sklearn.externals.six import StringIO  
from IPython.display import Image  
from sklearn.tree import export_graphviz
import pydotplus

dot_data = StringIO()
export_graphviz(model, out_file = dot_data,  
                filled = True, rounded = True,
                special_characters = True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
Image(graph.create_png())

Árbol de decisión generado

Aunque entrar en el detalle de las funciones involucradas haría este escenario excesivamente complejo, aquí van algunos comentarios:

Los parámetros usados en export_graphviz son los siguientes:

  • El primer parámetro es el modelo entrenado, ya sea de regresión o de clasificación.
  • out_file indica el fichero o el manejador al que volcar el resultado, en este caso es un objeto de la clase StringIO.
  • El parámetro filled colorea los nodos generados para indicar en cuáles existe un valor más frecuente que otros (en clasificación) o en cuáles hay valores extremos (en regresión).
  • rounded muestra los rectángulos representando los nodos con las esquinas redondeadas
  • special_characters, por último, incluye en el gráfico generado caracteres especiales (el ignorarlos tendría sentido solo para asegurar la compatibilidad con el formato PostScript)

La función pydotplus.graph_from_dot_data carga un gráfico en formato DOT y, por último, la función Image(graph.create_png()) crea y muestra la imagen generada en formato png.

Categoría
Submitted by admin on Wed, 04/03/2019 - 15:53