Skip to content

Commit

Permalink
perf: quantiles no longer stop plot function
Browse files Browse the repository at this point in the history
  • Loading branch information
MMenchero committed Jun 16, 2024
1 parent e737778 commit 9d5f866
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions R/nixtla_client_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ nixtla_client_plot <- function(df, fcst=NULL, h=NULL, id_col=NULL, time_col="ds"
# Add prediction intervals ----
levels <- grepl("(lo|hi)", names(fcst))
if(any(levels)){

# Build data frame with prediction intervals
lower <- fcst[,c(which(names(fcst) %in% c("unique_id", "ds")), grep("lo", names(fcst)))]
upper <- fcst[,c(which(names(fcst) %in% c("unique_id", "ds")), grep("hi", names(fcst)))]
Expand All @@ -203,6 +202,19 @@ nixtla_client_plot <- function(df, fcst=NULL, h=NULL, id_col=NULL, time_col="ds"
quant_long <- tidyr::pivot_longer(quant, cols=grep("-q-", names(quant)), values_to="value", names_to="variable") |>
dplyr::mutate(quantiles = gsub("^[^-]*-[^-]*-", "", .data$variable))

num_quantiles <- length(unique(quant_long$quantiles))
# Prepare colors for quantiles - max 10
qcolors <- c("#755faa", "#3da564", "#dabb35", "#29e2ff","#b5d56d","#9ca5e2", "#d954a0", "#cb4545", "#e45d17","#18392b")

if(num_quantiles > 10){
message("Can't plot more than 10 quantiles at the same time. Selecting 10 at random.")
qselect <- sample(unique(quant_long$quantiles), 10, replace = FALSE)
quant_long <- quant_long |>
dplyr::filter(quantiles %in% qselect)
num_quantiles <- 10
}

color_vals <- c(qcolors[1:num_quantiles], color_vals)
quant_long$quantiles <- paste0("q-", quant_long$quantiles)

plot <- plot+
Expand All @@ -220,17 +232,6 @@ nixtla_client_plot <- function(df, fcst=NULL, h=NULL, id_col=NULL, time_col="ds"
}
}

if(any(quantiles)){
num_quantiles <- length(unique(quant_long$quantiles))
if(num_quantiles > 10){
stop("Can't plot more than 10 quantiles")
}else{
qcolors <- c("#755faa", "#3da564", "#dabb35", "#29e2ff","#b5d56d",
"#9ca5e2", "#d954a0", "#cb4545", "#e45d17","#18392b")
color_vals <- c(qcolors[1:num_quantiles], color_vals)
}
}

plot <- plot+
ggplot2::geom_line(ggplot2::aes(x=.data$ds, y=.data$value, group=.data$variable, color=.data$variable), data=data_long)+
ggplot2::scale_color_manual(values = color_vals)+
Expand Down

0 comments on commit 9d5f866

Please sign in to comment.