Multivariate Bijections


The fundamental idea of normalizing flows also applies to multivariate random variables, and this is where its value is clearly seen: representing complex high-dimensional distributions. In this case, a simple multivariate source of noise, for example a standard i.i.d. normal distribution, $X\sim\mathcal{N}(\mathbf{0},I_{D\times D})$, is passed through a vector-valued bijection, $g:\mathbb{R}^D\rightarrow\mathbb{R}^D$, to produce the more complex transformed variable $Y=g(X)$.

Sampling $Y$ is again trivial and involves evaluation of the forward pass of $g$. We can score $Y$ using the multivariate substitution rule of integral calculus,

$$
\begin{aligned}
\mathbb{E}_{p_X(\cdot)}\left[f(X)\right] &= \int_{\text{supp}(X)}f(\mathbf{x})p_X(\mathbf{x})d\mathbf{x}\\
&= \int_{\text{supp}(Y)}f(g^{-1}(\mathbf{y}))p_X(g^{-1}(\mathbf{y}))\left|\det\frac{d\mathbf{x}}{d\mathbf{y}}\right|d\mathbf{y}\\
&= \mathbb{E}_{p_Y(\cdot)}\left[f(g^{-1}(Y))\right],
\end{aligned}
$$

where $d\mathbf{x}/d\mathbf{y}$ denotes the Jacobian matrix of $g^{-1}(\mathbf{y})$. Equating the last two lines we get,

$$
\begin{aligned}
\log(p_Y(\mathbf{y})) &= \log(p_X(g^{-1}(\mathbf{y})))+\log\left(\left|\det\frac{d\mathbf{x}}{d\mathbf{y}}\right|\right)\\
&= \log(p_X(g^{-1}(\mathbf{y})))-\log\left(\left|\det\frac{d\mathbf{y}}{d\mathbf{x}}\right|\right).
\end{aligned}
$$
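As a quick sanity check on this formula, the sketch below scores a sample under an invertible linear map, a case where the density of $Y$ is also known in closed form. The matrix `A`, the offset `b`, and the helper `log_prob_y` are illustrative names chosen for this example and are not part of any library.

```python
import torch

torch.manual_seed(0)

# Hypothetical linear bijection g(x) = A x + b on R^2 (det A = 3.5, so A is invertible).
A = torch.tensor([[2.0, 0.5], [-1.0, 1.5]])
b = torch.tensor([0.3, -0.7])

# Base distribution: standard normal noise in two dimensions.
base = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))


def log_prob_y(y):
    # Invert the bijection: x = g^{-1}(y) = A^{-1} (y - b).
    x = torch.linalg.solve(A, y - b)
    # The Jacobian dy/dx of g is A itself, so log|det dy/dx| = log|det A|.
    _, log_abs_det = torch.linalg.slogdet(A)
    return base.log_prob(x) - log_abs_det


# Sampling is the forward pass: draw x, then apply g.
y = A @ base.sample() + b

# For a linear map of Gaussian noise, Y ~ N(b, A A^T) exactly.
reference = torch.distributions.MultivariateNormal(b, covariance_matrix=A @ A.T)
print(log_prob_y(y).item(), reference.log_prob(y).item())
```

The two printed values should agree up to floating-point error, since the change-of-variables expression and the closed-form Gaussian describe the same density.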

Intuitively, this equation says that the log-density of $Y$ is equal to the log-density at the corresponding point in $X$ plus a term that corrects for the warp in volume of an infinitesimally small region around $Y$ caused by the transformation. For instance, in $2$ dimensions, the absolute value of the determinant of a Jacobian is the area of the parallelogram whose edges are the columns of the Jacobian. In $n$ dimensions, it is the hyper-volume of the parallelepiped with $n$ edges defined by the columns of the Jacobian (see a calculus reference such as [7] for more details).
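The geometric picture can be checked numerically. The sketch below uses a made-up smooth map `g` on $\mathbb{R}^2$ (chosen only for illustration), computes its Jacobian with automatic differentiation, and compares the parallelogram area spanned by the Jacobian's columns with the area of a tiny square pushed through `g`.

```python
import torch
from torch.autograd.functional import jacobian


def g(x):
    # Hypothetical smooth map R^2 -> R^2, used only to illustrate local volume change.
    return torch.stack([x[0] + 0.5 * x[1] ** 2, torch.exp(x[0]) * x[1]])


x0 = torch.tensor([0.3, 1.2], dtype=torch.float64)
J = jacobian(g, x0)  # J[i, j] = d g_i / d x_j

# Area of the parallelogram whose edges are the columns of J.
u, v = J[:, 0], J[:, 1]
parallelogram_area = torch.abs(u[0] * v[1] - u[1] * v[0])
abs_det = torch.abs(torch.linalg.det(J))

# Push a small square with side eps through g and measure the area of its image,
# relative to the original area eps**2.
eps = 1e-4
e1 = g(x0 + eps * torch.tensor([1.0, 0.0], dtype=torch.float64)) - g(x0)
e2 = g(x0 + eps * torch.tensor([0.0, 1.0], dtype=torch.float64)) - g(x0)
warped_area = torch.abs(e1[0] * e2[1] - e1[1] * e2[0]) / eps**2

print(parallelogram_area.item(), abs_det.item(), warped_area.item())
```

The first two numbers coincide by definition of the $2\times 2$ determinant, and the third approaches them as `eps` shrinks.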

Similar to the univariate case, we can compose such bijective transformations to produce even more complex distributions. By an inductive argument, if we have $L$ transforms $g_{(0)}, g_{(1)},\ldots,g_{(L-1)}$, then the log-density of the transformed variable $Y=(g_{(0)}\circ g_{(1)}\circ\cdots\circ g_{(L-1)})(X)$ is

$$
\begin{aligned}
\log(p_Y(y)) &= \log\left(p_X\left(\left(g_{(L-1)}^{-1}\circ\cdots\circ g_{(0)}^{-1}\right)\left(y\right)\right)\right)+\sum^{L-1}_{l=0}\log\left(\left|\det\frac{dg^{-1}_{(l)}(y_{(l)})}{dy_{(l)}}\right|\right),
\end{aligned}
$$

where, for convenience of notation, we've defined $y_{(0)}=y$ and $y_{(l+1)}=g_{(l)}^{-1}(y_{(l)})$, so that $y_{(L)}=x$.
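The sum of log-determinants can be sketched by peeling the transforms off one at a time, starting from $y$. The example below assumes two simple elementwise bijections (an affine map applied first, then an exponential) and compares a hand-written accumulation against PyTorch's `TransformedDistribution`. The names `inverse_steps`, `loc`, and `scale` are illustrative only.

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform

torch.manual_seed(0)

base = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)

# Y = exp(scale * X + loc): the affine map is applied first, the exponential last.
loc = torch.tensor([0.5, -1.0])
scale = torch.tensor([2.0, 0.7])

# Each entry is (inverse map, log|det of the inverse's Jacobian|), listed in the
# order the inverses are applied to y: undo the exponential, then undo the affine map.
inverse_steps = [
    (lambda y: torch.log(y), lambda y: -torch.log(y).sum(-1)),
    (lambda y: (y - loc) / scale, lambda y: -torch.log(scale.abs()).sum()),
]


def log_prob_y(y):
    log_det_sum = 0.0
    for inverse, log_abs_det in inverse_steps:
        # The log-determinant is evaluated at the current intermediate value...
        log_det_sum = log_det_sum + log_abs_det(y)
        # ...before stepping one transform closer to the base variable x.
        y = inverse(y)
    return base.log_prob(y) + log_det_sum


reference = TransformedDistribution(base, [AffineTransform(loc, scale), ExpTransform()])
y = reference.sample()
print(log_prob_y(y).item(), reference.log_prob(y).item())
```

Both values should match up to floating-point error, since `TransformedDistribution` performs the same accumulation internally.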

The main challenge is in designing parametrizable multivariate bijections that have closed-form expressions for both $g$ and $g^{-1}$, have a tractable Jacobian determinant whose calculation scales with $O(D)$ or $O(1)$ rather than $O(D^3)$, and can express a flexible class of functions.
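One common way to obtain an $O(D)$ log-determinant is to give the bijection a triangular Jacobian, so that only the diagonal contributes. The following sketch illustrates this with a made-up autoregressive affine map (the weights `W` and `log_scale` are random placeholders, not a trained model) and compares the cheap diagonal computation with a dense determinant.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
D = 4

# Hypothetical autoregressive map: output i depends only on inputs 0..i.
W = torch.tril(torch.randn(D, D), diagonal=-1)  # strictly lower-triangular weights
log_scale = 0.1 * torch.randn(D)


def g(x):
    # y_i = exp(log_scale_i) * x_i + tanh(sum_{j < i} W_ij x_j)
    return torch.exp(log_scale) * x + torch.tanh(W @ x)


x = torch.randn(D)
J = jacobian(g, x)

# General-purpose route: dense determinant of the D x D Jacobian, cubic cost in D.
_, dense_logdet = torch.linalg.slogdet(J)

# Structured route: J is lower triangular with diagonal exp(log_scale),
# so log|det J| is just the sum of log_scale, a linear-cost computation.
cheap_logdet = log_scale.sum()

print(dense_logdet.item(), cheap_logdet.item())
```

Because each output also depends on its own input only through a positive scale, the map can be inverted one coordinate at a time, which is the other property the paragraph above asks for.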

Multivariate Bijectors

In this section, we show how to use bij.SplineAutoregressive to learn the bivariate toy distribution from our running example. With a simple change we can represent bivariate distributions of the form $p(x_1,x_2)=p(x_1)p(x_2\mid x_1)$:

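# Base distribution: standard normal noise on R^2, with the last dimension
# reinterpreted as a single 2-dimensional event
dist_x = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(2), torch.ones(2)),
    1
)
bijector = bij.SplineAutoregressive()
dist_y = dist.Flow(dist_x, bijector)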

The bij.SplineAutoregressive bijector extends bij.Spline so that the spline parameters are the output of an autoregressive neural network. See [durkan2019neural] and [germain2015made] for more details.
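To give a feel for what "autoregressive neural network" means here, the sketch below builds a tiny masked network in the style of [germain2015made]. It is a minimal illustration of the masking idea under assumed sizes and random weights, not FlowTorch's implementation; the name `conditioner` and all shapes are made up for the example. The check at the end confirms that the parameters produced for dimension $d$ depend only on earlier inputs.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
D, H = 3, 8  # input dimension and hidden width (arbitrary choices)

# Assign a "degree" to every unit; connections are kept only where degrees allow it.
in_degree = torch.arange(1, D + 1)          # input x_d has degree d
hid_degree = torch.arange(H) % (D - 1) + 1  # hidden degrees cycle through 1..D-1
out_degree = torch.arange(1, D + 1)         # output d parametrizes dimension d

# Hidden unit k may see input d if hid_degree[k] >= in_degree[d].
mask_in = (hid_degree[:, None] >= in_degree[None, :]).float()   # shape (H, D)
# Output d may see hidden unit k only if out_degree[d] > hid_degree[k].
mask_out = (out_degree[:, None] > hid_degree[None, :]).float()  # shape (D, H)

W1, b1 = torch.randn(H, D), torch.randn(H)
W2, b2 = torch.randn(D, H), torch.randn(D)


def conditioner(x):
    # Masked two-layer network: output d is a function of x_1, ..., x_{d-1} only.
    h = torch.tanh((W1 * mask_in) @ x + b1)
    return (W2 * mask_out) @ h + b2


J = jacobian(conditioner, torch.randn(D))
print(J)  # strictly lower triangular: no output depends on its own or later inputs
```

In a spline flow of this kind, each output of such a network would be expanded into the spline parameters for one dimension, which yields the factorization $p(x_1)p(x_2\mid x_1)$ in the bivariate case.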

Similarly to before, we train this distribution on the toy dataset and plot the results:

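# `steps` is the number of training iterations, assumed to be set earlier in the tutorial
dataset = torch.tensor(X, dtype=torch.float)
optimizer = torch.optim.Adam(dist_y.parameters(), lr=5e-3)
for step in range(steps):
    optimizer.zero_grad()
    # Maximum likelihood: minimize the average negative log-density of the data
    loss = -dist_y.log_prob(dataset).mean()
    loss.backward()
    optimizer.step()

    if step % 500 == 0:
        print('step: {}, loss: {}'.format(step, loss.item()))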
step: 0, loss: 8.446191787719727
step: 500, loss: 2.0197808742523193
step: 1000, loss: 1.794958472251892
step: 1500, loss: 1.73616361618042
step: 2000, loss: 1.7254879474639893
step: 2500, loss: 1.691617488861084
step: 3000, loss: 1.679549217224121
step: 3500, loss: 1.6967085599899292
step: 4000, loss: 1.6723777055740356
step: 4500, loss: 1.6505967378616333
step: 5000, loss: 1.8024061918258667
# Draw samples from the trained flow and overlay them on the training data
X_flow = dist_y.sample(torch.Size([1000,])).detach().numpy()
plt.title(r'Joint Distribution')
plt.scatter(X[:,0], X[:,1], label='data', alpha=0.5)
plt.scatter(X_flow[:,0], X_flow[:,1], color='firebrick', label='flow', alpha=0.5)
plt.legend()
plt.show()

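# Compare the marginals of the data and the flow samples with kernel density estimates.
# Note: sns.distplot is deprecated in recent seaborn releases; sns.kdeplot is the
# suggested replacement for KDE-only plots.
plt.subplot(1, 2, 1)
sns.distplot(X[:,0], hist=False, kde=True,
             kde_kws={'linewidth': 2},
             label='data')
sns.distplot(X_flow[:,0], hist=False, kde=True,
             bins=None, color='firebrick',
             kde_kws={'linewidth': 2},
             label='flow')
plt.subplot(1, 2, 2)
sns.distplot(X[:,1], hist=False, kde=True,
             kde_kws={'linewidth': 2},
             label='data')
sns.distplot(X_flow[:,1], hist=False, kde=True,
             bins=None, color='firebrick',
             kde_kws={'linewidth': 2},
             label='flow')
plt.show()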

We see from the output that this normalizing flow has successfully learnt both the univariate marginals and the bivariate distribution.