## ----install-from-github, eval=FALSE------------------------------------------
# install.packages("devtools")
# devtools::install_github("magronah/power.nb")

## ----pkgs, message=FALSE, warning=FALSE---------------------------------------
##Load libraries
library(tidyverse)
library(dplyr)
library(power.nb)
library(patchwork)
library(rlist)
library(latex2exp)

# Directory holding the example data files inside the installed power.nb package
data_path <- system.file("pkgdata", package = "power.nb")

# Count table: tab-separated, with the first column used as row names
data <- read.table(
  file.path(data_path, "ob_ross_ASVs_table.tsv"),
  header       = TRUE,
  sep          = "\t",
  check.names  = FALSE,
  comment.char = "",
  row.names    = 1
)

# Sample metadata: tab-separated
metadata <- read.table(
  file.path(data_path, "ob_ross_metadata.tsv"),
  header       = TRUE,
  sep          = "\t",
  check.names  = FALSE,
  comment.char = ""
)

# Use the column names that the later calls refer to
# (`sample_colname = "SampleID"`, `group_colname = "Groups"`)
names(metadata) <- c("SampleID", "Groups")

# Quick look at the sizes and first rows of both tables
dim(data)
dim(metadata)
head(data, 2)
head(metadata, 2)


## ----filter-data--------------------------------------------------------------
# Filter the count table with power.nb's filter_low_count(), using the
# thresholds below (presumably abundance / sample-count cut-offs; see the
# package documentation for their exact meaning).
filter_data <- filter_low_count(
  countdata      = data,
  metadata       = metadata,
  abund_thresh   = 3,
  sample_thresh  = 2,
  sample_colname = "SampleID",
  group_colname  = "Groups"
)

# Size of the filtered table
dim(filter_data)

## ----logfoldchange, message=FALSE---------------------------------------------

# Group labels present in the metadata
unique(metadata$Groups)

# Run deseqfun() on the filtered counts with "H" as the reference group
foldchange_est <- deseqfun(
  countdata      = filter_data,
  metadata       = metadata,
  alpha_level    = 0.1,
  ref_name       = "H",
  group_colname  = "Groups",
  sample_colname = "SampleID"
)

# Keep the log2 fold change column of the returned estimate table
logfoldchange <- foldchange_est$deseq_estimate$log2FoldChange

head(logfoldchange)

## ----quiet_fun----------------------------------------------------------------
### Helper that silences an expression: printed output is captured and
### discarded, and messages and warnings are suppressed (used below to hide
### per-iteration progress messages).
###
### `expr` is evaluated in the caller's environment; its value is returned.

quiet <- function(expr) {
  code   <- substitute(expr)
  caller <- parent.frame()
  suppressWarnings(
    suppressMessages(
      capture.output(res <- eval(code, caller))
    )
  )
  res
}

### Directory of the precomputed objects (in the installed power.nb package)
### that the later chunks of this vignette load
extdata_path <- system.file("extdata", package = "power.nb")

## ----logmean, message=FALSE, warning=FALSE------------------------------------
# Natural log of the per-row (per-taxon) mean of the filtered counts
logmean <- log(rowMeans(filter_data))

# Precomputed fit, loaded from disk instead of being refitted
logmeanFit <- readRDS(
  file.path(extdata_path, "logmeanFit.rds")
)

## The code below was used to generate the precomputed
## `logmeanFit.rds` object. The saved result is loaded
## throughout this vignette to reduce vignette build times.

# logmeanFit =   quiet(logmean_fit(logmean, sig = 0.05,
#                                  max.comp = 4, max.boot = 100))


## ----logfoldchangeFit---------------------------------------------------------

# Precomputed log fold change fit, loaded from disk instead of being refitted
logfoldchangeFit <- readRDS(
  file.path(extdata_path, "logfoldchangeFit.rds")
)

logfoldchangeFit

## The code below was used to generate the precomputed
## `logfoldchangeFit.rds` object. The saved result is loaded
## throughout this vignette to reduce vignette build times.

# logfoldchangeFit <- quiet(logfoldchange_fit(logmean,
#                                             logfoldchange,
#                                             ncore = 1,
#                                             max_sd_ord = 2,
#                                             max_np = 5,
#                                             minval = -5,
#                                             maxval = 5,
#                                             itermax = 100,
#                                             NP = 800,
#                                             seed = 100))



