A Gentle Introduction to Deep Sets and Neural Processes

In this post, I will discuss two topics that I have been thinking a lot about recently: Deep Sets and Neural Processes. I'll lay out what I see as the important bits of these models, and hopefully provide some intuition as to why they are useful. My motivation will, in particular, focus on meta-learning. I'll start with a brief introduction to meta-learning, which will serve the dual purpose of (i) motivating everything I'll talk about later, and (ii) introducing some terminology and objects that we'll need for the rest of the post.

Quick Background on Meta-Learning

Let's first recall what the (supervised) meta-learning setting is. In the standard supervised learning setting, we are often interested in learning / approximating a function ($f$) that maps inputs ($x$) to outputs ($y$). A supervised learning algorithm ($L$) may be considered an algorithm that, given a dataset of such input-output pairs, returns a function approximator ($\hat f$). If $L$ is a good algorithm, $\hat f \approx f$ in some meaningful sense.

Fig. 1: Observations from a nonlinear function are passed to a learning algorithm. In this visualization, I am using a Gaussian Process (GP) as the learning algorithm $L$. Since GPs are probabilistic models, this produces a distribution over $\hat{f}$.

In the meta-learning setting, rather than assuming we have access to one such (potentially very large) dataset, we assume that our dataset is composed of many tasks, each containing a context set ($D_c$) and a target set ($D_\tau$). Each of these, in turn, contains a variable number of input-output pairs. Our assumption is that while the mapping between inputs and outputs may differ across tasks, the tasks share some statistical properties that, when modelled appropriately, should improve the overall performance of the learning algorithms.

The goal of meta-learning, then, is to learn the black-box algorithm that maps datasets to function approximators. In other words, our goal is to use our dataset of tasks to train a model that, at test time, accepts new context sets and produces function approximators that perform well on unseen data from the same task. Intuitively, the meta-learning algorithm learns a learning algorithm that is appropriate for all of the observed tasks. A good meta-learning algorithm results in a learning algorithm with desirable properties for tasks similar to those in our training set (e.g., it produces good function approximators, and it is sample-efficient). This is where the popular description of meta-learning as learning to learn comes from.

Fig. 2: A large collection of few-shot tasks is provided during meta-training. The meta-learner learns to map few-shot tasks to meaningful predictive distributions.
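To make this setup concrete, here is a minimal sketch of a toy task sampler. The sine-curve task family, the parameter ranges, and the name `sample_task` are all illustrative choices of mine, not part of any standard benchmark:

```python
import torch

def sample_task(num_context, num_target, x_range=(-2.0, 2.0)):
    """Sample a toy few-shot regression task: a random sine curve,
    split into a context set and a target set."""
    amplitude = torch.rand(1) * 4.0 + 0.1    # A ~ U[0.1, 4.1] (arbitrary)
    phase = torch.rand(1) * 3.14159          # phi ~ U[0, pi]  (arbitrary)
    n = num_context + num_target
    x = torch.rand(n, 1) * (x_range[1] - x_range[0]) + x_range[0]
    y = amplitude * torch.sin(x - phase)
    # The context set is what we condition on; the target set is where
    # we evaluate (and train) the resulting predictions.
    return (x[:num_context], y[:num_context]), (x[num_context:], y[num_context:])
```

Each call returns a fresh task: the input-output mapping changes (the amplitude and phase differ), but all tasks share statistical structure (they are all sine curves), which is exactly the structure meta-learning aims to exploit.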

One of the most compelling motivations for meta-learning is data efficiency. Neural networks notoriously require large datasets to learn from. However, it has been observed [1] that humans, for example, are able to learn from just a handful of examples. This is a major difference between human intelligence and our current machine learning systems, and an extremely attractive capability to aspire to in our learning systems. Sample-efficient learners would be wildly useful in many applications, such as robotics and reinforcement learning. This line of thinking has led to the sub-field of research called few-shot learning. In this setting, models are provided only a handful of examples of the task they are required to perform. Meta-learning is a particularly successful approach to designing models that can achieve this.

Fig. 3: Two examples of few-shot learning problems. (Left) An example image is shown (in the red box). All images of the same object must be identified from the set on the bottom. (Right) Same setup, now with characters from an alphabet. Images borrowed from [1].

The Neural Process Family

Let us now focus our attention on the Neural Process (NP) family [2]. NPs are a recently introduced family of models for probabilistic meta-learning. While there are interesting latent-variable variants, in this blog post I'll focus on conditional neural processes (CNPs), which are a simple deterministic variant of NPs.

CNPs leverage a series of mappings to directly model the predictive distribution at some location of interest ($x_t$) conditioned on a context set. Mathematically, for a Gaussian predictive likelihood we can express this as

$$ p(y_t | x_t, D_c) = \mathcal{N} \left(y_t; \mu_\theta(x_t, D_c), \sigma^2_\theta(x_t, D_c) \right), \tag{1}\label{eq1} $$

where $x_t$ is the input we wish to make a prediction at, and $D_c$ is the context set to condition on. The computational structure of CNPs is best understood from an encoder-decoder perspective; we can visualize the model as follows:

Fig. 4: Computational diagram of CNPs. On the left, $(x_i, y_i)$ represent the context set. The $e$ block is the encoder, the $a$-circle the pooling operation, and the $d$ block is the decoder. The output is a Gaussian distribution over $y_t$ for the $x_t$ passed to the decoder. Image borrowed from the DeepMind NP repository [3].

Note the important form of the encoder: it first applies the same mapping ($e$) to each of the $(x_i, y_i)$ pairs in the context set, producing a representation ($r_i$). A natural choice is to implement $e$ with a neural network. Then, a pooling operation $a$ is applied to the representations, yielding a single, vector-valued representation of the context set, denoted $r$. This pooling operation is very important, and we'll return to it later. Finally, to make predictions at $x_t$, we concatenate it to $r$, and pass this through our decoder, $d$, which maps to a mean and variance for the predictive distribution. A natural choice for $d$ is also a neural network.
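To make this concrete, here is a minimal CNP sketch in PyTorch, following the encoder / pooling / decoder structure of Fig. 4. The layer widths, the choice of mean pooling, and the softplus bound on the predictive standard deviation are illustrative choices of mine, not the exact architecture of the original paper:

```python
import torch
import torch.nn as nn

class CNP(nn.Module):
    """A minimal Conditional Neural Process (illustrative sketch)."""

    def __init__(self, x_dim=1, y_dim=1, r_dim=128, h_dim=128):
        super().__init__()
        # Encoder e: maps each (x_i, y_i) context pair to a representation r_i.
        self.encoder = nn.Sequential(
            nn.Linear(x_dim + y_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, r_dim),
        )
        # Decoder d: maps [x_t, r] to the mean and (raw) scale of the
        # Gaussian predictive distribution in Eq. (1).
        self.decoder = nn.Sequential(
            nn.Linear(x_dim + r_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, 2 * y_dim),
        )

    def forward(self, x_c, y_c, x_t):
        # Apply the same encoder to every context pair independently.
        r_i = self.encoder(torch.cat([x_c, y_c], dim=-1))   # (n_c, r_dim)
        # Pooling operation a: averaging over the context set yields a
        # single, permutation-invariant representation r.
        r = r_i.mean(dim=0, keepdim=True).expand(x_t.shape[0], -1)
        # Decode each target input, conditioned on r.
        out = self.decoder(torch.cat([x_t, r], dim=-1))
        mean, raw_scale = out.chunk(2, dim=-1)
        # Softplus keeps the predictive standard deviation positive.
        sigma = 0.1 + 0.9 * nn.functional.softplus(raw_scale)
        return mean, sigma
```

Note how averaging the $r_i$ makes $r$ independent of both the size and the order of the context set: any context set, large or small, is squeezed into the same fixed-dimensional vector.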

That's it! Given a dataset of tasks, we can now train the parameters of the model (the weights $\theta$ of the encoder and decoder) to maximize the log-likelihood, i.e., the $\log$ of Eq. $\eqref{eq1}$: $$ \theta^\ast = \underset{\theta \in \Theta}{\text{argmax}}\ \mathbb{E}_\tau \left[ \sum_{(x_t, y_t) \in D_\tau} \log p \left(y_t | x_t, D_c \right) \right], \tag{2}\label{eq2} $$

where the expectation over tasks $\tau$ is approximated by averaging over our dataset of tasks, and the inner summation is over the $(x_t, y_t)$ pairs in the target set. In practice, at every iteration we sample a batch of tasks from our dataset of tasks, and (if they are not already) partition them into context and target sets. We then pass each context set through $e$ to produce $r$. Each $x_t$ in the target sets is concatenated with $r$, and this is passed through the decoder to get the predictive distribution. The objective function for the iteration is then the sum of the log-likelihoods over the target sets, averaged across the tasks. This is fully differentiable with respect to the model parameters, so we can use standard optimizers and auto-differentiation to train the model.
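In code, one (illustrative) version of this training loop might look as follows; it reuses the `CNP` module and the toy `sample_task` helper sketched above, and the meta-batch size and learning rate are arbitrary choices:

```python
import torch
from torch.distributions import Normal

model = CNP()                                     # defined above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(10_000):
    optimizer.zero_grad()
    loss = 0.0
    num_tasks = 16                                # tasks per meta-batch
    for _ in range(num_tasks):
        # Sample a task and split it into context and target sets.
        (x_c, y_c), (x_t, y_t) = sample_task(num_context=5, num_target=10)
        mean, sigma = model(x_c, y_c, x_t)
        # Negative log-likelihood of the target set under Eq. (1),
        # summed over the (x_t, y_t) pairs.
        loss = loss - Normal(mean, sigma).log_prob(y_t).sum()
    (loss / num_tasks).backward()                 # average across tasks
    optimizer.step()
```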

Below, I show the training progression of a simple CNP trained on samples from a Gaussian Process with an EQ kernel after 0, 20, and 40 epochs of training.

Fig. 5: Snapshots from the CNP training procedure after 0, 20, and 40 epochs. Each epoch iterates over 1024 randomly sampled tasks.

One thing you might immediately notice is that the CNP seems to underfit compared to the oracle GP (by oracle, I mean a GP using the ground-truth kernel). On the one hand, this is true: the uncertainty is not tight, even when many data points are observed. On the other hand, the GP has far more domain knowledge (the exact kernel used to generate the data), and so it is not clear how certain the CNP really should be. Regardless, CNPs suffer from a severe underfitting problem. In my next post, I will explore several interesting additions to the NP family that are aimed at addressing this issue.

Deep Sets

OK, let's dive a little deeper into the inner workings of the CNP. A useful way to think about them is as follows. Our decoder is a standard neural network that maps $X \to Y$, with one small tweak: we condition it on an additional input $r$, which is a dataset-specific representation. The encoder's job is really just to embed datasets into an appropriate vector space. CNPs specify such an embedding, and provide a way to train this embedding end-to-end with a predictive model. However, one might ask – is this the “correct” form for such an embedding?

So the question that really arises is – what does a function that embeds datasets into a vector space look like? This is quite different from the standard machine learning model, where we expect inputs to be instances from some space, typically a vector space (feature vectors, images, etc.), or, at worst, sequences (e.g., for RNNs). Here, what we want is a function approximator that accepts sets as inputs.

What are the properties of sets, and what properties would we want such a function to have? Well, the first problem is that sets have varying sizes, and we would like to be able to handle arbitrarily-sized sets. The second key issue is that, by definition, sets have no order. As such, any module operating on sets must be invariant to the order of the elements to be considered valid. This property is known as permutation invariance, and is a key concept in the rapidly growing literature on modelling and representation learning with sets.

Now, this is a bit of a sticky situation: on the one hand, we've declared that sets have no ordering. On the other hand, I am now stating that our module must be invariant to their ordering. The resolution to this stickiness is that in practice, our algorithms must operate on representations of sets, and must process these sequentially. Thus, it would be more accurate to say that we are constructing modules that process ordered tuples, and we require that they be permutation invariant, which is equivalent to treating the ordered tuples as sets.
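A quick sanity check makes the distinction concrete: sum-pooling per-element embeddings is invariant to the order of the tuple, while concatenating them is not. (Here `phi` is just a stand-in for any per-element embedding.)

```python
import torch

phi = torch.nn.Linear(2, 8)          # stand-in per-element embedding
s = torch.randn(5, 2)                # a "set", stored as an ordered tuple
perm = torch.randperm(5)             # a random reordering

# Sum-pooling: the output is (numerically) identical under permutation.
print(torch.allclose(phi(s).sum(dim=0), phi(s[perm]).sum(dim=0)))   # True

# Concatenation: the output changes when the order changes.
print(torch.equal(phi(s).flatten(), phi(s[perm]).flatten()))        # False
```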

The $\rho$-sum Decomposition

It turns out that NPs were not the first model to bump into this question. The Neural Statistician [4] and PointNet [5], to name two examples, also considered the problem of representing sets, and proposed a solution similar to that of CNPs. This form has been dubbed the sum decomposition, or Deep Sets. For a set $S$, it is computed as follows \begin{equation} \hat{f}(S) = \rho \left( \sum_{s \in S} \phi (s) \right),\tag{3}\label{eq3} \end{equation}

where $\phi$ maps elements $s \in S$ to a vector space $\mathbb{R}^D$, and $\rho$ maps elements of $\mathbb{R}^D$ to $\mathbb{R}^M$. Note that this is a more general form of the NP computational graph: in Fig. 4, $e$ plays the role of $\phi$, and $d$ is (a conditional version of) $\rho$. There, the pooling operation ($a$ in Fig. 4) is implemented as a sum. This was also the operation chosen in the Neural Statistician and PointNet. However, the role of this operation is to enforce permutation invariance, and as such, any permutation-invariant pooling operation (such as mean or max) would be equally suitable in place of the sum.
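As a sketch, a generic Deep Sets module is just Eq. $\eqref{eq3}$ with a swappable pooling operation (the layer sizes here are arbitrary choices of mine):

```python
import torch
import torch.nn as nn

class DeepSet(nn.Module):
    """Sum decomposition of Eq. (3): f(S) = rho(pool_{s in S} phi(s))."""

    def __init__(self, in_dim, d_dim=64, out_dim=32, pool="sum"):
        super().__init__()
        # phi: embeds each set element into R^D.
        self.phi = nn.Sequential(
            nn.Linear(in_dim, d_dim), nn.ReLU(),
            nn.Linear(d_dim, d_dim),
        )
        # rho: maps the pooled R^D representation to R^M.
        self.rho = nn.Sequential(
            nn.Linear(d_dim, d_dim), nn.ReLU(),
            nn.Linear(d_dim, out_dim),
        )
        self.pool = pool

    def forward(self, s):                # s: (set_size, in_dim)
        e = self.phi(s)
        # Any permutation-invariant pooling enforces invariance to order.
        if self.pool == "sum":
            pooled = e.sum(dim=0)
        elif self.pool == "mean":
            pooled = e.mean(dim=0)
        else:
            pooled = e.max(dim=0).values
        return self.rho(pooled)
```

Because the pooling collapses the set dimension, a single `DeepSet` instance handles sets of any size: passing it a `(10, 3)` or a `(2, 3)` tensor yields a vector in $\mathbb{R}^M$ either way.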

Zaheer et al. (2017) [6] provided a more rigorous treatment of the issue. Their key theorem demonstrates that (with a few important caveats) any function on sets has a representation of the form of Eq. $\eqref{eq3}$. One caveat is that $\phi$ and $\rho$ must be universal approximators of their respective function classes (motivating the use of neural networks). Another important caveat is that, when the elements of $S$ are drawn from an uncountable set (e.g., $\mathbb{R}^m$), the theorem was only proven for fixed-size sets, falling just shy of our varying-size requirement.

Despite the caveats on the theory, Zaheer et al. provide a remarkable result towards characterizing the approximation capabilities of Deep Set networks, and justify their usage in a wide range of machine learning problems.

Summary

I'll wrap this post up here. We have taken a first look at the Neural Process family of models for probabilistic meta-learning, and examined the key role that representation learning on sets plays in this model class.

The modelling of sets is an important and fascinating subfield at the frontier of machine learning research. Personally, I am drawn to this topic due to its role in meta-learning, a topic which is near and dear to my heart. Indeed, one perspective on the meta- and few-shot learning lines of research is that they are about learning to flexibly condition models on small sets of data, and of course, Deep Sets have a key role to play in that perspective.

Neural processes take just this approach. Arguably the most straightforward perspective on CNPs is as simple predictors parameterized by Deep Set networks, trained with maximum likelihood. Yet this simple approach has proven extremely powerful. The original papers show that this class of models is capable of providing strong probabilistic predictions, and performs very well in the low-data regime. Requeima et al. [7] run with this view, and adapt CNPs to the large-scale, few-shot image classification setting, demonstrating very strong performance (full disclosure, I am an author on that paper). Finally, there have been several recent advances in this class of models that have increased their capacity and addressed their susceptibility to underfitting. These will be the topic of my next post, on Attentive [8] and Convolutional [9] CNPs.

References



1. Brendan Lake et al. Human-Level Concept Learning Through Probabilistic Program Induction. 2015.

2. Marta Garnelo et al. Neural Processes. 2018.

3. DeepMind Neural Process GitHub repository.

4. Harrison Edwards and Amos Storkey. Towards a Neural Statistician. 2016.

5. Charles R. Qi et al. PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation. 2016.

6. Manzil Zaheer et al. Deep Sets. 2017.

7. James Requeima et al. Fast and Flexible Multi-Task Classification Using Conditional Neural Adaptive Processes. 2019.

8. Hyunjik Kim et al. Attentive Neural Processes. 2019.

9. Jonathan Gordon et al. Convolutional Conditional Neural Processes. 2020.
