# Shapes

One of the advantages of using FlowTorch is that we have carefully thought out how shape information is propagated from the base distribution through the sequence of bijective transforms. Before we explain how shapes are handled in FlowTorch, let us revisit the shape conventions shared across PyTorch and TensorFlow.

## Shape Conventions

FlowTorch shares the shape conventions of PyTorch's `torch.distributions.Distribution` and TensorFlow's `tfp.distributions.Distribution` for representing random distributions. In these conventions, the shape of a tensor sampled from a random distribution is divided into three parts: the *sample shape*, the *batch shape*, and the *event shape*.

As described in the TensorFlow documentation,

- Event shape describes the shape of a single draw from the distribution, which may or may not be dependent across dimensions.
- Batch shape describes independent, not identically distributed draws, that is, a "batch" of distributions.
- Sample shape describes independent, identically distributed draws of batches from the distribution family.

## Examples

This is best illustrated with some simple examples. Let's begin with a standard normal distribution:

```python
import torch
import torch.distributions as dist

d = dist.Normal(loc=0, scale=1)
sample_shape = torch.Size([])

assert d.event_shape == torch.Size([])
assert d.batch_shape == torch.Size([])
assert d.sample(sample_shape).shape == torch.Size([])
```

In this example, we have a single scalar normal distribution from which we draw a scalar sample. Since it is a scalar distribution, `event_shape == torch.Size([])`. Since it is a single distribution, `batch_shape == torch.Size([])`. And we draw a scalar sample since `sample_shape == torch.Size([])`.

Note that *the event shape and batch shape are properties of the distribution itself*, whereas the sample shape depends on the size argument passed to `Distribution.sample` or `Distribution.rsample`. Also, the shape of `d.sample(sample_shape)` is the concatenation of the `sample_shape`, `batch_shape`, and `event_shape`, in that order.
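The concatenation rule can be checked directly. The following is a minimal sketch using only `torch.distributions`; the particular parameter shapes are arbitrary choices for illustration.

```python
import torch
import torch.distributions as dist

# A batch of three scalar normal distributions (shapes chosen only for illustration).
d = dist.Normal(loc=torch.zeros(3), scale=torch.ones(3))
sample_shape = torch.Size([2])

# Build the expected shape by concatenating the three parts in order.
expected = torch.Size(
    tuple(sample_shape) + tuple(d.batch_shape) + tuple(d.event_shape)
)

sample = d.sample(sample_shape)    # draw without gradient information
rsample = d.rsample(sample_shape)  # reparameterized draw

assert sample.shape == expected
assert rsample.shape == expected
```

Both `sample` and `rsample` follow the same rule, so the choice between them does not affect shapes.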

Let's look at another example:

```python
d = dist.Normal(loc=torch.zeros(1), scale=torch.ones(1))
sample_shape = torch.Size([2])

assert d.event_shape == torch.Size([])
assert d.batch_shape == torch.Size([1])
assert d.sample(sample_shape).shape == torch.Size([2, 1])
```

In this case, `event_shape == torch.Size([])` since we have a scalar distribution, but `batch_shape == torch.Size([1])` since we have a tensor of parameters of that shape defining the distribution. Also, `sample_shape == torch.Size([2])` so that `d.sample(sample_shape).shape == torch.Size([2, 1])`.

A further example:

```python
d = dist.Normal(loc=torch.zeros(2, 5), scale=torch.ones(2, 5))
sample_shape = torch.Size([3, 4])

assert d.event_shape == torch.Size([])
assert d.batch_shape == torch.Size([2, 5])
assert d.sample(sample_shape).shape == torch.Size([3, 4, 2, 5])
```

We see that batch shapes and sample shapes (as well as event shapes) can have an arbitrary number of dimensions and are not restricted to being vectors.

Is the event shape always `torch.Size([])`? This is not true for *multivariate* distributions, that is, distributions over vectors, matrices, and higher-order tensors that can have dependencies across their dimensions. For example:

```python
d = dist.MultivariateNormal(loc=torch.zeros(2, 5), covariance_matrix=torch.eye(5))
sample_shape = torch.Size([3, 4])

assert d.event_shape == torch.Size([5])
assert d.batch_shape == torch.Size([2])
assert d.sample(sample_shape).shape == torch.Size([3, 4, 2, 5])
```

Note that the `covariance_matrix` tensor will be broadcast across `loc`. *Whereas the previous example defined a matrix batch of scalar normal distributions, this example defines a vector batch of multivariate normal distributions.* This is an important distinction!
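One way to see why the distinction matters is to compare the shape of `log_prob`, which is not discussed above but returns one value per event. The sketch below assumes only standard `torch.distributions` behaviour: samples from the two distributions have the same shape, but the densities do not.

```python
import torch
import torch.distributions as dist

# Matrix batch (2 x 5) of scalar normals: every entry is its own event.
scalar_batch = dist.Normal(loc=torch.zeros(2, 5), scale=torch.ones(2, 5))

# Vector batch (2) of 5-dimensional multivariate normals: every row is one event.
mvn_batch = dist.MultivariateNormal(
    loc=torch.zeros(2, 5), covariance_matrix=torch.eye(5)
)

x = torch.zeros(2, 5)  # a valid single draw for either distribution

scalar_lp = scalar_batch.log_prob(x)  # one log-density per scalar event
mvn_lp = mvn_batch.log_prob(x)        # one log-density per 5-vector event

assert scalar_lp.shape == torch.Size([2, 5])
assert mvn_lp.shape == torch.Size([2])
```

With an identity covariance the two agree once the scalar log-densities are summed over the last dimension, but in general the multivariate distribution can encode dependencies that the batch of scalars cannot.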

See this page for further explanation on shape conventions.

## Non-conditional Transformed Distributions

How do shapes work for transformed distributions that do not condition on a context variable, that is, distributions of the form $p_\theta(\mathbf{x})$? The sample shape depends strictly on the input to `.sample` or `.rsample`, and so we restrict our attention to the batch and event shapes.

Returning to the diagram on the intro page, suppose the base distribution is $p_0$, and the distribution after applying the initial bijection, $f_1$, is $p_1$. Denote by $z_0$ a sample from the base distribution and $z_1=f_1(z_0)$. We make a few observations:

Firstly, since $f_1$ is a bijection, $z_0$ must have the same number of dimensions as $z_1$. In our shape terminology, the number of elements in an event, that is, the product of the entries of the event shape, must be the same for the base distribution and the transformed one.

Secondly, the batch shape is preserved from the base distribution to the transformed one. *By convention, we assume that a single bijection, $f_1$, is applied to a batch of base distributions, $\{p_{0,i}\}$, to produce a batch of the same shape of transformed distributions, $\{p_{1,i}\}$.*

Thirdly, the event shape of the base distribution must be compatible with the domain of the bijection. For instance, if the base distribution has event shape `torch.Size([])` and is a scalar, it does not make sense to apply a bijection on matrices with, e.g., $\text{Dom}[f_1]\subseteq \mathbb{R}^{n\times m}$.
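These observations can be illustrated with PyTorch's own `torch.distributions.TransformedDistribution`, used here as a stand-in because it follows the same conventions; it is not the FlowTorch API. The sketch applies the elementwise `ExpTransform` to a matrix batch of scalar normals and checks that the batch and event shapes carry over.

```python
import torch
import torch.distributions as dist
from torch.distributions.transforms import ExpTransform

# Base: a 2 x 5 batch of scalar normal distributions.
base = dist.Normal(loc=torch.zeros(2, 5), scale=torch.ones(2, 5))

# An elementwise bijection, so it operates over zero event dimensions.
transform = ExpTransform()
assert transform.domain.event_dim == 0
assert transform.codomain.event_dim == 0

# A single bijection is applied to every distribution in the batch.
flow = dist.TransformedDistribution(base, [transform])

assert flow.batch_shape == base.batch_shape  # batch shape is preserved
assert flow.event_shape == base.event_shape  # elementwise: event shape unchanged
assert flow.sample(torch.Size([3])).shape == torch.Size([3, 2, 5])
```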

Given a base distribution, `base`, and a non-conditional bijector, `bijector`, the pseudo-code to calculate the batch and event shape of the transformed distribution, `flow`, looks like this:

```python
import math

# The input event shape must have at least as many dimensions as the bijector operates over
assert len(base.event_shape) >= bijector.domain.event_dim

flow.batch_shape = base.batch_shape
flow.event_shape = bijector.forward_shape(base.event_shape)

# bijector.forward_shape and bijector.codomain.event_dim must be consistent
assert len(flow.event_shape) >= bijector.codomain.event_dim

# Bijectors preserve the number of elements in an event (the product of the event shape)
assert math.prod(flow.event_shape) == math.prod(base.event_shape)
```

The `bijector` class defines the number of dimensions that it operates over in `bijector.domain.event_dim` and `bijector.codomain.event_dim`, and has a method `bijector.forward_shape` that specifies how the event shape of the input relates to that of the output. (In most cases, this will be the identity function.)

This information is sufficient to construct the batch and event shapes of the transformed distribution from the base. For a Normalizing Flow that is the composition of multiple bijections, we apply this logic in succession, using the transformed distribution of the previous step as the base distribution of the next.
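Applying the logic in succession can be sketched in plain Python. Everything below is hypothetical and for illustration only: `ToyBijector` and `flow_shapes` are not FlowTorch classes or functions, and the toy bijector models only the shape metadata described above. The element-count check uses the product of the event shape, since that is the number of scalar entries in one event.

```python
from dataclasses import dataclass
from math import prod
from typing import Callable, Sequence, Tuple

Shape = Tuple[int, ...]


def identity_shape(shape: Shape) -> Shape:
    """Default forward_shape: the event shape passes through unchanged."""
    return shape


@dataclass
class ToyBijector:
    """Hypothetical stand-in for a bijector; holds shape metadata only."""
    domain_event_dim: int
    codomain_event_dim: int
    forward_shape: Callable[[Shape], Shape] = identity_shape


def flow_shapes(
    batch_shape: Shape, event_shape: Shape, bijectors: Sequence[ToyBijector]
) -> Tuple[Shape, Shape]:
    """Propagate shapes through the bijectors, one step at a time."""
    for b in bijectors:
        # The input event must have enough dimensions for this bijector.
        assert len(event_shape) >= b.domain_event_dim
        new_event_shape = tuple(b.forward_shape(event_shape))
        # forward_shape must agree with the codomain's event dimension.
        assert len(new_event_shape) >= b.codomain_event_dim
        # A bijection preserves the number of elements in an event.
        assert prod(new_event_shape) == prod(event_shape)
        # The output of this step is the base of the next; batch shape is untouched.
        event_shape = new_event_shape
    return tuple(batch_shape), event_shape


# An elementwise bijector followed by one that reshapes 6-vectors into 2 x 3 matrices.
elementwise = ToyBijector(domain_event_dim=0, codomain_event_dim=0)
reshape = ToyBijector(1, 2, forward_shape=lambda shape: (2, 3))

result = flow_shapes(batch_shape=(4,), event_shape=(6,), bijectors=[elementwise, reshape])
assert result == ((4,), (2, 3))
```

The batch shape is returned unchanged, while the event shape is threaded through each step, mirroring the single-bijector rule applied repeatedly.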