## ----dispersion,message=FALSE, warning=FALSE----------------------------------
# Dispersion values returned by deseqfun(), fitted against the log mean counts
dispersion    <- foldchange_est$dispersion
dispersionFit <- dispersion_fit(dispersion, logmean)

# Fitted parameters
dispersionFit$param

## ----sim-counts---------------------------------------------------------------

### Simulate log mean counts and log fold changes from the fitted
### log mean count and log fold change models

logmean_param       <- logmeanFit$param
logfoldchange_param <- logfoldchangeFit

# Number of rows (taxa) in the filtered count table
notu <- dim(filter_data)[1]
notu

logmean_sim <- logmean_sim_fun(logmean_param, notu)

logfoldchange_sim <- logfoldchange_sim_fun(
  logmean_sim         = logmean_sim,
  logfoldchange_param = logfoldchange_param,
  max_lfc             = 15,
  max_iter            = 30000
)

head(logfoldchange_sim)

## ----sim-counts-nb------------------------------------------------------------

# Number of simulated datasets, and group sizes taken from the observed
# metadata ("OB" samples as treatment, "H" samples as control)
nsim   <- 4
ntreat <- sum(metadata[["Groups"]] == "OB")
ncont  <- sum(metadata[["Groups"]] == "H")

dispersion_param <- dispersionFit$param

# Group sizes are given explicitly through `ncont`/`ntreat`, so
# `nsamp_per_group` is left NULL
countdata_sims <- countdata_sim_fun(
  logmean_param,
  logfoldchange_param,
  dispersion_param,
  nsamp_per_group = NULL,
  ncont           = ncont,
  ntreat          = ntreat,
  notu,
  nsim            = nsim,
  disp_scale      = 0.7,
  max_lfc         = 15,
  maxlfc_iter     = 1000,
  seed            = 121
)

# Dimensions of the pieces returned for the first simulation
dim(countdata_sims$treat_countdata_list$sim_1)
dim(countdata_sims$control_countdata_list$sim_1)
dim(countdata_sims$countdata_list$sim_1)
dim(countdata_sims$metadata_list$sim_1)

## ----compare-with-sim, warning=FALSE------------------------------------------

#' Compare simulated and observed count data by per-row mean or variance
#'
#' Computes the chosen statistic for every row of each dataset and overlays
#' density curves of its log: one coloured curve per simulated dataset and a
#' black dashed curve for the observed data.
#'
#' @param countdata_sim_list List of simulated count tables (rows are passed
#'   to the statistic one at a time).
#' @param countdata_obs Observed count table; it is added to the list under
#'   the name `"obs_filt"`.
#' @param method Statistic to compare: `"var"` (the default) or `"mean"`.
#' @return A ggplot object.
compare_dataset <- function(countdata_sim_list, countdata_obs,
                            method = c("var", "mean")) {

  # Resolve `method` to a single value. The default is the length-2 vector
  # c("var", "mean"); testing that directly in `if ()` fails on R >= 4.2
  # ("the condition has length > 1"), so match.arg() picks the first choice
  # when the argument is left at its default.
  method <- tryCatch(
    match.arg(method, choices = c("var", "mean")),
    error = function(e) {
      stop("Invalid method. Please choose either 'var' or 'mean'.",
           call. = FALSE)
    }
  )

  # Statistic and matching axis label for the chosen method
  if (method == "var") {
    calc_func <- stats::var
    xlab_text <- "$\\log$(variance of taxa)"
  } else {
    calc_func <- mean
    xlab_text <- "$\\log$(mean of taxa)"
  }

  # Simulated datasets plus the observed one, tagged "obs_filt"
  dlist <- list.append(countdata_sim_list, obs_filt = countdata_obs)

  # Row-wise statistic for each dataset, stacked into one long table whose
  # `type` column holds the dataset name
  vars <- dlist |> purrr::map(~ apply(., 1, calc_func)) |>
    purrr::map_dfr(~ tibble(var = .), .id = "type")

  # Separate observed and simulated data
  vars_obs <- vars[vars$type == "obs_filt", ]
  vars_sim <- vars[vars$type != "obs_filt", ]

  # Simulated densities coloured by dataset; observed density drawn on top
  # as a black dashed line
  ggplot(vars_sim, aes(x = log(var), colour = type)) +
    geom_density(lwd = 1.5) +
    geom_density(data = vars_obs, aes(x = log(var)),
                 colour = "black", linetype = "dashed", lwd = 1.5) +
    xlab(TeX(xlab_text)) + theme_bw()
}
                                   
