rCISSVAE Vignette

2026-01-20

Installing the package

Install devtools or remotes if not already installed:

install.packages("remotes")
# or
install.packages("devtools")

The rCISSVAE package can be installed with:

remotes::install_github("CISS-VAE/rCISS-VAE")
# or
devtools::install_github("CISS-VAE/rCISS-VAE")

Ensuring correct virtual environment for reticulate

This package uses reticulate to interface with the Python version of the package, ciss_vae. You therefore need a venv or conda environment with the ciss_vae package installed. If you are comfortable creating an environment and installing the package yourself, all you need to do is tell reticulate where to point.

For Venv

reticulate::use_virtualenv("./.venv", required = TRUE)

For conda

reticulate::use_condaenv("myenv", required = TRUE)

Virtual environment helper function

If you do not want to manually create the virtual environment, you can use the helper function create_cissvae_env() to create a virtual environment (venv) in your current working directory.

create_cissvae_env(
  envname = "./cissvae_environment", ## name of environment
  path = NULL, ## add path to wherever you want virtual environment to be
  install_python = FALSE, ## set to TRUE if you want create_cissvae_env to install Python for you
  python_version = "3.10" ## any version >= 3.10; Python 3.10 or 3.11 recommended
)

Note: If you run into issues with create_cissvae_env(), you can create the virtual environment manually by following this tutorial.
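Alternatively, reticulate can create and populate the environment directly. A minimal sketch, assuming ciss_vae is installable with pip in your setup:

reticulate::virtualenv_create("./cissvae_environment")
reticulate::virtualenv_install("./cissvae_environment", packages = "ciss_vae")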

Once the environment is created, activate it using:

reticulate::use_virtualenv("./cissvae_environment", required = TRUE)
# If you used a different environment name:
# reticulate::use_virtualenv("./your_environment_name", required = TRUE)

(optional) Installing other Python packages

If you want to install other Python packages (e.g., seaborn) into your environment, you can use reticulate::virtualenv_install(), as shown below.
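For example, to add seaborn to the environment created above:

reticulate::virtualenv_install("./cissvae_environment", packages = "seaborn")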

(optional) Use check_devices() to see available GPU devices

Use check_devices() to see which CPU/GPU devices are available for model training. Optionally pass the path to your virtual environment via the env_path parameter to make sure the correct environment is activated.

rCISSVAE::check_devices()
Available Devices:
  • cpu  (Main system processor) 

[1] "cpu"

Quickstart

Once reticulate is pointing to the virtual environment containing the ciss_vae Python package, you can use either the run_cissvae function or the autotune_cissvae function. For details on using the CISSVAE model with binary/categorical values, see the binary variables tutorial.

If you already know which hyperparameters you want to use, call run_cissvae().

Run CISSVAE with Training History and Progress Tracking

Your data should be in a data.frame with an optional index column. If you already have clusters you want to use, supply them as a vector separate from the data frame. If you do not have clusters to begin with, set clusters = NULL in run_cissvae().
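As a quick sketch of the expected shapes, using the sample data introduced below:

library(rCISSVAE)
data(df_missing)  # data.frame with an "index" column plus feature columns
data(clusters)    # data.frame with cluster labels in clusters$clusters
## run_cissvae() expects one cluster label per row of the data
stopifnot(nrow(df_missing) == length(clusters$clusters))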

Example Dataset

The rCISSVAE package ships with a sample dataset and predetermined clusters, which we will use throughout this tutorial.

The dataset, df_missing, contains an index column as well as the following columns:

Characteristic  N = 8,000            Unknown
Age             9.99 (8.73, 11.44)
Salary          5.70 (5.34, 6.17)
ZipCode10001    2,628 (33%)
ZipCode20002    2,697 (34%)
ZipCode30003    2,675 (33%)
Y11             0 (-11, 7)           3,122
Y12             49 (-22, 58)         3,118
Y13             72 (-17, 95)         3,110
Y14             69 (-12, 119)        3,129
Y15             72 (-12, 134)        3,141
Y21             -9 (-22, 0)          3,135
Y22             47 (-35, 58)         3,094
Y23             74 (-29, 100)        3,098
Y24             71 (-22, 128)        3,146
Y25             73 (-22, 145)        3,106
Y31             0 (-14, 11)          2,067
Y32             59 (-17, 69)         2,056
Y33             81 (-13, 101)        2,013
Y34             79 (-6, 124)         2,051
Y35             81 (-6, 139)         2,054
Y41             0 (-6, 5)            2,032
Y42             27 (-8, 32)          2,022
Y43             37 (-5, 46)          2,013
Y44             36 (-3, 56)          2,023
Y45             37 (-3, 63)          2,086
Y51             1.8 (-3.6, 5.8)      2,077
Y52             25 (-5, 29)          2,034
Y53             34 (-3, 41)          1,976
Y54             33 (-1, 50)          2,047
Y55             34 (-1, 56)          2,050
(Unknown = number of missing values)

The Age, Salary, and ZipCode columns represent demographic data with no missingness; columns Y1t-Y5t represent biomarker data obtained at different timepoints t.
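To check this structure directly, you can inspect per-column missingness with base R:

data(df_missing)
## proportion of missing values per column; the demographic columns
## (Age, Salary, ZipCode*) should come back as 0
round(colMeans(is.na(df_missing)), 2)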

Explanation of run_cissvae() Parameters

The run_cissvae() function is a comprehensive wrapper for the basic steps of running the CISS-VAE model: dataset preparation, optional clustering, and fitting the imputation model.


Dataset Parameters
- data: A data.frame containing the dataset to be imputed. May contain an optional index column.
- index_col: Name of the index column to be preserved when imputing. The index column will not have any values held out for validation.
- val_proportion: Fraction of non-missing entries to hold out for validation during training. To use a different proportion for each cluster, pass a vector.
- replacement_value: Fill value for masked entries during training. Default is 0.0.
- columns_ignore: Character or integer vector of columns to exclude when selecting validation data. These columns are still used during training.
- print_dataset: Set TRUE to print dataset summary information during processing.

Clustering Parameters (optional)
- clusters: Vector of one cluster label per row of the data. If NULL, clusters are determined automatically using Leiden clustering or KMeans.
- n_clusters: Number of clusters for KMeans clustering when clusters is NULL. If n_clusters is also NULL, Leiden clustering is used.
- leiden_resolution: Resolution parameter for Leiden clustering. Defaults to 0.5.
- k_neighbors: Number of nearest neighbors for the Leiden KNN graph construction. Defaults to 15.
- leiden_objective: Objective function for Leiden clustering. One of {"CPM", "RB", "Modularity"}. Defaults to "CPM".
- missingness_proportion_matrix: Optional pre-computed missingness proportion matrix for feature-based clustering. If provided, clustering is based on these proportions instead of the direct 0/1 missingness pattern.
- scale_features: Set TRUE to scale features when clustering on a missingness proportion matrix.

Model Parameters
- hidden_dims: A vector of hidden layer sizes for the encoder/decoder. The length of this vector determines the number of hidden layers.
- latent_dim: The dimension of the latent space representation.
- layer_order_enc: A vector giving the pattern of 'shared' and 'unshared' layers for the encoder. The length must match length(hidden_dims). Default c('unshared', 'unshared', 'unshared').
- layer_order_dec: A vector giving the pattern of 'shared' and 'unshared' layers for the decoder. The length must match length(hidden_dims). Default c('shared', 'shared', 'shared').
- latent_shared: Whether latent space weights are shared across clusters. If FALSE, each cluster has separate latent weights.
- output_shared: If FALSE, each cluster has a separate output layer.
- batch_size: Integer. Mini-batch size for training. Larger values may improve training stability but require more memory.
- return_model: If TRUE, returns the model object. Set TRUE to use plot_vae_architecture() after running.
- epochs: Number of epochs for the initial training phase.
- initial_lr: Initial learning rate for the optimizer.
- decay_factor: Exponential decay factor for the learning rate.
- beta: Weight for the KL divergence term in the VAE loss function.
- device: Device specification for computation ("cpu" or "cuda"). If NULL, the best available device is selected automatically.
- max_loops: Maximum number of impute-refit loops to perform.
- patience: Training stops if validation loss does not improve for this many consecutive impute-refit loops.
- epochs_per_loop: Number of epochs per refit loop. If NULL, uses the same value as epochs. Default NULL.
- decay_factor_refit: Decay factor for refit loops. If NULL, uses the same value as decay_factor. Default NULL.
- beta_refit: KL weight for refit loops. If NULL, uses the same value as beta. Default NULL.

Optional Parameters
- verbose: Set TRUE to print the MSE for each loop as it runs.
- return_silhouettes: If clusters are not given, returns silhouette scores from the automatic clustering.
- return_history: If TRUE, returns the training history as a data.frame. Useful for checking for overfitting.
- return_dataset: If TRUE, returns the ClusterDataset object.

Running the imputation

To run the imputation model, first load your data. You can use the cluster_summary() function to visualize missingness by cluster; it builds on {gtsummary}.

library(tidyverse)
library(reticulate)
library(rCISSVAE)
library(gtsummary)


## Set correct virtualenv
reticulate::use_virtualenv("./cissvae_environment", required = TRUE)

## Load the data
data(df_missing)
data(clusters) ## actual cluster labels in clusters$clusters (other column is index)

cluster_summary(
  data = df_missing, 
  clusters = clusters$clusters, 
  include = setdiff(names(df_missing), "index"),
  statistic = list(
    all_continuous() ~ "{mean} ({sd})",
    all_categorical() ~ "{n} / {N}\n ({p}%)"), 
  missing = "always")
Characteristic  N      Cluster 0           Cluster 1           Cluster 2           Cluster 3
                       (N = 2,000)         (N = 2,000)         (N = 2,000)         (N = 2,000)
Age             8,000  10.10 (2.04)        10.19 (2.08)        10.21 (2.14)        10.29 (2.06)
  Unknown              0                   0                   0                   0
Salary          8,000  5.81 (0.61)         5.83 (0.62)         5.83 (0.61)         5.81 (0.60)
  Unknown              0                   0                   0                   0
ZipCode10001    8,000  646 / 2,000 (32%)   674 / 2,000 (34%)   663 / 2,000 (33%)   645 / 2,000 (32%)
  Unknown              0                   0                   0                   0
ZipCode20002    8,000  703 / 2,000 (35%)   652 / 2,000 (33%)   655 / 2,000 (33%)   687 / 2,000 (34%)
  Unknown              0                   0                   0                   0
ZipCode30003    8,000  651 / 2,000 (33%)   674 / 2,000 (34%)   682 / 2,000 (34%)   668 / 2,000 (33%)
  Unknown              0                   0                   0                   0
Y11             4,878  -21 (10)            -16 (9)             8 (5)               -3 (6)
  Unknown              1,281               1,288               0                   553
Y12             4,882  69 (11)             -26 (9)             55 (6)              -24 (8)
  Unknown              1,264               1,283               0                   571
Y13             4,890  77 (12)             -25 (9)             98 (12)             -17 (7)
  Unknown              1,289               1,264               0                   557
Y14             4,871  73 (12)             -21 (8)             125 (16)            -11 (6)
  Unknown              1,300               1,283               0                   546
Y15             4,859  76 (12)             -12 (6)             141 (19)            -14 (6)
  Unknown              1,273               1,293               0                   575
Y21             4,865  -33 (12)            -28 (11)            1 (7)               -12 (7)
  Unknown              1,266               1,292               0                   577
Y22             4,906  69 (12)             -40 (12)            54 (6)              -36 (10)
  Unknown              1,266               1,276               0                   552
Y23             4,902  79 (13)             -38 (11)            104 (13)            -29 (9)
  Unknown              1,273               1,275               0                   550
Y24             4,854  75 (12)             -32 (10)            135 (18)            -22 (7)
  Unknown              1,302               1,287               0                   557
Y25             4,894  78 (13)             -22 (8)             153 (21)            -25 (8)
  Unknown              1,257               1,294               0                   555
Y31             5,933  -18 (10)            -13 (9)             13 (5)              1 (6)
  Unknown              192                 1,285               0                   590
Y32             5,944  74 (11)             -24 (10)            62 (7)              -21 (8)
  Unknown              206                 1,287               0                   563
Y33             5,987  84 (13)             -23 (10)            108 (13)            -14 (7)
  Unknown              203                 1,267               0                   543
Y34             5,949  81 (13)             -17 (8)             136 (17)            -7 (6)
  Unknown              195                 1,275               0                   581
Y35             5,946  83 (13)             -8 (6)              153 (20)            -10 (7)
  Unknown              204                 1,285               0                   565
Y41             5,968  -8 (4)              -5 (3)              6 (2)               1 (2)
  Unknown              184                 1,279               0                   569
Y42             5,978  35 (6)              -11 (4)             29 (4)              -9 (3)
  Unknown              199                 1,282               0                   541
Y43             5,987  39 (7)              -10 (3)             49 (6)              -6 (3)
  Unknown              217                 1,242               0                   554
Y44             5,977  37 (7)              -8 (3)              62 (9)              -3 (2)
  Unknown              186                 1,280               0                   557
Y45             5,914  39 (7)              -4 (3)              70 (10)             -5 (2)
  Unknown              204                 1,305               0                   577
Y51             5,923  -5.4 (3.6)          -2.9 (3.0)          6.9 (1.9)           2.5 (2.0)
  Unknown              222                 1,279               0                   576
Y52             5,966  32 (5)              -8 (3)              26 (3)              -6 (3)
  Unknown              209                 1,283               0                   542
Y53             6,024  35 (6)              -6 (3)              44 (6)              -3 (2)
  Unknown              184                 1,243               0                   549
Y54             5,953  34 (6)              -5 (3)              55 (7)              -1 (2)
  Unknown              217                 1,281               0                   549
Y55             5,950  35 (6)              -2 (2)              62 (9)              -2 (2)
  Unknown              207                 1,292               0                   551
Statistics: Mean (SD); n / N (%)

Then, plug your data and clusters into the run_cissvae() function.

## Run the imputation model. 
dat = run_cissvae(
  data = df_missing,
  index_col = "index",
  val_proportion = 0.1, ## pass a vector for different proportions by cluster
  columns_ignore = c("Age", "Salary", "ZipCode10001", "ZipCode20002", "ZipCode30003"), ## If there are columns in addition to the index you want to ignore when selecting validation set, list them here. In this case, we ignore the 'demographic' columns because we do not want to remove data from them for validation purposes. 
  clusters = clusters$clusters, ## we have precomputed cluster labels so we pass them here
  epochs = 500,
  return_silhouettes = FALSE,
  return_history = TRUE,  # Get detailed training history
  verbose = FALSE,
  return_model = TRUE, ## Allows for plotting model schematic
  device = "cpu",  # Explicit device selection
  layer_order_enc = c("unshared", "shared", "unshared"),
  layer_order_dec = c("shared", "unshared", "shared")
)

## Retrieve results
imputed_df <- dat$imputed
silhouette <- dat$silhouettes  # NULL here, since return_silhouettes = FALSE
training_history <- dat$history  # Detailed training progress

## Plot training progress
if (!is.null(training_history)) {
  plot(training_history$epoch, training_history$loss, 
       type = "l", main = "Training Loss Over Time", 
       xlab = "Epoch", ylab = "Loss")
}

plot_vae_architecture(model = dat$model, save_path = "test_plot_arch.png")

print(head(dat$imputed_dataset))
Cluster dataset:
 ClusterDataset(n_samples=8000, n_features=30, n_clusters=4)
  • Original missing: 61800 / 200000 (30.90%)
  • Validation held-out: 13783 (9.97% of non-missing)
  • .data shape:     (8000, 30)
  • .masks shape:    (8000, 30)
  • .val_data shape: (8000, 30)

#>   index       Age   Salary ZipCode10001 ZipCode20002 ZipCode30003         Y11
#> 0     0 11.044449 6.366204            0            1            0  -4.0495372
#> 1     1  9.727260 5.912558            1            0            0   0.5461677
#> 2     2 11.383020 6.636472            0            1            0  -1.2134339
#> 3     3 13.560905 5.896255            0            0            1 -10.6082144
#> 4     4  9.542490 6.128326            1            0            0   0.3575883
#> 5     5  9.542521 6.393217            1            0            0   4.7617960
#>         Y12         Y13        Y14        Y15        Y21       Y22       Y23
#> 0 -14.39339  -0.5147629 -14.369148 -17.564449 -10.736261 -18.53887 -35.77263
#> 1 -19.02272 -12.1895180  -7.722473   3.319988  -7.470250 -20.88160 -25.92436
#> 2 -19.03144 -20.3589058 -15.126495 -17.251385 -18.448421 -21.01385 -34.40086
#> 3 -22.24773  -7.1759834 -14.207619 -21.339748 -21.971752 -24.32133 -40.18794
#> 4 -16.48769 -11.3127708   7.535458  10.184475  -7.576005 -27.63498 -15.74972
#> 5 -18.96558 -12.2694435   5.667511  -9.094982  -4.120708 -29.16836 -25.07210
#>          Y24       Y25       Y31        Y32        Y33         Y34        Y35
#> 0 -28.098907 -30.24259 -1.627203  -6.557133 -16.769653 -10.6946259 -13.900185
#> 1 -17.231422 -18.69529  4.355482  -9.500225 -10.927032  -5.8868866  -6.088921
#> 2 -27.250603 -28.83982 -2.169048 -10.735180 -17.193771 -10.4940796 -12.286621
#> 3 -26.304344 -33.35109 -7.478534 -13.641594 -25.310104  -0.5022125 -15.429726
#> 4  -1.223648 -18.55935  8.103012 -13.563183  -9.832379  11.2815590  -2.962196
#> 5  -4.624435 -19.35896  7.710739  -8.304111 -10.381554   7.6654510  -4.992691
#>          Y41       Y42       Y43       Y44       Y45      Y51         Y52
#> 0 -0.9049605 -2.512167  3.527925 -3.694853 -5.680294 2.587588  1.84293461
#> 1  2.6245856 -4.218277 -5.776196 -1.379498 -2.329605 6.080512  0.04265404
#> 2  0.1033239 -4.733656 -7.215717 -3.350798 -6.895340 2.531148  0.31131554
#> 3 -2.8251922 -6.273870 -8.293898 -2.398297  3.204130 0.138556 -0.79631424
#> 4  3.6173897 -4.619453 -3.856878  7.120106 -1.495659 2.406390 -3.34118080
#> 5  2.1276770 -4.269014 -4.685389  5.764248 -2.990555 2.464692 -3.37204909
#>         Y53        Y54        Y55
#> 0 -4.681194 -2.2484055 -2.6790791
#> 1 -2.290062 -0.8873978  0.5625324
#> 2 -5.427431 -1.3301620 -2.3243809
#> 3 -5.729860 -1.6395874 -4.4457073
#> 4 -1.915190  6.5168343  0.1034508
#> 5 -2.237415  5.5294094 -1.0737858

Clustering Samples by Missingness Patterns

Before running CISS-VAE, you can cluster samples based on their missingness patterns. This groups together rows whose values tend to be missing in the same features, which can improve imputation quality.

library(rCISSVAE)

data(df_missing)


cluster_result <- cluster_on_missing(
  data = df_missing,
  cols_ignore =  c("index", "Age", "Salary", "ZipCode10001", "ZipCode20002", "ZipCode30003"),
  n_clusters = 4,  # Use KMeans with 4 clusters
  seed = 42
)

cluster_summary(
  data = df_missing,
  clusters = factor(cluster_result$clusters),
  include = setdiff(names(df_missing), "index"),
  statistic = list(
    gtsummary::all_continuous() ~ "{mean} ({sd})",
    gtsummary::all_categorical() ~ "{n} / {N}\n ({p}%)"),
  missing = "always")

cat(paste("Clustering quality (silhouette):", round(cluster_result$silhouette, 3)))


result <- run_cissvae(
  data = df_missing,
  index_col = "index",
  clusters = cluster_result$clusters,
  return_history = TRUE,
  verbose = FALSE,
  device = "cpu"
)
(Cluster summary output omitted: the table is identical to the cluster summary shown above.)
#> Clustering quality (silhouette):  0.135

Using Pre-computed Missingness Proportion Matrix

To create clusters based on the proportion of missingness across all timepoints for a given feature, you can provide a pre-computed missingness proportion matrix (built with create_missingness_prop_matrix() or by hand) directly to run_cissvae():


## Standardize df_missing column names to feature_timepoint format
colnames(df_missing) = c('index', 'Age', 'Salary', 'ZipCode10001', 'ZipCode20002', 'ZipCode30003', 'Y1_1', 'Y1_2', 'Y1_3', 'Y1_4', 'Y1_5', 'Y2_1', 'Y2_2', 'Y2_3', 'Y2_4', 'Y2_5', 'Y3_1', 'Y3_2', 'Y3_3', 'Y3_4', 'Y3_5', 'Y4_1', 'Y4_2', 'Y4_3', 'Y4_4', 'Y4_5', 'Y5_1', 'Y5_2', 'Y5_3', 'Y5_4', 'Y5_5')

# Create and examine missingness proportion matrix
prop_matrix <- create_missingness_prop_matrix(
  df_missing,
  index_col = "index",
  cols_ignore = c('Age', 'Salary', 'ZipCode10001', 'ZipCode20002', 'ZipCode30003'),
  repeat_feature_names = c("Y1", "Y2", "Y3", "Y4", "Y5")
)


cat("Missingness proportion matrix dimensions:\n")
cat(dim(prop_matrix), "\n")
cat("Sample of proportion matrix:\n")
print(head(prop_matrix[, 1:5]))
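For reference, the same proportions can be computed by hand: each entry is the share of missing timepoints for one repeated feature in one row. A minimal sketch, assuming the renamed columns above:

features <- c("Y1", "Y2", "Y3", "Y4", "Y5")
prop_manual <- sapply(features, function(f) {
  ## proportion of the 5 timepoints of feature f that are missing, per row
  rowMeans(is.na(df_missing[, paste0(f, "_", 1:5)]))
})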

# Use proportion matrix with scaling for better clustering
advanced_result <- run_cissvae(
  data = df_missing,
  index_col = "index",
  clusters = NULL,  # Let function cluster using prop_matrix
  columns_ignore = c('Age', 'Salary', 'ZipCode10001', 'ZipCode20002', 'ZipCode30003'), 
  missingness_proportion_matrix = prop_matrix,
  scale_features = TRUE,  # Standardize features before clustering
  n_clusters = 4,           # KMeans with 4 clusters
  leiden_resolution = 0.1,  # only used when n_clusters is NULL
  epochs = 5,
  return_history = TRUE,
  return_silhouettes = TRUE,
  device = "cpu",
  verbose = FALSE,
  return_clusters = TRUE
)

print("Clustering quality:")
print(paste("Silhouette score:", round(advanced_result$silhouette_width, 3)))

## Plotting imputation loss by epoch
library(ggplot2)

ggplot(advanced_result$training_history, aes(x = epoch, y = imputation_error)) +
  geom_point() +
  labs(x = "Epoch", y = "Imputation Loss") +
  theme_classic()
#> Missingness proportion matrix dimensions:
#> 8000 5
#> Sample of proportion matrix:
#>    Y1  Y2  Y3  Y4  Y5
#> 1 0.4 0.4 0.2 0.4 0.2
#> 2 0.4 0.2 0.2 0.2 0.2
#> 3 0.4 0.2 0.2 0.4 0.2
#> 4 0.4 0.2 0.4 0.4 0.2
#> 5 0.4 0.4 0.2 0.2 0.4
#> 6 0.2 0.2 0.4 0.2 0.4
#> Clustering quality:
#> Silhouette score: 0.656

Advanced Hyperparameter Optimization with Autotune

Understanding Parameter Types

For hyperparameter optimization in autotune_cissvae(), parameters can be specified as:

  • Fixed value: beta = 0.01 → parameter remains constant across trials
  • Categorical choice: c(64, 128, 256) → Optuna selects from the provided options
  • Float range: reticulate::tuple(1e-4, 1e-3) → Optuna suggests floats in the specified range
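For instance (illustrative values only; any tunable argument accepts these forms):

beta <- 0.01                          # fixed: constant across trials
batch_size <- c(64, 128, 256)         # categorical: Optuna picks one of the options
lr <- reticulate::tuple(1e-4, 1e-3)   # float range: Optuna samples within [1e-4, 1e-3]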

Layer Placement Strategies

The layer arrangement strategies control how shared and unshared layers are positioned:

  • "at_end": Places shared layers at the end of the encoder or start of the decoder
  • "at_start": Places shared layers at the start of the encoder or end of the decoder
  • "alternating": Distributes shared layers evenly throughout the architecture
  • "random": Uses random placement of shared layers (with reproducible seed)
library(tidyverse)
library(reticulate)
library(rCISSVAE)
reticulate::use_virtualenv("./cissvae_environment", required = TRUE)

data(df_missing)
data(clusters)

aut <- autotune_cissvae(
  data = df_missing,
  index_col = "index",
  clusters = clusters$clusters,
  save_model_path = NULL,
  save_search_space_path = NULL,
  n_trials = 3, ## Using low number of trials for demo
  study_name = "comprehensive_vae_autotune",
  device_preference = "cpu",
  show_progress = FALSE,  # Set TRUE for Rich progress bars with training visualization
  optuna_dashboard_db = "sqlite:///optuna_study.db",  # Save results to a database
  load_if_exists = FALSE, ## Set TRUE to load and continue the study if it exists
  seed = 42, 
  verbose = FALSE,
  
  # Search strategy options
  constant_layer_size = FALSE,     # Allow different sizes per layer
  evaluate_all_orders = FALSE,     # Sample layer arrangements efficiently
  max_exhaustive_orders = 100,     # Limit for exhaustive search
  
  ## Hyperparameter search space
  num_hidden_layers = c(2, 5),     # Try 2-5 hidden layers
  hidden_dims = c(64, 512),        # Layer sizes from 64 to 512
  latent_dim = c(10, 100),         # Latent dimension range
  latent_shared = c(TRUE, FALSE),
  output_shared = c(TRUE, FALSE),
  lr = 0.01,  # Fixed learning rate
  decay_factor = 0.99,
  beta = 0.01,  # Fixed KL weight
  num_epochs = 500,                # Fixed epochs for demo
  batch_size = c(1000, 4000),     # Batch size options
  num_shared_encode = c(0, 1, 2, 3),
  num_shared_decode = c(0, 1, 2, 3),
  
  # Layer placement strategies - try different arrangements
  encoder_shared_placement = c("at_end", "at_start", "alternating", "random"),
  decoder_shared_placement = c("at_start", "at_end", "alternating", "random"),
  
  refit_patience = 2,        # Early stopping patience
  refit_loops = 100,                # Fixed refit loops
  epochs_per_loop = 100,   # Epochs per refit loop
  reset_lr_refit = c(TRUE, FALSE)
)

# Analyze results
imputed <- aut$imputed
best_model <- aut$model
study <- aut$study
results <- aut$results

# View best hyperparameters (kable() and kable_styling() require knitr and kableExtra)
library(knitr)
library(kableExtra)
print("Trial results:")
results %>% kable() %>%
  kable_styling(font_size = 12)

# Plot model architecture
plot_vae_architecture(best_model, title = "Optimized CISSVAE Architecture")
#> [1] "Trial results:"
Trial 0: imputation_error 29.62110; num_hidden_layers 5; hidden_dims 512, 64, 512, 64, 64; latent_dim 100; latent_shared FALSE; output_shared FALSE; batch_size 1000; num_shared_encode 1; num_shared_decode 1; encoder_shared_placement alternating; decoder_shared_placement at_start; layer_order_enc_used S,U,U,U,U; layer_order_dec_used S,U,U,U,U

Trial 1: imputation_error 33.00819; num_hidden_layers 5; hidden_dims 64, 64, 64, 512, 64; latent_dim 10; latent_shared TRUE; output_shared TRUE; batch_size 1000; num_shared_encode 1; num_shared_decode 0; encoder_shared_placement alternating; decoder_shared_placement alternating; layer_order_enc_used S,U,U,U,U; layer_order_dec_used U,U,U,U,U

Trial 2: imputation_error 58.28946; num_hidden_layers 2; hidden_dims 512, 64; latent_dim 10; latent_shared TRUE; output_shared FALSE; batch_size 4000; num_shared_encode 0; num_shared_decode 3; encoder_shared_placement alternating; decoder_shared_placement random; layer_order_enc_used U,U; layer_order_dec_used S,S
[Figure: Autotuned VAE Architecture]

For more information on using the Optuna dashboard, see the Optuna Dashboard tutorial.