
Deep Neural Networks for Survival Analysis using torch
survdnn implements neural network-based models for
right-censored survival analysis using the native torch
backend in R. It supports multiple loss functions including Cox partial
likelihood, L2-penalized Cox, and Accelerated Failure Time (AFT) objectives,
as well as time-dependent extensions such as Cox-Time. The package
provides a formula interface, supports model evaluation using
time-dependent metrics (C-index, Brier score, IBS), cross-validation,
and hyperparameter tuning.
A methodological paper describing the design, implementation, and
evaluation of survdnn is currently under review at The
R Journal.
Formula interface for Surv() ~ . models
Modular neural architectures: configurable layers, activations, optimizers, and losses
Built-in survival loss functions:
"cox": Cox partial likelihood"cox_l2": penalized Cox"aft": Accelerated Failure Time"coxtime": deep time-dependent CoxEvaluation: C-index, Brier score, IBS
Model selection with cv_survdnn() and
tune_survdnn()
Prediction of survival curves via predict() and
plot()
# Install from CRAN
install.packages("surdnn")
# Install from GitHub
install.packages("remotes")
remotes::install_github("ielbadisy/survdnn")
# Or clone and install locally
git clone https://github.com/ielbadisy/survdnn.git
setwd("survdnn")
devtools::install()

library(survdnn)
library(survival, quietly = TRUE)
library(ggplot2)
veteran <- survival::veteran
mod <- survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
hidden = c(32, 16),
epochs = 300,
loss = "cox",
verbose = TRUE
)

## Epoch 50 - Loss: 3.983201
## Epoch 100 - Loss: 3.947356
## Epoch 150 - Loss: 3.934828
## Epoch 200 - Loss: 3.876191
## Epoch 250 - Loss: 3.813223
## Epoch 300 - Loss: 3.868888

summary(mod)

## Formula:
## Surv(time, status) ~ age + karno + celltype
## <environment: 0x611459d0ec80>
##
## Model architecture:
## Hidden layers: 32 : 16
## Activation: relu
## Dropout: 0.3
## Final loss: 3.868888
##
## Training summary:
## Epochs: 300
## Learning rate: 1e-04
## Loss function: cox
## Optimizer: adam
##
## Data summary:
## Observations: 137
## Predictors: age, karno, celltypesmallcell, celltypeadeno, celltypelarge
## Time range: [ 1, 999 ]
## Event rate: 93.4%
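The fitted model can be used for prediction right away. As a minimal sketch
(mirroring the predict() call shown later in this README), risk scores at a
180-day horizon can be obtained with:

# predicted risk scores at t = 180 for the first patients
head(predict(mod, veteran, type = "risk", times = 180))
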
plot(mod, group_by = "celltype", times = 1:300)

mod1 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "cox",
epochs = 300
)

## Epoch 50 - Loss: 3.986035
## Epoch 100 - Loss: 3.973183
## Epoch 150 - Loss: 3.944867
## Epoch 200 - Loss: 3.901533
## Epoch 250 - Loss: 3.849433
## Epoch 300 - Loss: 3.899746

mod2 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "aft",
epochs = 300
)

## Epoch 50 - Loss: 18.154217
## Epoch 100 - Loss: 17.844833
## Epoch 150 - Loss: 17.560537
## Epoch 200 - Loss: 17.134348
## Epoch 250 - Loss: 16.840366
## Epoch 300 - Loss: 16.344124

mod3 <- survdnn(
Surv(time, status) ~ age + karno,
data = veteran,
loss = "coxtime",
epochs = 300
)

## Epoch 50 - Loss: 4.932558
## Epoch 100 - Loss: 4.864682
## Epoch 150 - Loss: 4.830169
## Epoch 200 - Loss: 4.784954
## Epoch 250 - Loss: 4.764827
## Epoch 300 - Loss: 4.731824

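The three fits can be compared on the training data, for instance via
Harrell's C-index computed from predicted risk scores. A minimal sketch,
assuming that predict(..., type = "risk") returns higher values for
higher-risk patients, and using survival::concordance():

vet <- veteran
vet$risk_cox     <- predict(mod1, vet, type = "risk", times = 180)
vet$risk_aft     <- predict(mod2, vet, type = "risk", times = 180)
vet$risk_coxtime <- predict(mod3, vet, type = "risk", times = 180)

# reverse = TRUE because higher risk should mean shorter survival
survival::concordance(Surv(time, status) ~ risk_cox, data = vet, reverse = TRUE)
survival::concordance(Surv(time, status) ~ risk_aft, data = vet, reverse = TRUE)
survival::concordance(Surv(time, status) ~ risk_coxtime, data = vet, reverse = TRUE)
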
cv_results <- cv_survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(600),
metrics = c("cindex", "ibs"),
folds = 3,
hidden = c(16, 8),
loss = "cox",
epochs = 300
)
print(cv_results)

grid <- list(
hidden = list(c(16), c(32, 16)),
lr = c(1e-3),
activation = c("relu"),
epochs = c(100, 300),
loss = c("cox", "aft", "coxtime")
)
tune_res <- tune_survdnn(
formula = Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(90, 300),
metrics = "cindex",
param_grid = grid,
folds = 3,
refit = FALSE,
return = "summary"
)
print(tune_res)

tune_survdnn() can also be used to automatically refit
the best-performing model on the full dataset. This behavior is
controlled by the refit and return arguments.
For example:
best_model <- tune_survdnn(
formula = Surv(time, status) ~ age + karno + celltype,
data = veteran,
times = c(90, 300),
metrics = "cindex",
param_grid = grid,
folds = 3,
refit = TRUE,
return = "best_model"
)

In this mode, cross-validation is used to select the optimal
hyperparameter configuration, after which the selected model is refitted
on the full dataset. The function then returns a fitted object of class
"survdnn".
The resulting model can be used directly for prediction, visualization, and evaluation:
summary(best_model)
plot(best_model, times = 1:300)
predict(best_model, veteran, type = "risk", times = 180)

This makes tune_survdnn() suitable for end-to-end
workflows, combining model selection and final model fitting.
Fitted survival curves can be plotted by group, or summarized with the group means only:

plot(mod1, group_by = "celltype", times = 1:300)
plot(mod1, group_by = "celltype", times = 1:300, plot_mean_only = TRUE)
Full documentation is available from within R:

help(package = "survdnn")
?survdnn
?tune_survdnn
?cv_survdnn
?plot.survdnn

# run all tests
devtools::test()

By default, {torch} initializes model weights and
shuffles minibatches using random draws, so results may differ across
runs. Unlike set.seed(), which only controls R’s random
number generator, {torch} relies on its own RNG implemented
in C++ (and CUDA when using GPUs).
To ensure reproducibility, random seeds must therefore be set at the Torch level as well.
survdnn provides built-in control of randomness to
guarantee reproducible results across runs. The main fitting function,
survdnn(), exposes a dedicated .seed
argument:
mod <- survdnn(
Surv(time, status) ~ age + karno + celltype,
data = veteran,
epochs = 300,
.seed = 123
)

When .seed is provided, survdnn()
internally synchronizes both R and Torch random number generators via
survdnn_set_seed(), making the following reproducible:
weight initialization
dropout behavior
minibatch ordering
loss trajectories
If .seed = NULL (the default), randomness is left
uncontrolled and results may vary between runs.
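This can be verified directly: fitting the same model twice with the same
.seed should yield identical predictions (a minimal sketch; identical results
are expected on the same machine and {torch} version):

m1 <- survdnn(Surv(time, status) ~ age + karno, data = veteran,
              epochs = 100, .seed = 42)
m2 <- survdnn(Surv(time, status) ~ age + karno, data = veteran,
              epochs = 100, .seed = 42)

# TRUE if both runs produced the same risk predictions
all.equal(predict(m1, veteran, type = "risk", times = 180),
          predict(m2, veteran, type = "risk", times = 180))
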
For full reproducibility in cross-validation or hyperparameter
tuning, the same .seed mechanism is propagated internally
by cv_survdnn() and tune_survdnn(), ensuring
consistent data splits, model initialization, and optimization paths
across repetitions.
survdnn relies on the {torch} backend for
numerical computation. The number of CPU cores (threads) used during
training, prediction, and evaluation is controlled globally by
Torch.
By default, Torch automatically configures its CPU thread pools based on the available system resources, unless the user explicitly overrides this with:
torch::torch_set_num_threads(4)

This setting affects:
model training
prediction
evaluation metrics
cross-validation and hyperparameter tuning
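For example, to pin Torch to a fixed thread count for the whole session
(a minimal sketch; torch_get_num_threads() is the corresponding getter in {torch}):

# restrict all subsequent torch computations to two CPU threads
torch::torch_set_num_threads(2)
torch::torch_get_num_threads()
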
GPU acceleration can be enabled by setting
.device = "cuda" when calling survdnn()
(and likewise in cv_survdnn() and tune_survdnn()).
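For example, a sketch that falls back to the CPU when no GPU is available
(it assumes a CUDA-enabled {torch} installation and that the default device
is "cpu"; torch::cuda_is_available() checks for a usable GPU):

device <- if (torch::cuda_is_available()) "cuda" else "cpu"

mod_gpu <- survdnn(
  Surv(time, status) ~ age + karno + celltype,
  data = veteran,
  epochs = 300,
  .device = device
)
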
The survdnn R package is available on CRAN and GitHub.
Contributions, issues, and feature requests are welcome!
Open an issue or submit a pull request.
MIT License © 2025 Imad EL BADISY