# Compare the first simulated dataset (kept as a one-element list) with the
# observed, filtered counts
countdata_sim_list <- countdata_sims$countdata_list[1]
countdata_obs      <- filter_data

p11 <- compare_dataset(countdata_sim_list, countdata_obs, method = "mean")
p12 <- compare_dataset(countdata_sim_list, countdata_obs, method = "var")

# Show both plots side by side with the legends collected
(p11 | p12) + plot_layout(guides = "collect")

dim(countdata_obs)


## ----est-p-values, cache = FALSE----------------------------------------------

# Simulated count tables and their metadata, one entry per simulation
countdata_list <- countdata_sims$countdata_list
metadata_list  <- countdata_sims$metadata_list

# Run deseq_fun_est() on the simulations with printed output, messages and
# warnings silenced
desq_est <- quiet(
  deseq_fun_est(
    metadata_list  = metadata_list,
    countdata_list = countdata_list,
    alpha_level    = 0.1,
    group_colname  = "Groups",
    sample_colname = "Samples",
    num_cores      = 1,
    ref_name       = "control"
  )
)

# Pull the estimate table out of each simulation's result
deseq_est_list <- lapply(desq_est, function(sim_res) {
  sim_res$deseq_estimate
})

names(desq_est$sim_1)

## ----gam-fit, cache = FALSE, warning = FALSE----------------------------------
# Log means and log fold changes returned alongside the simulated counts
true_lmean_list       <- countdata_sims$logmean_list
true_lfoldchange_list <- countdata_sims$logfoldchange_list

gamFit <- gam_fit(
  deseq_est_list,
  true_lfoldchange_list,
  true_lmean_list,
  grid_len    = 50,
  alpha_level = 0.1
)

# Overall ranges across all simulations
range(true_lfoldchange_list)
range(true_lmean_list)

## ----contour-plot, cache = FALSE----------------------------------------------
# Contour breaks from 0 to 1 in steps of 0.1
cont_breaks    <- seq(0, 1, 0.1)
combined_data  <- gamFit$combined_data
power_estimate <- gamFit$power_estimate

contour_plot <- contour_plot_fun(
  combined_data,
  power_estimate,
  cont_breaks
)

contour_plot

## ----sim_multi_ss-------------------------------------------------------------
# Simulate `nsim` datasets at each per-group sample size in `nsample_vec`.
nsim        <- 4
nsample_vec <- seq(10, 100, 20)

# lapply() builds the whole list in one step: no `1:length()` indexing (which
# misbehaves for an empty vector) and no list grown element by element.
# The same seed is used for every sample size, as before.
countdata_sims_list <- lapply(nsample_vec, function(nsamp) {
  countdata_sim_fun(logmean_param,
                    logfoldchange_param,
                    dispersion_param,
                    nsamp_per_group = nsamp,
                    ncont  = NULL,
                    ntreat = NULL,
                    notu,
                    nsim = nsim,
                    disp_scale = 0.3,
                    max_lfc = 15,
                    maxlfc_iter = 1000,
                    seed = 121)
})

# Label each entry with its sample size, e.g. "sample_10"
names(countdata_sims_list) <- paste0("sample_", nsample_vec)

## ----deseq_multi_ss, cache = FALSE, message=FALSE-----------------------------

# Run deseq_fun_est() on the simulations of every sample size.
# lapply() keeps the per-sample-size inputs local, so the `countdata_list`
# and `metadata_list` objects created earlier in the script are no longer
# overwritten, and it avoids `1:length()` indexing and growing a list.
desq_est_list <- lapply(countdata_sims_list, function(sims) {
  deseq_fun_est(metadata_list  = sims$metadata_list,
                countdata_list = sims$countdata_list,
                alpha_level    = 0.1,
                group_colname  = "Groups",
                sample_colname = "Samples",
                num_cores      = 1,
                ref_name       = "control")
})

# Label each entry with its sample size, e.g. "sample_10"
names(desq_est_list) <- paste0("sample_", nsample_vec)


## ----gam_multi_ss,  cache = FALSE---------------------------------------------

