Shapes

One of the advantages of using FlowTorch is that we have carefully thought out how shape information is propagated from the base distribution through the sequence of bijective transforms. Before we explain how shapes are handled in FlowTorch, let us revisit the shape conventions shared across PyTorch and TensorFlow.

Shape Conventions

FlowTorch shares the shape conventions of PyTorch's torch.distributions.Distribution and TensorFlow's tfp.distributions.Distribution for representing probability distributions. In these conventions, the shape of a tensor sampled from a distribution is divided into three parts: the sample shape, the batch shape, and the event shape.

As described in the TensorFlow documentation,

  • Event shape describes the shape of a single draw from the distribution, which may or may not be dependent across dimensions.
  • Batch shape describes independent, not identically distributed draws, that is, a "batch" of distributions.
  • Sample shape describes independent, identically distributed draws of batches from the distribution family.

Examples

This is best illustrated with some simple examples. Let's begin with a standard normal distribution:

import torch
import torch.distributions as dist
d = dist.Normal(loc=0, scale=1)
sample_shape = torch.Size([])

assert d.event_shape == torch.Size([])
assert d.batch_shape == torch.Size([])
assert d.sample(sample_shape).shape == torch.Size([])

In this example, we have a single scalar normal distribution from which we draw a scalar sample. Since it is a scalar distribution, the event_shape == torch.Size([]). Since it is a single distribution, batch_shape == torch.Size([]). And we draw a scalar sample since sample_shape == torch.Size([]).

Note that the event shape and batch shape are properties of the distribution itself, whereas the sample shape is the sample_shape argument passed to Distribution.sample or Distribution.rsample. Also, the shape of d.sample(sample_shape) is the concatenation of the sample_shape, batch_shape, and event_shape, in that order.
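The concatenation rule can be checked directly. The following is a minimal sketch using only torch.distributions; the helper name expected_sample_shape is ours, introduced for illustration, and is not part of PyTorch or FlowTorch.

```python
import torch
import torch.distributions as dist


def expected_sample_shape(d, sample_shape):
    """Concatenate sample, batch and event shapes, in that order."""
    return tuple(sample_shape) + tuple(d.batch_shape) + tuple(d.event_shape)


d = dist.Normal(loc=torch.zeros(3), scale=torch.ones(3))
sample_shape = torch.Size([4])

# batch_shape and event_shape are fixed once the distribution is constructed ...
assert d.batch_shape == torch.Size([3])
assert d.event_shape == torch.Size([])

# ... while sample_shape is chosen on each call to sample or rsample.
assert tuple(d.sample(sample_shape).shape) == expected_sample_shape(d, sample_shape)
assert tuple(d.rsample(sample_shape).shape) == (4, 3)
```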

Let's look at another example:

d = dist.Normal(loc=torch.zeros(1), scale=torch.ones(1))
sample_shape = torch.Size([2])

assert d.event_shape == torch.Size([])
assert d.batch_shape == torch.Size([1])
assert d.sample(sample_shape).shape == torch.Size([2, 1])

In this case, event_shape == torch.Size([]) since we have a scalar distribution, but batch_shape == torch.Size([1]) since we have a tensor of parameters of that shape defining the distribution. Also, sample_shape == torch.Size([2]) so that d.sample(sample_shape).shape == torch.Size([2, 1]).

A further example:

d = dist.Normal(loc=torch.zeros(2, 5), scale=torch.ones(2, 5))
sample_shape = torch.Size([3, 4])

assert d.event_shape == torch.Size([])
assert d.batch_shape == torch.Size([2, 5])
assert d.sample(sample_shape).shape == torch.Size([3, 4, 2, 5])

We see that batch shapes, sample shapes (and event shapes) can have an arbitrary number of dimensions and are not restricted to being vectors.

Is the event shape always torch.Size([])? This is not true for multivariate distributions, that is, distributions over vectors, matrices, and higher-order tensors that can have dependencies across their dimensions. For example:

d = dist.MultivariateNormal(loc=torch.zeros(2, 5), covariance_matrix=torch.eye(5))
sample_shape = torch.Size([3, 4])

assert d.event_shape == torch.Size([5])
assert d.batch_shape == torch.Size([2])
assert d.sample(sample_shape).shape == torch.Size([3, 4, 2, 5])

Note that the covariance_matrix tensor will be broadcast across loc. Whereas the previous example defined a matrix batch of scalar normal distributions, this example defines a vector batch of multivariate normal distributions. This is an important distinction!
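One place where the distinction shows up is in the shape of log_prob, which returns one value per event and therefore drops the event dimensions. The sketch below uses only torch.distributions; the variable names are ours. It also shows dist.Independent, which reinterprets batch dimensions of a scalar batch as event dimensions.

```python
import torch
import torch.distributions as dist

sample_shape = torch.Size([3, 4])

# A matrix batch (2 x 5) of scalar normals: one log-density per scalar.
scalar_batch = dist.Normal(loc=torch.zeros(2, 5), scale=torch.ones(2, 5))
x = scalar_batch.sample(sample_shape)
assert scalar_batch.log_prob(x).shape == torch.Size([3, 4, 2, 5])

# A vector batch (2) of 5-dimensional normals: one log-density per vector.
mvn_batch = dist.MultivariateNormal(loc=torch.zeros(2, 5), covariance_matrix=torch.eye(5))
y = mvn_batch.sample(sample_shape)
assert mvn_batch.log_prob(y).shape == torch.Size([3, 4, 2])

# Independent reinterprets the rightmost batch dimension as an event dimension.
reinterpreted = dist.Independent(scalar_batch, 1)
assert reinterpreted.batch_shape == torch.Size([2])
assert reinterpreted.event_shape == torch.Size([5])
assert reinterpreted.log_prob(x).shape == torch.Size([3, 4, 2])
```

Samples from all three distributions have the same shape, torch.Size([3, 4, 2, 5]); only the split between batch and event dimensions differs.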

See this page for further explanation on shape conventions.

Non-conditional Transformed Distributions

How do shapes work for transformed distributions that do not condition on a context variable, that is, distributions of the form $p_\theta(\mathbf{x})$? The sample shape depends strictly on the input to .sample or .rsample and so we restrict our attention to the batch and event shapes.

Returning to the diagram on the intro page, suppose the base distribution is $p_0$, and the distribution after applying the initial bijection, $f_1$, is $p_1$. Denote by $z_0$ a sample from the base distribution and let $z_1=f_1(z_0)$. We make a few observations:

Firstly, since $f_1$ is a bijection, $z_0$ must have the same number of dimensions as $z_1$. In our shape terminology, the product of the sizes in the event shape of the base distribution, that is, the number of scalar elements in a single event, must equal that of the transformed distribution.

Secondly, the batch shape is preserved from the base distribution to the transformed one. By convention, we assume that a single bijection, $f_1$, is applied to a batch of base distributions, $\{p_{0,i}\}$, to produce a batch of transformed distributions, $\{p_{1,i}\}$, of the same shape.

Thirdly, the event shape of the base distribution must be compatible with the domain of the bijection. For instance, if the base distribution is scalar, with event shape torch.Size([]), it does not make sense to apply a bijection on matrices with, e.g., $\text{Dom}[f_1]\subseteq \mathbb{R}^{n\times m}$.
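The batch-shape observation can be seen with PyTorch's own torch.distributions.TransformedDistribution, used here only as an analogy and not as FlowTorch's API. The sketch assumes an elementwise bijection, ExpTransform, applied to a matrix batch of scalar normals.

```python
import torch
import torch.distributions as dist
from torch.distributions.transforms import ExpTransform

base = dist.Normal(loc=torch.zeros(2, 5), scale=torch.ones(2, 5))
flow = dist.TransformedDistribution(base, [ExpTransform()])

# One bijection applied to a batch of base distributions gives a batch of
# transformed distributions with the same batch shape.
assert flow.batch_shape == base.batch_shape == torch.Size([2, 5])

# exp acts on scalars, so the event shape is unchanged as well.
assert flow.event_shape == base.event_shape == torch.Size([])

# Samples follow the usual sample + batch + event layout and are positive.
z1 = flow.sample(torch.Size([3]))
assert z1.shape == torch.Size([3, 2, 5])
assert (z1 > 0).all()
```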

Given a base distribution, base, and a non-conditional bijector bijector, the pseudo-code to calculate the batch and event shape of the transformed distribution, flow, looks like this:

import math

# Input event shape must have at least as many dimensions as the bijector operates over
assert len(base.event_shape) >= bijector.domain.event_dim

flow.batch_shape = base.batch_shape
flow.event_shape = bijector.forward_shape(base.event_shape)

# bijector.forward_shape and bijector.codomain.event_dim must be consistent
assert len(flow.event_shape) >= bijector.codomain.event_dim

# Bijectors preserve the number of scalar elements in an event
# (the product of the sizes, so e.g. [2, 3] and [6] both hold 6 elements)
assert math.prod(flow.event_shape) == math.prod(base.event_shape)

The bijector class defines the number of dimensions that it operates over in bijector.domain.event_dim and bijector.codomain.event_dim, and has a method bijector.forward_shape that specifies how the event shape of the input relates to that of the output. (In most cases, this will be the identity function.)
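For a case where forward_shape is not the identity, PyTorch's torch.distributions.transforms.ReshapeTransform offers a concrete analogy. It exposes domain.event_dim, codomain.event_dim, and forward_shape under the same names, although FlowTorch's own bijectors may differ in detail. A minimal sketch, assuming a PyTorch version that ships ReshapeTransform:

```python
import math

import torch
from torch.distributions.transforms import ReshapeTransform

# Flatten 2 x 3 matrices into vectors of length 6.
t = ReshapeTransform(in_shape=torch.Size([2, 3]), out_shape=torch.Size([6]))

# The transform consumes 2 event dimensions and produces 1.
assert t.domain.event_dim == 2
assert t.codomain.event_dim == 1

in_event_shape = torch.Size([2, 3])
out_event_shape = t.forward_shape(in_event_shape)
assert tuple(out_event_shape) == (6,)

# The number of scalar elements per event is unchanged by the reshape.
assert math.prod(in_event_shape) == math.prod(out_event_shape) == 6
```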

This information is sufficient to construct the batch and event shapes of the transformed distribution from the base. For a Normalizing Flow that is the composition of multiple bijections, we apply this logic in succession, using the transformed distribution of the previous step as the base distribution of the next.
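Applying the rule in succession can be sketched in plain Python. ShapeSpec and propagate below are hypothetical names introduced for illustration; they carry only the shape metadata described above and are not FlowTorch classes.

```python
import math
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple

Shape = Tuple[int, ...]


@dataclass(frozen=True)
class ShapeSpec:
    """Hypothetical stand-in holding only the shape metadata of a bijector."""
    domain_event_dim: int
    codomain_event_dim: int
    forward_shape: Callable[[Shape], Shape]


def propagate(batch_shape: Shape, event_shape: Shape,
              bijectors: Sequence[ShapeSpec]) -> Tuple[Shape, Shape]:
    """Apply the single-bijector shape rule once per bijector, in order."""
    for b in bijectors:
        assert len(event_shape) >= b.domain_event_dim
        new_event_shape = b.forward_shape(event_shape)
        assert len(new_event_shape) >= b.codomain_event_dim
        # A bijection keeps the number of scalar elements per event.
        assert math.prod(new_event_shape) == math.prod(event_shape)
        event_shape = new_event_shape  # this output is the next step's base
    return batch_shape, event_shape  # the batch shape is never altered


def identity(shape: Shape) -> Shape:
    return shape


def flatten_last_two(shape: Shape) -> Shape:
    return shape[:-2] + (shape[-2] * shape[-1],)


elementwise = ShapeSpec(0, 0, identity)
flatten_matrix = ShapeSpec(2, 1, flatten_last_two)

batch, event = propagate((7,), (2, 3), [elementwise, flatten_matrix, elementwise])
assert batch == (7,)
assert event == (6,)
```

Keeping the shape rule separate from the numerical transform means that a shape mismatch anywhere in the chain surfaces when the flow is constructed rather than at sampling time.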