Masked Observations with NumPyro

Tags: Bayesian, NumPyro, Code
Author: Joshua Mabry
Published: June 12, 2024

For building Bayesian models, NumPyro offers flexibility, speed, and nearly limitless extensibility. That power comes at a price: the API can at times be intimidating and hard to decipher. Some of the most important building blocks in NumPyro are effect handlers, which let us modify how random variables are conditioned, sampled, or observed inside a probabilistic program.

While building factorized models for predicting consumer-level purchasing behavior, I needed the mask handler to exclude certain observations during model training. I couldn’t find a minimal example that gave me confidence I really understood the intended usage pattern, so I created a simulated dataset to test it out.

Problem Setup and Model Description

We are interested in building a factorized model of retail shopping, where a customer with features \(X\) decides whether or not to visit a store, \(Y_0\), and which items in the store to purchase if they do visit, \(Y_1\). This means that to train the item choice model, we will want to mask observations from any customers that choose not to visit the store.

Let \(X \in \mathbb{R}^D\) be a vector of features of length \(D\). We aim to model the joint probability \(P(Y_0, Y_1 \mid X)\) where:

  • \(Y_0\) is a binary outcome.
  • \(Y_1 \in \{0, 1\}^I\) is a vector of binary outcomes of length \(I\).

Assuming that the item choice is conditionally independent of the store visit decision, we decompose the joint probability as \[ P(Y_0, Y_1 \mid X)= P(Y_0 \mid X) \times P(Y_1 \mid Y_0 = 1, X) \]
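
Written out per customer, this factorization implies a log-likelihood in which the item choice terms only count for visitors. The customer index \(n = 1, \dots, N\), the item index \(i\), and the assumption that item choices are independent of each other given \(X\) are notation added here for illustration:

\[ \log L = \sum_{n=1}^{N} \left[ \log P(Y_{0,n} \mid X_n) + \mathbb{1}[Y_{0,n} = 1] \sum_{i=1}^{I} \log P(Y_{1,n,i} \mid Y_{0,n} = 1, X_n) \right] \]

The indicator \(\mathbb{1}[Y_{0,n} = 1]\) plays the role of the mask: it zeroes the item choice contribution of any customer who did not visit the store.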

If you are interested in the full details of how to simulate data from this model and how to define it, see the notebook at jmabry/datasci-playground/numpyro_masking/.
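
For readers who do not want to open the notebook, a rough sketch of such a simulation is shown below. This is not the notebook's code: the function name `simulate_decisions`, the randomly drawn coefficients, the omission of intercepts, and the choice to record non-visitors with all-zero item choices are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def simulate_decisions(rng_key, n_obs=1000, n_features=3, n_items=2):
    """Hypothetical simulator for (X, y0, y1); not the notebook's code."""
    k_x, k_b0, k_b1, k_y0, k_y1 = jax.random.split(rng_key, 5)
    X = jax.random.normal(k_x, (n_obs, n_features))

    # Ground-truth coefficients, drawn at random for illustration
    beta0 = jax.random.normal(k_b0, (n_features,))
    beta1 = jax.random.normal(k_b1, (n_items, n_features))

    # Store visit decision: y0 ~ Bernoulli(sigmoid(X @ beta0))
    y0 = jax.random.bernoulli(k_y0, jax.nn.sigmoid(X @ beta0)).astype(jnp.int32)

    # Item choices are drawn for everyone, then zeroed for non-visitors
    y1 = jax.random.bernoulli(k_y1, jax.nn.sigmoid(X @ beta1.T)).astype(jnp.int32)
    y1 = y1 * y0[:, None]
    return X, y0, y1, (beta0, beta1)


X, y0, y1, truth = simulate_decisions(jax.random.PRNGKey(0))
print(X.shape, y0.shape, y1.shape)
```

Keeping the ground-truth coefficients alongside the data is what makes it possible to check the fitted parameters later.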

NumPyro model with masking

The NumPyro model that allows for masked observations is defined as follows:

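import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import mask


def _logit_choice_model(X, name_prefix, n_outputs):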
    """ Helper function to define a logistic regression model.
    
    Parameters
    ----------
    X: jnp.ndarray
        Array of customer features of shape (N, D) where N is the number of observations and D is the number of features.
    name_prefix: str
        Prefix for the model parameters.
    n_outputs: int
        Number of output classes.
    """

    n_features = X.shape[1]
    beta = numpyro.sample(f'{name_prefix}_beta', dist.Normal(jnp.zeros((n_outputs, n_features)), jnp.ones((n_outputs, n_features))))
    intercept = numpyro.sample(f'{name_prefix}_intercept', dist.Normal(jnp.zeros(n_outputs), 1.))
    linear_combination = jnp.einsum('ij,kj->ik', X, beta) + intercept
    return jax.nn.sigmoid(linear_combination)

def joint_decision_model(X, I, y0=None, y1=None):
    """Model of joint store visit and item choice decisions.

    Parameters
    ----------
    X: jnp.ndarray
        Array of customer features of shape (N, D) where N is the number of observations and D is the number of features.
    I: int 
        Number of items to choose from.
    y0: jnp.ndarray
        Boolean array of store visit decisions of shape (N,).
    y1: jnp.ndarray
        Boolean array of item choice decisions of shape (N, I).

    Returns
    -------
    None
    """

    # Model P(Y0 | X)
    P_Y0 = _logit_choice_model(X, 'Y0', 1).squeeze()

    # Sample Y0
    y0_sample = numpyro.sample('y0', dist.Bernoulli(P_Y0), obs=y0)  

    # Masking to filter out Y1 calculations when Y0 is 0
    mask_array = (y0_sample == 1)[:, None]

    # Model P(Y1 | Y0 = 1, X)
    P_Y1_given_Y0 = _logit_choice_model(X, 'Y1_given_Y0', I)  

    # The mask zeroes the log-likelihood of y1 for rows where y0 == 0
    with numpyro.plate('products', I, dim=-1):
        with numpyro.plate('data_y1', X.shape[0], dim=-2):
            with mask(mask=mask_array):
                numpyro.sample('y1', dist.Bernoulli(P_Y1_given_Y0), obs=y1)

To validate that the model works correctly, we can run MCMC sampling to obtain estimates of the model parameters and compare them to the ground truth values defined in our simulation. Running a toy simulation with 10,000 observations of customers with 3 descriptive features, choosing from 2 items, we get the following parameter estimates, where each row represents a different model parameter:

                      ground_truth    mean     sd
Y0_beta[0, 0]               0.371    0.355   0.025
Y0_beta[0, 1]               0.305    0.290   0.026
Y0_beta[0, 2]               0.504    0.476   0.026
Y0_intercept[0]             1.353    1.345   0.026
Y1_given_Y0_beta[0, 0]     -2.474   -2.559   0.061
Y1_given_Y0_beta[0, 1]     -1.463   -1.450   0.047
Y1_given_Y0_beta[0, 2]      1.257    1.283   0.043
Y1_given_Y0_beta[1, 0]      2.197    2.290   0.054
Y1_given_Y0_beta[1, 1]     -0.647   -0.650   0.037
Y1_given_Y0_beta[1, 2]      0.478    0.456   0.036
Y1_given_Y0_intercept[0]    0.449    0.517   0.039
Y1_given_Y0_intercept[1]    0.887    0.953   0.035

We see that the parameter estimates closely match the ground truth used to simulate the data: every posterior mean lies within two posterior standard deviations of its true value. This gives us confidence that we are using the NumPyro masking effect handler as intended.

For comparison, omitting the masking operation leaves the store visit estimates essentially unchanged but yields very poor estimates for most of the item choice parameters, as shown below:

                      ground_truth    mean     sd
Y0_beta[0, 0]               0.371    0.355   0.025
Y0_beta[0, 1]               0.305    0.290   0.026
Y0_beta[0, 2]               0.504    0.476   0.028
Y0_intercept[0]             1.353    1.345   0.027
Y1_given_Y0_beta[0, 0]     -2.474   -1.030   0.029
Y1_given_Y0_beta[0, 1]     -1.463   -0.526   0.026
Y1_given_Y0_beta[0, 2]      1.257    0.879   0.027
Y1_given_Y0_beta[1, 0]      2.197    1.466   0.033
Y1_given_Y0_beta[1, 1]     -0.647   -0.205   0.026
Y1_given_Y0_beta[1, 2]      0.478    0.512   0.025
Y1_given_Y0_intercept[0]    0.449   -0.461   0.025
Y1_given_Y0_intercept[1]    0.887    0.016   0.024
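
One plausible reading of the degraded item choice estimates, assuming non-visitors are recorded with all-zero item choices, is that the unmasked model treats those zeros as real decisions not to buy. The toy arrays below are made up to illustrate the effect on a single item's purchase rate.

```python
import jax.numpy as jnp

# Six hypothetical customers: three visit the store, three do not
y0 = jnp.array([1, 1, 0, 0, 1, 0])
# Purchases of one item; non-visitors are recorded as 0 (an assumption)
y1_item = jnp.array([1, 0, 0, 0, 1, 0])

# Rate seen by an unmasked model: all rows count
rate_all = y1_item.mean()
# Rate seen by a masked model: only visitors count
rate_visitors = y1_item[y0 == 1].mean()

print(float(rate_all), float(rate_visitors))
```

The unmasked rate is lower because it is diluted by customers who never had the chance to buy, which would be consistent with the item choice intercepts being pulled downward in the table above.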

I hope this example helps other NumPyro amateurs learn how to use effect handlers and provides a template for validating a black-box API. Reach out to me if you have any questions!