#' Generate Coefficients for Simulated Differential ORF Translation
#'
#' @description
#' Simulates log-fold change coefficients for differential ORF translation
#' (DOT) analysis. This function assigns baseline coefficients to all ORFs
#' and applies specific log-fold changes to multi-cistronic ORFs based on a
#' selected regulatory scenario.
#'
#' @param orfs A data frame containing ORF annotations. Row names must align
#' with the filtered count matrices. Must include columns
#' \code{gene_id} and \code{orf_type}.
#'
#' @param scenario A character string specifying the regulatory scenario to
#' simulate. Must be one of:
#' \itemize{
#'     \item \code{"uORF_up_mORF_up"}
#'     \item \code{"uORF_up_mORF_down"}
#'     \item \code{"dORF_up_mORF_up"}
#'     \item \code{"uORF_up_mORF_flat"}
#'     \item \code{"uORF_down_mORF_flat"}
#'     \item \code{"uORF_down_mORF_down"}
#'     \item \code{"uORF_down_mORF_up"}
#'     \item \code{"uORF_flat_mORF_up"}
#'     \item \code{"uORF_flat_mORF_down"}
#'     \item \code{"dORF_up_mORF_down"}
#'     \item \code{"dORF_up_mORF_flat"}
#'     \item \code{"dORF_down_mORF_flat"}
#'     \item \code{"dORF_down_mORF_down"}
#'     \item \code{"dORF_down_mORF_up"}
#'     \item \code{"dORF_flat_mORF_up"}
#'     \item \code{"dORF_flat_mORF_down"}
#' }
#'
#' @param gcoeff Numeric. The log-fold change magnitude to apply to regulated
#' ORFs.
#'
#' @param shape Numeric. Shape parameter for the gamma distribution used to
#' simulate baseline coefficients.
#'
#' @param scale Numeric. Scale parameter for the gamma distribution used to
#' simulate baseline coefficients.
#'
#' @return A list with the following components:
#' \describe{
#'     \item{gcoeffs_ribo}{Named numeric vector of log-fold change
#'     coefficients for ribosome profiling data.}
#'     \item{gcoeffs_rna}{Named numeric vector of baseline log-fold
#'     change coefficients for RNA-seq data.}
#'     \item{labels}{Named binary vector indicating truly regulated
#'     ORFs (\code{1}) vs. non-regulated (\code{0}).}
#' }
#'
#' @keywords internal
#' @examples
#' \dontrun{
#' generate_coefficients(orfs_df, scenario = "uORF_up_mORF_down")
#' }
#' 
generate_coefficients <- function(
    orfs,
    scenario = "uORF_up_mORF_down",
    gcoeff = 1.5,
    shape = 0.6,
    scale = 0.5
) {

    df <- orfs

    # Fail early when the annotation lacks the columns the scenario logic
    # below relies on.
    missing_cols <- setdiff(c("gene_id", "orf_type"), colnames(df))
    if (length(missing_cols) > 0) {
        stop(
            "`orfs` is missing required column(s): ",
            paste(missing_cols, collapse = ", ")
        )
    }

    # Log-fold change applied to each ORF type, per scenario. A value of 0
    # means "not regulated" for that ORF type.
    scenario_lfc <- list(
        "uORF_up_mORF_up" = c(uORF = gcoeff, mORF = gcoeff, dORF = 0),
        "uORF_up_mORF_down" = c(uORF = gcoeff, mORF = -gcoeff, dORF = 0),
        "uORF_up_mORF_flat" = c(uORF = gcoeff, mORF = 0, dORF = 0),
        "uORF_down_mORF_flat" = c(uORF = -gcoeff, mORF = 0, dORF = 0),
        "uORF_down_mORF_down" = c(uORF = -gcoeff, mORF = -gcoeff, dORF = 0),
        "uORF_down_mORF_up" = c(uORF = -gcoeff, mORF = gcoeff, dORF = 0),
        "uORF_flat_mORF_up" = c(uORF = 0, mORF = gcoeff, dORF = 0),
        "uORF_flat_mORF_down" = c(uORF = 0, mORF = -gcoeff, dORF = 0),
        "dORF_up_mORF_up" = c(uORF = 0, mORF = gcoeff, dORF = gcoeff),
        "dORF_up_mORF_down" = c(uORF = 0, mORF = -gcoeff, dORF = gcoeff),
        "dORF_up_mORF_flat" = c(uORF = 0, mORF = 0, dORF = gcoeff),
        "dORF_down_mORF_flat" = c(uORF = 0, mORF = 0, dORF = -gcoeff),
        "dORF_down_mORF_down" = c(uORF = 0, mORF = -gcoeff, dORF = -gcoeff),
        "dORF_down_mORF_up" = c(uORF = 0, mORF = gcoeff, dORF = -gcoeff),
        "dORF_flat_mORF_up" = c(uORF = 0, mORF = gcoeff, dORF = 0),
        "dORF_flat_mORF_down" = c(uORF = 0, mORF = -gcoeff, dORF = 0)
    )

    # `scenario` must be a single known name; checking type and length first
    # avoids a cryptic "argument is of length zero" error from `if`.
    scenario_ok <- is.character(scenario) &&
        length(scenario) == 1 &&
        scenario %in% names(scenario_lfc)
    if (!scenario_ok) {
        stop(
            "Invalid scenario specified. ",
            "Please choose from available scenarios."
        )
    }

    # Baseline coefficients shared by both assays: gamma-distributed
    # magnitudes with a random sign. The RNA-seq coefficients are exactly
    # this baseline.
    n_orfs <- nrow(df)
    gcoeffs <- rgamma(n_orfs, shape = shape, scale = scale) *
        sample(c(-1, 1), n_orfs, replace = TRUE)
    names(gcoeffs) <- rownames(df)

    # Genes contributing more than one ORF row are treated as multi-cistronic;
    # only their ORFs can receive a scenario effect.
    gene_counts <- table(df$gene_id)
    multi_cistronic_genes <- names(gene_counts[gene_counts > 1])
    in_multi <- df$gene_id %in% multi_cistronic_genes

    # Look up the scenario effect for every ORF by its type. ORF types that
    # the scenario does not mention (or missing types) get no effect.
    orf_lfc <- unname(scenario_lfc[[scenario]][as.character(df$orf_type)])
    orf_lfc[is.na(orf_lfc)] <- 0

    regulated <- in_multi & orf_lfc != 0

    # Ribo-seq coefficients are the baseline plus the scenario effect.
    gcoeffs_ribo <- gcoeffs
    gcoeffs_ribo[regulated] <- gcoeffs_ribo[regulated] + orf_lfc[regulated]

    # 1 marks a truly regulated ORF, 0 everything else.
    labels <- as.numeric(regulated)
    names(labels) <- rownames(df)

    list(
        gcoeffs_ribo = gcoeffs_ribo,
        gcoeffs_rna = gcoeffs,
        labels = labels
    )
}


