Learning partially-observable higher-order sequences using local and immediate credit assignment

One of our key projects is a memory system that can learn to associate distant cause & effect while only using local, immediate & unsupervised credit assignment. Our approach is called RSM – Recurrent Sparse Memory. We recently uploaded a preprint describing RSM. This is the first of several blog posts about the experiments we have performed. These will include more detail than is possible in a standard paper.

Partial Observability

One of the experiments in the paper aims to show successful prediction of “partially-observable”, “higher order” sequences. In a partially-observable world, the true state is never observed directly. Instead, we observe “clues” that suggest the identity of the actual state, but don’t reveal the state completely, reliably and clearly. How do we overcome the limited information? We can use memory to store a history of “clues”, and use them all to predict the next state in a sequence.

In our partially-observable experiment we have an underlying sequence or process – this can be deterministic, or stochastic (subject to uncertainty). Since we have covered the stochastic case in other experiments (see paper), for now we will use deterministic sequences. One such sequence might be:

0,1,2,3,4,5,6,7,8,9

The task is to predict the next digit in the sequence. The sequence repeats when it reaches the end (i.e. after 9). To generate a “clue” to the identity of the current state, we use the MNIST dataset of hand-written digit images. These are relatively easy to classify, but introduce a realistic element of observational uncertainty. Rather than giving our algorithm the label of each digit, we provide a randomly selected MNIST image of that class.
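A minimal sketch of how such a stream of clues could be generated is shown below. It assumes the images are already grouped by class in a dictionary (`images_by_label`, a hypothetical name); loading MNIST itself is left out, and the toy usage substitutes placeholder tuples for real images.

```python
import random
from typing import Any, Dict, List, Sequence, Tuple


def make_clue_stream(
    sequence: Sequence[int],
    images_by_label: Dict[int, List[Any]],
    num_steps: int,
    rng: random.Random,
) -> Tuple[List[Any], List[int]]:
    """Repeat `sequence` cyclically, replacing each label with a random exemplar.

    `images_by_label` is a hypothetical mapping from class label to the list of
    images of that class (e.g. MNIST images grouped by digit).
    """
    observations: List[Any] = []
    labels: List[int] = []
    for t in range(num_steps):
        label = sequence[t % len(sequence)]            # wrap around after the last digit
        exemplar = rng.choice(images_by_label[label])  # fresh random draw every time
        observations.append(exemplar)                  # what the learner sees
        labels.append(label)                           # kept for evaluation only
    return observations, labels


# Toy usage: placeholder "images" are (label, exemplar_index) tuples.
rng = random.Random(0)
toy_images = {d: [(d, k) for k in range(5)] for d in range(10)}
obs, labels = make_clue_stream(list(range(10)), toy_images, num_steps=12, rng=rng)
```

The labels are returned only so that predictions can be scored later; the learner itself would only ever be given `obs`.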

Higher-Order Sequences

A first-order sequence is one where the next digit can be predicted from the current digit. For example, 0,1,2,3 is a first-order sequence.

But what about the sequence 0,1,0,2? Now we must remember a history of 2 digits to successfully predict the next digit (e.g. [1,0] → 2, and [2,0] → 1). We say this is a second-order sequence.

We can make the problem harder by increasing the length of context required to make successful predictions. For example, in our paper we used sequences like:

0,1,2,3,4, 0,1,2,3,4, 0,4,3,2,1

These are tricky for two reasons. First, to predict what follows the third zero (a 4 rather than a 1) we must remember a long history of digits: the context has to reach back beyond the preceding 0,1,2,3,4 block. Second, we have deliberately reused the same digits in repeating sub-sequences to misleadingly suggest a lower-order, repeating pattern.
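To make the notion of order concrete, the brute-force check below finds the shortest history (counting the current digit) whose contents always determine the next digit. `min_context_order` is a hypothetical helper for illustration, not part of RSM. It assumes the sequence repeats cyclically; other counting conventions (e.g. ignoring the wrap-around, or not counting the current digit) shift the number slightly.

```python
from typing import Dict, Sequence, Tuple


def min_context_order(seq: Sequence[int]) -> int:
    """Smallest k such that the last k digits always determine the next one,
    treating `seq` as repeating forever."""
    n = len(seq)
    for k in range(1, n + 1):
        next_for_context: Dict[Tuple[int, ...], int] = {}
        consistent = True
        for i in range(n):
            # The k digits ending at position i, wrapping around the start.
            context = tuple(seq[(i - j) % n] for j in range(k - 1, -1, -1))
            nxt = seq[(i + 1) % n]
            # Same history seen before but followed by a different digit?
            if next_for_context.setdefault(context, nxt) != nxt:
                consistent = False
                break
        if consistent:
            return k
    return n  # a full period of history always suffices
```

Under these assumptions it returns 1 for 0,1,2,3 and 2 for 0,1,0,2. For the 15-digit sequence above it returns 7: the history must reach one digit past the previous 0,1,2,3,4 block to distinguish the third zero (followed by 4) from the second zero (followed by 1).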

Fig. 1: Unlike Stephen Hawking, we find time-travel unacceptable. Hawking invited time travellers to a party, only advertising the event after it had finished. No-one showed up.

Credit assignment constraints

In summary, the task is to predict the label of the next digit given only a sequence of randomly selected images of prior digits. The underlying sequences are therefore partially observable and higher-order. As stated in our paper, we also set ourselves a “biological plausibility” constraint: use only local and immediate credit assignment.

The most popular approach to sequence learning with artificial neural networks is backpropagation through time (BPTT). When a prediction error occurs, it is propagated backwards through the historical states of the network until it reaches the weights that could have made the eventual prediction better. The biologically implausible aspect is that we either need time travel to do this, or we must somehow remember the historical state of every synapse in the brain. Unlike Stephen Hawking (see fig. 1), we find neither acceptable.

In the paper we briefly explore alternatives to backpropagation through time, notably autoregressive feed-forward networks that receive a fixed, finite historical context as input. To cut a long story short, for our purposes these are also unacceptable.

Recurrent Sparse Memory approach

We won’t go into the details of the RSM algorithm here, but we’ll elaborate on the conditions in which it is trained and tested in this experiment. Figure 2 shows how RSM is trained on a sequence of MNIST images. RSM never sees any labels, nor are labels provided as input to the loss function. Instead, RSM simply tries to minimize the mean-square error between predicted and observed images.

RSM is a recurrent neural network (RNN), similar to an Elman architecture. Its hidden activity state is fed back in as a recurrent input, which allows it to exploit historical context. But note that the actual labels are never incorporated in this context – only MNIST images.

Fig. 2: RSM training regime. For each occurrence of each digit, we draw a random sample from all images of that class in the MNIST dataset. The image is then provided as input to RSM. RSM generates a prediction of the next image – not the next label – and is trained to minimize the mean-square error between the predicted and actual next image. RSM never observes the labels of the sequence, not even as an input to the loss. To allow RSM to exploit historical context, it receives its own hidden state as a recurrent input.
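To illustrate what “local and immediate” credit assignment means for a recurrent network, here is a minimal sketch of a generic Elman-style RNN that predicts the next image and updates its weights using only the current step's error. This is not the RSM algorithm, which we don't describe here: it is a plain tanh RNN trained with one-step gradients, and the class name, sizes and learning rate are assumptions. The key point is that the previous hidden state is treated as a fixed input, so nothing is propagated back through time.

```python
import numpy as np


class OneStepElman:
    """Hypothetical Elman-style RNN trained with one-step (immediate) gradients.

    Illustration only, NOT RSM. The previous hidden state is treated as a
    constant input, so no gradient flows back through time.
    """

    def __init__(self, input_dim: int, hidden_dim: int, lr: float = 0.01, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.W_x = rng.normal(0.0, 0.1, (hidden_dim, input_dim))
        self.W_h = rng.normal(0.0, 0.1, (hidden_dim, hidden_dim))
        self.b = np.zeros(hidden_dim)
        self.W_o = rng.normal(0.0, 0.1, (input_dim, hidden_dim))
        self.c = np.zeros(input_dim)
        self.h = np.zeros(hidden_dim)  # recurrent context carried between steps
        self.lr = lr

    def step(self, x: np.ndarray, x_next: np.ndarray) -> float:
        """See image x, predict the next image, learn from x_next, return the MSE."""
        h_prev = self.h  # fixed input: never differentiated through
        h = np.tanh(self.W_x @ x + self.W_h @ h_prev + self.b)
        y = self.W_o @ h + self.c  # predicted next image (not a label)
        err = y - x_next
        loss = float(np.mean(err ** 2))

        # Gradients of this step's MSE only.
        d_y = 2.0 * err / err.size
        d_h = (self.W_o.T @ d_y) * (1.0 - h ** 2)  # back through tanh
        self.W_o -= self.lr * np.outer(d_y, h)
        self.c -= self.lr * d_y
        self.W_x -= self.lr * np.outer(d_h, x)
        self.W_h -= self.lr * np.outer(d_h, h_prev)
        self.b -= self.lr * d_h

        self.h = h  # feed the hidden state back as next step's context
        return loss
```

Because `h_prev` is never differentiated through, each update needs only quantities available at the current step, whereas BPTT would have to retain every past hidden state. The trade-off is that plain one-step gradients give no direct credit to earlier steps.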

Obtaining classifications

When trained as described above, RSM produces images that predict the appearance of the next digit. We can inspect them individually to verify that they are “correct” (i.e. generic versions of the correct next digit). We can also measure the MSE loss – very low loss values indicate good agreement between the predicted and observed digit appearance. But we would like to measure classification accuracy directly.

Our paper describes how we do this by adding a “bolt-on” predictor network. The predictor is a fully-connected 2-layer network. Its input is the RSM hidden state. The output is a classification over the 10 MNIST classes, so we train it with a softmax cross-entropy loss. The hidden layer nonlinearity is a leaky-ReLU. We do not allow gradients to propagate into the RSM. The predictor is trained to predict the next digit class from the current RSM hidden state; it never receives a label as input, but labels are used in the predictor loss function (see figure 3).

Fig. 3: Training a predictor network to generate the label of the next digit in the sequence given only the current hidden state of the RSM layer. To obtain a classification, we must compare predicted labels and actual labels with a supervised loss function (see box marked ‘Loss’). The predictor does not receive labels as an input, nor anything derived from a label. The predictor is a fully-connected two-layer network.
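A minimal NumPy sketch of such a bolt-on predictor follows. The hidden width (64), learning rate, leak slope (0.01) and initialisation are our illustrative assumptions, and the class name `BoltOnPredictor` is hypothetical. It reads the recurrent model's hidden state, applies a leaky-ReLU hidden layer and a softmax output over 10 classes, and is trained with cross-entropy against the label of the next digit.

```python
import numpy as np


class BoltOnPredictor:
    """Hypothetical 2-layer classifier that reads a frozen recurrent hidden state."""

    def __init__(self, state_dim: int, hidden_dim: int = 64, num_classes: int = 10,
                 lr: float = 0.05, leak: float = 0.01, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(0.0, 0.1, (hidden_dim, state_dim))
        self.b1 = np.zeros(hidden_dim)
        self.W2 = rng.normal(0.0, 0.1, (num_classes, hidden_dim))
        self.b2 = np.zeros(num_classes)
        self.lr, self.leak = lr, leak

    def _forward(self, state: np.ndarray):
        z1 = self.W1 @ state + self.b1
        a1 = np.where(z1 > 0, z1, self.leak * z1)  # leaky-ReLU hidden layer
        logits = self.W2 @ a1 + self.b2
        e = np.exp(logits - logits.max())          # numerically stable softmax
        return z1, a1, e / e.sum()

    def probs(self, state: np.ndarray) -> np.ndarray:
        """Predicted distribution over the class of the next digit."""
        return self._forward(state)[2]

    def train_step(self, state: np.ndarray, next_label: int) -> float:
        """One SGD step on softmax cross-entropy; returns the loss."""
        z1, a1, p = self._forward(state)
        loss = -float(np.log(p[next_label]))
        d_logits = p.copy()
        d_logits[next_label] -= 1.0                # d(CE)/d(logits) = p - one_hot
        d_z1 = (self.W2.T @ d_logits) * np.where(z1 > 0, 1.0, self.leak)
        self.W2 -= self.lr * np.outer(d_logits, a1)
        self.b2 -= self.lr * d_logits
        self.W1 -= self.lr * np.outer(d_z1, state)
        self.b1 -= self.lr * d_z1
        # No gradient w.r.t. `state` is computed: credit stops at the predictor.
        return loss
```

`train_step` never computes a gradient with respect to its input state, which mirrors the rule that no gradients propagate into RSM. In an autodiff framework the equivalent would be to detach (stop the gradient of) the state before feeding it to the predictor.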

Results

We tested several sequences and observed that the predictor achieves >99% accuracy. Of course, we would expect this accuracy to decrease for sufficiently long and/or complex sequences. Occasional runs of very unusual images throw it off, but it recovers rapidly. RSM alone is similarly accurate; this can be verified by manually inspecting its predicted images and/or inferred from the very low MSE loss (i.e. RSM does the predicting; the additional network merely translates each prediction into a label).

RSM never has access to any labels. The predictor only uses RSM state as input, but is trained to maximize classification performance on the labels. Error gradients do not propagate into the RSM layer from the predictor. Therefore we have achieved the goal of learning higher-order, partially-observable sequences using only local and immediate credit assignment.

RSM predictions can be viewed as images. In sequences with genuine uncertainty regarding the next digit, or early during training before RSM has discovered all the higher-order sequence structure, uncertain predictions look like superpositions of likely digits. Due to uncertainty in the appearance of any given digit, RSM predicts a “generic” digit form. Examples can be seen in the figures above.
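One way to see why uncertain predictions look like superpositions: under a squared-error loss, the single prediction with the lowest expected error is the probability-weighted mean of the possible outcomes. The toy example below, with made-up 4-pixel “images”, shows that blending two equally likely next digits beats committing to either one; averaging over many exemplars of one class would similarly yield a blurry, “generic” digit. This is a standard property of MSE, offered as an interpretation rather than an analysis of RSM's internals.

```python
import numpy as np


def mse_optimal_prediction(candidates: np.ndarray, probs: np.ndarray) -> np.ndarray:
    """Prediction minimising expected squared error: the probability-weighted mean."""
    return probs @ candidates


def expected_mse(prediction: np.ndarray, candidates: np.ndarray, probs: np.ndarray) -> float:
    per_outcome = np.mean((candidates - prediction) ** 2, axis=1)
    return float(probs @ per_outcome)


# Hypothetical 4-pixel "images" of two equally likely next digits.
candidates = np.array([[1.0, 1.0, 0.0, 0.0],
                       [0.0, 1.0, 1.0, 0.0]])
probs = np.array([0.5, 0.5])

blend = mse_optimal_prediction(candidates, probs)      # a superposition of both
print(expected_mse(blend, candidates, probs))          # 0.125
print(expected_mse(candidates[0], candidates, probs))  # 0.25: committing costs more
```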

Further research

It would be interesting to combine the partially observable nature of this test with stochastic sequences such as those used in the Embedded Reber Grammar task. Using a grammar, we can define the complexity and predictability of the generated sequences. The only limitation is that MNIST restricts us to 10 possible labels. The NIST SD-19 dataset is similar to MNIST, but also includes the alphabetic characters A-Z, so MNIST + SD-19 would offer a set of 36 unique labels. We performed some preliminary experiments of this type and note that RSM seems able to cope with genuine uncertainty as you’d expect – it predicts all likely options – and the predictor generates the correct multi-modal distributions.
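For readers unfamiliar with the Embedded Reber Grammar, here is a small generator and validator. The transition graph is the standard Reber grammar as commonly used in the LSTM literature; the function names are hypothetical. The embedding wraps a Reber string as BT…TE or BP…PE, so the second symbol must be remembered until the second-last one: a long-range dependency with genuinely stochastic content in between.

```python
import random
from typing import Dict, List, Tuple

# Reber grammar transitions: state -> [(symbol, next_state)].
# Reaching state 6 means the string is complete and 'E' is emitted.
REBER: Dict[int, List[Tuple[str, int]]] = {
    1: [("T", 2), ("P", 3)],
    2: [("S", 2), ("X", 4)],
    3: [("T", 3), ("V", 5)],
    4: [("X", 3), ("S", 6)],
    5: [("P", 4), ("V", 6)],
}


def reber_string(rng: random.Random) -> str:
    symbols = ["B"]
    state = 1
    while state != 6:
        symbol, state = rng.choice(REBER[state])  # each branch equally likely
        symbols.append(symbol)
    symbols.append("E")
    return "".join(symbols)


def embedded_reber_string(rng: random.Random) -> str:
    branch = rng.choice("TP")  # must be remembered until the second-last symbol
    return "B" + branch + reber_string(rng) + branch + "E"


def is_reber(s: str) -> bool:
    if len(s) < 2 or s[0] != "B" or s[-1] != "E":
        return False
    state = 1
    for ch in s[1:-1]:
        moves = dict(REBER.get(state, []))  # no moves left once state 6 is reached
        if ch not in moves:
            return False
        state = moves[ch]
    return state == 6


def is_embedded_reber(s: str) -> bool:
    return (len(s) >= 4 and s[0] == "B" and s[-1] == "E"
            and s[1] in "TP" and s[1] == s[-2] and is_reber(s[2:-2]))
```

To make the task partially observable, each symbol would then be replaced by a randomly drawn image of a corresponding class (e.g. SD-19 letters); that mapping is not shown.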

There’s no established benchmark for the stochastic MNIST sequences described above, and the problem is open-ended: we could keep increasing the difficulty indefinitely. For this reason, we are also now looking at the Moving-MNIST benchmark:

http://www.cs.toronto.edu/~nitish/unsupervised_video/

In this test, a pair of MNIST digits bounce around a box. The appearance of each digit is drawn randomly from MNIST. Digits vary in speed and direction, which yields a large variety of trajectories over time. The objective is to predict the sequence many steps ahead without losing track of digit identity, while correctly predicting the motion, including bounces off the walls. Digits transiently occlude each other as well!

Fig. 4: Moving MNIST dataset. A pair of MNIST digits bounce around a box. The task is to predict – generate – the appearance of the digits many steps into the future.
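A rough sketch of the Moving MNIST dynamics is below. It assumes the commonly used setup of a 64x64 canvas with 28x28 digits; the speed range, the per-pixel maximum used to combine overlapping digits, and the function names are our illustrative assumptions, not the benchmark's reference generator. Each digit's top-left corner moves with constant velocity and reflects off the walls.

```python
import math
import random
from typing import List, Sequence, Tuple

import numpy as np

CANVAS = 64  # assumed frame size
DIGIT = 28   # MNIST digit size


def _bounce(pos: float, vel: float, lim: float) -> Tuple[float, float]:
    """Reflect a coordinate that overshot a wall and reverse its velocity."""
    if pos < 0:
        return -pos, -vel
    if pos > lim:
        return 2 * lim - pos, -vel
    return pos, vel


def bouncing_trajectory(num_frames: int, rng: random.Random,
                        speed_range: Tuple[float, float] = (2.0, 5.0)) -> List[Tuple[float, float]]:
    """Top-left positions of one digit; speeds stay well below the free range,
    so a single reflection per step keeps the digit inside the canvas."""
    lim = CANVAS - DIGIT
    x, y = rng.uniform(0, lim), rng.uniform(0, lim)
    angle = rng.uniform(0.0, 2.0 * math.pi)
    speed = rng.uniform(*speed_range)
    vx, vy = speed * math.cos(angle), speed * math.sin(angle)
    path = []
    for _ in range(num_frames):
        path.append((x, y))
        x, vx = _bounce(x + vx, vx, lim)
        y, vy = _bounce(y + vy, vy, lim)
    return path


def render_frame(digits: Sequence[np.ndarray], positions: Sequence[Tuple[float, float]]) -> np.ndarray:
    """Paste digit images onto a blank canvas; overlapping pixels take the max."""
    frame = np.zeros((CANVAS, CANVAS), dtype=np.float32)
    for img, (x, y) in zip(digits, positions):
        r, c = int(round(y)), int(round(x))
        region = frame[r:r + DIGIT, c:c + DIGIT]  # a view into `frame`
        np.maximum(region, np.asarray(img, dtype=np.float32), out=region)
    return frame
```

A sequence of frames is then `render_frame([digit_a, digit_b], [path_a[t], path_b[t]])` for each time step `t`, where `digit_a` and `digit_b` are randomly drawn MNIST images.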


David Rawlinson

https://agi.io