The Type Theory of Transformers


3/14/2023


Next-token prediction is critical to the recent success of large language models like GPT-3. Today, we will see why next-token prediction is so exciting purely by analyzing the model's function type. In particular, we will see how capabilities such as prompting and in-context learning can be viewed as currying.

Transformer Capabilities

Functionally, an autoregressive model takes in a series of tokens (think words) and predicts a distribution over the next word. For example, given the input "The cat in the ", a well-trained model might give a \(90\%\) chance that the next word is "hat", a \(0.0001\%\) chance that the next word is "pickle", etc. Given this distribution, one can decode an output by passing in a prompt, selecting a next token, appending it to the input, and repeating. In this article, we'll assume the next token is selected greedily by picking the likeliest next token.
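
To make the decoding loop concrete, here is a minimal Python sketch; next_token_distribution is a hypothetical stand-in for the trained model, mapping a token list to a dict of probabilities.

    def greedy_decode(next_token_distribution, prompt, num_tokens):
        # `next_token_distribution` stands in for the trained model: it maps
        # a token list to a dict of {candidate next token: probability}.
        tokens = list(prompt)
        for _ in range(num_tokens):
            dist = next_token_distribution(tokens)
            tokens.append(max(dist, key=dist.get))  # greedy: pick the likeliest
        return tokens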


The model learns this distribution by practicing next-token prediction over the entire internet. The transformer architecture (Vaswani et al, 2017) with causal masking is known for being easy to scale for this training. People have also found that this architecture can compactly express complicated computation, such as parallel computation (Liu et al, 2022), assembly programs (Giannou et al, 2023), and linear algebra (Akyürek et al, 2022).


As these transformers get better at next-word prediction, some capabilities that were not explicitly targeted qualitatively emerge. For example, if I want the model to give more accurate answers to geology queries, I can prepend "Suppose you are a leading expert in geology" to my question. To get the model to show its work and somehow provide more correct answers, I can append "Let's think step by step" to the question (Kojima et al, 2022). If I want the model to perform a task, such as translation, I can provide a few examples along with my query, such as "cat chat apple pomme hello", and get the output "bonjour" (Brown et al, 2020).

Type Theory Tutorial

I'll denote that a value x has type t by x : t. The tuple type of t1 and t2 is t1 * t2, and the function type from t1 to t2 is t1 -> t2.


Let's say you have a function f : t1 * t2 -> t3 which takes in two arguments and produces an output. Sometimes, you will have the first argument available before the second, and you'd like to partially evaluate the function on what you have. You can curry the function f into g : t1 -> (t2 -> t3). After passing in the first input, you receive a function which can take in the second value to produce the output (shown in the following Python code). This allows you to pass in the first input before you have the second. It's important to note that for every f, there exists an (extensionally) equivalent g and vice versa.

    def f(x, y):
        return x ** x + y

    def g(x):
        # The curried form: x ** x is computed once, when x arrives.
        xpow = x ** x
        def helper(y):
            return xpow + y
        return helper
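
As a quick sanity check that the two forms agree:

    assert f(2, 3) == g(2)(3) == 7  # both compute 2 ** 2 + 3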

There's no reason to stop at currying two arguments. For example, t1 * t2 * t3 -> t4 is isomorphic to t1 -> t2 -> t3 -> t4 (with right-associative arrows).


Beyond allowing further modularity, currying can help stage computation. For example, suppose you need to perform an expensive pre-compute with the first argument. Without currying, you would repeat the pre-compute every time you called f(x, y). With currying, you can first construct f_with_x = g(x); for subsequent calls that share x as the first argument, you can simply call f_with_x(y) as many times as you'd like.
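
Using the g defined above, the pre-compute x ** x runs once rather than a hundred times:

    f_with_x = g(1000)                            # x ** x computed once, here
    total = sum(f_with_x(y) for y in range(100))  # 100 cheap calls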

Partially Evaluated Transformers

A typical neural network like a multi-layer perceptron might map from \(\mathbb{R}^d\) to \(\mathbb{R}^k\) and have type MLP = real * real * ... * real -> real * ... * real. Typically, when I have some new information, I have to take my original_NN and train it on the new data to produce a completely new finetuned_NN. This process is costly and annoying; moreover, to share the model, I have to send over the entire function finetuned_NN.


However, the specific emergent capabilities we discussed earlier defy this paradigm. We know that a transformer with greedy decoding has type transformer = token list -> token. The killer property of this type is partial evaluation. Suppose I have new information in prompt : token list. I can write a new function prompted_transformer L = original_transformer (prompt concat L). When called, this prompted_transformer can utilize the new information without any training or weight updates whatsoever! Moreover, to share prompted_transformer, I only have to send prompt to somebody who has a copy of original_transformer! This means the entire execution environment is compactly encoded as a string rather than billions of model weights. The prompting paradigm uniquely enables this "stateless fine-tuning".
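
In Python, prompted_transformer is just a closure over prompt; this minimal sketch assumes a stand-in original_transformer that maps a token list to the next token:

    def make_prompted_transformer(original_transformer, prompt):
        def prompted_transformer(L):
            # No training or weight updates: just prepend the prompt.
            return original_transformer(prompt + L)
        return prompted_transformer

    # To share the "fine-tuned" model, send only `prompt`; anyone with a
    # copy of original_transformer can rebuild the closure locally.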

Infinite Currying??

All of this actually emulates the power of currying. However, there is a very important difference. In typical currying, since tuples have fixed size, you have a fixed number of arguments you can partially evaluate before producing an output. However, with a list as an input, you can successively curry as much information as you'd like, leading to infinite currying?!?


In particular, the input list L : token list could have length 2, length 4, or any other length. So your transformer has to support token * token -> token and token * token * token * token -> token (mathematically shown in the next section). We can curry these functions to get token -> token -> token and token -> (token * token) -> token -> token and a bunch of other wacky in-betweens. Therefore, as soon as you write an interesting function of type token list -> token, you automatically create an infinite collection of curried functions that represent any partial evaluation at any token count.
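
To see this family concretely, here is a sketch where curry_at produces a partial evaluation at any split point; the transformer in the comments is a stand-in for any token list -> token function:

    def curry_at(f, prefix):
        # Partially evaluate a token-list function on a prefix of any length.
        return lambda suffix: f(prefix + suffix)

    # Successive partial evaluations, as many times as we'd like:
    #   step1 = curry_at(transformer, ["the"])
    #   step2 = curry_at(step1, ["cat", "in"])
    #   step2(["the"]) == transformer(["the", "cat", "in", "the"])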

"Math"

Here, I'll make some fun remarks that use type algebra and isomorphisms, which you can learn about from these slides. We can see that token list \(\;\cong\;\) unit + token + token * token + ... either by casing on the length of the list or directly from the generating function of the list type. We also know that the type a -> b has cardinality \(|b|^{|a|}\). Putting these together with serious notational abuse,


token list -> token

\(\cong\;\) token ^ (unit + token + token * token + ...)

\(\cong\;\) token ^ unit * token ^ token * token ^ (token * token) * ...

\(\cong\;\) (unit -> token) * (token -> token) * (token * token -> token) * ...


This spells out the intuition that writing a function from token list -> token is as powerful as writing a function from tuples of every length to token!
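
For fun, we can restate this in cardinality terms (continuing the notational abuse). Writing \(v\) for the number of possible tokens, the left side has cardinality \(v^{1 + v + v^2 + \cdots}\), which matches the product \(v^{1} \cdot v^{v} \cdot v^{v^2} \cdots = \prod_{n \ge 0} v^{v^n}\) given by the right side.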


By seeing the flexibility of this type, we get a better picture of why a model with strong next-token prediction is so impressive. Thank you for reading, and feel free to reach out with any questions or thoughts!