#' Estimate zero-inflated negative binomial parameters from a real dataset
#'
#' @description
#' This function estimates the parameters of a zero-inflated negative binomial
#' distribution based on a real count data set using the method of moments.
#' The function also returns a spline fit of log mean to log size which can be
#' used when generating new simulated data.
#'
#' @param counts A matrix of counts.
#' 
#' @param threshold Only estimate parameters from transcripts with row means
#' greater than this threshold.
#' 
#' @param size_factor Optional numeric scalar to scale the estimated size
#' parameters.
#' 
#' @param min_size Optional numeric value to enforce a minimum size parameter.
#' 
#' @param scale_p0 Optional numeric scalar to scale the zero-inflation
#' probabilities.
#'
#' @return A list containing:
#'  \describe{
#'      \item{p0}{A vector of probabilities that the count will be zero,
#'   one for each gene/transcript.}
#'      \item{mu}{The estimated negative binomial mean by method of
#'   moments for the non-zero counts.}
#'      \item{size}{The estimated negative binomial size by method of
#'   moments for the non-zero counts.}
#'      \item{fit}{A fit relating log mean to log size for use in
#'   simulating new data.}
#'  }
#'
#' @author Jeff Leek (original), Chun Shen Lim (modifications)
#'
#' @importFrom stats smooth.spline
#'
#' @keywords internal
#'
get_params <- function(
    counts,
    threshold = NULL,
    size_factor = NULL,
    min_size = NULL,
    scale_p0 = NULL
) {

    # Optionally restrict to rows whose mean count exceeds `threshold`.
    # `drop = FALSE` keeps a matrix even when a single row survives.
    if (!is.null(threshold)) {
        rowm <- rowMeans(counts)
        index1 <- which(rowm > threshold)
        counts <- counts[index1, , drop = FALSE]
    }

    nsamples <- ncol(counts)
    counts0 <- counts == 0
    nn0 <- rowSums(!counts0)

    # A variance needs at least two non-zero counts. Rows with zero or one
    # non-zero value are always removed; otherwise an all-zero row would
    # produce NaN means and sizes and break the spline fit below.
    keep <- nn0 > 1
    if (!all(keep)) {
        counts <- counts[keep, , drop = FALSE]
        nn0 <- nn0[keep]
        counts0 <- counts == 0
    }

    # Method-of-moments estimates computed over the non-zero counts only.
    mu <- rowSums((!counts0) * counts) / nn0
    s2 <- rowSums((!counts0) * (counts - mu)^2) / (nn0 - 1)
    size <- mu^2 / (s2 - mu + 0.0001)

    # Rows whose variance does not exceed the mean give a non-positive size;
    # they are assigned the smallest positive size estimated from the data.
    positive <- which(size > 0)
    if (length(positive) == 0) {
        stop(
            "get_params error: no positive size estimate could be obtained ",
            "from `counts`."
        )
    }
    non_positive <- !is.na(size) & size <= 0
    size[non_positive] <- min(size[positive])

    # Fraction of zero counts per row.
    p0 <- (nsamples - nn0) / nsamples

    if (!is.null(size_factor)) {
        size <- size * size_factor
    }

    if (!is.null(min_size)) {
        size[size < min_size] <- min_size
    }

    if (!is.null(scale_p0)) {
        p0 <- p0 * scale_p0
        p0[p0 > 1] <- 1
    }

    # Spline of log size on log mean, used later to predict a size for any
    # simulated mean.
    lsize <- log(size)
    lmu <- log(mu + 0.0001)
    fit <- smooth.spline(lsize ~ lmu)

    list(p0 = p0, mu = mu, size = size, fit = fit)
}


