La librería sklearn incluye la función sklearn.tree.export_graphviz que permite la visualización de un árbol de decisión. Probémosla con el dataset de tips proveído por la librería seaborn:
import seaborn as sns
tips = sns.load_dataset("tips")
tips.sample(5)
Como la implementación del árbol de decisión para escenarios de regresión de sklearn no permite trabajar con características categóricas, comenzamos codificándolas:
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
tips.sex = le.fit_transform(tips.sex)
tips.smoker = le.fit_transform(tips.smoker)
tips.day = le.fit_transform(tips.day)
tips.time = le.fit_transform(tips.time)
Ahora podemos crear las estructuras X e y:
y = tips.pop("tip")
X = tips
...e importar el algoritmo, instanciarlo y entrenarlo (con los parámetros por defecto salvo la profundidad máxima, que fijamos en 3 para no crear un árbol demasiado grande):
from sklearn.tree import DecisionTreeRegressor
model = DecisionTreeRegressor(max_depth = 3)
model.fit(X, y)
Ahora importamos la función export_graphviz y un par de clases adicionales de sklearn, IPython y pydotplus:
from sklearn.externals.six import StringIO
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus
dot_data = StringIO()
export_graphviz(model, out_file = dot_data,
filled = True, rounded = True,
special_characters = True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
Aunque entrar en el detalle de las funciones involucradas haría este escenario excesivamente complejo, aquí van algunos comentarios:
Los parámetros usados en export_graphviz son los siguientes:
- El primer parámetro es el modelo entrenado, ya sea de regresión o de clasificación.
- out_file indica el fichero o el manejador al que volcar el resultado, en este caso es un objeto de la clase StringIO.
- El parámetro filled colorea los nodos generados para indicar en cuáles existe un valor más frecuente que otros (en clasificación) o en cuáles hay valores extremos (en regresión).
- rounded muestra los rectángulos representando los nodos con las esquinas redondeadas
- special_characters, por último, incluye en el gráfico generado caracteres especiales (el ignorarlos tendría sentido solo para asegurar la compatibilidad con el formato PostScript)
La función pydotplus.graph_from_dot_data carga un gráfico en formato DOT y, por último, la función Image(graph.create_png()) crea y muestra la imagen generada en formato png.