An extension for mlr3 to enable using various keras models as learners.
mlr3keras is at a very early stage and currently under development. Functionality is therefore experimental, and we do not guarantee correctness, safety or stability. It builds on top of the (awesome) R packages
keras. Comments, discussion, issues/bug reports and PRs are highly appreciated.
If you want to contribute, please propose and discuss new functionality in an issue first, in order to avoid unnecessary or duplicate work.
    # Install from GitHub
    remotes::install_github("mlr-org/mlr3keras")
If you encounter problems using the correct python versions, see here.
mlr3keras is currently tested and works using the python packages
keras (2.4) and
One possible workflow for working with mlr3keras is described below. While (1.) and (2.) are one-time setup steps, (3.) has to be called every time mlr3keras is loaded.
Note from the author: The workflow described below works for me personally, as I often have to switch between versions and projects. I describe it here because I find it useful. It assumes the R packages
reticulate are installed. In order to load mlr3keras, I now have to execute one additional line (see 3.), but version management is heavily simplified. Note that this seems to work on Linux but has not been extensively tested on other systems.
    # Execute and restart R afterwards
    reticulate::install_miniconda()
NOTE: It might make sense to set the RETICULATE_PYTHON environment variable to your Miniconda / Anaconda installation path as explained here.
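A minimal sketch of setting the variable for the current session. It assumes Miniconda was installed by reticulate into its default location on Linux or macOS; the `bin/python` sub-path is an assumption and differs on Windows. reticulate expects the variable to point at the Python binary, and for a permanent setting the same name/value pair could go into `~/.Renviron` instead.

```r
# Assumption: Miniconda lives at reticulate's default install location.
python_bin = file.path(reticulate::miniconda_path(), "bin", "python")

# Has to happen before reticulate initialises Python in this session.
Sys.setenv(RETICULATE_PYTHON = python_bin)

# Show which interpreter reticulate will be asked to use.
Sys.getenv("RETICULATE_PYTHON")
```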
    # Execute and restart R afterwards
    reticulate::conda_create(
      envname = "mlr3keras",
      packages = c("pandas", "python=3.8")
    )
    keras::install_keras("conda", tensorflow = "2.3.1", envname = "mlr3keras")

    # Install the tabnet Python package into the same environment via pip
    reticulate::conda_install("mlr3keras", packages = "tabnet", pip = TRUE)
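The per-session step (3.) is not shown above. A plausible sketch, assuming it consists of activating the conda environment created in (2.) before loading the package; the environment name `"mlr3keras"` is taken from the setup code, everything else is an assumption rather than the author's exact command:

```r
# Assumption: the environment was created under the name "mlr3keras" as above.
reticulate::use_condaenv("mlr3keras", required = TRUE)
library(mlr3keras)

# Optional: confirm which Python installation reticulate picked up.
reticulate::py_config()
```

Binding the environment before `library(mlr3keras)` matters because reticulate fixes its Python interpreter on first use within a session.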
mlr3keras currently exposes the following Learners for regression and classification respectively.
| Learner | Details | Reference |
|---|---|---|
| regr/classif.keras | A generic wrapper that allows supplying a custom keras architecture as a hyperparameter. | – |
| regr/classif.kerasFF | A fully-connected feed-forward neural network with entity embeddings | Guo et al. (2016): Entity Embeddings of Categorical Variables |
| regr/classif.tabNet | An implementation of TabNet | Arik, S. O. and Pfister, T. (2019): TabNet |
| regr/classif.smlp | Shaped MLP inspired by Configuration Space 1* | Zimmer, L. et al. (2020): Auto PyTorch Tabular |
| regr/classif.smlp2 | Shaped MLP inspired by Configuration Space 2* | Zimmer, L. et al. (2020): Auto PyTorch Tabular |
| regr/classif.deep_wide | Deep and Wide Architecture inspired by | Erickson et al. (2020): AutoGluon-Tabular: Robust and Accurate AutoML for Structured Data |
| classif.kerascnn | Various CNN Applications via Transfer-learning | Uses |
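As an illustration of the generic wrapper in the first row, the sketch below compiles a small keras model by hand and passes it to the learner. The class name `LearnerClassifKeras` and the hyperparameter name `model` are assumptions based on the naming used elsewhere in this document; check the package reference before relying on them.

```r
library(mlr3)
library(mlr3keras)
library(keras)

# Architecture sized by hand for iris: 4 numeric features, 3 classes.
model = keras_model_sequential() %>%
  layer_dense(units = 12L, input_shape = 4L, activation = "relu") %>%
  layer_dense(units = 3L, activation = "softmax")
model %>% compile(
  optimizer = "sgd",
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)

# Assumption: the compiled model is supplied via a hyperparameter named `model`.
learner = LearnerClassifKeras$new()
learner$param_set$values$model = model
learner$train(tsk("iris"))
```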
Learners can be used for training and prediction as follows:
    library("mlr3")
    library("mlr3keras")

    # Instantiate Learner
    lrn = LearnerClassifKerasFF$new()

    # Set Learner Hyperparams
    lrn$param_set$values$epochs = 50
    lrn$param_set$values$layer_units = 12

    # Train and Predict
    lrn$train(tsk("iris"))
    lrn$predict(tsk("iris"))
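The example above predicts on the same rows it was trained on. A hedged variant that holds out a test set, using only the learner and hyperparameters shown above plus standard mlr3 helpers (`partition()`, `msr()`); the split ratio is an arbitrary choice:

```r
library(mlr3)
library(mlr3keras)

task = tsk("iris")
lrn = LearnerClassifKerasFF$new()
lrn$param_set$values$epochs = 50
lrn$param_set$values$layer_units = 12

# Hold out one third of the rows for testing.
split = partition(task, ratio = 2 / 3)
lrn$train(task, row_ids = split$train)
pred = lrn$predict(task, row_ids = split$test)

# Accuracy on rows the network has not seen during training.
acc = pred$score(msr("classif.acc"))
```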
The vignette has some examples on how to use some of the functionality introduced in
This package’s purpose for now is to understand the design decisions required to make tensorflow work with mlr3 while remaining flexible enough for users.
Several design decisions are not made yet, so input is highly appreciated.
The goal of the project is to expose keras models as mlr3 learners. A keras model in this context should be understood as the combination of
keras_model_sequential() %>% ... %>% layer_activation(...)
All hyperparameters that control the steps:
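In plain keras, these ingredients can be sketched as below. The assumption here is that the steps in question are `compile()` and `fit()`; the layer sizes, optimizer and epoch count are arbitrary illustration values, not package defaults.

```r
library(keras)

x = as.matrix(iris[, 1:4])
y = to_categorical(as.integer(iris$Species) - 1L, num_classes = 3L)

# 1. Architecture
model = keras_model_sequential() %>%
  layer_dense(units = 12L, input_shape = 4L) %>%
  layer_activation("relu") %>%
  layer_dense(units = 3L) %>%
  layer_activation("softmax")

# 2. Hyperparameters of compile(): optimizer, loss, metrics
model %>% compile(
  optimizer = "adam",
  loss = "categorical_crossentropy",
  metrics = "accuracy"
)

# 3. Hyperparameters of fit(): epochs, batch size, ...
history = model %>% fit(x, y, epochs = 5L, batch_size = 32L, verbose = 0)
```

A learner then has to expose all three groups of settings as hyperparameters so that they can be set and tuned in one place.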
Some important caveats:
Architectures are often data-dependent, e.g. they require the correct number of input / output neurons. As a result, in
mlr3keras, the architecture is a function of the incoming training data. In
mlr3keras, this is abstracted via
KerasArchitectureFF for an example. This Architecture is initialized with a
build_arch_fun which, given the
task and a set of hyperparameters, constructs and compiles the architecture.
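A minimal sketch of such a builder for a classification task. The function name `build_arch` and the fields of `pars` are hypothetical and chosen for illustration; they are not the package's actual signature. The point is only that the input and output sizes are read from the task rather than hard-coded.

```r
library(mlr3)
library(keras)

# Hypothetical builder: name and `pars` fields are illustrative only.
build_arch = function(task, pars) {
  n_in = length(task$feature_names)  # input neurons follow the task's features
  n_out = length(task$class_names)   # one output neuron per class
  model = keras_model_sequential() %>%
    layer_dense(units = pars$units, input_shape = n_in, activation = "relu") %>%
    layer_dense(units = n_out, activation = "softmax")
  model %>% compile(optimizer = pars$optimizer, loss = "categorical_crossentropy")
  model
}

model = build_arch(tsk("iris"), pars = list(units = 12L, optimizer = "adam"))
```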
Depending on the architecture, different data formats are required for x (features) and y (target), e.g. a matrix for a feed-forward NN or a list of features if we use embeddings. To accommodate this, each architecture comes with an x_transform and a y_transform method, which are called on the features and target respectively before passing those on to
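To make the idea concrete, here is a sketch of what such a pair of transforms could look like for a feed-forward network on a classification task. Both functions are hypothetical stand-ins written for this example, not the package's methods, and their argument lists are assumptions.

```r
library(mlr3)

# Hypothetical transforms for a feed-forward network.
x_transform = function(features) {
  # table of numeric features -> numeric matrix
  as.matrix(features)
}
y_transform = function(target) {
  # factor -> one-hot matrix with one column per class level
  onehot = diag(nlevels(target))[as.integer(target), , drop = FALSE]
  colnames(onehot) = levels(target)
  onehot
}

task = tsk("iris")
x = x_transform(task$data(cols = task$feature_names))
y = y_transform(task$truth())
```

An architecture with embeddings would instead return a list with one element per input from its feature transform, which is why the transforms belong to the architecture and not to the learner.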
Scope

The current scope for mlr3keras is to support deep learning on different kinds of tabular data. In the future, we aim to extend this to other data modalities, but as of yet, work on this has not started.
In an initial version, we aim to support two types of models:
Pre-defined architectures: In many cases, we just want to try out and tune architectures that have already been successfully used in other contexts (LeNet, ResNet, TabNet). We aim to implement / make those accessible for simplified tuning and fast iteration. Example:
Fully custom architectures: Some operations require completely new architectures. We aim to allow users to supply custom architectures and tune hyperparameters of those. This can be done via
KerasArchitectureCustom by providing a function that builds the model given a
Task and a set of hyperparameters.
All architectures can be parametrized and tuned using the
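As a hedged illustration of parametrization, the sketch below defines a search space with the paradox package over the two hyperparameters used in the usage example (`epochs`, `layer_units`) and draws a few random configurations. The ranges are arbitrary, and treating `layer_units` as a single integer is an assumption made for this example.

```r
library(paradox)

# Search space over hyperparameters seen in the usage example; ranges are arbitrary.
search_space = ps(
  epochs = p_int(lower = 10, upper = 100),
  layer_units = p_int(lower = 4, upper = 64)
)

# Draw five random configurations that a tuner could evaluate.
design = generate_design_random(search_space, n = 5)
design$data
```

Each row of `design$data` is one candidate configuration whose values could be assigned to a learner's `param_set$values` before training.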
keras and its ecosystem are constantly growing, so the interface needs to be flexible and highly adaptable. We try to solve this using a
keras_reflections that stores possible values for exchangeable parts of an architecture. New methods can be added by adding to the respective reflection.
mlr3 does not support features such as images / audio / …; therefore, mlr3keras is not yet applicable to image classification and other related tasks. We aim to make this possible in the future! A minor roadblock here is finding a way to load images directly from disk instead of reading them into memory in R, to avoid additional overhead.
mlr3pipelines, yet we hope this can be solved at some point in the future.
mlr3 focuses heavily on standard classification and regression. Many Deep Learning tasks require slight extensions (image annotation, bounding boxes, object detection, …) of those existing data containers (