#' Generate a simulated data set based on known model parameters
#'
#' @param mu Baseline mean expression for negative binomial model.
#' 
#' @param fit Fitted relationship between log mean and log size.
#' 
#' @param p0 A vector of the probabilities a count is zero.
#' 
#' @param m Number of genes/transcripts to simulate (not necessary if
#' \code{mod}, \code{beta} are specified).
#' 
#' @param n Number of samples to simulate (not necessary if \code{mod},
#' \code{beta} are specified).
#' 
#' @param mod Model matrix you would like to simulate from without an 
#' intercept.
#' 
#' @param beta Set of coefficients for the model matrix (must have same 
#' number of columns as \code{mod}).
#'
#' @return A numeric matrix of simulated counts with genes/transcripts in
#' rows and samples in columns.
#'
#' @author Jeff Leek
#'
#' @importFrom stats rnbinom
#'
#' @keywords internal
#' 
create_read_numbers <- function(
    mu,
    fit,
    p0,
    m = NULL,
    n = NULL,
    mod = NULL,
    beta = NULL
) {

    if (is.null(mod) || is.null(beta)) {
        message("Generating data from baseline model.\n")
        # Without a design, the matrix dimensions must be given explicitly.
        if (is.null(m) || is.null(n)) {
            stop(
                "create_read_numbers error: if you don't specify mod and ",
                "beta, you must specify m and n."
            )
        }
        # Draw m baseline means (and their matching zero probabilities)
        # without replacement from the estimated parameters.
        index <- sample(seq_len(length(mu)), size = m)
        mus <- mu[index]
        p0s <- p0[index]
        # Every sample shares the same log mean for a given row.
        mumat <- log(mus + 0.001) %*% t(rep(1, n))
    } else {
        if (ncol(beta) != ncol(mod)) {
            stop(
                "create_read_numbers error: beta and mod must have the ",
                "same number of columns."
            )
        }
        m <- nrow(beta)
        n <- nrow(mod)
        index <- sample(seq_len(length(mu)), size = m)
        mus <- mu[index]
        p0s <- p0[index]

        # Drop intercept-like columns (all ones) from the design and the
        # matching coefficients. `drop = FALSE` keeps both objects as
        # matrices when there is a single row or a single column left, so
        # the matrix product below stays conformable.
        ind <- vapply(
            seq_len(ncol(mod)),
            function(j) !all(mod[, j] == 1),
            logical(1)
        )
        mod <- mod[, ind, drop = FALSE]
        beta <- beta[, ind, drop = FALSE]
        # Log mean per row and sample: baseline plus design effects.
        mumat <- log(mus + 0.001) + beta %*% t(mod)
    }

    # Predict a log size for every log mean from the fitted spline.
    muvec <- as.vector(mumat)
    sizevec <- stats::predict(fit, muvec)$y
    sizemat <- matrix(sizevec, nrow = m)

    # Zero-inflated negative binomial draw, one row at a time: a Bernoulli
    # "not a structural zero" indicator times a negative binomial count.
    counts <- matrix(NA_real_, nrow = m, ncol = n)
    for (i in seq_len(m)) {
        counts[i, ] <- stats::rbinom(n, prob = (1 - p0s[i]), size = 1) *
            stats::rnbinom(n, mu = exp(mumat[i, ]), size = exp(sizemat[i, ]))
    }
    counts
}