# Precomputed multi-sample-size power fit, loaded from disk
pow_ss <- readRDS(
  file.path(extdata_path, "gam_fit_MultiSamples.rds")
)

## The code below was used to generate the precomputed
## `gam_fit_MultiSamples.rds` object.
## The saved result is loaded throughout
## this vignette to reduce vignette build times.

# deseq_list = lapply(desq_est_list, function(x){
#   read_data(x, "deseq_estimate")
# })
# 
# pval_est_list <- lapply(deseq_list, function(sample_list) {
#   lapply(sample_list, function(sim_df) {
#     sim_df$padj
#   })
# })
# 
# 
# logfoldchange_list  =   read_data(countdata_sims_list,"logfoldchange_list")
# logmean_list        =   read_data(countdata_sims_list,"logmean_list")
# 
# pow_ss <- power_fun_ss(pval_est_list,
#                          logmean_list,
#                          nsample_vec = nsample_vec,
#                          logfoldchange_list,
#                          alpha_level=0.1)
  

## ----ss_estimation------------------------------------------------------------
# Inputs for a single sample-size estimate
target_power <- 0.8
model        <- pow_ss$gam_mod
abs_lfc      <- 1.2
# NOTE: this rebinds `logmean`, which earlier in the script held
# log(rowMeans(filter_data)); from here on it is the scalar 4
logmean      <- 4

ss_estimate <- uniroot_ss(
  target_power, logmean, abs_lfc, model,
  xmin       = log2(10),
  xmax       = log2(5000),
  maxiter    = 10000,
  max_report = 2000
)

ss_estimate

## ----ss_estimation_grid-------------------------------------------------------
# Values to combine: target power 0.50, 0.55, ..., 0.90 (a coarser
# alternative would be c(0.6, 0.7, 0.8, 0.9)), four absolute log fold
# changes and three log mean counts
target_power_vec <- seq(0.5, 0.9, 0.05)
abs_lfc_vec      <- c(0.5, 1.0, 1.5, 2.0)
logmean_vec      <- c(2, 4, 6)

# One row per combination of the three vectors
param_grid <- expand.grid(
  target_power = target_power_vec,
  abs_lfc      = abs_lfc_vec,
  logmean      = logmean_vec
)

# Run uniroot_ss() for every row and keep its `sample_size_per_group` element
param_grid[["sample_size_per_group"]] <- apply(
  param_grid,
  1,
  function(grid_row) {
    ss_fit <- uniroot_ss(
      target_power = as.numeric(grid_row["target_power"]),
      logmean      = as.numeric(grid_row["logmean"]),
      abs_lfc      = as.numeric(grid_row["abs_lfc"]),
      model        = pow_ss$gam_mod,
      xmin         = log2(10),
      xmax         = log2(5000),
      maxiter      = 10000,
      max_report   = 2000
    )
    ss_fit$sample_size_per_group
  }
)

head(param_grid)

## ----ss_grid_plot, fig.width=11, fig.height=8---------------------------------
# Estimated sample size per group against target power, with one panel per
# combination of |log fold change| (rows) and log mean count (columns)
ggplot(
  param_grid,
  aes(x = target_power, y = sample_size_per_group, group = 1)
) +
  geom_line() +
  geom_point() +
  facet_grid(
    abs_lfc ~ logmean,
    labeller = labeller(
      abs_lfc = function(lfc) paste0("|log(fold change)| = ", lfc),
      logmean = function(lm) paste0("log(mean count) = ", lm)
    )
  ) +
  labs(
    x     = "Target Power",
    y     = "Estimated Sample Size per Group",
    title = "Estimated Sample Size by Target Power"
  ) +
  theme_bw() +
  # Centre the plot title
  theme(plot.title = element_text(hjust = 0.5))


## ----echo=FALSE, warning=FALSE, message=FALSE---------------------------------
library(knitr)
library(kableExtra)
library(knitr)
library(kableExtra)

