Your First Flow
The Task
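Let's begin training our first Normalizing Flow with a simple example! The target distribution that we intend to learn is,
$$
Y' \sim \mathcal{N}\left(\mu = \begin{bmatrix} 5 \\ 5 \end{bmatrix},\ \Sigma = \begin{bmatrix} 0.25 & 0 \\ 0 & 0.25 \end{bmatrix}\right),
$$
that is, an affine transformation of a standard multivariate normal distribution, with mean 5 and standard deviation 0.5 in each coordinate. The base distribution is,
$$
X \sim \mathcal{N}\left(\mu = \begin{bmatrix} 0 \\ 0 \end{bmatrix},\ \Sigma = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}\right),
$$
that is, standard normal noise (which is typical for Normalizing Flows). The task is to learn some bijection $g_\theta$ so that
$$
Y \triangleq g_\theta(X) \overset{d}{=} Y'
$$
approximately holds. We will define our Normalizing Flow, $g_\theta$, by a single affine transformation,
$$
g_\theta(\mathbf{x}) \triangleq \begin{bmatrix} \mu_1 \\ \mu_2(x_1) \end{bmatrix} + \begin{bmatrix} \sigma_1 \\ \sigma_2(x_1) \end{bmatrix} \otimes \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}.
$$
In this notation, $\mathbf{x} = (x_1, x_2)^T$, $\otimes$ denotes element-wise multiplication, and the parameters are the scalars $\mu_1, \sigma_1$ and the parameters of the neural networks $\mu_2(\cdot)$ and $\sigma_2(\cdot)$. (Think of the NNs as very simple shallow feedforward nets in this example.) This is an example of an Inverse Autoregressive Flow.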
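To make the structure of this transformation concrete, here is a minimal sketch in plain PyTorch. It is not FlowTorch's implementation: the class name TinyAffineAutoregressive, the hidden width, and the choice to parametrize the scales through their logarithms (so that they stay positive) are all assumptions made for illustration.

```python
import torch
import torch.nn as nn


class TinyAffineAutoregressive(nn.Module):
    """Hypothetical 2-D affine autoregressive map, for illustration only."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        # Scalars for the first coordinate: mu_1 and log(sigma_1).
        self.mu1 = nn.Parameter(torch.zeros(1))
        self.log_sigma1 = nn.Parameter(torch.zeros(1))
        # Shallow net mapping x_1 to (mu_2(x_1), log sigma_2(x_1)).
        self.net = nn.Sequential(
            nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 2)
        )

    def forward(self, x: torch.Tensor):
        x1, x2 = x[..., :1], x[..., 1:]
        mu2, log_sigma2 = self.net(x1).chunk(2, dim=-1)
        # y_1 depends only on x_1; y_2 depends on x_2 and, through the net, on x_1.
        y1 = self.mu1 + self.log_sigma1.exp() * x1
        y2 = mu2 + log_sigma2.exp() * x2
        # The Jacobian is lower triangular, so
        # log|det J| = log sigma_1 + log sigma_2(x_1).
        log_det = (self.log_sigma1 + log_sigma2).squeeze(-1)
        return torch.cat([y1, y2], dim=-1), log_det


g = TinyAffineAutoregressive()
x = torch.randn(5, 2)
y, log_det = g(x)
print(y.shape, log_det.shape)
```

Because the second output only sees $x_1$ through the networks, the Jacobian is triangular and its log-determinant is a simple sum, which is what makes this family of transformations cheap to train.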
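There are several measures we could use to train $Y$ to be close in distribution to $Y'$. First, let us denote the target distribution of $Y'$ by $p(\cdot)$ and the learnable distribution of the normalizing flow, $Y$, as $q_\theta(\cdot)$ (in the following sections, we will explain how to calculate $q_\theta$ from $g_\theta$). Let's use the forward KL-divergence,
$$
\begin{aligned}
\text{KL}\{p\ ||\ q_\theta\} &\triangleq \mathbb{E}_{p(\cdot)}\left[\log\frac{p(Y')}{q_\theta(Y')}\right] \\
&= -\mathbb{E}_{p(\cdot)}\left[\log q_\theta(Y')\right] + C,
\end{aligned}
$$
where $C = \mathbb{E}_{p(\cdot)}\left[\log p(Y')\right]$ is a constant that does not depend on $\theta$. In practice, we draw a finite sample, $\{y_1, \ldots, y_M\}$, from $p$ and optimize a Monte Carlo estimate of the KL-divergence with stochastic gradient descent so that the loss is,
$$
\mathcal{L}(\theta) = -\frac{1}{M}\sum_{m=1}^{M} \log q_\theta(y_m).
$$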
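The relation between this loss and the KL-divergence can be checked numerically. The sketch below uses a fixed standard normal as a stand-in for $q_\theta$ (an assumption made purely for illustration, since no flow has been defined at this point) and compares the Monte Carlo estimate, shifted by the constant, against the closed-form KL-divergence that torch.distributions provides for Gaussians.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

torch.manual_seed(0)

# Target p (mean 5, standard deviation 0.5) and a fixed stand-in for q_theta.
p = Independent(Normal(torch.full((2,), 5.0), torch.full((2,), 0.5)), 1)
q = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)

# Monte Carlo estimate of -E_p[log q(Y')] from M samples of p.
M = 200_000
y = p.sample((M,))
nll = -q.log_prob(y).mean()

# The constant C = E_p[log p(Y')] equals minus the entropy of p.
C = -p.entropy()

kl_estimate = nll + C           # Monte Carlo estimate of KL{p || q}
kl_exact = kl_divergence(p, q)  # closed form for two Gaussians
print(float(kl_estimate), float(kl_exact))
```

Since $C$ does not involve $\theta$, dropping it changes the value of the objective but not its gradients, which is why the training loss is just the negative mean log-probability.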
So, to summarize, the task at hand is to learn how to transform standard bivariate normal noise into another bivariate normal distribution using an affine transformation, and we will do so by minimizing a Monte Carlo estimate of the forward KL-divergence between the two distributions.
Implementation
First, we import the relevant libraries:
import torch
import flowtorch.bijectors as bij
import flowtorch.distributions as dist
The base and target distributions are defined using standard PyTorch:
base_dist = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(2), torch.ones(2)),
    1
)
target_dist = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(2)+5, torch.ones(2)*0.5),
    1
)
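Note the use of torch.distributions.Independent so that our base and target distributions are vector valued: it reinterprets the rightmost batch dimension as an event dimension, so that log_prob returns one value per two-dimensional sample rather than one value per coordinate.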
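The effect of Independent is easiest to see by comparing shapes. The following short sketch uses only standard torch.distributions calls; the variable names are chosen here for illustration.

```python
import torch
from torch.distributions import Independent, Normal

normal = Normal(torch.zeros(2), torch.ones(2))
vector_normal = Independent(normal, 1)

x = vector_normal.sample((1000,))  # 1000 two-dimensional samples

# Plain Normal: a batch of two scalar distributions, one log-density per coordinate.
print(normal.batch_shape, normal.event_shape, normal.log_prob(x).shape)
# torch.Size([2]) torch.Size([]) torch.Size([1000, 2])

# Independent: a single distribution over 2-vectors, one log-density per sample.
print(vector_normal.batch_shape, vector_normal.event_shape, vector_normal.log_prob(x).shape)
# torch.Size([]) torch.Size([2]) torch.Size([1000])
```

The log-density of the vector-valued distribution is the sum of the per-coordinate log-densities, which is the quantity the loss needs.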
We can visualize samples from the base and target:
A Normalizing Flow is created in two steps. First, we create a "plan" for the flow as a flowtorch.bijectors.Bijector object,
# Lazily instantiated flow
bijectors = bij.AffineAutoregressive()
This plan is then made concrete by combining it with the base distribution, which provides the input shape, and constructing a flowtorch.distributions.Flow object, an extension of torch.distributions.Distribution:
# Instantiate transformed distribution and parameters
flow = dist.Flow(base_dist, bijectors)
At this point, we have an object, flow, for the distribution, $q_\theta(\cdot)$, that follows the standard PyTorch interface. Therefore, it can be trained with the following code, which will be familiar to readers who have used torch.distributions before:
# Training loop
opt = torch.optim.Adam(flow.parameters(), lr=5e-3)
for idx in range(3001):
    opt.zero_grad()

    # Minimize KL(p || q)
    y = target_dist.sample((1000,))
    loss = -flow.log_prob(y).mean()

    if idx % 500 == 0:
        print('epoch', idx, 'loss', loss.item())

    loss.backward()
    opt.step()
Note how we obtain the learnable parameters of the normalizing flow from the flow object, which is a torch.nn.Module. Visualizing samples after learning, we see that we have been successful in matching the target distribution.
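For readers who want to see what such a training loop does without FlowTorch, here is a rough plain-PyTorch analogue. It swaps the autoregressive bijector for a simpler element-wise affine map, $y = \text{loc} + \exp(\text{log\_scale}) \otimes x$, and computes the log-density by hand with the change-of-variables formula. The names loc, log_scale, and flow_log_prob, as well as the learning rate and step count, are assumptions made for this sketch and are not part of FlowTorch.

```python
import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)

base_dist = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)
target_dist = Independent(Normal(torch.zeros(2) + 5, torch.ones(2) * 0.5), 1)

# Learnable parameters of the element-wise affine map (hypothetical names).
loc = torch.zeros(2, requires_grad=True)
log_scale = torch.zeros(2, requires_grad=True)


def flow_log_prob(y: torch.Tensor) -> torch.Tensor:
    # Change of variables: log q(y) = log p_base(g^{-1}(y)) - log|det J_g|,
    # where log|det J_g| is the sum of the log-scales for an affine map.
    x = (y - loc) / log_scale.exp()
    return base_dist.log_prob(x) - log_scale.sum()


opt = torch.optim.Adam([loc, log_scale], lr=5e-2)
for step in range(1500):
    opt.zero_grad()
    y = target_dist.sample((1000,))
    loss = -flow_log_prob(y).mean()  # Monte Carlo estimate of the forward KL, up to a constant
    loss.backward()
    opt.step()

# The learned shift and scale should end up near the target's mean and standard deviation.
print(loc.detach(), log_scale.exp().detach())
```

An element-wise map is enough here because the target has a diagonal covariance; the autoregressive bijector used in the tutorial can additionally capture dependence of the second coordinate on the first.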
Discussion
This simple example illustrates a few important points of FlowTorch's design:
Firstly, Bijector objects are agnostic to their shape. A Bijector object specifies how the shape is changed by the forward and inverse operations, and only calculates the exact shapes once it obtains knowledge of the base distribution, that is, when flow = dist.Flow(base_dist, bijectors) is run. Any neural networks or other parametrized functions, which also require this shape information, are not instantiated until the same moment. In this sense, a Bijector can be thought of as a lazy plan for creating a normalizing flow. The advantage of doing things this way is that the shape information can be "type checked" and does not need to be specified in multiple locations, which keeps these quantities consistent.
Secondly, all objects are designed to have sensible defaults. We do not need to define the conditioning network for bij.AffineAutoregressive; it will use a MADE network with sensible hyperparameters and defer initialization until it later receives shape information.
Thirdly, there is compatibility, as far as is possible, with standard PyTorch interfaces such as torch.distributions.
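The "lazy plan" idea from the first point can be illustrated with a toy pattern in plain PyTorch. This is not how FlowTorch is implemented internally; LazyAffinePlan and its bind method are hypothetical names invented for this sketch, which only shows how hyperparameters can be stored first and shapes resolved later from the base distribution.

```python
import torch
import torch.nn as nn


class LazyAffinePlan:
    """Hypothetical 'plan': stores hyperparameters but creates no parameters yet."""

    def __init__(self, hidden: int = 16):
        self.hidden = hidden

    def bind(self, event_shape: torch.Size) -> nn.Module:
        # Shapes are resolved here, in one place, from the base distribution.
        dim = event_shape[-1]
        return nn.Sequential(
            nn.Linear(dim, self.hidden),
            nn.ReLU(),
            nn.Linear(self.hidden, 2 * dim),  # one shift and one scale per coordinate
        )


base_dist = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1
)

plan = LazyAffinePlan()                 # no tensors allocated yet
net = plan.bind(base_dist.event_shape)  # input and output sizes inferred from base_dist
print(net)
```

Because the dimension is read from base_dist.event_shape, changing the base distribution is enough to resize the network; nothing else in the code has to repeat the number 2.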