Source: R/LearnerClassifKeras.R (documentation: LearnerClassifKeras.Rd)
Neural Network using Keras and TensorFlow.
This learner allows supplying a custom architecture.
Training calls keras::fit() from package keras.
Parameters:
Most parameters are described in the keras documentation.
The exceptions are documented here.
Format:
R6::R6Class() inheriting from mlr3::LearnerClassif.
Construction:
LearnerClassifKeras$new()
mlr3::mlr_learners$get("classif.keras")
mlr3::lrn("classif.keras")
Methods:
Keras learners offer several methods for easy access to the stored model.
.$plot()
Plots the training history, i.e. the training and validation loss over the training epochs.
.$save(file_path)
Saves the model to the provided file_path in HDF5 ('h5') format.
.$load_model_from_file(file_path)
Loads a model saved using .$save() back into the learner.
When the learner is serialized, the model must be saved separately; the learner can then be restored using this function.
Currently not implemented for 'TabNet'.
.$lr_find(task, epochs, lr_min, lr_max, batch_size)
Runs an implementation of the learning rate finder, as popularized by Jeremy Howard in fast.ai (http://course.fast.ai/), on the learner.
For details on the parameters, see find_lr().
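A save-and-restore round trip using .$save() and .$load_model_from_file() might look like the following sketch. It assumes mlr3 and mlr3keras are installed with a working TensorFlow backend, and that the learner object itself can be written with saveRDS() while the Keras model inside it has to go to a separate h5 file, as described above. The setting epochs = 10L and the temporary file names are illustrative assumptions, not recommended values.

```r
library(mlr3)
library(mlr3keras)
library(keras)

# Small network for the 4-feature, 3-class iris task
model = keras_model_sequential() %>%
  layer_dense(units = 12L, input_shape = 4L, activation = "relu") %>%
  layer_dense(units = 3L, activation = "softmax") %>%
  compile(
    optimizer = optimizer_sgd(),
    loss = "categorical_crossentropy",
    metrics = "accuracy"
  )

task = tsk("iris")
learner = lrn("classif.keras")
learner$param_set$values$model = model
# Assumption: a short run is enough for illustration
learner$param_set$values$epochs = 10L
learner$train(task)

# Write the Keras model to its own h5 file, then serialize the learner
model_file = tempfile(fileext = ".h5")
learner_file = tempfile(fileext = ".rds")
learner$save(model_file)
saveRDS(learner, learner_file)

# Later (e.g. in a new session): restore the learner, then reattach the model
restored = readRDS(learner_file)
restored$load_model_from_file(model_file)
prediction = restored$predict(task)
```

Keeping the model file next to the .rds file is the caller's responsibility; the learner does not track where its model was saved.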
Examples:
library(mlr3keras)
library(keras)

# Define a model
model = keras_model_sequential() %>%
  layer_dense(units = 12L, input_shape = 4L, activation = "relu") %>%
  layer_dense(units = 12L, activation = "relu") %>%
  layer_dense(units = 3L, activation = "softmax") %>%
  compile(
    optimizer = optimizer_sgd(),
    loss = "categorical_crossentropy",
    metrics = "accuracy"
  )

# Create the learner
learner = LearnerClassifKeras$new()
learner$param_set$values$model = model
learner$train(mlr3::mlr_tasks$get("iris"))