Convolutional Neural Network (CNN) application from keras. This learner builds and compiles the keras model from the hyperparameters in param_set and does not require a supplied, compiled model. The 'application' parameter refers to a 'keras::application_*' CNN architecture, possibly with pre-trained weights.

Calls keras::fit_generator together with keras::flow_images_from_dataframe from the keras package. Layers are set up as follows (a short keras code sketch follows this list):

  • The last layer (the classification layer) is removed from the neural network.

  • A new classification layer with 'cl_layer_units' units is added.

  • The weights of all layers are frozen.

  • The last 'unfreeze_n_last_layers' layers are unfrozen.
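A minimal sketch of this setup in plain keras R code (illustrative only, not the learner's exact internals; the pooling choice, the hidden activation and the example values for n_classes, cl_layer_units and unfreeze_n_last_layers are assumptions):

library(keras)

# Pre-trained application without its top (classification) layer.
base = application_resnet50(include_top = FALSE, weights = "imagenet",
  pooling = "avg")

n_classes = 2           # assumed number of target classes
cl_layer_units = 1024   # 'cl_layer_units' parameter
unfreeze_n_last_layers = 5

# New classification head on top of the application's output.
output = base$output %>%
  layer_dense(units = cl_layer_units, activation = "relu") %>%
  layer_dense(units = n_classes, activation = "softmax")
model = keras_model(inputs = base$input, outputs = output)

# Freeze all layers, then unfreeze the last 'unfreeze_n_last_layers'.
freeze_weights(model)
unfreeze_weights(model, from = length(model$layers) - unfreeze_n_last_layers + 1)
# (The learner then compiles the model with the chosen optimizer and loss.)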

Parameters:
Most of the parameters can be obtained from the keras documentation. Some exceptions are documented here; a short sketch of setting them on a constructed learner follows this list.

  • application: A (possibly pre-trained) CNN architecture. Default: keras::application_resnet50.

  • cl_layer_units: Number of units in the classification layer.

  • unfreeze_n_last_layers: Number of last layers to be unfrozen.

  • optimizer: Some optimizers and their arguments can be found below.
    The constructed optimizer inherits from tensorflow.python.keras.optimizer_v2.

    "sgd"     : optimizer_sgd(lr, momentum, decay = decay),
    "rmsprop" : optimizer_rmsprop(lr, rho, decay = decay),
    "adagrad" : optimizer_adagrad(lr, decay = decay),
    "adam"    : optimizer_adam(lr, beta_1, beta_2, decay = decay),
    "nadam"   : optimizer_nadam(lr, beta_1, beta_2, schedule_decay = decay)
    
  • class_weight: A named list of class weights for the classes, numbered from 0 to c-1 (for c classes).

    Example:
    wts = c(0.5, 1)
    setNames(as.list(wts), seq_along(wts) - 1)  # list("0" = 0.5, "1" = 1)
    
  • callbacks: A list of keras callbacks. See ?callbacks.
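A short sketch of setting these hyperparameters on a constructed learner (values are illustrative; the learner is assumed to be provided by the mlr3keras package):

library(mlr3)
library(mlr3keras)  # assumed to register "classif.kerascnn"
library(keras)

learner = lrn("classif.kerascnn")

# Illustrative values; otherwise the defaults shown under Examples apply.
learner$param_set$values$application = keras::application_mobilenet_v2
learner$param_set$values$cl_layer_units = 512
learner$param_set$values$unfreeze_n_last_layers = 3
learner$param_set$values$optimizer = keras::optimizer_adam(lr = 1e-4)

# Named list of class weights, classes numbered from 0.
wts = c(0.5, 1)
learner$param_set$values$class_weight = setNames(as.list(wts), seq_along(wts) - 1)

# Keras callbacks, e.g. early stopping.
learner$param_set$values$callbacks = list(
  keras::callback_early_stopping(patience = 3)
)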

Format

R6::R6Class() inheriting from LearnerClassifKeras.

Construction

LearnerClassifKerasCNN$new()
mlr3::mlr_learners$get("classif.kerascnn")
mlr3::lrn("classif.kerascnn")

Learner Methods

Keras Learners offer several methods for easy access to the stored models; a short usage sketch follows this list.

  • .$plot()
    Plots the history, i.e. the train-validation loss during training.

  • .$save(file_path)
    Dumps the model to a provided file_path in 'h5' format.

  • .$load_model_from_file(file_path)
    Loads a model saved via .$save() back into the learner. When the learner is serialized, the model needs to be saved separately; in that case, it can be restored through this method. Currently not implemented for 'TabNet'.

  • .$lr_find(task, epochs, lr_min, lr_max, batch_size)
    Employs an implementation of the learning rate finder as popularized by Jeremy Howard in fast.ai (http://course.fast.ai/) for the learner. For more info on parameters, see find_lr.
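A short usage sketch (illustrative; 'task' stands in for an image classification task prepared for this learner, and the lr_find values are examples):

learner$train(task)

learner$plot()                       # train/validation loss history
learner$save("kerascnn_model.h5")    # dump the keras model in 'h5' format
learner$load_model_from_file("kerascnn_model.h5")

# Learning rate finder; see find_lr for parameter details.
learner$lr_find(task, epochs = 5, lr_min = 1e-4, lr_max = 0.8, batch_size = 128)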

See also

Examples

learner = mlr3::lrn("classif.kerascnn")
print(learner)
#> <LearnerClassifKeras:classif.keras>
#> * Model: -
#> * Parameters: epochs=100, callbacks=<list>, validation_split=0.3333,
#>   batch_size=128, low_memory=FALSE, verbose=0, application=<function>,
#>   optimizer=<keras.optimizer_v2.adam.Adam>,
#>   loss=categorical_crossentropy, metrics=accuracy, cl_layer_units=1024,
#>   output_activation=softmax, unfreeze_n_last_layers=5,
#>   validation_fraction=0.2
#> * Packages: keras
#> * Predict Type: response
#> * Feature types: -
#> * Properties: multiclass, twoclass

# available parameters:
learner$param_set$ids()
#>  [1] "epochs"                 "model"                  "class_weight"
#>  [4] "validation_split"       "batch_size"             "callbacks"
#>  [7] "low_memory"             "verbose"                "optimizer"
#> [10] "loss"                   "output_activation"      "application"
#> [13] "cl_layer_units"         "unfreeze_n_last_layers" "metrics"
#> [16] "validation_fraction"