An extension for mlr3
to enable using various keras
models as learners.
mlr3keras
is in very early stages, and currently under development. Functionality is therefore experimental and we do not guarantee correctness, safety or stability. It builds on top of the (awesome) R packages reticulate
, tensorflow
and keras
. Comments, discussion and issues/bug reports and PR’s are highly appreciated.
If you want to contribute, please propose / discuss adding functionality in an issue in order to avoid unnecessary or duplicate work.
# Install from GitHub
remotes::install_github("mlr-org/mlr3keras")
Troubleshooting:
If you encounter problems using the correct python versions, see here.
mlr3keras
is currently tested and works using the python packages keras (2.4)
and tensorflow (2.3.1)
.
One possible workflow for working with mlr3keras is described below. While (1.) and (2.) are one-time setup steps, (3.) now has to be called everytime mlr3keras is loaded.
Note from the author: The workflow described below is something that works for me personally, as I have to switch between versions and projects often. It is described for the user, as I personally find it useful. It assumes, the R packages
keras
,tensorflow
andreticulate
are installed. In order to load mlr3keras I now have to execute an additional one additional line (see 3.), but version management is heavily simplified. Note, that this seems to work on Linux but has not been extensively tested on other systems.
# Execute and restart R afterwards
reticulate::install_miniconda()
NOTE: It might make sense, to set the
RETICULATE_PYTHON
environment variable to your miniconda / Anaconda installation path as explained here.
keras
and tensorflow
# Execute and restart R afterwards
reticulate::conda_create(
envname = "mlr3keras",
packages = c("pandas", "python=3.8")
)
keras::install_keras("conda", tensorflow="2.3.1", envname="mlr3keras")
reticulate::use_condaenv("mlr3keras")
library(mlr3keras)
conda_install("mlr3keras", packages = "tabnet", pip = TRUE)
mlr3keras
currently exposes three Learners
for regression and classification respectively.
Learner | Details | Reference |
---|---|---|
regr/classif.keras | A generic wrapper that allows to supply a custom keras architecture as a hyperparameter. | – |
regr/classif.kerasFF | A fully-connected feed-forward Neural Network with entity embeddings | Guo et al. (2016) Entity Embeddings for Categorical Variables |
regr/classif.tabNet | An implementation of TabNet
|
Sercan, A. and Pfister, T. (2019): TabNet |
regr/classif.smlp | Shaped MLP inspired by Configuration Space 1* | Zimmer, L. et al. (2020): Auto PyTorch Tabular |
regr/classif.smlp2 | Shaped MLP inspired by Configuration Space 2* | Zimmer, L. et al. (2020): Auto PyTorch Tabular |
regr/classif.deep_wide | Deep and Wide Architecture inspired by | Ericson et al., 2020 AutoGluon-Tabular: Robust and Accurate AutoML for Structured Data |
classif.kerascnn | Various CNN Applications via Transfer-learning | Uses keras::application_XXX (e.g. mobilenet) |
Learners can be used for training
and prediction
as follows:
library("mlr3")
# Instantiate Learner
lrn = LearnerClassifKerasFF$new()
# Set Learner Hyperparams
lrn$param_set$values$epochs = 50
lrn$param_set$values$layer_units = 12
# Train and Predict
lrn$train(tsk("iris"))
lrn$predict(tsk("iris"))
The vignette has some examples on how to use some of the functionality introduces in mlr3keras
.
This package’s purpose for now is to understand the design-decisions required to make keras
tensorflow
work with mlr3
and flexible enough for users.
Several design decisions are not made yet, so input is highly appreciated.
Design
The goal of the project is to expose keras models as mlr3 learners. A keras model in this context should be understood as the combination of
model architecture:
keras_model_sequential() %>%
... %>%
layer_activation(...)
training procedure:
model$compile(...)
model$fit(...)
All hyperparameters that control the steps:
Some important caveats:
Architectures are often data-dependent, e.g. require correct number of input / output neurons. As a result, in mlr3keras
, the architecture is a function of the incoming training data. In mlr3keras
, this is abstracted via KerasArchitecture
: See KerasArchitectureFF
for an example. This Architecture is initialized with a build_arch_fun
which given the task
and a set of hyperparameters constructs & compiles the architecture.
Depending on the architecture, different data-formats are required for x
(features) and y
(target) (e.g. a matrix for a feed-forward NN, a list of features if we use embeddings, …) To accomodate this, each architecture comes with an x_transform
and a y_transform
method, which are called on the features and target respectively before passing those on to fit(...)
.
Scope The current scope for mlr3keras
is to support deep learning on different kinds of tabular data. In the future, we aim to extend this to other data modalities, but as of yet, work on this has not started.
In an initial version, we aim to support two types of models:
Pre-defined architectures: In many cases, we just want to try out and tune architectures that have already been successfully used in other contexts (LeNet, ResNet, TabNet). We aim to implement / make those accessible for simplified tuning and fast iteration. Example: LearnerClassifTabNet$new()
.
Fully custom architectures: Some operations require completely new architectures. We aim to allow users to supply custom architectures and tune hyperparameters of those. This can be done via KerasArchitectureCustom
by providing a function that builds the model given a Task
and a set of hyperparameters.
All architectures can be parametrized and tuned using the mlr3tuning
library.
Open Issues:
keras
and it’s ecosystem is constantly growing, the interface needs to be flexible and highly adaptable. We try to solve this using a reflections
mechanism: keras_reflections
that stores possible values for exchangable parts of an architecture. New methods can now be added by adding to the respective reflection.mlr3
does not support features such as images / audio / …, therefore mlr3keras
is not yet applicable for image classification and other related tasks. We aim to make this possible in the future! A minor road block here is to find a way to not read images to memory in R but directly load from disk to avoid additional overhead.mlr3pipelines
, yet we hope this can be solved at some point in the future.mlr3
focusses heavily on standard classification and regression. Many Deep Learning tasks require slight extensions (image annotation, bounding boxes, object detection, … ) of those existing data containers (Task
, Prediction
, …).
remotes::install_github("mlr-org/mlr3keras")
mlr3
and its ecosystem, but it is still unfinished.