In Stan, when a parameter is declared as an array, the data frame of samples (draws) has columns that use [i] notation to denote the i-th element of the array. For example, suppose we had a model with two parameters: \(\lambda_1\) and \(\lambda_2\). Instead of declaring them individually (e.g. as lambda1 and lambda2), we may declare them as a single lambda array of size 2:

parameters {
  array[2] real lambda; // Stan 2.26 or newer; older releases wrote: real lambda[2];
}

When we sample from that model, we will end up with samples for lambda[1] and lambda[2]. We want to extract the i from [i] and the name of the parameter into separate columns, yielding a tidy dataset.
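To make the target concrete, here is a minimal base-R sketch of the reshaping we are after. The `wide` data frame and its values are made up for illustration (they are not output from a real model), and the column names `variable`, `i` and `value` are my own choice:

```r
# Hypothetical draws for a size-2 `lambda` array (made-up values).
wide <- data.frame(
  `lambda[1]` = c(0.11, 0.12, 0.13),
  `lambda[2]` = c(0.91, 0.92, 0.93),
  check.names = FALSE # keep the square brackets in the column names
)

# Stack the columns on top of each other; `ind` holds the old column names.
long <- stack(wide)
labels <- as.character(long$ind)

# Split each label such as "lambda[1]" into a parameter name and an index.
tidy <- data.frame(
  variable = sub("\\[.*$", "", labels),
  i = as.integer(sub("^.*\\[([0-9]+)\\]$", "\\1", labels)),
  value = long$values
)
tidy
```

Each row of `tidy` now holds one value together with the parameter name and the array index it belongs to.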

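There are two ways to accomplish this:

  • with {tidybayes}, which extracts the index for us (Approach 1)
  • by pivoting the draws ourselves with tidyr::pivot_longer and a regular expression (Approach 2)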

In this post I demonstrate both approaches with {CmdStanR} models.

Walkthrough

Setup

Package installation instructions for following along:

pkgs <- c(
  "purrr", "tidyr", "dplyr", "here", # data & file paths
  "cmdstanr", "posterior", "tidybayes" # modeling
)
options(repos = c(
    MC_STAN = "https://mc-stan.org/r-packages/",
    RSTUDIO = "https://packagemanager.rstudio.com/all/latest",
    CRAN = "https://cran.rstudio.com/"
))
install.packages(pkgs)

Refer to these instructions for installing CmdStan with {CmdStanR}. Windows users should refer to Max Mantei’s post.

Let’s load them up:

suppressPackageStartupMessages({
  # Data Manipulation:
  library("tidyr") # pivot_longer
  library("dplyr")
  library("purrr") # *map*
  library("here") # here
  # Modeling:
  library("cmdstanr")
  library("posterior") # as_draws_df
  library("tidybayes") # gather_draws
})

# Set CmdStanR as engine for Stan chunks:
register_knitr_engine()

Tip: I like to leave comments next to library() commands when there’s a particular function I’ll be using from that package.

Data

Let’s generate data from three Normal distributions, using {purrr}’s pmap_dfr to operate on a tuple of mean, standard deviation, and number of observations for each “group” and row-bind the random values into a data frame:

mus <- c(0.4, 0.7, 0.6)
sigmas <- c(0.2, 0.1, 0.3)
n_obs <- c(30, 20, 5)

set.seed(42)
data <- pmap_dfr(
  # Process these in parallel:
  list(mus, sigmas, n_obs),
  # Do the following with each tuple:
  function(mu, sigma, n) {
    random_values <- tibble(y = rnorm(n, mean = mu, sd = sigma))
    return(random_values)
  },
  .id = "group"
)
data %>%
  group_by(group) %>%
  summarize(
    n_obs = n(),
    sample_mean = mean(y),
    sample_sd = sd(y)
  )
## # A tibble: 3 x 4
##   group n_obs sample_mean sample_sd
##   <chr> <int>       <dbl>     <dbl>
## 1 1        30       0.414    0.251 
## 2 2        20       0.681    0.0986
## 3 3         5       0.711    0.257
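For readers who have not used pmap_dfr before, the following is a rough base-R sketch of the same idea: build one data frame per (mean, standard deviation, size) tuple, label each with its position, and row-bind them. The names `pieces`, `labelled` and `data_base` are hypothetical, and I make no claim that this matches {purrr}'s internals. Note that the label is stored as a character string, which is why the group column is converted with as.integer() before it is passed to Stan.

```r
mus <- c(0.4, 0.7, 0.6)
sigmas <- c(0.2, 0.1, 0.3)
n_obs <- c(30, 20, 5)

set.seed(42)
# One data frame per tuple, processed in parallel over the three vectors:
pieces <- Map(
  function(mu, sigma, n) data.frame(y = rnorm(n, mean = mu, sd = sigma)),
  mus, sigmas, n_obs
)

# Label each piece with its position (as a string), similar to `.id = "group"`:
labelled <- Map(
  function(piece, id) cbind(group = as.character(id), piece),
  pieces, seq_along(pieces)
)

# Row-bind the pieces into a single data frame:
data_base <- do.call(rbind, labelled)
table(data_base$group)
```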

Model

Thanks to the magic of R Markdown and {CmdStanR}’s knitr engine, the following Stan model is compiled and made available to us as example_model:

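// Note: the array[...] syntax requires Stan 2.26 or newer;
// older releases wrote these declarations as e.g. real y[N];
data {
  int<lower = 1>                     G;     // groups
  int<lower = 0>                     N;     // observations
  array[N] real                      y;     // observed outcome
  array[N] int<lower = 1, upper = G> group; // group membership
}
parameters {
  array[G] real mu;
  array[G] real sigma;
}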
model {
  for (i in 1:N) {
    y[i] ~ normal(mu[group[i]], sigma[group[i]]);
  }
}
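If Stan syntax is unfamiliar, it may help to see what the loop in the model block computes. Each sampling statement adds the Normal log-density of one observation, evaluated at the mean and standard deviation of that observation's group. A base-R sketch of the same quantity follows, equal to Stan's version up to an additive constant; the function name `log_lik` and the toy inputs are hypothetical:

```r
# Log-likelihood of the observations given group-level parameters.
# `mu[group]` and `sigma[group]` look up each observation's own parameters.
log_lik <- function(y, group, mu, sigma) {
  sum(dnorm(y, mean = mu[group], sd = sigma[group], log = TRUE))
}

# Toy inputs (made up): three observations in two groups.
y <- c(0.3, 0.5, 0.8)
group <- c(1L, 1L, 2L)
log_lik(y, group, mu = c(0.4, 0.7), sigma = c(0.2, 0.1))
```

The model declares no priors, so the posterior that Stan samples from is proportional to this likelihood alone.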

Sampling

example_data <- list(
  G = 3, N = nrow(data),
  y = data$y, group = as.integer(data$group)
)

example_fit <- example_model$sample(
  data = example_data,
  seed = 42L,
  refresh = 0, # silence progress notification
  output_dir = here("temp")
)
## Running MCMC with 4 parallel chains...
## 
## Chain 1 finished in 2.2 seconds.
## Chain 2 finished in 2.1 seconds.
## Chain 3 finished in 1.2 seconds.
## Chain 4 finished in 1.1 seconds.
## 
## All 4 chains finished successfully.
## Mean chain execution time: 1.6 seconds.
## Total execution time: 2.4 seconds.
example_fit$summary()
## # A tibble: 7 x 10
##   variable   mean median     sd    mad      q5    q95  rhat ess_bulk ess_tail
##   <chr>     <dbl>  <dbl>  <dbl>  <dbl>   <dbl>  <dbl> <dbl>    <dbl>    <dbl>
## 1 lp__     64.9   65.3   2.25   2.00   60.7    67.7    1.00     872.     989.
## 2 mu[1]     0.414  0.414 0.0491 0.0480  0.334   0.495  1.00    1988.    1866.
## 3 mu[2]     0.681  0.680 0.0236 0.0226  0.643   0.720  1.00    1765.    1745.
## 4 mu[3]     0.705  0.710 0.204  0.146   0.408   1.01   1.00    1820.    1467.
## 5 sigma[1]  0.262  0.259 0.0351 0.0348  0.211   0.326  1.00    2008.    1788.
## 6 sigma[2]  0.105  0.103 0.0188 0.0165  0.0799  0.140  1.01    1482.     992.
## 7 sigma[3]  0.399  0.330 0.243  0.140   0.185   0.833  1.00    1042.     784.

Approach 1: tidybayes

At the time of writing, {tidybayes} does not work with {CmdStanR} fits out of the box, so we need one extra step. To include basic support for new models, one need only implement the tidy_draws() generic function for that model. It looks like this extra step may become unnecessary once tidy_draws starts using as_draws_df from the {posterior} package (not yet on CRAN).

tidy_draws.CmdStanMCMC <- function(model, ...) {
  return(as_draws_df(model$draws()))
}
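The helper above relies on S3 method dispatch: a function named generic.class is picked up automatically when the generic is called on an object of that class. Here is a toy sketch of that mechanism, using a hypothetical generic `describe` and a hypothetical class `toy_fit` rather than anything from {tidybayes} or {CmdStanR}:

```r
# A hypothetical generic, standing in for a generic such as tidy_draws():
describe <- function(x, ...) UseMethod("describe")

# Fallback for classes nobody has written a method for:
describe.default <- function(x, ...) {
  stop("no describe() method for class ", class(x)[1])
}

# Adding support for a new class only requires defining generic.class:
describe.toy_fit <- function(x, ...) {
  paste("toy fit with", length(x$draws), "draws")
}

fit <- structure(list(draws = rnorm(10)), class = "toy_fit")
describe(fit)
```

Functions built on top of the generic then work with the new class without further changes, which is why defining the single method is enough.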

Once this helper method is available, we can use gather_draws():

approach1 <- gather_draws(example_fit, mu[group], sigma[group]) %>%
  ungroup() # because otherwise the output is left "grouped"

approach1
## # A tibble: 24,000 x 6
##    group .chain .iteration .draw .variable .value
##    <int>  <int>      <int> <int> <chr>      <dbl>
##  1     1      1          1     1 mu         0.409
##  2     1      1          2     2 mu         0.455
##  3     1      1          3     3 mu         0.366
##  4     1      1          4     4 mu         0.355
##  5     1      1          5     5 mu         0.355
##  6     1      1          6     6 mu         0.364
##  7     1      1          7     7 mu         0.432
##  8     1      1          8     8 mu         0.399
##  9     1      1          9     9 mu         0.422
## 10     1      1         10    10 mu         0.419
## # ... with 23,990 more rows
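As a quick sanity check on the 24,000 rows: the long format has one row per draw, per group, per parameter. A small arithmetic sketch, assuming 4 chains with 1000 post-warmup draws each (variable names are my own):

```r
n_chains <- 4      # chains reported by the sampler
n_iter <- 1000     # post-warmup draws per chain (assumed)
n_groups <- 3      # G in the model
n_parameters <- 2  # mu and sigma

n_rows <- n_chains * n_iter * n_groups * n_parameters
n_rows
```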

Approach 2: pivoting

The second approach is a general-purpose technique that is useful to have in your data-wrangling toolkit. First, let’s use {posterior}’s as_draws_df to get the posterior samples into a tibble:

draws <- example_fit$draws(variables = c("mu", "sigma")) %>%
  as_draws_df()

draws
## # A draws_df: 1000 iterations, 4 chains, and 6 variables
##    mu[1] mu[2] mu[3] sigma[1] sigma[2] sigma[3]
## 1   0.41  0.68  0.99     0.25    0.139     0.78
## 2   0.45  0.71  0.62     0.26    0.082     0.25
## 3   0.37  0.70  0.77     0.27    0.102     0.28
## 4   0.36  0.70  0.67     0.25    0.098     0.20
## 5   0.36  0.70  0.67     0.25    0.098     0.20
## 6   0.36  0.65  0.84     0.28    0.094     0.26
## 7   0.43  0.68  0.76     0.23    0.134     0.29
## 8   0.40  0.66  0.90     0.21    0.141     0.48
## 9   0.42  0.66  0.79     0.21    0.115     0.66
## 10  0.42  0.67  0.93     0.25    0.115     0.33
## # ... with 3990 more draws
## # ... hidden meta-columns {'.chain', '.iteration', '.draw'}

Let’s use tidyr::pivot_longer to turn this “wide” dataset – with all its mu-s and sigma-s as columns – into a “long” one:

approach2 <- draws %>%
  pivot_longer(
    cols = matches("(mu|sigma)"),
    names_pattern = r"((mu|sigma)\[([1-3])\])",
    names_to = c(".variable", "group"),
    values_to = ".value" # to mimic output of gather_draws
  ) %>%
  mutate(group = as.integer(group)) # to mimic output of gather_draws

Breaking it down:

  • matches() (from {tidyselect}) selects columns based on a pattern
  • the regular expression is written with the raw-string r"()" notation (see earlier blog post on strings in R 4.0.0), so the backslashes do not need doubling; it contains two capture groups: the first fills the .variable column, the second fills the group column
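To see what the two capture groups pick out, we can try the same pattern on the column names alone. This is a base-R sketch that only mimics the name-splitting step; it is not how {tidyr} is implemented, and the names `cols` and `parts` are hypothetical:

```r
# Column names as they appear in the draws (raw strings need R 4.0.0+):
cols <- c("mu[1]", "mu[2]", "mu[3]", "sigma[1]", "sigma[2]", "sigma[3]")
pattern <- r"((mu|sigma)\[([1-3])\])"

# regexec() records the full match plus each capture group;
# regmatches() extracts them as one character vector per column name.
matched <- regmatches(cols, regexec(pattern, cols))

parts <- do.call(rbind, matched)
colnames(parts) <- c("full_match", ".variable", "group")
parts
```

Note that [1-3] only covers single-digit indices from 1 to 3; a model with more groups would need something like [0-9]+ instead.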

Let’s see what we got:

approach2
## # A tibble: 24,000 x 6
##    .chain .iteration .draw .variable group .value
##     <int>      <int> <int> <chr>     <int>  <dbl>
##  1      1          1     1 mu            1  0.409
##  2      1          1     1 mu            2  0.681
##  3      1          1     1 mu            3  0.986
##  4      1          1     1 sigma         1  0.250
##  5      1          1     1 sigma         2  0.139
##  6      1          1     1 sigma         3  0.777
##  7      1          2     2 mu            1  0.455
##  8      1          2     2 mu            2  0.710
##  9      1          2     2 mu            3  0.618
## 10      1          2     2 sigma         1  0.256
## # ... with 23,990 more rows

Verification

To verify that the two approaches yield the same result (although the ordering of the columns is different), we’ll combine the two results and then use pivot_longer()'s sibling pivot_wider(). If everything matches up, we’ll see “OK” in both (tidybayes & pivoting) columns; a row produced by only one approach would show NA in the other column:

approaches <- list(
  "tidybayes" = mutate(approach1, present = "OK"),
  "pivoting" = mutate(approach2, present = "OK")
)

approaches %>%
  bind_rows(.id = "approach") %>%
  pivot_wider(
    # id_cols auto-includes all except "approach" & "present"
    names_from = "approach",
    values_from = "present"
  )
## # A tibble: 24,000 x 8
##    group .chain .iteration .draw .variable .value tidybayes pivoting
##    <int>  <int>      <int> <int> <chr>      <dbl> <chr>     <chr>   
##  1     1      1          1     1 mu         0.409 OK        OK      
##  2     1      1          2     2 mu         0.455 OK        OK      
##  3     1      1          3     3 mu         0.366 OK        OK      
##  4     1      1          4     4 mu         0.355 OK        OK      
##  5     1      1          5     5 mu         0.355 OK        OK      
##  6     1      1          6     6 mu         0.364 OK        OK      
##  7     1      1          7     7 mu         0.432 OK        OK      
##  8     1      1          8     8 mu         0.399 OK        OK      
##  9     1      1          9     9 mu         0.422 OK        OK      
## 10     1      1         10    10 mu         0.419 OK        OK      
## # ... with 23,990 more rows
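Scanning 24,000 rows by eye is not practical, so a programmatic check is a useful complement. One possible sketch in base R: put both results into a canonical column and row order, then compare them. The helper `same_rows` and the toy data frames are hypothetical; the sketch assumes plain data frames (tibbles could be converted with as.data.frame() first) and compares numbers with all.equal()'s default tolerance rather than exactly.

```r
# Returns TRUE when x and y contain the same rows,
# regardless of row order or column order.
same_rows <- function(x, y) {
  cols <- sort(names(x))
  if (!identical(cols, sort(names(y)))) {
    return(FALSE)
  }
  canonical <- function(d) {
    # Sort rows by every column, and put columns in a fixed order:
    d <- d[do.call(order, unname(as.list(d[cols]))), cols, drop = FALSE]
    rownames(d) <- NULL
    d
  }
  isTRUE(all.equal(canonical(x), canonical(y)))
}

# Toy data (made up): b holds the rows of a, with rows and columns shuffled.
a <- data.frame(
  group = c(1L, 2L, 1L, 2L),
  .draw = c(1L, 1L, 2L, 2L),
  .value = c(0.41, 0.68, 0.45, 0.71)
)
b <- a[c(4, 1, 3, 2), c(".draw", "group", ".value")]

same_rows(a, b) # TRUE
```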