Skip to content

Commit

Permalink
added more flexibility to predict method in terms of allowing custom …
Browse files Browse the repository at this point in the history
…innovations for the simulated distribution
  • Loading branch information
alexiosg committed Apr 18, 2023
1 parent b2a912c commit 9f2d403
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 37 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ vignettes/*.pdf

# R Environment Variables
.Renviron
*.Rmd
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: tsgam
Type: Package
Title: Interface to mgcv
Version: 0.2.0
Version: 0.3.0
Authors@R: c(person("Alexios", "Galanos", role = c("aut", "cre"),
email = "[email protected]"))
Maintainer: alexios galanos <[email protected]>
Expand All @@ -11,5 +11,5 @@ Depends: R (>= 3.5.0), tsmethods
Imports: methods, tsaux, tsdistributions, zoo, xts, data.table, mgcv, gratia, future, future.apply, progressr
Encoding: UTF-8
LazyData: true
RoxygenNote: 7.2.1
RoxygenNote: 7.2.3
Remotes: tsmodels/tsmethods, tsmodels/tsaux, tsmodels/tsdistributions
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ importFrom(stats,na.omit)
importFrom(stats,predict)
importFrom(stats,quantile)
importFrom(stats,residuals)
importFrom(stats,sd)
importFrom(tsaux,bias)
importFrom(tsaux,crps)
importFrom(tsaux,mape)
Expand All @@ -39,6 +40,7 @@ importFrom(tsaux,mslre)
importFrom(tsaux,sampling_frequency)
importFrom(tsaux,smape)
importFrom(tsdistributions,distribution_modelspec)
importFrom(tsdistributions,qdist)
importFrom(tsdistributions,rdist)
importFrom(xts,as.xts)
importFrom(xts,is.xts)
Expand Down
146 changes: 119 additions & 27 deletions R/mgcv-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ estimate.gam.spec <- function(object, ...)
#' @param distribution A valid distribution from the tsdistributions package to be used to fit the
#' response residuals and then proxy for the predictive distribution. Only valid if the estimated
#' model used the gaussian family, else will throw an error.
#' @param innov an optional matrix of innovations (see innov_type for admissible types),
#' of dimensions nsim x horizon.
#' @param innov_type if \\sQuote{innov} is not NULL, then this denotes the type of values
#' passed, with \\dQuote{q} denoting quantile probabilities (default and
#' backwards compatible), \\dQuote{z} for standardized errors and \\dQuote{r} for
#' raw innovations (non standardized). See details.
#' @param tabular whether to return the data in long format as a data.table instead
#' of an object of tsmodels.predict
#' @param ... not currently used.
Expand All @@ -88,13 +94,24 @@ estimate.gam.spec <- function(object, ...)
#' the model residuals in the prediction step to proxy for the uncertainty is a 2 step option
#' which may yield superior results to using something like family scat in the mgcv model as this
#' has proven to be problematic in certain cases.
#' @details
#' The \dQuote{distribution} and \dQuote{innovations} arguments allow for significant customization
#' of the simulated distribution. When \dQuote{innov_type} is type q, then the values
#' are transformed back to normally distributed predictions using the normal quantile function
#' with mean the prediction vector and standard deviation calculated from the model residuals.
#' If distribution is also not NULL, then instead of using the normal quantile function, a model
#' if first fitted to the residuals given the distribution chosen and the quantile function of
#' that distribution is used to transform the uniform value.
#' When \dQuote{innov_type} is either z or r, then the values are simply rescaled using the
#' standard deviation of the residuals (for type z), and re-centered by the prediction means
#' (for type z and r).
#' @aliases predict
#' @method predict gam.estimate
#' @rdname predict
#' @export
#'
#'
predict.gam.estimate <- function(object, newdata, nsim = 9000, tabular = FALSE, distribution = NULL, ...)
predict.gam.estimate <- function(object, newdata, nsim = 9000, tabular = FALSE, distribution = NULL, innov = NULL, innov_type = "q", ...)
{
# setup data.table variable names to avoid notes in checking
prediction <- parameter <- forecast_date <- estimation_date <- draw <- value <- NULL
Expand All @@ -106,6 +123,7 @@ predict.gam.estimate <- function(object, newdata, nsim = 9000, tabular = FALSE,
# create the decomposition
decomp <- decompose_model(object, newdata = newx, type = "predict")
decomp <- xts(decomp, index(newdata))
if (!is.null(innov) & innov_type != "q") distribution <- NULL
if (!is.null(distribution)) {
if (object$model$family$family != "gaussian") stop("\ndistribution option only available for models estimated using gaussian family")
# estimate distribution on residuals of model
Expand Down Expand Up @@ -146,34 +164,108 @@ predict.gam.estimate <- function(object, newdata, nsim = 9000, tabular = FALSE,
return(simulated_draws)
}
} else {
simulated_draws <- as.data.table(predicted_samples(object$model, n = nsim, newdata = newx))
simulated_draws <- dcast(simulated_draws, draw~row, value.var = "response")
simulated_draws <- simulated_draws[order(draw)]
if (tabular) {
setcolorder(simulated_draws, c("draw",paste0(1:(ncol(simulated_draws)-1))))
colnames(simulated_draws) <- c("draw", as.character(index(newdata)))
simulated_draws <- melt(simulated_draws, id.vars = "draw", measure.vars = 2:ncol(simulated_draws), variable.name = "forecast_date", value.name = "value")
if (object$spec$time_class == "POSIXct") {
fun <- as.POSIXct
if (!is.null(innov)) {
h <- NROW(newdata)
innov_type <- match.arg(innov_type, c("q", "z", "r"))
innov <- as.matrix(innov)
if (NROW(innov) != nsim) {
stop("\nnrow of innov must be nsim")
}
if (NCOL(innov) != h) {
stop("\nncol of innov must be NROW(newdata)")
}
# check that the innovations are uniform samples (from a copula)
if (innov_type == "q") {
if (any(innov < 0 | innov > 1)) {
stop("\ninnov must be >0 and <1 (uniform samples) for innov_type = 'q'")
}
if (any(innov == 0))
innov[which(innov == 0)] <- 1e-12
if (any(innov == 1))
innov[which(innov == 1)] <- (1 - 1e-12)
}
innov <- matrix(innov, nsim, h)
sig <- sd(object$model$residuals)
if (innov_type == "q") {
if (!is.null(distribution)) {
spec <- distribution_modelspec(residuals(object$model, type = "response"), distribution = distribution)
distribution_fit <- estimate(spec)
p_matrix <- distribution_fit$spec$parmatrix
sigma <- p_matrix[parameter == "sigma"]$value
skew <- p_matrix[parameter == "skew"]$value
shape <- p_matrix[parameter == "shape"]$value
lambda <- p_matrix[parameter == "lambda"]$value
simulated_draws <- do.call(cbind, lapply(1:length(mean_prediction), function(i) {
qdist(distribution = distribution, p = innov[,i], mu = mean_prediction[i], sigma = sigma,
skew = skew, shape = shape, lambda = lambda)
}))
} else {
simulated_draws <- do.call(cbind, lapply(1:length(mean_prediction), function(i) {
qdist(distribution = "norm", p = innov[,i], mu = mean_prediction[i], sigma = sig)
}))
}
} else {
fun <- as.Date
if (innov_type == "z") {
simulated_draws <- sweep(innov, 2, sig, "*")
simulated_draws <- sweep(simulated_draws, 2, mean_prediction, "+")
} else {
simulated_draws <- sweep(innov, 2, mean_prediction, "+")
}
}
if (tabular) {
simulated_draws <- as.data.table(simulated_draws)
simulated_draws[,draw := 1:.N]
simulated_draws <- melt(simulated_draws, id.vars = "draw", measure.vars = 1:(ncol(simulated_draws) - 1), variable.name = "forecast_date", value.name = "value")
if (object$spec$time_class == "POSIXct") {
fun <- as.POSIXct
} else {
fun <- as.Date
}
simulated_draws[,forecast_date := fun(forecast_date, tz = object$spec$time_zone)]
e_date <- max(object$spec$clean_time_index)
simulated_draws[,estimation_date := e_date]
simulated_draws <- simulated_draws[,.(estimation_date, draw, forecast_date, value)]
} else {
simulated_draws <- as.matrix(simulated_draws)
colnames(simulated_draws) <- as.character(index(newdata))
class(simulated_draws) <- "tsmodel.distribution"
attr(simulated_draws, "date_class") <- object$time_class
simulated_draws <- list(original_series = zoo(object$model$y, object$spec$clean_time_index),
distribution = simulated_draws,
mean = zoo(mean_prediction, index(newdata)), decomposition = decomp)
class(simulated_draws) <- c("gam.predict","tsmodel.predict")
return(simulated_draws)
}
simulated_draws[,forecast_date := fun(forecast_date, tz = object$spec$time_zone)]
e_date <- max(object$spec$clean_time_index)
simulated_draws[,estimation_date := e_date]
simulated_draws <- simulated_draws[,.(estimation_date, draw, forecast_date, value)]
return(simulated_draws)
} else {
simulated_draws <- simulated_draws[,draw := NULL]
setcolorder(simulated_draws, paste0(1:ncol(simulated_draws)))
simulated_draws <- as.matrix(simulated_draws)
colnames(simulated_draws) <- as.character(index(newdata))
class(simulated_draws) <- "tsmodel.distribution"
attr(simulated_draws, "date_class") <- object$time_class
simulated_draws <- list(original_series = zoo(object$model$y, object$spec$clean_time_index),
distribution = simulated_draws, mean = zoo(mean_prediction, index(newdata)),decomposition = decomp)
class(simulated_draws) <- c("gam.predict","tsmodel.predict")
return(simulated_draws)
simulated_draws <- as.data.table(predicted_samples(object$model, n = nsim, newdata = newx))
simulated_draws <- dcast(simulated_draws, draw~row, value.var = "response")
simulated_draws <- simulated_draws[order(draw)]
if (tabular) {
setcolorder(simulated_draws, c("draw",paste0(1:(ncol(simulated_draws) - 1))))
colnames(simulated_draws) <- c("draw", as.character(index(newdata)))
simulated_draws <- melt(simulated_draws, id.vars = "draw", measure.vars = 2:ncol(simulated_draws), variable.name = "forecast_date", value.name = "value")
if (object$spec$time_class == "POSIXct") {
fun <- as.POSIXct
} else {
fun <- as.Date
}
simulated_draws[,forecast_date := fun(forecast_date, tz = object$spec$time_zone)]
e_date <- max(object$spec$clean_time_index)
simulated_draws[,estimation_date := e_date]
simulated_draws <- simulated_draws[,.(estimation_date, draw, forecast_date, value)]
return(simulated_draws)
} else {
simulated_draws <- simulated_draws[,draw := NULL]
setcolorder(simulated_draws, paste0(1:ncol(simulated_draws)))
simulated_draws <- as.matrix(simulated_draws)
colnames(simulated_draws) <- as.character(index(newdata))
class(simulated_draws) <- "tsmodel.distribution"
attr(simulated_draws, "date_class") <- object$time_class
simulated_draws <- list(original_series = zoo(object$model$y, object$spec$clean_time_index),
distribution = simulated_draws, mean = zoo(mean_prediction, index(newdata)),decomposition = decomp)
class(simulated_draws) <- c("gam.predict","tsmodel.predict")
return(simulated_draws)
}
}
}
}
Expand Down Expand Up @@ -211,7 +303,7 @@ gam_trainspec <- function(formula, family = gaussian(link = "identity"), data, e
if (length(prediction_dates) != n) stop("\nlength of prediction_dates list must be equal to length(estimation_dates)")
if (validate) {
check <- sapply(1:n, function(i){
all(prediction_dates[[i]]>estimation_dates[i])
all(prediction_dates[[i]] > estimation_dates[i])
})
if (any(!check)) warning("\nvalidation failed for prediction_dates>estimation_dates.")
data_dates <- index(data)
Expand Down
4 changes: 2 additions & 2 deletions R/tsgam-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#' @import tsmethods
#' @import data.table
#' @import mgcv
#' @importFrom stats as.formula gaussian quantile na.omit predict residuals coef fitted AIC
#' @importFrom stats as.formula gaussian quantile na.omit predict residuals coef fitted AIC sd
#' @importFrom gratia predicted_samples
#' @importFrom tsaux smape mape bias crps msis mis mslre sampling_frequency
#' @importFrom future.apply future_lapply
#' @importFrom future %<-%
#' @importFrom tsdistributions distribution_modelspec rdist
#' @importFrom tsdistributions distribution_modelspec rdist qdist
#' @importFrom progressr handlers progressor
#' @importFrom zoo index as.zoo zoo coredata
#' @importFrom xts xts as.xts is.xts tzone
Expand Down
9 changes: 7 additions & 2 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,10 @@ tsmodels due to the nature of GAMs.

## Installation

For installation instructions and vignettes for **tsmodels** packages,
see https://tsmodels.github.io/.
The package can be installed from the tsmodels github repo.

```{r,eval=FALSE}
remotes::install_github("tsmodels/tsgam", dependencies = TRUE)
```

A short vignette is available [here](https://www.nopredict.com/packages/tsgam.html).
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

[![R-CMD-check](https://github.com/tsmodels/tsgam/workflows/R-CMD-check/badge.svg)](https://github.com/tsmodels/tsgam/actions)
[![Last-changedate](https://img.shields.io/badge/last%20change-2022--10--08-yellowgreen.svg)](/commits/master)
[![packageversion](https://img.shields.io/badge/Package%20version-0.2.0-orange.svg?style=flat-square)](commits/master)
[![Last-changedate](https://img.shields.io/badge/last%20change-2023--04--18-yellowgreen.svg)](/commits/master)
[![packageversion](https://img.shields.io/badge/Package%20version-0.3.0-orange.svg?style=flat-square)](commits/master)
[![CRAN_Status_Badge](https://www.r-pkg.org/badges/version/tsgam)](https://cran.r-project.org/package=tsgam)

# tsgam
Expand All @@ -12,5 +12,11 @@ than other models in tsmodels due to the nature of GAMs.

## Installation

For installation instructions and vignettes for **tsmodels** packages,
see <https://tsmodels.github.io/>.
The package can be installed from the tsmodels github repo.

``` r
remotes::install_github("tsmodels/tsgam", dependencies = TRUE)
```

A short vignette is available
[here](https://www.nopredict.com/packages/tsgam.html).
22 changes: 22 additions & 0 deletions man/predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9f2d403

Please sign in to comment.