
Univariate Bijections

Background

Normalizing Flows are a family of methods for constructing flexible distributions. As mentioned in the introduction, Normalizing Flows can be seen as a modern take on the change of variables method for random distributions, and this is most apparent for univariate bijections. Thus, in this first section we restrict our attention to representing univariate distributions with bijections.

The basic idea is that a simple source of noise, for example a variable with a standard normal distribution, $X\sim\mathcal{N}(0,1)$, is passed through a bijective (i.e. invertible) function, $g(\cdot)$, to produce a more complex transformed variable $Y=g(X)$. For such a random variable, we typically want to perform two operations: sampling and scoring. Sampling $Y$ is trivial. First, we sample $X=x$, then calculate $y=g(x)$. Scoring $Y$, or rather, evaluating the log-density $\log(p_Y(y))$, is more involved. How does the density of $Y$ relate to the density of $X$? We can use the substitution rule of integral calculus to answer this. Suppose we want to evaluate the expectation of some function of $X$. Then,

$$\begin{aligned} \mathbb{E}_{p_X(\cdot)}\left[f(X)\right] &= \int_{\text{supp}(X)}f(x)p_X(x)dx\\ &= \int_{\text{supp}(Y)}f(g^{-1}(y))p_X(g^{-1}(y))\left|\frac{dx}{dy}\right|dy \\ &= \mathbb{E}_{p_Y(\cdot)}\left[f(g^{-1}(Y))\right], \end{aligned}$$

where $\text{supp}(X)$ denotes the support of $X$, which in this case is $(-\infty,\infty)$. Crucially, we used the fact that $g$ is bijective to apply the substitution rule in going from the first to the second line. Equating the last two lines we get,

$$\begin{aligned} \log(p_Y(y)) &= \log(p_X(g^{-1}(y)))+\log\left(\left|\frac{dx}{dy}\right|\right)\\ &= \log(p_X(g^{-1}(y)))-\log\left(\left|\frac{dy}{dx}\right|\right). \end{aligned}$$

Intuitively, this equation says that the log-density of $Y$ equals the log-density at the corresponding point in $X$, plus a term that corrects for how the transformation warps the volume of an infinitesimally small interval around $Y$.
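As a quick numerical sanity check of this formula (a minimal sketch in plain PyTorch, independent of the FlowTorch API introduced below), take $g(x)=\text{exp}(x)$, for which $\log\left(\left|\frac{dy}{dx}\right|\right)=x$, and compare against the known log-density of the log-normal distribution:

import torch

base = torch.distributions.Normal(0.0, 1.0)        # p_X
target = torch.distributions.LogNormal(0.0, 1.0)   # known density of Y = exp(X)

y = torch.tensor([0.5, 1.0, 2.0])
x = torch.log(y)                                    # g^{-1}(y) for g = exp

# log p_Y(y) = log p_X(g^{-1}(y)) - log|dy/dx|, and dy/dx = exp(x) so log|dy/dx| = x
log_py = base.log_prob(x) - x

print(torch.allclose(log_py, target.log_prob(y)))   # True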

If $g$ is cleverly constructed (and we will see several examples shortly), we can produce distributions that are more complex than standard normal noise and yet have easy sampling and computationally tractable scoring. Moreover, we can compose such bijective transformations to produce even more complex distributions. By an inductive argument, if we have $L$ transforms $g_{(0)}, g_{(1)},\ldots,g_{(L-1)}$, then the log-density of the transformed variable $Y=(g_{(0)}\circ g_{(1)}\circ\cdots\circ g_{(L-1)})(X)$ is

$$\begin{aligned} \log(p_Y(y)) &= \log\left(p_X\left(\left(g_{(L-1)}^{-1}\circ\cdots\circ g_{(0)}^{-1}\right)\left(y\right)\right)\right)+\sum^{L-1}_{l=0}\log\left(\left|\frac{dg^{-1}_{(l)}(y_{(l)})}{dy'}\right|\right), \end{aligned}$$

where we've defined $y_{(0)}=x$, $y_{(L-1)}=y$ for convenience of notation. In the following section, we will see how to generalize this method to multivariate $X$.
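The same bookkeeping extends to compositions: each transform contributes one log-absolute-derivative term to the sum. Below is a hand-rolled sketch in plain PyTorch (not the FlowTorch API) for $Y=\text{exp}(\mu+\sigma X)$, whose density we can check against $\text{LogNormal}(\mu,\sigma^2)$:

import math
import torch

mu, sigma = 0.5, 1.5
base = torch.distributions.Normal(0.0, 1.0)
target = torch.distributions.LogNormal(mu, sigma)   # Y = exp(mu + sigma * X)

y = torch.tensor([0.5, 1.0, 2.0])

# Invert each transform in turn: the exponential first, then the affine map.
u = torch.log(y)                 # inverse of the outer transform, exp
x = (u - mu) / sigma             # inverse of the inner transform, mu + sigma * x

# Accumulate the log-absolute-derivative of each transform at its input:
# d/du exp(u) = exp(u) and d/dx (mu + sigma * x) = sigma.
log_det = u + math.log(sigma)

log_py = base.log_prob(x) - log_det
print(torch.allclose(log_py, target.log_prob(y)))   # True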

Fixed Univariate Bijectors

FlowTorch contains classes for representing fixed univariate bijective transformations. These are particularly useful for restricting the range of transformed distributions, for example to lie on the unit hypercube. (In the following sections, we will explore how to represent learnable bijectors.)

Let us begin by showing how to represent and manipulate a simple transformed distribution,

$$\begin{aligned} X &\sim \mathcal{N}(0,1)\\ Y &= \text{exp}(X). \end{aligned}$$

You may have recognized that, by definition, $Y\sim\text{LogNormal}(0,1)$.

We begin by importing the relevant libraries:

import torch
import flowtorch.bijectors as bij
import flowtorch.distributions as dist
import matplotlib.pyplot as plt
import seaborn as sns

A variety of bijective transformations live in the flowtorch.bijectors module, and the classes to define transformed distributions live in flowtorch.distributions. We first create the base distribution of $X$ and the class encapsulating the transform $\text{exp}(\cdot)$:

dist_x = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(1), torch.ones(1)),
    1
)
bijector = bij.Exp()

The class bij.Exp derives from bij.Fixed and defines the forward, inverse, and log-absolute-derivative operations for this transform,

$$\begin{aligned} g(x) &= \text{exp}(x)\\ g^{-1}(y) &= \log(y)\\ \log\left(\left|\frac{dg}{dx}\right|\right) &= x. \end{aligned}$$

In general, a bijector class defines these three operations, which are sufficient for performing sampling and scoring. We should think of a bijector as a plan for constructing a normalizing flow rather than the normalizing flow itself: it must be instantiated with a concrete base distribution that supplies the relevant shape information,

dist_y = dist.Flow(dist_x, bijector)

This statement returns an object dist_y of type flowtorch.distributions.Flow, which has an interface compatible with torch.distributions.Distribution. We are able to sample from and score the dist_y object using its methods .sample, .rsample, and .log_prob.
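For example, here is a small usage sketch based on the methods just listed (output shapes follow the usual torch.distributions conventions):

# Draw a few samples from the transformed distribution and score them.
y = dist_y.sample(torch.Size([5]))   # samples of Y = exp(X), shape (5, 1)
log_py = dist_y.log_prob(y)          # log-density of each sample
print(y.shape, log_py.shape)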

Now, we plot samples from both the base and transformed distributions to verify that we have produced the log-normal distribution:

plt.subplot(1, 2, 1)
plt.hist(dist_x.sample([1000]).numpy(), bins=50)
plt.title('Standard Normal')
plt.subplot(1, 2, 2)
plt.hist(dist_y.sample([1000]).numpy(), bins=50)
plt.title('Standard Log-Normal')
plt.show()

Our example uses a single transform. However, we can compose transforms to produce more expressive distributions. For instance, if we first apply an affine transformation and then the exponential, we can produce the general log-normal distribution,
$$\begin{aligned} X &\sim \mathcal{N}(0,1)\\ Y &= \text{exp}(\mu+\sigma X), \end{aligned}$$

or rather, $Y\sim\text{LogNormal}(\mu,\sigma^2)$. In FlowTorch this is accomplished, e.g. for $\mu=3, \sigma=0.5$, as follows:

bijectors = bij.Compose([
    bij.AffineFixed(loc=3, scale=0.5),
    bij.Exp()
])
dist_y = dist.Flow(dist_x, bijectors)

plt.subplot(1, 2, 1)
plt.hist(dist_x.sample([1000]).numpy(), bins=50)
plt.title('Standard Normal')
plt.subplot(1, 2, 2)
plt.hist(dist_y.sample([1000]).numpy(), bins=50)
plt.title('Log-Normal')
plt.show()

The class bij.Compose combines multiple Bijectors with function composition to produce a single plan for a Normalizing Flow, which is then instantiated in the regular way. For the forward operation, transformations are applied in the order of the list: in this case, AffineFixed is applied to the base distribution first and then Exp.
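As a sanity check (a sketch that relies only on the .log_prob method shown earlier), we can compare the composed flow against torch.distributions.LogNormal, since $Y=\text{exp}(3+0.5X)$ is log-normal with $\mu=3$ and $\sigma=0.5$:

# Score a few test points under the flow and under the closed-form log-normal.
reference = torch.distributions.LogNormal(3.0, 0.5)

y = torch.tensor([[10.0], [20.0], [40.0]])   # event shape (1,) to match dist_y
print(dist_y.log_prob(y))                    # from the composed flow
print(reference.log_prob(y).squeeze(-1))     # closed form; the two should closely agree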

Learnable Univariate Bijectors

Having introduced the interface for bijections and transformed distributions, we now show how to represent learnable transforms and use them for density estimation. Our dataset in this section and the next will comprise samples along two concentric circles. Examining the joint and marginal distributions:

import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler

n_samples = 1000
X, y = datasets.make_circles(n_samples=n_samples, factor=0.5, noise=0.05)
X = StandardScaler().fit_transform(X)

plt.title(r'Samples from $p(x_1,x_2)$')
plt.xlabel(r'$x_1$')
plt.ylabel(r'$x_2$')
plt.scatter(X[:,0], X[:,1], alpha=0.5)
plt.show()

plt.subplot(1, 2, 1)
sns.distplot(X[:,0], hist=False, kde=True,
             bins=None,
             hist_kws={'edgecolor':'black'},
             kde_kws={'linewidth': 2})
plt.title(r'$p(x_1)$')
plt.subplot(1, 2, 2)
sns.distplot(X[:,1], hist=False, kde=True,
             bins=None,
             hist_kws={'edgecolor':'black'},
             kde_kws={'linewidth': 2})
plt.title(r'$p(x_2)$')
plt.show()

We will learn the marginals of the above distribution using a learnable transform, bij.Spline, defined on a two-dimensional input:

dist_x = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(2), torch.ones(2)),
    1
)
bijector = bij.Spline()
dist_y = dist.Flow(dist_x, bijector)

bij.Spline passes each dimension of its input through a separate monotonically increasing function known as a spline. At a high level, a spline is a flexible parametrizable curve specified by a set of points, known as knots, that it passes through, along with its derivatives at those knots. The knots and their derivatives are parameters that can be learnt, e.g., through stochastic gradient descent on a maximum likelihood objective, as we demonstrate below.
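Before training, we can confirm that the flow exposes its learnable parameters through dist_y.parameters(), which is also what we hand to the optimizer below (a small sketch; the exact parameter count depends on the spline configuration):

# Count the learnable parameters of the spline flow (the knots and derivatives described above).
n_params = sum(p.numel() for p in dist_y.parameters())
print(f'bij.Spline flow has {n_params} learnable parameters')

The maximum likelihood training loop is then: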

steps = 1001  # number of optimization steps; chosen here to match the logged output below
X_train = torch.tensor(X, dtype=torch.float)  # convert the numpy array to a tensor for scoring

optimizer = torch.optim.Adam(dist_y.parameters(), lr=1e-2)
for step in range(steps):
    optimizer.zero_grad()
    loss = -dist_y.log_prob(X_train).mean()
    loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print('step: {}, loss: {}'.format(step, loss.item()))

step: 0, loss: 2.682476758956909
step: 200, loss: 1.278384804725647
step: 400, loss: 1.2647961378097534
step: 600, loss: 1.2601449489593506
step: 800, loss: 1.2561875581741333
step: 1000, loss: 1.2545257806777954

Plotting samples drawn from the transformed distribution after learning:

X_flow = dist_y.sample(torch.Size([1000,])).detach().numpy()
plt.title(r'Joint Distribution')
plt.xlabel(r'$x_1$')
plt.ylabel(r'$x_2$')
plt.scatter(X[:,0], X[:,1], label='data', alpha=0.5)
plt.scatter(X_flow[:,0], X_flow[:,1], color='firebrick', label='flow', alpha=0.5)
plt.legend()
plt.show()

plt.subplot(1, 2, 1)
sns.distplot(X[:,0], hist=False, kde=True,
             bins=None,
             hist_kws={'edgecolor':'black'},
             kde_kws={'linewidth': 2},
             label='data')
sns.distplot(X_flow[:,0], hist=False, kde=True,
             bins=None, color='firebrick',
             hist_kws={'edgecolor':'black'},
             kde_kws={'linewidth': 2},
             label='flow')
plt.title(r'$p(x_1)$')
plt.subplot(1, 2, 2)
sns.distplot(X[:,1], hist=False, kde=True,
             bins=None,
             hist_kws={'edgecolor':'black'},
             kde_kws={'linewidth': 2},
             label='data')
sns.distplot(X_flow[:,1], hist=False, kde=True,
             bins=None, color='firebrick',
             hist_kws={'edgecolor':'black'},
             kde_kws={'linewidth': 2},
             label='flow')
plt.title(r'$p(x_2)$')
plt.show()

As we can see, we have learnt close approximations to the marginal distributions, $p(x_1)$ and $p(x_2)$. It would have been challenging to fit the irregularly shaped marginals with standard methods, for example, a mixture of normal distributions. As expected, since there is a dependency between the two dimensions, we do not learn a good representation of the joint, $p(x_1,x_2)$. In the next section, we explain how to learn multivariate distributions whose dimensions are not independent.