# FUNCTIONS FOR PRINTING/SUMMARIZING ---------------------------------------

#' @rdname estimate_dm
#' @export
print.fits_ids_dm <- function(x, ...) {
  fit_info <- x$drift_dm_fit_info

  # Objects without a fit_procedure_name (i.e., not from the deprecated
  # load_fits_ids() route) are printed via the header of their summary.
  if (is.null(fit_info$fit_procedure_name)) {
    print(summary(x), just_header = TRUE)
    return(invisible(x))
  }

  # Legacy objects (deprecated; kept for backward compatibility)
  cat("Fit procedure name:", fit_info$fit_procedure_name)
  cat("\n")
  print_classes(class(fit_info$drift_dm_obj), header = "Fitted model type:")

  # the call time is optional; the line break is emitted regardless
  call_time <- fit_info$time_call
  if (!is.null(call_time)) {
    cat("Time of (last) call:", call_time)
  }
  cat("\n")

  cat("N Individuals:", length(x$all_fits))
  cat("\n")

  invisible(x)
}


#' @rdname summary.fits_ids_dm
#' @export
print.summary.fits_ids_dm <- function(
  x,
  ...,
  just_header = FALSE,
  round_digits = drift_dm_default_rounding()
) {
  summary_obj <- x

  # Summaries of legacy objects (deprecated load_fits_ids()) carry a
  # fit_procedure_name; summaries of estimate_dm() objects do not. The same
  # flag drives both the header and the footer so they always agree.
  is_legacy <- !is.null(summary_obj$fit_procedure_name)

  # header ----
  if (is_legacy) {
    cat("Fit Procedure Name:", summary_obj$fit_procedure_name)
    cat("\n")
    cat("N Individuals:", summary_obj$N, "\n")
  } else {
    cat("Fit approach: separately - classical\n")
    print_classes(
      header = "Fitted model type:",
      class_vector = summary_obj$summary_drift_dm_obj$class
    )
    cat("Optimizer:", summary_obj$optimizer, "\n")

    # not_conv holds one flag per individual; an entry that is TRUE marks a
    # non-converged fit. vapply() always yields a logical vector (also for
    # zero individuals), and counting does not rely on the entries being
    # named.
    all_not_conv <- summary_obj$conv_info$not_conv
    n_failed <- sum(vapply(all_not_conv, isTRUE, logical(1)))
    if (n_failed > 0) {
      info <- paste(
        "Failed for",
        n_failed,
        if (n_failed > 1) "participants" else "participant"
      )
    } else if (anyNA(all_not_conv)) {
      # convergence status unknown for at least one individual
      info <- "NA"
    } else {
      info <- "TRUE"
    }
    cat("Convergence:", info, "\n")
    cat("N Individuals:", summary_obj$obs_data$N, "\n")
    print_trial_numbers(
      trials_vector = summary_obj$obs_data$avg_trials,
      round_digits = 0,
      header = "Average Trial Numbers:\n"
    )
    print_cost_function(
      cost_function_label = summary_obj$summary_drift_dm_obj$cost_function
    )
  }

  if (just_header) {
    return(invisible(x))
  }

  # parameter summaries (mean and standard error per condition) ----
  cat("\n")
  for (one_cond in names(summary_obj$stats)) {
    cat("Parameter Summary:", one_cond, "\n")
    print(round(summary_obj$stats[[one_cond]], round_digits))
    cat("\n")
  }

  # search space bounds (only stored for legacy objects) ----
  if (!is.null(summary_obj$lower)) {
    cat("\n")
    cat("Parameter Space:\n")
    temp <- rbind(summary_obj$lower, summary_obj$upper)
    rownames(temp) <- c("lower", "upper")
    colnames(temp) <- names(summary_obj$upper)
    print(temp)
    cat("\n")
  }

  # footer ----
  cat("-------\n")
  if (is_legacy) {
    cat("Fitted Model Type:", summary_obj$model_type)
    cat("\n")
    cat("Time of (Last) Call:", summary_obj$time_call)
    cat("\n")
  } else {
    print_deriving_pdfs(
      solver = summary_obj$summary_drift_dm_obj$solver,
      prms_solve = summary_obj$summary_drift_dm_obj$prms_solve
    )
  }

  invisible(x)
}


