## ----setup, include = FALSE---------------------------------------------------
# Chunk defaults for the whole document: collapsed "#>" output, a fixed
# 7 x 5 in centred figure size, and warnings hidden in rendered output.
knitr::opts_chunk$set(
    collapse = TRUE,
    comment = "#>",
    fig.width = 7,
    fig.height = 5,
    fig.align = "center",
    warning = FALSE
)
# Seed R's global RNG so any code that draws random numbers without its own
# `seed` argument is reproducible.
set.seed(42)
# No par() settings here: they would only affect base graphics (all figures
# below are ggplot2/ggtern), knitr resets par() between chunks by default,
# and outside knitr the call would just open a stray graphics device.

## ----libraries, message = FALSE-----------------------------------------------
library(yaap)
library(generics)
library(ggplot2)
library(ggtern)
library(recipes)
library(tune)

## ----load-usarrests-----------------------------------------------------------
# USArrests has one row per US state, with the state names as row names.
arrests <- USArrests
state <- row.names(arrests)
# Census region (a factor) for each state, looked up through the parallel
# `state.name` / `state.region` vectors from the datasets package.
region <- state.region[match(state, table = state.name)]
# Number of archetypes / principal components used throughout.
analysis_K <- 3
# Fixed hex colour per region level, shared by both score plots.
region_cols <- setNames(
    c("#0072B2", "#D55E00", "#009E73", "#CC79A7"),
    c("Northeast", "South", "North Central", "West")
)

data.frame(state = state, region = region, arrests, row.names = NULL) |>
    head()

## ----recipe-fit---------------------------------------------------------------
# Preprocessing shared by every model below: normalise all numeric columns.
rec_base <- step_normalize(recipe(~ ., data = arrests), all_numeric())

# Archetype step on top of the normalised columns.
# NOTE(review): `tol_r2 = 0` presumably disables the R^2-based early stop so
# the fit runs up to `max_iter` iterations - TODO confirm in the yaap docs.
rec <- step_archetypes(
    rec_base,
    all_numeric(),
    seed = 42,
    num_comp = analysis_K,
    options = list(
        init = "furthest_sum",
        max_iter = 100,
        tol_r2 = 0
    )
)

# Estimate both steps on the training data and show the trained recipe.
rec_prep <- rec |> prep(training = arrests)
rec_prep

## ----recipe-bake--------------------------------------------------------------
# Apply the trained recipe; archetype weights appear as columns A1..A<K>.
baked <- rec_prep |> bake(new_data = arrests)
data.frame(state = state, region = region, baked, row.names = NULL) |>
    head()

# Sanity check: min and max of the per-state sums of archetype weights
# (presumably both 1 if the weights are simplex coordinates).
aa_cols <- sprintf("A%d", seq_len(analysis_K))
range(rowSums(as.matrix(baked[, aa_cols])))

## ----recipe-keep-original-----------------------------------------------------
# Same archetype fit as `rec`, but with `keep_original_cols = TRUE` so the
# input columns are kept next to A1..A<K>. Those retained columns have
# already gone through step_normalize() in `rec_base`, so they are z-scores,
# not raw arrest rates.
rec_keep <- rec_base |>
    step_archetypes(
        all_numeric(),
        num_comp = analysis_K,
        keep_original_cols = TRUE,
        seed = 42,
        options = list(
            # Use the same initialisation as the other archetype fits so the
            # A1..A<K> columns here should match those in `baked`.
            init = "furthest_sum",
            max_iter = 100,
            tol_r2 = 0
        )
    ) |>
    prep(training = arrests)

arrests_with_originals <- data.frame(
    state = state,
    region = region,
    bake(rec_keep, new_data = arrests),
    row.names = NULL
)

head(arrests_with_originals)

## ----compare-pca-aa-----------------------------------------------------------
# Fit an archetype model and a K-component PCA on the same normalised data
# so their per-state scores can be compared side by side.
rec_aa <- prep(
    step_archetypes(
        rec_base,
        all_numeric(),
        seed = 42,
        num_comp = analysis_K,
        options = list(init = "furthest_sum", max_iter = 100, tol_r2 = 0)
    ),
    training = arrests
)

rec_pca <- prep(
    step_pca(rec_base, all_numeric(), num_comp = analysis_K),
    training = arrests
)

# Archetype weights (A1..A<K>) and PCA scores (PC1..PC<K>) for every state.
aa_scores <- bake(rec_aa, new_data = arrests)
pca_scores <- bake(rec_pca, new_data = arrests)