#' Simulate Differential ORF Translation (DOT)
#'
#' @description
#' Simulates ribosome profiling and matched RNA-seq count matrices with
#' specified differential ORF translation (DOT) effects. The simulation can
#' include batch effects and supports multiple experimental conditions and
#' replicates.
#'
#' @param ribo A matrix or data frame of ribosome profiling counts
#' (genes x samples).
#' 
#' @param rna A matrix or data frame of RNA-seq counts (genes x samples).
#' 
#' @param annotation A GRanges object with ORF level annotation, 
#' typically obtained from \code{\link{getORFs}}.
#' 
#' @param regulation_type Character. Specifies the type of DOT effect to
#' simulate. Passed to the \code{scenario} argument of
#' \code{generate_coefficients}.
#' 
#' @param te_genes Numeric. Percentage of genes to be assigned as
#' differentially translated (default: 10).
#' 
#' @param bgenes Numeric. Percentage of genes to carry a batch effect
#' (default: 10).
#' 
#' @param num_samples Integer. Number of biological replicates per condition
#' (default: 2).
#' 
#' @param conditions Integer. Number of experimental conditions (default: 2).
#' 
#' @param gcoeff Numeric. Magnitude of log-fold change for DOT effects
#' (default: 1.5).
#' 
#' @param bcoeff Numeric. Magnitude of batch effect coefficient (default: 0.9).
#' 
#' @param num_batches Integer. Number of batches (default: 2).
#' 
#' @param size_factor Numeric scalar. A multiplicative factor applied to the
#' estimated size parameter (\eqn{r}) for all transcripts. Since
#' dispersion \eqn{\phi = 1/r}, a value greater than 1 (e.g., 1.5) will
#' decrease biological dispersion (noise), making the simulated data
#' less variable. A value less than 1 will increase dispersion
#' (default: \code{NULL}, i.e. the estimated size is left unscaled).
#' 
#' @param min_size Numeric scalar. A lower bound for the modified size
#' parameter (\eqn{r}). Any transcript whose modified \eqn{r} falls
#' below this value will be set to \code{min_size}. This caps maximum
#' dispersion and prevents unrealistic variability (default: \code{NULL},
#' i.e. no lower bound is enforced).
#' 
#' @param scale_p0 Optional numeric scalar to scale the zero-inflation
#' probabilities.
#' 
#' @param shape Numeric. Shape parameter for gamma distribution used to
#' simulate baseline coefficients (default: 0.6).
#' 
#' @param scale Numeric. Scale parameter for gamma distribution used to
#' simulate baseline coefficients (default: 0.5).
#' 
#' @param batch_scenario Character. Specifies the batch effect design. Must 
#' be one of:
#' \itemize{
#'     \item \code{"balanced"}
#'     \item \code{"confounded"}
#'     \item \code{"random"}
#'     \item \code{"unbalanced"}
#'     \item \code{"nested"}
#'     \item \code{"modality_specific"}
#' }
#' @param diagplot_ribo Logical. If \code{TRUE}, generate diagnostic plots
#' for ribo data (default: \code{FALSE}).
#' @param diagplot_rna Logical. If \code{TRUE}, generate diagnostic plots
#' for RNA data (default: \code{FALSE}).
#'
#' @return A \code{\link{DOTSeqDataSets-class}} object containing:
#' \describe{
#'     \item{DOU}{
#'         A \code{\link{DOUData-class}} object containing simulated count 
#'         matrix (\code{assay} slot), sample metadata (\code{colData} slot), and ORF-level 
#'         annotation (\code{rowRanges} slot). The \code{rowRanges} slot also stores 
#'         \code{status} (binary vector marking simulated true positives with 1) and 
#'         \code{logFC} (log-fold changes for the simulated DOU effect) for modeling 
#'         Differential ORF Usage (DOU).
#'     }
#'     \item{DTE}{
#'         A \code{\link{DTEData-class}} object used for modeling 
#'         Differential Translation Efficiency (DTE). Stores all data above
#'         except for \code{rowRanges}
#'     }
#' }
#'
#' @importFrom stats prcomp rgamma runif df model.matrix
#' @import SummarizedExperiment
#' @importFrom S4Vectors mcols mcols<-
#'
#' @export
#'
#' @examples
#' library(SummarizedExperiment)
#' dir <- system.file("extdata", package = "DOTSeq")
#'
#' cnt <- read.table(file.path(dir, "featureCounts.cell_cycle_subset.txt.gz"),
#'     header = TRUE, comment.char = "#"
#' )
#' names(cnt) <- gsub(".*(SRR[0-9]+).*", "\\1", names(cnt))
#'
#' flat <- file.path(dir, "gencode.v47.orf_flattened_subset.gtf.gz")
#' bed <- file.path(dir, "gencode.v47.orf_flattened_subset.bed.gz")
#'
#' meta <- read.table(file.path(dir, "metadata.txt.gz"))
#' names(meta) <- c("run", "strategy", "replicate", "treatment", "condition")
#' cond <- meta[meta$treatment == "chx", ]
#' cond$treatment <- NULL
#'
#' d <- DOTSeqDataSetsFromFeatureCounts(
#'     count_table = cnt,
#'     condition_table = cond,
#'     flattened_gtf = flat,
#'     flattened_bed = bed
#' )
#' raw_counts <- assay(getDOU(d))
#' raw_counts <- raw_counts[, grep("Cycling|Interphase",
#'     colnames(raw_counts))]
#' ribo <- raw_counts[, grep("ribo", colnames(raw_counts))]
#' rna <- raw_counts[, grep("rna", colnames(raw_counts))]
#' rowranges <- rowRanges(getDOU(d))
#' r <- "uORF_up_mORF_down"
#' g <- 1.5
#' d <- simDOT(
#'     ribo,
#'     rna,
#'     annotation = rowranges,
#'     regulation_type = r,
#'     gcoeff = g,
#'     num_samples = 1,
#'     num_batches = 2
#' )
#'
#' show(d)
#' 
#' rowData(getDOU(d))
#' 
#' @references
#' Frazee, A. C., Jaffe, A. E., Langmead, B., & Leek, J. T. (2015). 
#' Polyester: Simulating RNA-seq datasets with differential transcript 
#' expression. Bioinformatics, 31(17), 2778-2784. 
#' DOI: 10.1093/bioinformatics/btv272
#' 
#' Chothani, S., Adami, E., Ouyang, J. F., Viswanathan, S., Hubner, N., 
#' Cook, S. A., Schafer, S., Rackham, O. J. L.  (2019). deltaTE: Detection 
#' of translationally regulated genes by integrative analysis of Ribo-seq 
#' and RNA-seq data. Current Protocols in Molecular Biology, 129, e108. 
#' DOI: 10.1002/cpmb.108
#' 
simDOT <- function(
    ribo,
    rna,
    annotation = NULL,
    regulation_type = NULL,
    te_genes = 10,
    bgenes = 10,
    num_samples = 2,
    conditions = 2,
    gcoeff = 1.5,
    bcoeff = 0.9,
    num_batches = 2,
    size_factor = NULL,
    min_size = NULL,
    scale_p0 = NULL,
    shape = 0.6,
    scale = 0.5,
    batch_scenario = "balanced",
    diagplot_ribo = FALSE,
    diagplot_rna = FALSE
) {

    # Reject unknown designs up front; otherwise no model matrix is built
    # and the function fails much later with an unrelated error.
    valid_batch_scenarios <- c(
        "balanced", "confounded", "random",
        "unbalanced", "nested", "modality_specific"
    )
    batch_scenario_ok <- is.character(batch_scenario) &&
        length(batch_scenario) == 1 &&
        batch_scenario %in% valid_batch_scenarios
    if (!batch_scenario_ok) {
        stop(
            "Invalid batch_scenario. Must be one of: ",
            paste(valid_batch_scenarios, collapse = ", ")
        )
    }

    # sample() interprets a length-one numeric `x` as seq_len(x). Sampling
    # positions instead always draws from the values actually supplied and
    # consumes the random stream identically for longer vectors.
    sample_from <- function(x, size, replace = FALSE) {
        x[sample.int(length(x), size, replace = replace)]
    }

    # Ensure all inputs are matrices and have consistent gene names
    ribo <- as.matrix(ribo)
    rna <- as.matrix(rna)

    # DOU is simulated only when both a regulation type and an annotation
    # are supplied; otherwise gene-level DTE is simulated.
    simulate_dou <- !is.null(regulation_type) && !is.null(annotation)

    # Features present in both count matrices (and, for DOU, the annotation)
    original_genes <- intersect(rownames(ribo), rownames(rna))
    if (simulate_dou) {
        common_genes_all <- intersect(original_genes, names(annotation))
    } else {
        common_genes_all <- original_genes
    }
    if (length(common_genes_all) == 0) {
        stop("No common genes found across all three input datasets (ribo, rna, annotation).")
    }

    # Filter all inputs to this common set of genes, keeping them as matrices
    ribo_subset <- ribo[common_genes_all, , drop = FALSE]
    rna_subset <- rna[common_genes_all, , drop = FALSE]
    if (simulate_dou) {
        message("simulating Differential ORF Usage (DOU)")
        # Keep only the annotation columns needed downstream
        cols_to_keep <- c("gene_id", "orf_type")
        mcols(annotation) <- mcols(annotation)[, cols_to_keep, drop = FALSE]

        # ORF table for the whole annotation, and the subset shared with
        # the count matrices (exact name matching).
        orfs <- data.frame(
            row.names = names(annotation),
            gene_id = mcols(annotation)$gene_id,
            orf_type = mcols(annotation)$orf_type
        )
        orfs_filtered <- orfs[
            match(common_genes_all, rownames(orfs)), ,
            drop = FALSE
        ]
    } else {
        message("simulating Differential Translation Efficiency (DTE)")
    }

    # A feature passes when at least two samples have more than 5 counts
    ribo_pass <- rowSums(ribo_subset > 5) >= 2
    rna_pass <- rowSums(rna_subset > 5) >= 2

    # Features passing the filter in both assays
    common_filtered_genes <- intersect(
        names(which(ribo_pass)),
        names(which(rna_pass))
    )

    if (length(common_filtered_genes) == 0) {
        stop("No genes passed the filtering criteria in both ribo and rna datasets.")
    }

    # Final alignment of all data to the filtered gene list
    counts_ribo_filtered <- ribo_subset[common_filtered_genes, , drop = FALSE]
    counts_rna_filtered <- rna_subset[common_filtered_genes, , drop = FALSE]
    if (simulate_dou) {
        orfs_filtered <- orfs_filtered[
            match(common_filtered_genes, rownames(orfs_filtered)), ,
            drop = FALSE
        ]
    }
    n_features <- nrow(counts_ribo_filtered)

    # Sample layout: within each batch block, conditions in order, each
    # with `num_samples` replicates.
    total_samples <- num_samples * conditions * num_batches
    group <- rep(rep(seq(0, conditions - 1), each = num_samples), num_batches)

    # Explicitly convert group to factor for correct model.matrix behavior
    group <- as.factor(group)

    # Batch assignment logic
    if (batch_scenario == "balanced") {
        message("using batch_scenario: ", batch_scenario)
        batch <- rep(seq(1, num_batches), each = num_samples * conditions)
        batch <- as.factor(batch)
        if (num_batches == 1) {
            mod <- model.matrix(~ -1 + group)
        } else {
            mod <- model.matrix(~ -1 + batch + group)
        }
    } else if (batch_scenario == "confounded") {
        message("using batch_scenario: ", batch_scenario)
        if (num_batches > 1) {
            # In a confounded design, batch and group are the same factor
            batch <- group
            mod <- model.matrix(~ -1 + group)
        } else {
            stop("Use num_batches > 1 for confounded to generate simulated data")
        }
    } else if (batch_scenario == "random") {
        if (num_batches > 1) {
            message("using batch_scenario: ", batch_scenario)
            batch <- sample(seq(1, num_batches), total_samples, replace = TRUE)
            batch <- as.factor(batch)
            mod <- model.matrix(~ -1 + batch + group)
        } else {
            stop("Use num_batches > 1 for random to generate simulated data")
        }
    } else if (batch_scenario == "unbalanced") {
        if (num_batches > 1) {
            message("using batch_scenario: ", batch_scenario)
            # About 60% of the samples go to batch 1; the rest are spread
            # over batches 2..num_batches. With two batches the minor set
            # is the single value 2, which plain sample() would expand
            # to 1:2.
            major <- rep(1, round(0.6 * total_samples))
            minor <- sample_from(
                seq(2, num_batches),
                total_samples - length(major),
                replace = TRUE
            )
            batch <- c(major, minor)
            batch <- as.factor(batch)
            mod <- model.matrix(~ -1 + batch + group)
        } else {
            stop("Use num_batches > 1 for unbalanced to generate simulated data")
        }
    } else if (batch_scenario == "nested") {
        message("using batch_scenario: ", batch_scenario)
        # One factor level per condition/batch combination. It is built from
        # `group` so that the condition encoded in the level always agrees
        # with the condition reported for the same sample.
        batch_index <- rep(seq_len(num_batches), each = num_samples * conditions)
        nested_factor <- as.factor(paste0("C", group, "_B", batch_index))
        mod <- model.matrix(~ -1 + nested_factor)
        batch <- nested_factor
    } else if (batch_scenario == "modality_specific") {
        if (num_batches > 1) {
            message("using batch_scenario: ", batch_scenario)
            # Create a consistent batch vector for both modalities
            batch <- rep(seq_len(num_batches), each = num_samples * conditions)
            batch <- as.factor(batch)
            mod_ribo <- model.matrix(~ -1 + batch + group)
            mod_rna <- model.matrix(~ -1 + group)
        } else {
            stop("num_batches must be > 1 for modality_specific scenario.")
        }
    }

    # Design columns that carry the condition effect. Designs without a
    # batch term (~ -1 + group) contain a column for the reference condition
    # as well; giving it the same coefficient as the other conditions would
    # cancel the simulated effect, so the reference column is excluded.
    reference_group_col <- paste0("group", levels(group)[1])
    group_effect_cols <- function(design_cols) {
        which(
            startsWith(design_cols, "group") &
                design_cols != reference_group_col
        )
    }

    params_ribo <- get_params(
        counts_ribo_filtered,
        size_factor = size_factor,
        min_size = min_size,
        scale_p0 = scale_p0
    )
    params_rna <- get_params(
        counts_rna_filtered,
        size_factor = size_factor,
        min_size = min_size,
        scale_p0 = scale_p0
    )

    if (simulate_dou) {
        coeffs_list <- generate_coefficients(
            orfs = orfs_filtered,
            scenario = regulation_type,
            gcoeff = gcoeff,
            shape = shape,
            scale = scale
        )
        gcoeffs_ribo <- coeffs_list$gcoeffs_ribo
        gcoeffs_rna <- coeffs_list$gcoeffs_rna
        labels <- coeffs_list$labels

        # `bgenes` is a percentage; convert to a feature count and pick the
        # features that carry a batch effect with a random sign.
        bgenes <- round(bgenes * n_features / 100)
        bselect <- if (bgenes > 0) {
            sample(seq_len(n_features), bgenes)
        } else {
            integer(0)
        }
        bcoeffs <- rep(0, n_features)
        if (length(bselect) > 0) {
            bcoeffs[bselect] <- bcoeff * sample(c(1, -1), bgenes, replace = TRUE)
        }
    } else {
        dTE_genes <- round(te_genes * n_features / 100)
        bgenes <- round(bgenes * n_features / 100)

        gcoeffs_rna <- rgamma(n_features, shape = shape, scale = scale) *
            sample(c(-1, 1), n_features, replace = TRUE)
        select <- sample(seq_len(n_features), dTE_genes)

        # Batch-affected features are taken from the differentially
        # translated set first and topped up from the remaining features.
        bcoeffs <- rep(0, n_features)
        if (bgenes <= length(select)) {
            bselect <- sample_from(select, bgenes)
        } else {
            others <- setdiff(seq_len(n_features), select)
            bselect <- c(
                select,
                sample_from(others, bgenes - length(select))
            )
        }
        bcoeffs[bselect] <- bcoeff * sample(c(1, -1), bgenes, replace = TRUE)

        # Ribo coefficients equal the RNA ones except for the selected
        # features, which get +/- gcoeff added.
        gcoeffs_ribo <- gcoeffs_rna
        gcoeffs_ribo[select] <- gcoeffs_rna[select] +
            sample(c(gcoeff, -gcoeff), length(select), replace = TRUE)

        labels <- rep(0, n_features)
        names(labels) <- rownames(counts_ribo_filtered)
        labels[select] <- 1
    }

    if (batch_scenario == "confounded" && length(bselect) > 0) {
        # The batch effect and condition effect are the same.
        gcoeffs_ribo[bselect] <- gcoeffs_ribo[bselect] + bcoeffs[bselect]
        gcoeffs_rna[bselect] <- gcoeffs_rna[bselect] + bcoeffs[bselect]
    }

    # Build the beta matrices using the coefficients and bcoeffs
    if (batch_scenario == "modality_specific") {
        # Create beta matrices with correct dimensions for each modality
        coeffs_ribo <- matrix(0, nrow = n_features, ncol = ncol(mod_ribo))
        colnames(coeffs_ribo) <- colnames(mod_ribo)
        batch_cols <- grep("^batch", colnames(mod_ribo))

        # Apply batch coefficients only to ribo data
        if (length(bselect) > 0) {
            for (i in seq_along(batch_cols)) {
                coeffs_ribo[bselect, batch_cols[i]] <-
                    bcoeffs[bselect] * runif(length(bselect), 0.8, 1.2)
            }
        }
        group_cols <- group_effect_cols(colnames(mod_ribo))
        if (length(group_cols) > 0) {
            coeffs_ribo[, group_cols] <- gcoeffs_ribo
        }

        # RNA-seq data has no batch effect
        coeffs_rna <- matrix(0, nrow = n_features, ncol = ncol(mod_rna))
        colnames(coeffs_rna) <- colnames(mod_rna)
        group_cols <- group_effect_cols(colnames(mod_rna))
        if (length(group_cols) > 0) {
            coeffs_rna[, group_cols] <- gcoeffs_rna
        }
    } else if (batch_scenario == "nested") {
        coeffs_ribo <- matrix(0, nrow = n_features, ncol = ncol(mod))
        colnames(coeffs_ribo) <- colnames(mod)
        coeffs_rna <- matrix(0, nrow = n_features, ncol = ncol(mod))
        colnames(coeffs_rna) <- colnames(mod)

        # The condition effect goes to every nested level that belongs to
        # a non-reference condition; levels of the reference condition stay
        # at baseline so that conditions actually differ.
        reference_prefix <- paste0("nested_factorC", levels(group)[1], "_")
        effect_cols <- which(!startsWith(colnames(mod), reference_prefix))
        if (length(effect_cols) > 0) {
            coeffs_ribo[, effect_cols] <- gcoeffs_ribo
            coeffs_rna[, effect_cols] <- gcoeffs_rna
        }
        # NOTE(review): no separate batch coefficient is applied in this
        # scenario (bcoeffs is unused here) - confirm this is intended.
    } else {
        # This covers balanced, random, unbalanced, and confounded
        coeffs_ribo <- matrix(0, nrow = n_features, ncol = ncol(mod))
        colnames(coeffs_ribo) <- colnames(mod)

        coeffs_rna <- matrix(0, nrow = n_features, ncol = ncol(mod))
        colnames(coeffs_rna) <- colnames(mod)

        # Assign batch coefficients, jittered independently per batch column
        # and per assay
        batch_cols <- grep("^batch", colnames(mod))
        if (length(bselect) > 0 && length(batch_cols) > 0) {
            for (i in seq_along(batch_cols)) {
                coeffs_ribo[bselect, batch_cols[i]] <-
                    bcoeffs[bselect] * runif(length(bselect), 0.8, 1.2)
                coeffs_rna[bselect, batch_cols[i]] <-
                    bcoeffs[bselect] * runif(length(bselect), 0.8, 1.2)
            }
        }

        # Assign group coefficients (DOT effect)
        group_cols <- group_effect_cols(colnames(mod))
        if (length(group_cols) > 0) {
            coeffs_ribo[, group_cols] <- gcoeffs_ribo
            coeffs_rna[, group_cols] <- gcoeffs_rna
        }
    }

    # Simulate counts; only the modality-specific scenario uses a different
    # design per assay.
    if (batch_scenario == "modality_specific") {
        design_ribo <- mod_ribo
        design_rna <- mod_rna
    } else {
        design_ribo <- mod
        design_rna <- mod
    }
    sim_ribo_filtered <- create_read_numbers(
        params_ribo$mu,
        params_ribo$fit,
        params_ribo$p0,
        beta = coeffs_ribo,
        mod = design_ribo
    )
    sim_rna_filtered <- create_read_numbers(
        params_rna$mu,
        params_rna$fit,
        params_rna$p0,
        beta = coeffs_rna,
        mod = design_rna
    )

    sample_prefix <- paste0(
        "sample",
        seq_len(total_samples),
        ".condition",
        group,
        ".batch",
        batch
    )
    final_cols_ribo <- paste0(sample_prefix, ".ribo")
    final_cols_rna <- paste0(sample_prefix, ".rna")

    # Full-size outputs: features that did not pass filtering keep zeros
    sim_ribo_full <- matrix(
        0,
        nrow = length(original_genes),
        ncol = total_samples,
        dimnames = list(original_genes, final_cols_ribo)
    )
    sim_rna_full <- matrix(
        0,
        nrow = length(original_genes),
        ncol = total_samples,
        dimnames = list(original_genes, final_cols_rna)
    )

    sim_logfc <- matrix(
        0,
        nrow = length(original_genes),
        ncol = 1,
        dimnames = list(original_genes, "logFC")
    )

    sim_ribo_full[common_filtered_genes, ] <- sim_ribo_filtered
    sim_rna_full[common_filtered_genes, ] <- sim_rna_filtered

    sim_logfc[common_filtered_genes, ] <- gcoeffs_ribo - gcoeffs_rna

    # Down-scale uORF and dORF counts. ORF types are only known when an ORF
    # table was built (DOU mode), and only ORFs present in the count
    # matrices can be scaled.
    if (simulate_dou) {
        uorf_ids <- intersect(
            rownames(orfs)[orfs$orf_type %in% "uORF"],
            original_genes
        )
        dorf_ids <- intersect(
            rownames(orfs)[orfs$orf_type %in% "dORF"],
            original_genes
        )

        scale_uorf <- 0.1
        scale_dorf_ribo <- 0.01
        scale_dorf_rna <- 0.1

        sim_ribo_full[uorf_ids, ] <- sim_ribo_full[uorf_ids, ] * scale_uorf
        sim_ribo_full[dorf_ids, ] <- sim_ribo_full[dorf_ids, ] * scale_dorf_ribo

        sim_rna_full[uorf_ids, ] <- sim_rna_full[uorf_ids, ] * scale_uorf
        sim_rna_full[dorf_ids, ] <- sim_rna_full[dorf_ids, ] * scale_dorf_rna
    }

    merged <- cbind(sim_ribo_full, sim_rna_full)
    merged <- round(merged)
    storage.mode(merged) <- "integer"

    replicate <- rep(rep(seq(1, num_samples), conditions), num_batches)

    # Ensure batch vector is correct for colData
    if (batch_scenario == "modality_specific") {
        batch_rna <- rep("none", total_samples)
        batch_ribo <- as.character(batch)
        batch_full <- c(batch_ribo, batch_rna)
    } else {
        batch_full <- as.character(rep(batch, times = 2))
    }

    coldata <- data.frame(
        run = colnames(merged),
        condition = factor(rep(group, 2)),
        replicate = factor(rep(replicate, 2)),
        strategy = factor(rep(c("ribo", "rna"), each = total_samples)),
        batch = factor(batch_full)
    )

    # Ground-truth labels for every feature; `labels` is ordered like the
    # filtered count matrices in both modes.
    final_labels <- rep(0, length(original_genes))
    names(final_labels) <- original_genes
    final_labels[common_filtered_genes] <- labels

    # Attach ground truth to the annotation. Values are matched by feature
    # name so that an annotation with a different order or extra entries is
    # not mislabelled; entries absent from the counts get 0.
    if (!is.null(annotation)) {
        ann_names <- names(annotation)
        if (is.null(ann_names)) {
            # Unnamed annotation: positional assignment is the only option
            mcols(annotation)$status <- final_labels
            mcols(annotation)$logFC <- sim_logfc[, 1]
        } else {
            ann_idx <- match(ann_names, original_genes)
            status <- final_labels[ann_idx]
            status[is.na(ann_idx)] <- 0
            names(status) <- ann_names
            logfc <- sim_logfc[, 1][ann_idx]
            logfc[is.na(ann_idx)] <- 0
            names(logfc) <- ann_names
            mcols(annotation)$status <- status
            mcols(annotation)$logFC <- logfc
        }
    }

    # Store simData in DOTSeqDataSets
    # NOTE(review): when no annotation was supplied it is passed on as NULL;
    # confirm that the constructor accepts that.
    d <- DOTSeqDataSetsFromSummarizeOverlaps(
        count_table = as.data.frame(merged),
        condition_table = coldata,
        annotation = annotation
    )

    # Optional diagnostic PCA plots. A plotting failure must not discard the
    # simulated data, but it is reported instead of being silently ignored.
    plot_requested <- c(
        ribo = isTRUE(diagplot_ribo),
        rna = isTRUE(diagplot_rna)
    )
    for (plot_strategy in names(plot_requested)[plot_requested]) {
        tryCatch(
            {
                plot_pca(
                    gcoeff = gcoeff,
                    bcoeff = bcoeff,
                    batch_scenario = batch_scenario,
                    num_batches = num_batches,
                    countdata = merged,
                    coldata = coldata,
                    strategy = plot_strategy,
                    formula1 = ~strategy,
                    formula2 = ~ condition + batch
                )
            },
            error = function(e) {
                warning(
                    "Diagnostic PCA plot for ", plot_strategy,
                    " data failed: ", conditionMessage(e),
                    call. = FALSE
                )
            }
        )
    }

    return(d)
}


