Nuestra clase NNClassifier contará con tres métodos principales:
- El constructor (__init__), que será el encargado de recibir los hiperparámetros de la red neuronal (número de capas y número de neuronas en cada capa, número de epochs, etc.)
- El método fit que recibirá las características predictivas y la variable objetivo y entrenará la red neuronal.
- El método predict que recibirá un conjunto de características predictivas y devolverá la predicción para cada muestra.
Podemos crear la estructura básica de la clase, por lo tanto, de la siguiente forma:
class NNClassifier(object):
def __init__(self):
""" Constructor de la red neuronal """
def fit(self, X, y):
""" Entrenamiento de la red neuronal"""
pass
def predict(self, X):
""" Predicción """
pass
La creación de una instancia de nuestro algoritmo se realizaría -en este momento- con la siguiente instrucción:
model = NNClassifier()