Stan’s algorithms, including MCMC sampling (NUTS), optimization (L-BFGS), variational inference, Pathfinder, and Laplace approximation, are gradient-based methods. This means they require not only the log-probability function but also its gradient (the vector of partial derivatives with respect to each parameter) to work efficiently.
StanEstimators provides three ways to compute
gradients. This vignette focuses on one of them: using the
RTMB package to compute exact gradients via
automatic differentiation (AD).

To use RTMB with StanEstimators, you need to install the
RTMB, withr, and future
packages:
Once installed, simply set grad_fun = "RTMB" in any
StanEstimators function to enable automatic
differentiation.
For basic usage of StanEstimators, see the Getting Started vignette.
Next, we’ll examine a generalized linear model (GLM) for count data. Poisson regression uses a log-link function: \(\log(\lambda_i) = X_i\beta\), where \(\lambda_i\) is the expected count.
# Start all three regression coefficients at zero.
inits_pois <- rep(0, 3)

# Finite differences: no grad_fun is supplied for this run.
timing_pois_fd <- system.time({
  fit_pois_fd <- stan_sample(poisson_loglik, inits_pois,
                             additional_args = list(y = y_pois, X = X),
                             num_chains = 1, seed = 1234)
})

# RTMB: identical call, but gradients come from automatic differentiation.
timing_pois_rtmb <- system.time({
  fit_pois_rtmb <- stan_sample(poisson_loglik, inits_pois,
                               grad_fun = "RTMB",
                               additional_args = list(y = y_pois, X = X),
                               num_chains = 1, seed = 1234)
})

# Extract the elapsed time with [[ ]] so the "elapsed" name is dropped.
# A named vector here leaks into the data frame's row names and shows up
# as a stray column when the table is printed with knitr::kable().
timing_results_pois <- data.frame(
  Method = c("Finite Differences", "RTMB"),
  Time_seconds = c(timing_pois_fd[["elapsed"]], timing_pois_rtmb[["elapsed"]]),
  Speedup = c(1, timing_pois_fd[["elapsed"]] / timing_pois_rtmb[["elapsed"]])
)
knitr::kable(timing_results_pois, digits = 2,
caption = "Performance comparison for Poisson regression")

| Method | Time_seconds | Speedup |
|---|---|---|
| Finite Differences | 10.73 | 1.00 |
| RTMB | 2.36 | 4.56 |
summary(fit_pois_rtmb)
#> # A tibble: 4 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -315. -314. 1.31 1.02 -317. -313. 1.00 421.
#> 2 pars[1] 0.546 0.546 0.0613 0.0611 0.447 0.646 1.00 506.
#> 3 pars[2] 1.15 1.15 0.0478 0.0453 1.07 1.23 1.00 515.
#> 4 pars[3] -0.730 -0.729 0.0507 0.0496 -0.818 -0.645 1.00 542.
#> # ℹ 1 more variable: ess_tail <dbl>

RTMB handles the matrix operations and log-link function automatically, providing improved performance (roughly a 4.5x speedup in the run shown above) while correctly recovering the true parameter values.
Logistic regression models binary outcomes using a logit link: \(\text{logit}(p_i) = X_i\beta\), where \(p_i\) is the probability of success.
# Start all three regression coefficients at zero.
inits_logit <- rep(0, 3)

# Finite differences: no grad_fun is supplied for this run.
timing_logit_fd <- system.time({
  fit_logit_fd <- stan_sample(logistic_loglik, inits_logit,
                              additional_args = list(y = y_binom, X = X_logit),
                              num_chains = 1, seed = 1234)
})

# RTMB: identical call, but gradients come from automatic differentiation.
timing_logit_rtmb <- system.time({
  fit_logit_rtmb <- stan_sample(logistic_loglik, inits_logit,
                                grad_fun = "RTMB",
                                additional_args = list(y = y_binom, X = X_logit),
                                num_chains = 1, seed = 1234)
})

# Extract the elapsed time with [[ ]] so the "elapsed" name is dropped.
# A named vector here leaks into the data frame's row names and shows up
# as a stray column when the table is printed with knitr::kable().
timing_results_logit <- data.frame(
  Method = c("Finite Differences", "RTMB"),
  Time_seconds = c(timing_logit_fd[["elapsed"]],
                   timing_logit_rtmb[["elapsed"]]),
  Speedup = c(1, timing_logit_fd[["elapsed"]] / timing_logit_rtmb[["elapsed"]])
)
knitr::kable(timing_results_logit, digits = 2,
caption = "Performance comparison for Logistic regression")

| Method | Time_seconds | Speedup |
|---|---|---|
| Finite Differences | 5.45 | 1.00 |
| RTMB | 1.08 | 5.04 |
summary(fit_logit_rtmb)
#> # A tibble: 4 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -152. -152. 1.31 0.924 -155. -151. 0.999 495.
#> 2 pars[1] 0.220 0.225 0.150 0.142 -0.0330 0.453 1.01 759.
#> 3 pars[2] 1.55 1.54 0.201 0.210 1.23 1.87 0.999 695.
#> 4 pars[3] -0.896 -0.895 0.160 0.154 -1.18 -0.625 1.00 915.
#> # ℹ 1 more variable: ess_tail <dbl>

Mixture models represent complex latent structure and demonstrate RTMB’s benefits for challenging models. We’ll fit a two-component Gaussian mixture.
The model is: \(p(y) = \pi \cdot N(\mu_1, \sigma_1^2) + (1-\pi) \cdot N(\mu_2, \sigma_2^2)\)
# Log-likelihood of a two-component Gaussian mixture.
#
# pars: numeric vector of length 5, read on its natural scale (no
#       transformation is applied inside this function):
#         pars[1] mixing proportion of component 1
#         pars[2] mean of component 1
#         pars[3] mean of component 2
#         pars[4] standard deviation of component 1
#         pars[5] standard deviation of component 2
# y:    numeric vector of observations.
#
# Returns the summed log-density of y under the mixture.
mixture_loglik <- function(pars, y) {
  weight1 <- pars[1]
  mean1 <- pars[2]
  mean2 <- pars[3]
  sd1 <- pars[4]
  sd2 <- pars[5]

  # Per-observation log of (mixing weight * component density)
  log_comp1 <- log(weight1) + dnorm(y, mean = mean1, sd = sd1, log = TRUE)
  log_comp2 <- log(1 - weight1) + dnorm(y, mean = mean2, sd = sd2, log = TRUE)

  # Combine the two components on the density scale, then total the logs
  mix_density <- exp(log_comp1) + exp(log_comp2)
  sum(log(mix_density))
}
# Initialize near true values (mixture models can have multimodality)
# Order: mixing proportion, two means, two standard deviations.
inits_mix <- c(0.3, -2, 3, 1, 1.5)

# Finite differences: no grad_fun is supplied for this run.
# Bounds keep the mixing proportion in [0, 1] and both sds positive.
timing_mix_fd <- system.time({
  fit_mix_fd <- stan_sample(mixture_loglik, inits_mix,
                            lower = c(0, -Inf, -Inf, 0, 0),
                            upper = c(1, Inf, Inf, Inf, Inf),
                            additional_args = list(y = y_mix),
                            num_chains = 1, seed = 1234)
})

# RTMB: identical call, but gradients come from automatic differentiation.
timing_mix_rtmb <- system.time({
  fit_mix_rtmb <- stan_sample(mixture_loglik, inits_mix,
                              lower = c(0, -Inf, -Inf, 0, 0),
                              upper = c(1, Inf, Inf, Inf, Inf),
                              grad_fun = "RTMB",
                              additional_args = list(y = y_mix),
                              num_chains = 1, seed = 1234)
})

# Extract the elapsed time with [[ ]] so the "elapsed" name is dropped.
# A named vector here leaks into the data frame's row names and shows up
# as a stray column when the table is printed with knitr::kable().
timing_results_mix <- data.frame(
  Method = c("Finite Differences", "RTMB"),
  Time_seconds = c(timing_mix_fd[["elapsed"]], timing_mix_rtmb[["elapsed"]]),
  Speedup = c(1, timing_mix_fd[["elapsed"]] / timing_mix_rtmb[["elapsed"]])
)
knitr::kable(timing_results_mix, digits = 2,
caption = "Performance comparison for Gaussian Mixture")

| Method | Time_seconds | Speedup |
|---|---|---|
| Finite Differences | 15.26 | 1.00 |
| RTMB | 1.62 | 9.44 |
summary(fit_mix_rtmb)
#> # A tibble: 6 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -909. -909. 1.54 1.42 -912. -907. 0.999 461.
#> 2 pars[1] 0.294 0.293 0.0254 0.0256 0.252 0.336 1.000 802.
#> 3 pars[2] -2.05 -2.06 0.120 0.123 -2.23 -1.85 0.999 744.
#> 4 pars[3] 2.95 2.96 0.104 0.103 2.78 3.12 1.000 603.
#> 5 pars[4] 1.08 1.07 0.0988 0.0960 0.929 1.25 1.01 779.
#> 6 pars[5] 1.51 1.50 0.0833 0.0808 1.38 1.66 1.00 644.
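#> # ℹ 1 more variable: ess_tail <dbl>

An autoregressive model of order 1 (AR(1)) captures temporal dependence: \(y_t = \phi y_{t-1} + \epsilon_t\), where \(|\phi| < 1\) for stationarity.
Rather than transforming the parameters inside the likelihood (for example with tanh()), the constraints on φ and σ are imposed through the lower and upper bound arguments of the fitting function.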
# Log-likelihood of a zero-mean Gaussian AR(1) process.
#
# pars: numeric vector of length 2:
#         pars[1] autoregressive coefficient phi (needs |phi| < 1)
#         pars[2] innovation standard deviation sigma (needs sigma > 0)
#       No transformation is applied here; callers are expected to keep the
#       parameters in range (e.g. through bounds on the sampler).
# y:    numeric vector holding the observed series, in time order.
#
# Returns a single number: the log-density of y[1] under the stationary
# distribution plus the conditional log-densities of y[2], ..., y[n].
ar1_loglik <- function(pars, y) {
  phi <- pars[1]
  sigma <- pars[2]
  n <- length(y)

  # First observation from the stationary distribution N(0, sigma^2 / (1 - phi^2))
  ll <- dnorm(y[1], 0, sigma / sqrt(1 - phi^2), log = TRUE)

  # Remaining observations: y[t] | y[t-1] ~ N(phi * y[t-1], sigma^2).
  # Vectorised instead of `for (t in 2:n)`: `2:n` counts down to c(2, 1) when
  # n == 1, which made the loop index past the data. With negative indexing a
  # one-observation series simply contributes an empty sum of 0.
  ll + sum(dnorm(y[-1], phi * y[-n], sigma, log = TRUE))
}
inits_ar <- c(0.5, 1)
# Bounds, in parameter order (phi, sigma): phi in (-1, 1), sigma > 0.
# phi's upper bound must be 1, not 0: an upper bound of 0 excludes the
# starting value 0.5 and rules out any positive autocorrelation, which
# pins the posterior for phi against the boundary.

# Finite differences: no grad_fun is supplied for this run.
timing_ar_fd <- system.time({
  fit_ar_fd <- stan_sample(ar1_loglik, inits_ar,
                           lower = c(-1, 0),
                           upper = c(1, Inf),
                           additional_args = list(y = y_ar),
                           num_chains = 1, seed = 1234)
})

# RTMB: identical call, but gradients come from automatic differentiation.
timing_ar_rtmb <- system.time({
  fit_ar_rtmb <- stan_sample(ar1_loglik, inits_ar,
                             lower = c(-1, 0),
                             upper = c(1, Inf),
                             grad_fun = "RTMB",
                             additional_args = list(y = y_ar),
                             num_chains = 1, seed = 1234)
})

# Extract the elapsed time with [[ ]] so the "elapsed" name is dropped.
# A named vector here leaks into the data frame's row names and shows up
# as a stray column when the table is printed with knitr::kable().
timing_results_ar <- data.frame(
  Method = c("Finite Differences", "RTMB"),
  Time_seconds = c(timing_ar_fd[["elapsed"]], timing_ar_rtmb[["elapsed"]]),
  Speedup = c(1, timing_ar_fd[["elapsed"]] / timing_ar_rtmb[["elapsed"]])
)
knitr::kable(timing_results_ar, digits = 2,
caption = "Performance comparison for AR(1) model")

| Method | Time_seconds | Speedup |
|---|---|---|
| Finite Differences | 41.01 | 1.00 |
| RTMB | 0.94 | 43.44 |
summary(fit_ar_rtmb)
#> # A tibble: 3 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -373. -3.73e+2 1.15 0.768 -3.75e+2 -3.72e+2 1.01 228.
#> 2 pars[1] -0.00659 -4.74e-3 0.00634 0.00487 -1.94e-2 -1.66e-4 1.00 448.
#> 3 pars[2] 1.53 1.52e+0 0.0734 0.0722 1.41e+0 1.65e+0 1.00 751.
#> # ℹ 1 more variable: ess_tail <dbl>

RTMB also works with Pathfinder, Stan’s fast variational inference method:
# Fit the AR(1) model with Pathfinder, using RTMB for the gradients.
# Keep phi in (-1, 1) and sigma positive: outside that region
# sqrt(1 - phi^2) or the normal density is undefined and the
# log-likelihood evaluates to NaN.
fit_ar_path <- stan_pathfinder(ar1_loglik, inits_ar,
                               lower = c(-1, 0),
                               upper = c(1, Inf),
                               grad_fun = "RTMB",
                               additional_args = list(y = y_ar))
summary(fit_ar_path)
#> # A tibble: 5 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp_approx__ 3.18 3.53 1.10 0.777 1.33 4.20 1.01 788.
#> 2 lp__ -284. -284. 1.03 0.750 -286. -283. 1.00 805.
#> 3 path__ 2.48 2 1.14 1.48 1 4 2.65 1.20
#> 4 pars[1] 0.750 0.751 0.0469 0.0460 0.671 0.825 1.01 846.
#> 5 pars[2] 1.01 1.00 0.0507 0.0508 0.921 1.09 0.999 763.
#> # ℹ 1 more variable: ess_tail <dbl>