#' Plot PCA for simulated RNA-seq or Ribo-seq data
#'
#' @description
#' Generates a PCA plot from simulated RNA-seq or Ribo-seq count data, 
#' highlighting sample-level variation across conditions and batches. 
#' The function applies variance-stabilizing transformation (VST), 
#' performs principal component analysis (PCA), and dynamically adjusts 
#' plot margins to accommodate legends based on device size and label width.
#'
#' @param gcoeff Numeric. Magnitude of log-fold change for DOT effects.
#' 
#' @param bcoeff Numeric vector. Batch effect coefficients.
#' 
#' @param batch_scenario Character. Describes the batch effect design 
#' (e.g., \code{"balanced"}, \code{"confounded"}).
#' 
#' @param num_batches Integer. Number of batches.
#' 
#' @param countdata A matrix or data frame of raw counts (genes x samples).
#' 
#' @param coldata A data frame containing sample metadata. Must include 
#' \code{condition}, \code{batch}, and \code{strategy} columns.
#' 
#' @param strategy Character string. Specifies which strategy to plot 
#' (e.g., \code{"rna"} or \code{"ribo"}).
#' 
#' @param formula1 A formula object specifying the initial design for 
#' DESeq2 object construction. Default is \code{~strategy}.
#' 
#' @param formula2 A formula object specifying the design for PCA modeling. 
#' Default is \code{~ condition + batch}.
#'
#' @return A PCA plot is rendered to the active graphics device. The plot 
#' includes sample points colored by condition and shaped by batch, with 
#' legends placed dynamically to avoid overlap.
#'
#' @importFrom DESeq2 DESeqDataSetFromMatrix design<- 
#' @importFrom DESeq2 varianceStabilizingTransformation
#' @importFrom stats prcomp
#' @importFrom grDevices dev.cur dev.new dev.size
#' 
#' @keywords internal
#' 
plot_pca <- function(
        gcoeff, 
        bcoeff,
        batch_scenario,
        num_batches,
        countdata, 
        coldata, 
        strategy = c("ribo", "rna"),
        formula1 = ~strategy, 
        formula2 = ~ condition + batch
) {

    # Resolve the default vector to a single, validated choice so that the
    # comparisons below always see a length-one string.
    strategy <- match.arg(strategy)

    dds <- DESeqDataSetFromMatrix(
        countData = round(countdata),
        colData = coldata,
        design = formula1
    )

    # Keep only the samples of the requested assay, then switch to the
    # design used for the transformation.
    dds <- dds[, dds$strategy == strategy]
    design(dds) <- formula2

    vst <- varianceStabilizingTransformation(dds, blind = FALSE)
    pca <- prcomp(t(assay(vst)))

    sample_condition <- colData(vst)$condition
    sample_batch <- colData(vst)$batch

    # Ensure a graphics device is open
    if (dev.cur() == 1) dev.new()

    # Device width in inches, used to place the legends right of the plot
    dev_dims <- dev.size("in")

    # Widen the right margin according to the widest condition label
    max_legend_width <- max(
        strwidth(as.character(unique(sample_condition)), units = "inches")
    )
    extra_margin <- max(2, ceiling(max_legend_width * 2.5))  # scale factor can be tuned

    # Change the graphics parameters and restore the caller's settings on
    # exit, including when plotting fails part-way.
    old_par <- par(
        xpd = TRUE,
        mar = par("mar") + c(0, 0, 0, extra_margin)
    )
    on.exit(par(old_par), add = TRUE)

    legend_inset <- c(1 + max_legend_width / dev_dims[1], 0)

    # PCA plot: colour encodes condition, plotting symbol encodes batch
    percentVar <- round(100 * pca$sdev^2 / sum(pca$sdev^2))
    colors <- as.numeric(as.factor(sample_condition))
    shapes <- as.numeric(as.factor(sample_batch))

    fig_head <- switch(
        strategy,
        rna = "RNA-seq",
        ribo = "Ribo-seq"
    )

    plot(
        x = pca$x[, 1],
        y = pca$x[, 2],
        col = colors,
        pch = shapes,
        xlab = paste0("PC1: ", percentVar[1], "% variance"),
        ylab = paste0("PC2: ", percentVar[2], "% variance"),
        main = paste0(
            "PCA of ", fig_head, " (gcoeff: ", gcoeff, ")\n", batch_scenario,
            " (num_batches: ", num_batches, ", bcoeff: ", paste(bcoeff, collapse = ", "), ")"
        )
    )

    # Condition legend; unique() keeps first-occurrence order, so labels and
    # colours stay paired.
    legend(
        "topleft",
        inset = legend_inset,
        xpd = TRUE,
        bty = "n",
        legend = as.character(unique(sample_condition)),
        col = unique(colors),
        pch = 16,
        title = "Condition"
    )

    # Batch legend, paired with the plotting symbols the same way
    legend(
        "bottomleft",
        inset = legend_inset,
        xpd = TRUE,
        bty = "n",
        legend = as.character(unique(sample_batch)),
        pch = unique(shapes),
        col = "black",
        title = "Batch"
    )

    invisible(NULL)
}