score_data <- data.frame(
    state = state, region = region,
    pca_scores, aa_scores,
    row.names = NULL
)
head(score_data)

## ----comparison-labels, include=FALSE-----------------------------------------
# Names of the archetype-weight columns (A1, A2, ...).
aa_cols <- grep("^A[0-9]+$", names(aa_scores), value = TRUE)
aa_mat <- as.matrix(score_data[, aa_cols])

# For each archetype, the row of the state with the largest weight on it.
simplex_label_ix <- vapply(
    seq_along(aa_cols),
    \(col) which.max(aa_mat[, col]),
    FUN.VALUE = integer(1L)
)

# Rows holding the min and max of PC1 and of PC2, duplicates removed.
pca_label_ix <- unique(unlist(
    lapply(score_data[c("PC1", "PC2")], \(pc) c(which.min(pc), which.max(pc))),
    use.names = FALSE
))
label_ix <- unique(c(pca_label_ix, simplex_label_ix))

# Keep the state name only on selected rows; all other rows get NA, which
# geom_text(na.rm = TRUE) skips in both plots.
score_plot_data <- transform(
    score_data,
    label = replace(state, !(seq_along(state) %in% label_ix), NA)
)

## ----compare-pca-aa-tidy------------------------------------------------------
# Step 2 of each recipe is the model step (step 1 is step_normalize()):
# variance summary for the PCA, tidy summary for the archetype step.
rec_pca |> tidy(number = 2, type = "variance")
rec_aa |> tidy(number = 2)

## ----pca-score-plot, fig.width = 6, fig.height = 5, fig.cap = "PCA scores for USArrests. Labels mark states selected as PCA extremes or AA simplex-vertex representatives."----
# One shared, slightly padded range for both axes so the fixed 1:1 aspect
# ratio shows PC1 and PC2 on the same scale.
pca_rng <- with(score_plot_data, extendrange(c(PC1, PC2)))
ggplot(score_plot_data, aes(x = PC1, y = PC2, colour = region)) +
    # Light reference lines through the origin, drawn beneath the points.
    geom_hline(yintercept = 0, colour = "grey85") +
    geom_vline(xintercept = 0, colour = "grey85") +
    geom_point(size = 1.8) +
    # Labels sit slightly above their points; NA labels are dropped.
    geom_text(
        aes(label = label),
        size = 3,
        nudge_y = 0.035 * diff(pca_rng),
        show.legend = FALSE,
        na.rm = TRUE
    ) +
    coord_fixed(xlim = pca_rng, ylim = pca_rng) +
    scale_colour_manual(values = region_cols, drop = FALSE) +
    labs(title = "step_pca()", x = "PC1", y = "PC2", colour = "Region") +
    theme_minimal(base_size = 10) +
    theme(panel.grid.minor = element_blank())

## ----aa-simplex-plot, fig.width = 6, fig.height = 5, fig.cap = "AA composition weights for USArrests on the archetype simplex. Labels match the PCA plot for comparison."----
# Ternary (simplex) plot of each state's archetype weights A1..A3, coloured
# by region and labelled with the same `label` column as the PCA plot.
# NOTE(review): hard-codes three weight columns, so this assumes
# analysis_K == 3.
ggtern(score_plot_data, aes(A1, A2, A3, colour = region)) +
    geom_point(size = 1.8) +
    # NA labels are dropped, so only the selected states are annotated.
    geom_text(
        aes(label = label),
        na.rm = TRUE,
        position = "identity",
        size = 3,
        show.legend = FALSE
    ) +
    # drop = FALSE keeps every region in the legend even if one is absent.
    scale_colour_manual(values = region_cols, drop = FALSE) +
    labs(
        title = "step_archetypes()",
        x = "A1",
        y = "A2",
        z = "A3",
        colour = "Region"
    ) +
    theme_minimal(base_size = 10) +
    theme_showarrows() +
    theme(
        legend.position = "bottom",
        # Wider left/right margins than top/bottom for the ternary panel.
        plot.margin = margin(8, 16, 8, 16),
        # NOTE(review): presumably turns off ggtern's panel mask so points
        # and labels near the simplex edges are not hidden - TODO confirm.
        tern.panel.mask.show = FALSE
    )

## ----tunable------------------------------------------------------------------
# Recipe with tune() placeholders for two archetype arguments, used to show
# which parameters of step_archetypes() are reported as tunable.
rec_tune <- step_archetypes(
    rec_base,
    all_numeric(),
    delta = tune(),
    num_comp = tune()
)

# The archetype step is the last step, after step_normalize().
tunable(rec_tune$steps[[length(rec_tune$steps)]])

