Vector Research Blog: Structured Neural Networks for Density Estimation and Causal Inference
January 22, 2024
By Asic Q. Chen, Ruian Shi, Xiang Gao, Ricardo Baptista, Rahul G. Krishnan
In this blog post we introduce StrNN, an efficient way to inject known or assumed structure among variables into arbitrary neural networks via weight masking. The modular nature of StrNN leads to applications in density estimation, generative modeling, causal inference, and more.
Based on the NeurIPS 2023 paper: Structured Neural Networks for Density Estimation and Causal Inference.
To run your own version of StrNN, find our Python package on GitHub.
While neural networks are universal function approximators, it can often be beneficial to constrain the class of functions a neural network can model. For example, in the widely cited Deep Sets [1] paper, Zaheer et al. focus on permutation-invariant functions, whose outputs remain unchanged regardless of the order of the inputs, to better handle training data that come in unordered sets. In our paper, we demonstrate how structure in neural networks can lead to function invariances that are beneficial in other applications.
One main motivating use case is probability density estimation, where we use neural networks to model the joint probability density of random variables. In this setting, we frequently already know certain facts about the data generating process itself, namely independence relationships between the random variables. We usually come up with these independence statements through domain expertise or structure discovery algorithms. The structure of these independencies is commonly described using Bayesian networks.
More generally, we use an adjacency matrix to encode arbitrary independence statements between inputs and outputs. If we have d-dimensional data, our adjacency matrix is a binary-valued A ∈ {0,1}^{d×d} where A_{ij} = 0 if and only if x_i ⊥ x_j | x_{{1,…,i} \ pa(i)}, and A_{ij} = 1 otherwise. Consider an autoencoder with d inputs and outputs. When the outputs are autoregressive with respect to the inputs, A has all ones below the diagonal and all zeros elsewhere. Existing work such as the Masked Autoencoder for Distribution Estimation (MADE) [2] has exploited this structure to use each output node of an autoencoder to model one conditional factor in the probability chain rule. We are instead interested in more complex known independence structures, which means A is not only lower triangular, but also sparse below the diagonal, like the example shown in Figure 1.
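As a concrete toy illustration (the sparse matrix below is our own example, not the one in Figure 1), here is what the fully autoregressive case and one possible sparser structure look like for d = 4, with rows indexing outputs x̂_1, …, x̂_4 and columns indexing inputs x_1, …, x_4:

    import numpy as np

    # Fully autoregressive structure: every output x̂_i may depend on all preceding inputs x_1, ..., x_{i-1}.
    A_full = np.tril(np.ones((4, 4), dtype=int), k=-1)
    # [[0 0 0 0]
    #  [1 0 0 0]
    #  [1 1 0 0]
    #  [1 1 1 0]]

    # A hypothetical sparser structure: x̂_2 depends on x_1, x̂_3 depends only on x_2,
    # and x̂_4 depends on x_1 and x_3 (lower triangular, but sparse below the diagonal).
    A_sparse = np.array([[0, 0, 0, 0],
                         [1, 0, 0, 0],
                         [0, 1, 0, 0],
                         [1, 0, 1, 0]])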
Taking inspiration from the approach used in MADE, we enforce the structure given by A by element-wise multiplying the neural network's weight matrices with binary masks. In this way, we zero out certain connections in the network so that there are no paths between inputs and outputs that A deems independent. More concretely, for a toy neural network y = f(x) with a single hidden layer, we element-wise multiply (denoted by ⊙) the weight matrices W and V with binary masks M_W and M_V:

h(x) = g((W ⊙ M_W)x + b),   y = f((V ⊙ M_V)h(x) + c)
To inject the structure prescribed by A, all we need to do is find appropriate mask matrices M_W and M_V.
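As a rough illustration of the masking mechanism (a minimal NumPy sketch with arbitrary sizes, weights, and activations of our own choosing, not the implementation in our package), the masked forward pass looks like:

    import numpy as np

    d, h = 4, 8                       # input/output dimension and hidden width
    rng = np.random.default_rng(0)

    W, V = rng.normal(size=(h, d)), rng.normal(size=(d, h))   # dense weight matrices
    b, c = np.zeros(h), np.zeros(d)                           # biases

    # Placeholder masks; in StrNN these would come from factorizing the adjacency matrix A.
    M_W = np.ones((h, d), dtype=int)
    M_V = np.ones((d, h), dtype=int)

    def masked_forward(x):
        # h(x) = g((W ⊙ M_W) x + b), with g = ReLU
        hidden = np.maximum((W * M_W) @ x + b, 0.0)
        # y = f((V ⊙ M_V) h(x) + c), with f = identity for simplicity
        return (V * M_V) @ hidden + c

    y = masked_forward(rng.normal(size=d))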
The key insight is that injecting structure into a neural network boils down to a binary matrix factorization problem, which we frame formally as follows:
Given an adjacency matrix A ∈ {0,1}^{d×d} and a neural network with L hidden layers of widths h_1, h_2, …, h_L (each ≥ d), we want to factor A into mask matrices M_1 ∈ {0,1}^{h_1×d}, M_2 ∈ {0,1}^{h_2×h_1}, …, M_L ∈ {0,1}^{d×h_L} such that A′ ∼ A, where A′ = M_L × ⋯ × M_2 × M_1. We use the notation A′ ∼ A to denote that the matrices A′ and A share the same sparsity pattern, i.e., the exact same locations of zeros and non-zeros. (We are overloading notation here – it does not mean matrix similarity as in linear algebra!) Note that A is a binary matrix while A′ is an integer matrix. We then mask the neural network's hidden layers with M_1, M_2, …, M_L through element-wise multiplication, as in the equation above, to obtain a Structured Neural Network (StrNN) that respects the independence constraints prescribed by A. The value of each entry A′_{ij} then corresponds to the number of connection paths flowing from input x_j to output x̂_i in the StrNN.
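To make the A′ ∼ A condition concrete, a few lines of NumPy (illustrative only, independent of our released code) suffice to form A′ from a candidate list of masks and check that it matches A's sparsity pattern:

    import numpy as np

    def same_sparsity_pattern(A_prime, A):
        # A′ ∼ A: non-zeros (and zeros) in exactly the same positions.
        return np.array_equal(A_prime != 0, A != 0)

    def product_of_masks(masks):
        # masks = [M_1, M_2, ..., M_L]; returns A′ = M_L × ... × M_2 × M_1.
        A_prime = masks[0]
        for M in masks[1:]:
            A_prime = M @ A_prime
        return A_prime  # entry (i, j) counts paths from input x_j to output x̂_i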
Binary matrix factorization is itself an NP-hard problem. Although there is a large body of existing literature, it mostly focuses on finding low-rank factors while minimizing reconstruction loss. In our case, we need zero reconstruction loss with respect to the sparsity pattern of A, so we have to devise our own factorization methods.
Identifiability is also an issue: when the hidden layers are wide, many different sets of masks are compatible with the same adjacency matrix, so we need to specify an optimization objective, and this choice directly shapes the resulting network architecture. In this paper, we mainly test the hypothesis that maximizing the number of connection paths that remain in the masked neural network leads to better expressiveness and generalization. Building on this idea, we investigate two objectives. The first, Equation 2 in the paper, maximizes the sum of all entries in A′:

maximize Σ_{i,j} A′_{ij}   subject to   A′ = M_L × ⋯ × M_2 × M_1 and A′ ∼ A.   (Equation 2)

Since each entry A′_{ij} counts connection paths, this is equivalent to maximizing the total number of paths between all inputs and outputs in the StrNN. The second objective, Equation 3 in the paper, adds a variance penalty term so that the remaining paths are not overly concentrated on a few outputs. Through empirical evaluations on various synthetic datasets, we found that the added variance penalty does not make a significant difference in StrNN's density estimation performance, so we adopt Equation 2 as our objective for the rest of the project.
We now turn to methods for solving the binary matrix factorization problem. We can find exact solutions that maximize Equations 2 and 3 via integer programming, but empirically this is prohibitively slow for larger input dimensions (e.g., d ≥ 20). We therefore propose a simple and efficient greedy algorithm that approximates the optimization objective (Equation 2) while exactly preserving the sparsity pattern of the adjacency matrix A. The pseudocode is given as Algorithm 1 in our paper, and we provide a visualization for one example adjacency matrix below.
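To give a flavour of what such a factorization can look like, here is a simplified single-hidden-layer sketch (our own illustration, not a verbatim transcription of Algorithm 1): each hidden unit copies a row of A, and each output then connects to every hidden unit whose inputs are a subset of that output's parents, which preserves A's sparsity pattern while keeping many paths alive.

    import numpy as np

    def greedy_single_layer_masks(A, h):
        # Factor A into M1 (h x d) and M2 (d x h) such that (M2 @ M1) ~ A.
        d = A.shape[0]
        assert h >= d, "need at least d hidden units"
        # Each hidden unit k sees exactly the inputs of one output (cycling through rows of A).
        M1 = np.stack([A[k % d] for k in range(h)])
        # Output i connects to hidden unit k only if unit k's inputs are a subset of output i's parents.
        M2 = np.zeros((d, h), dtype=int)
        for i in range(d):
            for k in range(h):
                if np.all(M1[k] <= A[i]):
                    M2[i, k] = 1
        return M1, M2

    A = np.array([[0, 0, 0, 0],
                  [1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [1, 0, 1, 0]])
    M1, M2 = greedy_single_layer_masks(A, h=8)
    A_prime = M2 @ M1
    assert np.array_equal(A_prime != 0, A != 0)   # A′ ∼ A holds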
Now let’s look at some applications of StrNN.
Normalizing flows are a popular framework for probability density estimation and generative modeling tasks. They belong to a family of density estimation models that frame the problem as learning complex mappings between high-dimensional spaces. In particular, the normalizing flows framework learns invertible maps between a simple base distribution and a complex target distribution, which enables simple and efficient likelihood estimation as well as sample generation. So, it is not hard to see why we might want to make use of the known conditional independencies between input variables while training a flow. We do so by using StrNN to enforce the corresponding function invariances in the flow networks.
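As background, recall the change-of-variables identity that underlies all normalizing flows (standard material rather than a result of our paper): if z = f(x) is an invertible map from the data x to a base variable z with density p_Z, then

log p_X(x) = log p_Z(f(x)) + log |det ∂f(x)/∂x|,

so evaluating the likelihood of a data point requires the forward map and the log-determinant of its Jacobian, while generating samples requires the inverse map.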
The most natural class of normalizing flows to extend is autoregressive flows (ARFs) [3], which enforce an autoregressive structure between inputs and outputs: the Jacobian of an autoregressive map is triangular, so the log-determinant in the change-of-variables formula reduces to a sum of the logarithms of the absolute diagonal entries and is cheap to compute. Replacing the autoregressive conditioners in ARFs with StrNN, we can encode additional independence statements to improve both likelihood estimation and sample generation quality. We call this the Structured Autoregressive Flow (StrAF), as seen in Figure 3. Based on a similar rationale, we introduce structure into a continuous normalizing flow, FFJORD [4], by using StrNN to parameterize the differential equation that describes the continuous data-generating dynamics; we call this the StrCNF. Figure 4 shows a comparison of samples generated by StrAF, StrCNF, and baselines.
While injecting structure, both StrAF and StrCNF inherit the efficiency of StrNN due to our choice of weight masking. Specifically, the output of the StrNN can be computed with a single forward pass through the network. In comparison, input masking approaches such as the Graphical Normalizing Flows [5] baseline must perform d forward passes to compute the output for a single datum. This not only prevents the efficient application of input masking to high-dimensional data, but is also a barrier to integrating the method with certain architectures. For example, FFJORD already requires many neural network evaluations to numerically solve the ODE defining the flow map, so making d passes per evaluation is particularly inefficient. This makes our StrNN the simplest and most efficient way to inject structure into this type of continuous flow.
We further apply StrAF to causal effect estimation. We build on prior work, Causal Autoregressive Flows [6], which models structural equation models as affine flows and thereby enjoys favourable identifiability theorems. Experimental results in Figure 5 show that leveraging the exact independence structure improves performance on interventional and counterfactual queries for many variables.
We introduced the Structured Neural Network (StrNN), a function approximator that allows us to inject arbitrary variable structure through weight masking. We framed weight masking as a binary matrix factorization problem and proposed various algorithms to solve it. We applied structured neural networks to normalizing flows for improved density estimation and generative modeling, which also gives us a powerful tool to model structural equation models for causal effect estimation.
In our work, we have demonstrated the plug-and-play advantage of the StrNN by integrating it into flow architectures to perform density estimation. Similarly, the StrNN can also be easily incorporated into other existing SOTA architectures to enforce known structure for various tasks. We believe extension into diffusion models, variational inference, and even supervised learning could be promising avenues for future work.
[1] Zaheer, Manzil, Satwik Kottur, Siamak Ravanbakhsh, Barnabas Poczos, Russ R. Salakhutdinov, and Alexander J. Smola. “Deep Sets.” In Advances in Neural Information Processing Systems 30 (2017).
[2] Germain, Mathieu, Karol Gregor, Iain Murray, and Hugo Larochelle. “MADE: Masked Autoencoder for Distribution Estimation.” In International Conference on Machine Learning, pp. 881-889. PMLR, 2015.
[3] Huang, Chin-Wei, David Krueger, Alexandre Lacoste, and Aaron Courville. “Neural Autoregressive Flows.” In International Conference on Machine Learning, pp. 2078-2087. PMLR, 2018.
[4] Grathwohl, Will, Ricky T. Q. Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. “FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models.” arXiv preprint arXiv:1810.01367 (2018).
[5] Wehenkel, Antoine, and Gilles Louppe. “Graphical Normalizing Flows.” In International Conference on Artificial Intelligence and Statistics, pp. 37-45. PMLR, 2021.
[6] Khemakhem, Ilyes, Ricardo Monti, Robert Leech, and Aapo Hyvarinen. “Causal Autoregressive Flows.” In International Conference on Artificial Intelligence and Statistics, pp. 3520-3528. PMLR, 2021.