Skip to content

Commit

Permalink
WIP: Added better support for tsibbles and improved function to infer…
Browse files Browse the repository at this point in the history
… frequency.
  • Loading branch information
MMenchero committed Oct 31, 2023
1 parent 3b5939a commit fc857ac
Show file tree
Hide file tree
Showing 17 changed files with 220 additions and 184 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Depends:
LazyData: true
Imports:
httr2,
lubridate,
tsibble
Suggests:
httptest2,
Expand Down
8 changes: 5 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Generated by roxygen2: do not edit by hand

export(find_frequency)
export(prepare_multi_series)
export(prepare_single_series)
export(date_conversion)
export(infer_frequency)
export(prepare_data)
export(set_token)
export(timegpt_forecast)
export(validate_token)
importFrom(httr2,req_headers)
importFrom(httr2,req_perform)
importFrom(httr2,request)
importFrom(httr2,resp_status)
importFrom(lubridate,ymd)
importFrom(lubridate,ymd_hms)
importFrom(tsibble,is_tsibble)
40 changes: 40 additions & 0 deletions R/date_conversion.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#' Infer frequency of a tsibble and convert its index to date or string.
#'
#' @param df A tsibble.
#'
#' @return A list with the inferred frequency and df with the new index.
#' @export
#'
date_conversion <- function(df){
cls <- class(df$ds)[1]

if(cls == "integer"){
freq <- "Y"
df$ds <- paste0(df$ds, "-01-01")

}else if(cls %in% c("yearquarter", "yearmonth", "yearweek")){
freq <- switch(cls,
yearquarter = "Q",
yearmonth = "MS",
yearweek = "W")
df$ds <- as.Date(df$ds)
df$ds <- as.character(df$ds)

}else if(cls == "Date"){
freq <- "D"

}else if(cls %in% c("POSIXct", "POSIXt")){
freq <- "H"

}else{
freq <- NULL

}

if(!is.null(freq)){
message(paste0("Frequency chosen: ", freq))
}

res <- list(df = df, freq = freq)
return(res)
}
48 changes: 0 additions & 48 deletions R/find_frequency.R

This file was deleted.

53 changes: 53 additions & 0 deletions R/infer_frequency.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#' Infer frequency of a data frame.
#'
#' @param df A data frame with time series data.
#'
#' @return The inferred frequency.
#' @export
#'
infer_frequency <- function(df){

freq <- NA
dates <- sort(unique(df$ds))

# Check if it's hourly data
nchrs <- lapply(as.character(dates), nchar)
ntable <- sort(table(unlist(nchrs)))
nmode <- ntable[length(ntable)]
nmode <- as.numeric(names(nmode))

if(nmode > 10){
freq <- "H" # We'll assume hourly data
message("Frequency chosen: H")
return(freq)
}

# If it's not hourly data, check the time differences in days
ddiff <- diff(as.Date(dates))
table <- sort(table(ddiff))
mode <- table[length(table)]
mode <- as.numeric(names(mode))

freq_list = list(
list(alias = "Y", value = c(365,366)),
list(alias = "Q", value = c(91,92)),
list(alias = "MS", value = c(30,31)),
list(alias = "W", value = c(7)),
list(alias = "D", value = c(1))
)

for(i in 1:length(freq_list)){
if(mode %in% freq_list[i][[1]]$value){
freq <- freq_list[i][[1]]$alias
}
}

if(is.na(freq)){
freq <- "D"
message("I'm not sure about the frequency of the data. Will default to daily (D). Please provide it if you know it.")
}

message(paste0("Frequency chosen: ", freq))

return(freq)
}
2 changes: 2 additions & 0 deletions R/nixtlaR-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#' @importFrom httr2 req_perform
#' @importFrom httr2 request
#' @importFrom httr2 resp_status
#' @importFrom lubridate ymd
#' @importFrom lubridate ymd_hms
#' @importFrom tsibble is_tsibble
## usethis namespace: end
NULL
22 changes: 22 additions & 0 deletions R/prepare_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#' Prepare time series data for TimeGPT's API
#'
#' @param df A tsibble or a data frame with time series data.
#'
#' @return A list with the time series data for TimeGPT's API.
#' @export
#'
prepare_data <- function(df){
if("unique_id" %in% names(df)){
df <- df[, c("unique_id", "ds", "y")]
y <- list(
columns = names(df),
data = lapply(1:nrow(df), function(i) as.list(df[i,]))
)
}else{
# only "ds" and "y" columns
y <- df$y
names(y) <- df$ds
y <- as.list(y)
}
return(y)
}
30 changes: 0 additions & 30 deletions R/prepare_multi_series.R

