Con el código anterior ya podemos dar forma a una función reutilizable con cualquier dataset:
def plot_decision_boundaries(model, X, ax):
minX = min(X[:, 0])
maxX = max(X[:, 0])
minY = min(X[:, 1])
maxY = max(X[:, 1])
marginX = (maxX - minX) * 0.1
marginY = (maxY - minY) * 0.1
x = np.linspace(minX - marginX, maxX + marginX, 100)
y = np.linspace(minY - marginY, maxY + marginY, 100)
X, Y = np.meshgrid(x, y)
Z = model.predict(np.c_[X.ravel(), Y.ravel()]).reshape(X.shape)
ax.contourf(X, Y, Z, levels = 2,
colors = ["#E3BCAB", "#B0D9CB", "#75B6E6"],
zorder = 0
)
minX = min(X[:, 0])
maxX = max(X[:, 0])
minY = min(X[:, 1])
maxY = max(X[:, 1])
marginX = (maxX - minX) * 0.1
marginY = (maxY - minY) * 0.1
x = np.linspace(minX - marginX, maxX + marginX, 100)
y = np.linspace(minY - marginY, maxY + marginY, 100)
X, Y = np.meshgrid(x, y)
Z = model.predict(np.c_[X.ravel(), Y.ravel()]).reshape(X.shape)
ax.contourf(X, Y, Z, levels = 2,
colors = ["#E3BCAB", "#B0D9CB", "#75B6E6"],
zorder = 0
)
La función recibe como argumentos el modelo (que será utilizado para realizar las predicciones), el dataset X del que extraer las primeras dos características predictivas con formato de array NumPy y, previendo que esta gráfica vaya a ser añadida a un conjunto de ejes sobre el que también se quiera mostrar un diagrama de dispersión son los datos de entrenamiento, también requiere como tercer argumento el conjunto de ejes.
Esta función es la que tenemos disponible en la librería boundaries.