# Summary table of fitted parameters, one row per dataset (31 rows, 13 columns).
# Row 29 has a blank dataset name and blank entries in most columns; its
# `$\mu_2$` and `${f(x)}_{\sigma_2}$` cells carry additional components.
# Columns 2-9 sit under the "Log mean count" header and columns 10-13 under
# "Log fold change" when the table is rendered.
#
# `Dataset` must be given with `=`: writing `Dataset <- c(...)` inside
# data.frame() passes an unnamed argument, so the column header becomes the
# deparsed code and a stray global `Dataset` is created.
# `check.names = FALSE` keeps the LaTeX and duplicated column names as written.
df <- data.frame(
  `Dataset` = c("ArcticFireSoils", "Blueberry", "cdi_schubert", "cdi_vincent", 
                    "Chemerin",  "crc_baxter", "crc_zeller", "edd_singh", "Exercise",
                   "glass_plastic_oberbeckmann", "GWMC_ASIA_NA", "GWMC_HOT_COLD", "hiv_dinh",
                   "hiv_lozupone", "hiv_noguerajulian", "ibd_papa", "Ji_WTP_DS", "MALL", "ob_goodrich",
                   "ob_ross","ob_turnbaugh", "ob_zhu", "ob_zupancic", "par_scheperjans",
                   "sed_plastic_hoellein",  "sed_plastic_rosato", "sw_plastic_frere", "t1d_alkanani",
                   " ", "t1d_mejialeon",    "wood_plastic_kesy"),
  `$n_{comp}$` = c("4", "4", "3", "2", "3", "4", "5", "1", "3", "3",
  "4", "5", "3", "3", "4", "3", "3", "3", "4", "2",
  "3", "3", "3", "3", "3", "3", "5", "5"," ",  "2", "3"),
  `$p_1$`        =  c(
  "(0.10,0.38,0.37,0.15)",
  "(0.37,0.22,0.29,0.12)",
  "(0.54,0.37,0.09)",
  "(0.56,0.44)",
  "(0.56,0.02,0.42)",
  "(0.16,0.33,0.20,0.32)",
  "(0.13,0.40,0.30,0.13,0.03)",
  "1",
  "(0.15,0.48,0.36)",
  "(0.44,0.50,0.06)",
  "(0.10,0.35,0.38,0.17)",
  "(0.14,0.35,0.37,0.00,0.14)",
  "(0.49,0.45,0.06)",
  "(0.30,0.28,0.42)",
  "(0.09,0.46,0.25,0.19)",
  "(0.06,0.49,0.45)",
  "(0.33,0.30,0.37)",
  "(0.39,0.11,0.50)",
  "(0.15,0.16,0.62,0.06)",
  "(0.34,0.66)",
  "(0.55,0.03,0.42)",
  "(0.57,0.01,0.42)",
  "(0.45,0.32,0.23)",
  "(0.39,0.32,0.28)",
  "(0.41,0.39,0.20)",
  "(0.40,0.43,0.17)",
  "(0.06,0.25,0.33,0.27,0.08)",
  "(0.06,0.30,0.47,0.15,0.02)",
   " ",
  "(0.47,0.53)",
  "(0.29,0.42,0.29)"
),
  `$\\mu_1$`   =  c(
  "(-2.81,-1.56,0.31,3.27)",
  "(-0.47,-1.45,0.83,2.54)",
  "(-2.02,0.78,4.17)",
  "(3.06,-0.03)",
  "(-0.86,0.74,2.27)",
  "(-3.13,-2.01,-0.49,2.73)",
  "(-2.32,-1.33,0.15,2.33,5.75)",
  "(-0.42)",
  "(-1.20,0.17,2.39)",
  "(-0.08,1.91,5.71)",
  "(-4.72,-3.18,-1.17,1.41)",
  "(-4.73,-3.16,-1.06,-0.93,1.57)",
  "(0.42,2.15,5.79)",
  "(-0.12,1.30,3.80)",
  "(-2.83,-1.71,0.21,3.64)",
  "(0.14,1.71,4.69)",
  "(-0.68,2.61,6.67)",
  "(-0.65,0.65,3.26)",
  "(-3.89,-2.39,-0.17,3.31)",
  "(-0.58,2.34)",
  "(-0.98,0.14,1.22)",
  "(-0.12,0.53,2.77)",
  "(-1.98,-0.56,1.91)",
  "(-1.71,-0.31,1.34)",
  "(-0.03,1.46,3.32)",
  "(-0.25,1.58,4.06)",
  "(-1.88,-1.04,0.19,1.93,4.08)",
  "(-3.06,-2.54,-1.91,-0.87,0.82)",
   " ",
  "(0.73,3.57)",
  "(-1.16,1.16,4.65)"),
  
  `$\\sigma_1$`  = c(
  "(0.45,0.80,1.35,2.39)",
  "(0.71,0.52,1.05,1.66)",
  "(0.94,1.57,2.04)",
  "(2.10,0.89)",
  "(1.14,0.09,2.32)",
  "(0.55,0.76,2.46,1.21)",
  "(0.50,0.71,1.09,1.70,2.52)",
  "(2.39)",
  "(0.50,1.10,1.97)",
  "(0.64,1.23,1.18)",
  "(0.61,0.98,1.50,2.28)",
  "(0.69,1.55,1.05,0.07,2.28)",
  "(0.84,1.25,0.70)",
  "(0.92,0.50,1.83)",
  "(0.46,0.75,1.33,2.32)",
  "(0.69,0.57,0.98)",
  "(0.84,2.69,1.84)",
  "(0.82,0.40,1.83)",
  "(0.58,1.61,0.98,2.47)",
  "(0.75,1.98)",
  "(0.79,0.08,1.66)",
  "(0.82,0.04,1.97)",
  "(0.87,1.42,0.41)",
  "(0.71,1.02,1.58)",
  "(0.65,1.05,1.67)",
  "(0.70,1.23,1.93)",
  "(0.33,0.52,0.83,1.35,2.21)",
  "(0.17,0.30,0.46,0.72,1.37)",
  " ",
  "(0.92,1.94)",
  "(0.70,1.43,2.46)"
),
  `  `  = c(rep(" ", 31)),
  `  `  = c(rep(" ", 31)),
  `  `  = c(rep(" ", 31)),
  `  `  = c(rep(" ", 31)),
  `$n_{comp}$`   = c(rep("2", 27), "5", " ", "2", "2"),
  `$p_2$`        =  c(
  "(-3.88)", "(4.97)", "(-1.16)", "(4.93)", "(4.52)",
  "(4.65)", "(3.66)", "(-1.99)", "(-4.83)", "(-3.83)",
  "(4.08)", "(3.58)", "(-2.29)", "(-4.86)", "(-2.44)",
  "(-1.31)", "(4.44)", "(4.26)", "(4.90)", "(4.40)",
  "(1.60)", "(-2.12)", "(-1.90)", "(4.16)", "(-4.73)",
  "(4.65)", "(2.68)", "(0.23,-2.85,-4.30,4.83)",
  " ",
  "(-2.93)", "(3.96)"
),
  `$\\mu_2$`     = c(
  "$0.13 + 0.06x_i,\\;-1.55 - 0.62x_i$",
  "$0.16 - 4.31x_i,\\;0.01 + 0.01x_i$",
  "$-1.42 - 0.37x_i,\\;2.61 + 0.77x_i$",
  "$1.04 - 1.14x_i,\\;-0.21 - 0.09x_i$",
  "$2.95 + 3.42x_i,\\;-0.02 - 0.00x_i$",
  "$4.25 + 3.11x_i,\\;-0.09 - 0.03x_i$",
  "$0.78 + 0.88x_i,\\;-0.16 - 0.02x_i$",
  "$0.09 + 0.01x_i,\\;-2.20 - 0.72x_i$",
  "$0.03 - 0.02x_i,\\;2.27 - 2.38x_i$",
  "$-0.07 - 0.05x_i,\\;2.17 + 0.13x_i$",
  "$3.69 - 0.21x_i,\\;0.11 + 0.03x_i$",
  "$1.62 + 0.25x_i,\\;-0.03 - 0.01x_i$",
  "$0.33 + 0.11x_i,\\;-1.54 - 0.68x_i$",
  "$0.30 + 0.24x_i,\\;-4.39 + 0.21x_i$",
  "$0.07 + 0.04x_i,\\;-0.47 - 0.19x_i$",
  "$-0.32 + 0.11x_i,\\;0.57 + 0.11x_i$",
  "$2.59 - 1.66x_i,\\;0.10 + 0.00x_i$",
  "$-0.09 + 1.07x_i,\\;0.08 + 0.05x_i$",
  "$-0.35 - 2.37x_i,\\;-0.09 - 0.02x_i$",
  "$0.81 + 0.42x_i,\\;0.02 + 0.01x_i$",
  "$-1.24 - 0.36x_i,\\;0.28 + 0.10x_i$",
  "$-0.30 + 0.05x_i,\\;1.03 + 0.47x_i$",
  "$-0.07 + 0.03x_i,\\;1.26 + 0.18x_i$",
  "$0.88 + 0.44x_i,\\;-0.10 - 0.02x_i$",
  "$-0.07 - 0.00x_i,\\;-0.52 + 0.41x_i$",
  "$-0.25 - 1.49x_i,\\;-0.08 - 0.04x_i$",
  "$1.82 + 0.35x_i,\\;-0.06 + 0.01x_i$",
  "$-1.75 + 0.01x_i,\\;-0.36 - 2.61x_i,\\;-0.34 + 2.04x_i$",
  "$-0.13 - 0.93x_i,\\;-0.03 - 0.02x_i$",
  "$0.25 + 0.02x_i,\\;-3.32 + 0.13x_i$",
  "$0.28 + 2.24x_i,\\;-0.00 - 0.01x_i$"
),
`${f(x)}_{\\sigma_2}$`  = c(
"$-0.65 + 0.30x_i,\\;0.44 + 0.46x_i$",
"$-1.17 + 2.07x_i,\\;-1.07 + 0.13x_i$",
"$-0.27 + 0.54x_i,\\;-0.17 + 0.23x_i$",
"$1.05 - 0.49x_i,\\;0.18 + 0.13x_i$",
"$-1.25 - 2.24x_i,\\;-1.31 + 0.07x_i$",
"$0.34 + 1.84x_i,\\;-0.39 + 0.43x_i$",
"$1.53 - 0.39x_i - 1.30x_i^2,\\;-0.31 + 0.35x_i - 0.03x_i^2$",
"$0.13 + 0.28x_i,\\;-0.54 + 0.16x_i$",
"$-1.23 + 0.05x_i,\\;-1.55 + 1.39x_i$",
"$-0.51 + 0.06x_i,\\;-1.39 + 0.51x_i$",
"$0.26 - 1.68x_i,\\;0.55 + 0.46x_i$",
"$0.33 + 0.36x_i,\\;0.68 + 0.65x_i$",
"$-0.36 + 0.14x_i,\\;-0.93 + 0.25x_i$",
"$-0.04 + 0.17x_i,\\;-0.79 - 0.25x_i$",
"$-0.72 + 0.19x_i,\\;-0.18 - 0.20x_i$",
"$0.16 + 0.06x_i,\\;-1.30 + 0.48x_i$",
"$-1.34 - 1.83x_i,\\;-0.78 + 0.04x_i$",
"$-0.22 - 1.87x_i + 1.43x_i^2,\\;0.11 + 0.20x_i - 0.03x_i^2$",
"$-2.92 + 1.76x_i,\\;-0.71 + 0.39x_i$",
"$-2.19 + 0.27x_i,\\;-0.11 + 0.01x_i$",
"$-0.38 + 0.46x_i,\\;-0.24 + 0.18x_i$",
"$0.01 + 0.11x_i,\\;-0.61 + 0.19x_i$",
"$-0.04 + 0.17x_i,\\;-0.74 - 0.29x_i$",
"$-0.07 + 0.62x_i,\\;-0.44 + 0.15x_i$",
"$-0.42 + 0.10x_i,\\;0.23 - 0.99x_i$",
"$0.70 + 0.67x_i,\\;-1.61 + 0.19x_i$",
"$0.43 - 0.40x_i,\\;-0.54 + 0.11x_i$",
"$3.94 + 3.15x_i,\\;-1.28 - 1.36x_i,\\;-3.81 - 3.44x_i$",
"$-2.31 + 3.62x_i,\\;-1.12 + 0.63x_i$",
"$-0.04 + 0.06x_i,\\;-0.25 + 0.35x_i$",
"$0.12 + 0.67x_i,\\;-1.31 + 0.10x_i$"
),
  check.names = FALSE
)

# Render the table. `escape = FALSE` passes the cell and header text through
# unescaped; the grouped header spans 1 + 8 + 4 = 13 columns.
kbl(df, booktabs = TRUE, longtable = TRUE, escape = FALSE) |>
  kable_styling(latex_options = c("hold_position", "repeat_header")) |>
  add_header_above(
    c(
      " "               = 1,
      "Log mean count"  = 8,
      "Log fold change" = 4
    )
  )