This file was deleted.

28 changes: 0 additions & 28 deletions R/prepare_single_series.R

This file was deleted.

53 changes: 42 additions & 11 deletions R/timegpt_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#' @param id_col Column that identifies each series.
#' @param time_col Column that identifies each timestep.
#' @param target_col Column that contains the target variable.
#' @param X_df A tsibble or a data frame with future exogenous variables.
#' @param level The confidence levels (0-100) for the prediction intervals.
#' @param finetune_steps Number of steps used to finetune TimeGPT in the new data.
#' @param clean_ex_first Clean exogenous signal before making the forecasts using TimeGPT.
Expand All @@ -14,29 +15,43 @@
#' @return TimeGPT forecasts for point predictions and probabilistic predictions (if level is not NULL).
#' @export
#'
timegpt_forecast <- function(df, h, freq=NULL, id_col=NULL, time_col="ds", target_col="y", level=NULL, finetune_steps=0, clean_ex_first=TRUE, add_history=FALSE){
timegpt_forecast <- function(df, h=8, freq=NULL, id_col=NULL, time_col="ds", target_col="y", X_df=NULL, level=NULL, finetune_steps=0, clean_ex_first=TRUE, add_history=FALSE){

token <- get("NIXTLAR_TOKEN", envir = nixtlaR_env)

if(!tsibble::is_tsibble(df) & !is.data.frame(df)){
stop("Only tsibbles or data frames are allowed.")
}

# Rename columns
names(df)[which(names(df) == time_col)] <- "ds"
names(df)[which(names(df) == target_col)] <- "y"
if(!is.null(id_col)){
names(df)[which(names(df) == id_col)] <- "unique_id"
}

# If df is a tsibble, convert dates to strings and infer frequency
if(tsibble::is_tsibble(df)){
res <- date_conversion(df)
df <- res$df
freq <- res$freq
}

# Infer frequency if not available
if(is.null(freq)){
freq <- infer_frequency(df)
}

# Check if single or multi-series and prepare data
if(is.null(id_col)){
url <- "https://dashboard.nixtla.io/api/timegpt"
series_type <- "single"
y <- prepare_single_series(df, time_col, target_col)
}else{
url <- "https://dashboard.nixtla.io/api/timegpt_multi_series"
series_type <- "multi"
y <- prepare_multi_series(df, id_col, time_col, target_col)
}

# Prepare request
if(is.null(freq)){
freq <- find_frequency(df, time_col)
}
y <- prepare_data(df)

timegpt_data <- list(
fh = h,
Expand All @@ -46,6 +61,18 @@ timegpt_forecast <- function(df, h, freq=NULL, id_col=NULL, time_col="ds", targe
clean_ex_first = clean_ex_first
)

# if(!is.null(X_df)){
# names(X_df)[which(names(X_df) == time_col)] <- "ds"
# if(!is.null(id_col)){
# names(X_df)[which(names(X_df) == id_col)] <- "unique_id"
# }
# x <- list(
# columns = names(X_df),
# data = lapply(1:nrow(X_df), function(i) as.list(X_df[i,]))
# )
# timegpt_data[["x"]] <- x
# }

if(!is.null(level)){
level <- as.list(level) # TimeGPT requires level to be a list.
# Users of the forecast package are used to define the level as a vector.
Expand Down Expand Up @@ -157,11 +184,15 @@ timegpt_forecast <- function(df, h, freq=NULL, id_col=NULL, time_col="ds", targe
}
}

# This part needs work
# Return a tsibble if the input was a tsibble
#if(tsibble::is_tsibble(df)){
# fcst <- tsibble::as_tsibble(fcst, key = "unique_id", index = "ds")
#}
if(tsibble::is_tsibble(df)){
if(freq == "H"){
fcst$ds <- lubridate::ymd_hms(fcst$ds)
}else{
fcst$ds <- lubridate::ymd(fcst$ds)
}
fcst <- tsibble::as_tsibble(fcst, key="unique_id", index="ds")
}

# Rename columns to original names
colnames(fcst)[which(colnames(fcst) == "ds")] <- time_col
Expand Down
17 changes: 17 additions & 0 deletions man/date_conversion.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 0 additions & 19 deletions man/find_frequency.Rd

This file was deleted.

Loading

0 comments on commit fc857ac

Please sign in to comment.