Skip to content

Commit

Permalink
add num_partitions argument to support large datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
MMenchero committed Aug 27, 2024
1 parent 60622ef commit af251c3
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 25 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Depends:
LazyData: true
Imports:
dplyr,
future,
future.apply,
ggplot2,
httr2,
lubridate,
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ importFrom(dplyr,slice)
importFrom(dplyr,slice_tail)
importFrom(dplyr,summarize)
importFrom(dplyr,ungroup)
importFrom(future,availableCores)
importFrom(future,multisession)
importFrom(future,plan)
importFrom(future.apply,future_lapply)
importFrom(ggplot2,aes)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_line)
Expand Down
4 changes: 4 additions & 0 deletions R/nixtlaR-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
#' @importFrom dplyr slice_tail
#' @importFrom dplyr summarize
#' @importFrom dplyr ungroup
#' @importFrom future availableCores
#' @importFrom future multisession
#' @importFrom future plan
#' @importFrom future.apply future_lapply
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 facet_wrap
#' @importFrom ggplot2 geom_line
Expand Down
7 changes: 1 addition & 6 deletions R/nixtla_client_cross_validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
#'
nixtla_client_cross_validation <- function(df, h=8, freq=NULL, id_col=NULL, time_col="ds", target_col="y", X_df=NULL, level=NULL, quantiles=NULL, n_windows=1, step_size=NULL, finetune_steps=0, finetune_loss="default", clean_ex_first=TRUE, model="timegpt-1", num_partitions=NULL){

start <- Sys.time()

# Prepare data ----
names(df)[which(names(df) == time_col)] <- "ds"
names(df)[which(names(df) == target_col)] <- "y"
Expand Down Expand Up @@ -158,7 +156,7 @@ nixtla_client_cross_validation <- function(df, h=8, freq=NULL, id_col=NULL, time
}

# Date transformation ----
res <- .transform_output_dates(res, "ds", freq, data$flag)
res <- .transform_output_dates(res, id_col, "ds", freq, data$flag)
new_cutoff <- future.apply::future_lapply(res$cutoff, lubridate::ymd_hms)
res$cutoff <- do.call(c, new_cutoff)

Expand All @@ -174,8 +172,5 @@ nixtla_client_cross_validation <- function(df, h=8, freq=NULL, id_col=NULL, time

row.names(res) <- NULL

end <- Sys.time()
print(paste0("Total execution time: ", end-start))

return(res)
}
7 changes: 1 addition & 6 deletions R/nixtla_client_detect_anomalies.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
#'
nixtla_client_detect_anomalies <- function(df, freq=NULL, id_col=NULL, time_col="ds", target_col="y", level=c(99), clean_ex_first=TRUE, model="timegpt-1", num_partitions=NULL){

start <- Sys.time()

# Prepare data ----
names(df)[which(names(df) == time_col)] <- "ds"
names(df)[which(names(df) == target_col)] <- "y"
Expand Down Expand Up @@ -89,7 +87,7 @@ nixtla_client_detect_anomalies <- function(df, freq=NULL, id_col=NULL, time_col=
res[, 3:ncol(res)] <- future.apply::future_lapply(res[, 3:ncol(res)], as.numeric)

# Date transformation ----
res <- .transform_output_dates(res, "ds", freq, data$flag)
res <- .transform_output_dates(res, id_col, "ds", freq, data$flag)

# Rename columns ----
colnames(res)[which(colnames(res) == "ds")] <- time_col
Expand All @@ -103,8 +101,5 @@ nixtla_client_detect_anomalies <- function(df, freq=NULL, id_col=NULL, time_col=

row.names(res) <- NULL

end <- Sys.time()
print(paste0("Total execution time: ", end-start))

return(res)
}
6 changes: 1 addition & 5 deletions R/nixtla_client_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#'
nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col=NULL, time_col="ds", target_col="y", X_df=NULL, level=NULL, quantiles=NULL, finetune_steps=0, finetune_loss="default", clean_ex_first=TRUE, add_history=FALSE, model="timegpt-1", num_partitions=NULL){

start <- Sys.time()
# Prepare data ----
names(df)[which(names(df) == time_col)] <- "ds"
names(df)[which(names(df) == target_col)] <- "y"
Expand Down Expand Up @@ -153,7 +152,7 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col=NULL, time_col="ds
}

# Date transformation ----
fcst <- .transform_output_dates(fcst, "ds", freq, data$flag)
fcst <- .transform_output_dates(fcst, id_col, "ds", freq, data$flag)

# Rename columns ----
names(fcst)[which(names(fcst) == "ds")] <- time_col
Expand All @@ -180,8 +179,5 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col=NULL, time_col="ds

row.names(fcst) <- NULL

end <- Sys.time()
print(paste0("Total execution time: ", end-start))

return(fcst)
}
7 changes: 1 addition & 6 deletions R/nixtla_client_historic.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
#'
nixtla_client_historic <- function(df, freq=NULL, id_col=NULL, time_col="ds", target_col="y", level=NULL, quantiles=NULL, finetune_steps=0, finetune_loss="default", clean_ex_first=TRUE, model="timegpt-1", num_partitions=NULL){

start <- Sys.time()

# Prepare data ----
names(df)[which(names(df) == time_col)] <- "ds"
names(df)[which(names(df) == target_col)] <- "y"
Expand Down Expand Up @@ -131,7 +129,7 @@ nixtla_client_historic <- function(df, freq=NULL, id_col=NULL, time_col="ds", ta
}

# Date transformation ----
fitted <- .transform_output_dates(fitted, "ds", freq, data$flag)
fitted <- .transform_output_dates(fitted, id_col, "ds", freq, data$flag)

# Rename columns ----
names(fitted)[which(names(fitted) == "ds")] <- time_col
Expand All @@ -145,8 +143,5 @@ nixtla_client_historic <- function(df, freq=NULL, id_col=NULL, time_col="ds", ta

row.names(fitted) <- NULL

end <- Sys.time()
print(paste0("Total execution time: ", end-start))

return(fitted)
}
3 changes: 2 additions & 1 deletion R/transform_output_dates.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
##' This is a private function of 'nixtlar'
#'
#' @param df Dataframe with the 'TimeGPT' output, where column 'col' contains date strings.
#' @param id_col Name of the column that identifies each series.
#' @param col Name of the column with the dates to transform.
#' @param freq Frequency of the data.
#' @param flag Indicator where 1 denotes 'tsibble' and 0 denotes 'dataframe'.
Expand All @@ -15,7 +16,7 @@
#' fcst <- .transform_output_dates(fcst, id_col, col, freq, flag)
#' }
#'
.transform_output_dates <- function(df, col, freq, flag){
.transform_output_dates <- function(df, id_col, col, freq, flag){

index_col <- which(names(df) == col)

Expand Down
4 changes: 3 additions & 1 deletion man/dot-transform_output_dates.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit af251c3

Please sign in to comment.