#' Summary and Printing for fits_ids_dm Objects
#'
#' Methods for summarizing and printing objects of the class `fits_ids_dm`,
#' which contain multiple fits across individuals.
#'
#' @param object an object of class `fits_ids_dm`, as returned by
#'   [dRiftDM::estimate_dm()] or by the deprecated [dRiftDM::load_fits_ids()].
#' @param x an object of class `summary.fits_ids_dm`.
#' @param round_digits an integer, specifying the number of decimal places for
#'   rounding in the printed summary. Defaults to the value returned by
#'   `drift_dm_default_rounding()`.
#' @param select_unique logical, passed to [dRiftDM::coef.drift_dm()].
#' @param just_header logical, if `TRUE` only print the header information
#' without details. Default is `FALSE`.
#' @param ... additional arguments (currently unused).
#'
#' @details
#' The `summary.fits_ids_dm` function creates a summary object. The contents of
#' this summary object depend on whether the user supplies a `fits_ids_dm`
#' object that was created with [dRiftDM::estimate_dm()] or the deprecated
#' function [dRiftDM::load_fits_ids()].
#'
#' - In the first case, the object contains:
#'  - **summary_drift_dm_obj**: A list with information about the underlying
#'   drift diffusion model (as returned by [dRiftDM::summary.drift_dm()]).
#'  - **prms**: All parameter values across all conditions (the output of
#'   `coef()`, called with the `select_unique` argument).
#'  - **stats**: A named list of data frames, one per condition, with the rows
#'   `mean` and `std_err` for each parameter.
#'  - **obs_data**: A list providing the number of individual participants and
#'   the average number of trials per condition across participants.
#'  - **optimizer**: A string of the optimizer that was used.
#'  - **conv_info**: A list providing a summary of the convergence and messages
#'   for all IDs.
#'
#'
#' - In the second case, the object contains:
#'  - **fit_procedure_name** and **time_call**: The name of the fit procedure
#'   and the time of the (last) call.
#'  - **lower** and **upper**: Lower and upper bounds of the search space.
#'  - **model_type**: Description of the model type, based on class information.
#'  - **prms**: All parameter values across all conditions (the output of
#'   `coef()`, called with the `select_unique` argument).
#'  - **stats**: A named list of data frames, one per condition, with the rows
#'   `mean` and `std_err` for each parameter.
#'  - **N**: The number of individuals.
#'
#' The `print.summary.fits_ids_dm` function displays the summary object in a
#' formatted manner.
#'
#' @return
#' `summary.fits_ids_dm()` returns a list of class `summary.fits_ids_dm` (see
#' the Details section summarizing each entry of this list).
#'
#' `print.summary.fits_ids_dm()` returns invisibly the `summary.fits_ids_dm`
#'  object.
#'
#' @examples
#' # get an auxiliary object of type fits_ids_dm for demonstration purpose
#' all_fits <- get_example_fits("fits_ids_dm")
#' sum_obj <- summary(all_fits)
#' print(sum_obj, round_digits = 2)
#'
#' @export
summary.fits_ids_dm <- function(object, ..., select_unique = FALSE) {
  fits_ids <- object
  fit_info <- fits_ids$drift_dm_fit_info

  # objects from the deprecated load_fits_ids() carry a fit_procedure_name
  is_legacy <- !is.null(fit_info$fit_procedure_name)
  ans <- list()

  # model information (entry order kept stable for downstream code) ----
  if (is_legacy) {
    ans$fit_procedure_name <- fit_info$fit_procedure_name
    ans$time_call <- fit_info$time_call

    l_u <- get_parameters_smart(
      drift_dm_obj = fit_info$drift_dm_obj,
      input_a = fit_info$lower,
      input_b = fit_info$upper
    )
    ans$lower <- l_u$vec_a
    ans$upper <- l_u$vec_b
    ans$model_type <- paste(class(fit_info$drift_dm_obj), collapse = ", ")
  } else {
    ans$summary_drift_dm_obj <- unclass(summary(fit_info$drift_dm_obj))
  }

  # parameter statistics (coef() works for both legacy and current objects) --
  all_prms <- coef(fits_ids, select_unique = select_unique)
  ans$prms <- all_prms
  prm_names <- colnames(all_prms)[!(colnames(all_prms) %in% c("ID", "Cond"))]
  by_cond <- all_prms["Cond"]
  means <- stats::aggregate(all_prms[prm_names], by = by_cond, FUN = mean)
  std_errs <- stats::aggregate(
    all_prms[prm_names],
    by = by_cond,
    FUN = \(x) stats::sd(x) / sqrt(length(x))
  )
  ans$stats <- sapply(
    conds(fits_ids),
    function(one_cond) {
      # select parameter columns by name with drop = FALSE, so the result
      # stays a data frame (with column names) even for a single parameter;
      # each table is indexed with its own Cond column
      mean_row <- means[means$Cond == one_cond, prm_names, drop = FALSE]
      se_row <- std_errs[std_errs$Cond == one_cond, prm_names, drop = FALSE]
      stats_data_frame <- rbind(mean_row, se_row)
      rownames(stats_data_frame) <- c("mean", "std_err")
      stats_data_frame
    },
    simplify = FALSE,
    USE.NAMES = TRUE
  )

  # data, optimizer, and convergence info ----
  if (is_legacy) {
    ans$N <- length(fits_ids$all_fits)
  } else {
    ans$obs_data <- get_avg_trials(fit_info$obs_data_ids)
    ans$optimizer <- fit_info$optimizer
    ans$conv_info <- fit_info$conv_info
  }

  structure(ans, class = "summary.fits_ids_dm")
}


# HELPER FUNCTION ---------------------------------------------------------

#' Average trial counts per condition across individuals
#'
#' Internal helper. Cross-tabulates the trials in `obs_data_ids` by `ID` and
#' `Cond` and averages the resulting counts over individuals. It assumes that
#' every `ID` was observed in the same set of conditions; a missing `ID` x
#' `Cond` combination would enter the average as a count of zero.
#'
#' @param obs_data_ids a data frame with (at least) the columns `ID` and
#'   `Cond`, holding one row per trial.
#'
#' @return a list with two entries:
#'   - `N`: the number of unique individuals in `ID`
#'   - `avg_trials`: a named numeric vector with the average number of trials
#'     per condition
#'
#' @keywords internal
get_avg_trials <- function(obs_data_ids) {
  # rows: individuals, columns: conditions, cells: number of trials
  counts_by_id_cond <- table(obs_data_ids$ID, obs_data_ids$Cond)
  list(
    N = length(unique(obs_data_ids$ID)),
    avg_trials = colMeans(counts_by_id_cond)
  )
}
