Pivoting posteriors
In Stan, when a parameter is declared as an array, the samples/draws data frame will have columns that use the [i] notation to denote the i-th element of the array. For example, suppose we had a model with two parameters – \(\lambda_1\) and \(\lambda_2\). Instead of declaring them individually – e.g. lambda1 and lambda2, respectively – we may declare them as a single lambda array of size 2:
parameters {
  array[2] real lambda;
}
When we sample from that model, we will end up with samples for lambda[1] and lambda[2]. We want to extract the i from [i] and the name of the parameter into separate columns, yielding a tidy dataset.
There are two ways to accomplish this:
- tidybayes::gather_draws is specifically for Bayesian model fits; see using tidy data with Bayesian models
- tidyr::pivot_longer is for data wrangling in general; see pivoting vignette for an introduction to this function
In this post I demonstrate both approaches with {CmdStanR} models.
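To make the goal concrete, here is a minimal base-R sketch of the name-splitting step on its own. The col_names vector and the pattern are illustrative assumptions, not part of the model used later; both approaches in the walkthrough do essentially this, plus the reshaping.

```r
# Hypothetical column names, as they might appear in a draws data frame
col_names <- c("lambda[1]", "lambda[2]")

# Two capture groups: the parameter name, then the index inside [ ]
pattern <- "^([A-Za-z0-9_.]+)\\[([0-9]+)\\]$"

parts <- data.frame(
  column   = col_names,
  variable = sub(pattern, "\\1", col_names),            # keep first group
  index    = as.integer(sub(pattern, "\\2", col_names)) # keep second group
)
parts
```

Each original column name is now described by a variable name and an integer index, which is the shape a tidy dataset needs.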
Walkthrough
Setup
Package installation instructions for following along:
pkgs <- c(
  "purrr", "tidyr", "dplyr", "here", # data and file paths
  "cmdstanr", "posterior", "tidybayes" # modeling
)
options(repos = c(
  MC_STAN = "https://mc-stan.org/r-packages/",
  CRAN = "https://cran.rstudio.com/"
))
install.packages(pkgs)
Refer to these instructions for installing CmdStan with {CmdStanR}. Windows users should refer to Max Mantei’s post.
Let’s load them up:
suppressPackageStartupMessages({
  # Data Manipulation:
  library("tidyr") # pivot_longer
  library("dplyr")
  library("purrr") # pmap_dfr
  library("here") # here
  # Modeling:
  library("cmdstanr")
  library("posterior") # as_draws_df
  library("tidybayes") # gather_draws
})
# Set CmdStanR as engine for Stan chunks:
register_knitr_engine()
Tip: I like to leave comments next to library() commands when there’s a particular function I’ll be using from that package.
Data
Let’s generate data from three Normal distributions, using {purrr}’s pmap_dfr to operate on a tuple of mean, standard deviation, and number of observations for each “group” and row-bind the random values into a data frame:
mus <- c(0.4, 0.7, 0.6)
sigmas <- c(0.2, 0.1, 0.3)
n_obs <- c(30, 20, 5)
set.seed(42)
data <- pmap_dfr(
  # Iterate over these vectors in lockstep (element-wise):
  list(mus, sigmas, n_obs),
  # Do the following with each tuple:
  function(mu, sigma, n) {
    random_values <- tibble(y = rnorm(n, mean = mu, sd = sigma))
    return(random_values)
  },
  # Store each tuple's position (as a character) in a "group" column:
  .id = "group"
)
data %>%
  group_by(group) %>%
  summarize(
    n_obs = n(),
    sample_mean = mean(y),
    sample_sd = sd(y)
  )
## # A tibble: 3 × 4
## group n_obs sample_mean sample_sd
## <chr> <int> <dbl> <dbl>
## 1 1 30 0.414 0.251
## 2 2 20 0.681 0.0986
## 3 3 5 0.711 0.257
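For readers who have not used pmap_dfr before, the following is a rough base-R sketch of the same idea: walk several vectors in lockstep, build one data frame per tuple, tag each with its position, and row-bind the pieces. The names pieces, tagged, and data_base are hypothetical and used only for illustration.

```r
mus <- c(0.4, 0.7, 0.6)
sigmas <- c(0.2, 0.1, 0.3)
n_obs <- c(30, 20, 5)

set.seed(42)
# Map() walks the three vectors element by element, like pmap()
pieces <- Map(
  function(mu, sigma, n) data.frame(y = rnorm(n, mean = mu, sd = sigma)),
  mus, sigmas, n_obs
)
# Tag each piece with its position (the role played by .id = "group")
tagged <- Map(
  function(piece, id) cbind(group = as.character(id), piece),
  pieces, seq_along(pieces)
)
# Row-bind everything into a single data frame
data_base <- do.call(rbind, tagged)
table(data_base$group)
```

The {purrr} version is shorter and returns a tibble, but the three steps (iterate, tag, bind) are the same.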
Model
Thanks to the magic of R Markdown and {CmdStanR}’s knitr engine, the following Stan model is compiled and made available to us as example_model:
data {
  int<lower = 1> G; // groups
  int<lower = 0> N; // observations
  array[N] real y; // observed outcome
  array[N] int<lower = 1, upper = G> group; // group membership
}
parameters {
  array[G] real mu; // group means
  array[G] real<lower = 0> sigma; // group standard deviations (must be positive)
}
model {
  for (i in 1:N) {
    y[i] ~ normal(mu[group[i]], sigma[group[i]]);
  }
}
Sampling
example_data <- list(
  G = 3, N = nrow(data),
  y = data$y, group = as.integer(data$group)
)
example_fit <- example_model$sample(
  data = example_data,
  seed = 42L,
  refresh = 0, # silence progress notification
  output_dir = here("temp")
)
## Running MCMC with 4 parallel chains...
##
## Chain 2 finished in 0.1 seconds.
## Chain 1 finished in 0.2 seconds.
## Chain 3 finished in 0.2 seconds.
## Chain 4 finished in 0.2 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 0.2 seconds.
## Total execution time: 0.4 seconds.
example_fit$summary()
## # A tibble: 7 × 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ 64.7 65.2 2.28 2.11 60.5 67.6 1.00 825. 778.
## 2 mu[1] 0.413 0.413 0.0491 0.0475 0.335 0.494 1.00 2336. 2423.
## 3 mu[2] 0.681 0.681 0.0249 0.0229 0.641 0.722 1.00 1567. 1442.
## 4 mu[3] 0.708 0.712 0.204 0.155 0.394 1.01 1.00 2032. 1059.
## 5 sigma[1] 0.262 0.258 0.0361 0.0335 0.211 0.328 1.00 2463. 1580.
## 6 sigma[2] 0.107 0.104 0.0202 0.0173 0.0804 0.144 1.00 1324. 1037.
## 7 sigma[3] 0.401 0.334 0.245 0.143 0.181 0.836 1.00 954. 713.
Approach 1: tidybayes
At the time of writing, {tidybayes} does not work with {CmdStanR} fits out of the box, so one extra step is needed. It looks like once tidy_draws starts using as_draws_df from the {posterior} package (not yet on CRAN), this extra step may not be necessary.
To include basic support for new models, one need only implement the tidy_draws() generic function for that model:
tidy_draws.CmdStanMCMC <- function(model, ...) {
  return(as_draws_df(model$draws()))
}
Once this helper method is available, we can use gather_draws():
approach1 <- gather_draws(example_fit, mu[group], sigma[group]) %>%
  ungroup() # because otherwise the output is left "grouped"
approach1
## # A tibble: 24,000 × 6
## group .chain .iteration .draw .variable .value
## <int> <int> <int> <int> <chr> <dbl>
## 1 1 1 1 1 mu 0.417
## 2 1 1 2 2 mu 0.391
## 3 1 1 3 3 mu 0.396
## 4 1 1 4 4 mu 0.398
## 5 1 1 5 5 mu 0.424
## 6 1 1 6 6 mu 0.378
## 7 1 1 7 7 mu 0.371
## 8 1 1 8 8 mu 0.451
## 9 1 1 9 9 mu 0.389
## 10 1 1 10 10 mu 0.442
## # … with 23,990 more rows
Approach 2: pivoting
The second approach is a general-purpose technique that is useful to have in your data-wrangling toolkit. First, let’s use {posterior}’s as_draws_df to get the posterior samples into a tibble:
draws <- example_fit$draws(variables = c("mu", "sigma")) %>%
  as_draws_df()
draws
## # A draws_df: 1000 iterations, 4 chains, and 6 variables
## mu[1] mu[2] mu[3] sigma[1] sigma[2] sigma[3]
## 1 0.42 0.63 1.10 0.29 0.100 1.02
## 2 0.39 0.67 0.92 0.28 0.095 0.79
## 3 0.40 0.68 0.93 0.27 0.115 0.52
## 4 0.40 0.70 1.03 0.24 0.122 0.28
## 5 0.42 0.69 0.80 0.25 0.112 0.42
## 6 0.38 0.65 0.76 0.22 0.083 0.27
## 7 0.37 0.65 0.77 0.25 0.086 0.17
## 8 0.45 0.72 0.83 0.24 0.123 0.30
## 9 0.39 0.70 0.73 0.23 0.115 0.21
## 10 0.44 0.65 0.67 0.28 0.115 0.17
## # ... with 3990 more draws
## # ... hidden reserved variables {'.chain', '.iteration', '.draw'}
Let’s use tidyr::pivot_longer to turn this “wide” dataset – with all its mu-s and sigma-s as columns – into a “long” one:
approach2 <- draws %>%
  pivot_longer(
    cols = matches("(mu|sigma)"),
    names_pattern = r"((mu|sigma)\[([1-3])\])",
    names_to = c(".variable", "group"),
    values_to = ".value" # to mimic output of gather_draws
  ) %>%
  mutate(group = as.integer(group)) # to mimic output of gather_draws
Breaking it down:
- matches() (from {tidyselect}) selects columns based on a pattern
- the regular expression is created with r"()" notation (see earlier blog post on strings in R 4.0.0) and contains two capture groups: the first for the .variable column, the second for the group column
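The two capture groups are easier to see on a small, self-contained example. The sketch below uses a made-up toy_draws tibble (hypothetical values, not output from the model) and two variations on the code above: [0-9]+ instead of [1-3], so that indices with more than one digit also match, and pivot_longer's names_transform argument in place of the separate mutate() step. Raw strings require R 4.0.0 or later.

```r
library(tidyr)
library(tibble)

# Toy "wide" draws with made-up values; note the two-digit index 12
toy_draws <- tibble(
  .draw = 1:2,
  `mu[1]` = c(0.41, 0.39),
  `mu[12]` = c(0.70, 0.68),
  `sigma[1]` = c(0.29, 0.28),
  `sigma[12]` = c(0.10, 0.11)
)

toy_long <- pivot_longer(
  toy_draws,
  cols = matches("^(mu|sigma)\\["),
  # 1st group -> .variable, 2nd group -> group
  names_pattern = r"((mu|sigma)\[([0-9]+)\])",
  names_to = c(".variable", "group"),
  # convert the captured index from character to integer while pivoting
  names_transform = list(group = as.integer),
  values_to = ".value"
)
toy_long
```

Whether to convert the index with names_transform or with a later mutate() is a matter of taste; the result should be the same either way.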
Let’s see what we got:
approach2
## # A tibble: 24,000 × 6
## .chain .iteration .draw .variable group .value
## <int> <int> <int> <chr> <int> <dbl>
## 1 1 1 1 mu 1 0.417
## 2 1 1 1 mu 2 0.634
## 3 1 1 1 mu 3 1.10
## 4 1 1 1 sigma 1 0.293
## 5 1 1 1 sigma 2 0.100
## 6 1 1 1 sigma 3 1.02
## 7 1 2 2 mu 1 0.391
## 8 1 2 2 mu 2 0.667
## 9 1 2 2 mu 3 0.917
## 10 1 2 2 sigma 1 0.278
## # … with 23,990 more rows
Verification
To verify that the two approaches have yielded the same result (although the ordering of the columns is different), we’ll combine the two results and then use pivot_longer()’s sibling pivot_wider(). If everything matches up, we’ll see “OK” in both (tidybayes & pivoting) columns; a row present in only one of the results would show NA in the other column:
approaches <- list(
  "tidybayes" = mutate(approach1, present = "OK"),
  "pivoting" = mutate(approach2, present = "OK")
)
approaches %>%
  bind_rows(.id = "approach") %>%
  pivot_wider(
    # id_cols auto-includes all except "approach" & "present"
    names_from = "approach",
    values_from = "present"
  )
## # A tibble: 24,000 × 8
## group .chain .iteration .draw .variable .value tidybayes pivoting
## <int> <int> <int> <int> <chr> <dbl> <chr> <chr>
## 1 1 1 1 1 mu 0.417 OK OK
## 2 1 1 2 2 mu 0.391 OK OK
## 3 1 1 3 3 mu 0.396 OK OK
## 4 1 1 4 4 mu 0.398 OK OK
## 5 1 1 5 5 mu 0.424 OK OK
## 6 1 1 6 6 mu 0.378 OK OK
## 7 1 1 7 7 mu 0.371 OK OK
## 8 1 1 8 8 mu 0.451 OK OK
## 9 1 1 9 9 mu 0.389 OK OK
## 10 1 1 10 10 mu 0.442 OK OK
## # … with 23,990 more rows
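Since only the first ten rows are printed, a programmatic check may be more convincing than eyeballing the output. One possible sketch, shown here on small toy stand-ins (toy1 and toy2 are hypothetical, with made-up values), uses dplyr's anti_join() in both directions: if neither join returns any rows, the two tables contain the same rows regardless of row or column order.

```r
library(dplyr)
library(tibble)

# Toy stand-in for the tidybayes result (made-up values)
toy1 <- tibble(
  group = c(1L, 1L, 2L, 2L),
  .draw = c(1L, 2L, 1L, 2L),
  .variable = "mu",
  .value = c(0.417, 0.391, 0.634, 0.667)
)
# Toy stand-in for the pivoting result: same rows, different row/column order
toy2 <- toy1 %>%
  arrange(.draw, group) %>%
  select(.draw, .variable, group, .value)

keys <- c("group", ".draw", ".variable", ".value")
only_in_1 <- anti_join(toy1, toy2, by = keys) # rows missing from toy2
only_in_2 <- anti_join(toy2, toy1, by = keys) # rows missing from toy1

# Stops with an error if either table has a row the other lacks
stopifnot(nrow(only_in_1) == 0, nrow(only_in_2) == 0)
```

To apply this to the real results, substitute approach1 and approach2 for the toy tables and add ".chain" and ".iteration" to keys. Note that this compares .value for exact equality, just as the pivot_wider() check does.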
- Posted on: September 7, 2020