
From a quick skim, the paper delivers on its promise. It's not particularly long or difficult to follow.

> Causal tracing. The transformer could be viewed as a causal graph that propagates information from the input to the output through a grid of intermediate states, which allows for a variety of causal analyses on its internal computation

> [...] There are in total three steps:

> 1. The normal run records the model’s hidden state activations on a regular input [...]

> 2. In the perturbed run, a slightly perturbed input is fed to the model which changes the prediction, where again the hidden state activations are recorded. [...] Specifically, for the hidden state of interest, we replace the input token at the same position as the state to be a random alternative of the same type (e.g., r1 → r′1) that leads to a different target prediction (e.g., t → t′).

> 3. Intervention. During the normal run, we intervene the state of interest by replacing its activation with its activation in the perturbed run. We then run the remaining computations and measure if the target state (top-1 token through logit lens) is altered. The ratio of such alterations (between 0 and 1) quantitatively characterizes the causal strength between the state of interest and the target.
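To make the three steps concrete, here is a minimal sketch on a toy "model" whose states are plain integers rather than real activations. Everything here is hypothetical: `run`, `causal_strength`, the two-layer `TOY` model and the input pairs are illustrations, and the "top-1 token through logit lens" readout is replaced by simply comparing the target state's value. It follows the quoted protocol: record a normal run, record a perturbed run where the token at the state's position is swapped, patch that one state into the normal run, and count how often the target changes.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Position = Tuple[int, int]                # (layer, token position), i.e. S[layer, pos]
Layer = Callable[[List[int]], List[int]]  # maps one layer's states to the next layer's

def run(layers: Sequence[Layer], tokens: List[int],
        patch: Optional[Dict[Position, int]] = None) -> List[List[int]]:
    """Record every state S[layer][pos]; `patch` overwrites chosen states
    before any later layer reads them."""
    patch = patch or {}
    grid = [[patch.get((0, p), x) for p, x in enumerate(tokens)]]  # layer 0 = inputs
    for i, layer in enumerate(layers, start=1):
        grid.append([patch.get((i, p), s) for p, s in enumerate(layer(grid[-1]))])
    return grid

def causal_strength(layers: Sequence[Layer],
                    pairs: Sequence[Tuple[List[int], List[int]]],
                    state: Position, target: Position) -> float:
    """Fraction of (normal, perturbed) pairs in which patching `state` with its
    perturbed-run value changes the value at `target`."""
    flips = 0
    for normal, perturbed in pairs:
        clean = run(layers, normal)                    # 1. normal run
        corrupt = run(layers, perturbed)               # 2. perturbed run
        donor = corrupt[state[0]][state[1]]
        patched = run(layers, normal, {state: donor})  # 3. intervention
        flips += int(patched[target[0]][target[1]] != clean[target[0]][target[1]])
    return flips / len(pairs)

# Hypothetical toy model over [h, r1, r2]: layer 1 mixes h into position 1 and
# carries r2 forward; layer 2 writes the prediction at position 2.
TOY: List[Layer] = [
    lambda s: [s[0], s[0] + s[1], s[2]],
    lambda s: [s[0], s[1], s[1] * s[2]],
]
# Perturb the token at the same position as the state being patched,
# choosing values that change the final prediction.
R1_PAIRS = [([1, 2, 3], [1, 5, 3]), ([2, 1, 4], [2, 3, 4])]  # swap r1 (position 1)
H_PAIRS = [([1, 2, 3], [4, 2, 3]), ([2, 1, 4], [5, 1, 4])]   # swap h (position 0)

print(causal_strength(TOY, R1_PAIRS, state=(1, 1), target=(2, 2)))  # 1.0
print(causal_strength(TOY, H_PAIRS, state=(1, 0), target=(2, 2)))   # 0.0
```

The second result is the interesting one. Swapping h does change the output, but patching S[1, 0] alone has no effect, because layer 1 has already moved the information about h into position 1. That is the kind of pathway the ratio is meant to expose. In a real transformer you would patch activation tensors via forward hooks instead of integers.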

> The generalizing circuit. [...] The discovered generalizing circuit (i.e., the causal computational pathways after grokking) is illustrated in Figure 4(a). Specifically, we locate a highly interpretable causal graph consisting of states in layer 0, 5, and 8, [...]. Layer 5 splits the circuit into lower and upper layers, where 1) the lower layers retrieve the first-hop fact (h, r1, b) from the input h, r1, store the bridge entity b in S[5, r1], and “delay” the processing of r2 to S[5, r2]; 2) the upper layers retrieve the second-hop fact (b, r2, t) from S[5, r1] and S[5, r2], and store the tail t to the output state S[8, r2].
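As a rough mental model, the circuit computes a function composition split at layer 5. The sketch below idealizes it: the states hold exact symbols rather than vectors that merely encode b and r2, and the `ATOMIC` facts and function names are made up for illustration.

```python
from typing import Dict, Tuple

# Hypothetical atomic facts: (head, relation) -> tail.
ATOMIC: Dict[Tuple[str, str], str] = {
    ("alice", "mother"): "carol",
    ("alice", "boss"): "dave",
    ("carol", "city"): "paris",
    ("dave", "city"): "rome",
}

def lower_layers(h: str, r1: str, r2: str) -> Tuple[str, str]:
    """Layers 0-5, first hop: returns the contents of S[5, r1] and S[5, r2]."""
    b = ATOMIC[(h, r1)]  # bridge entity b, stored at the r1 position
    return b, r2         # r2 is carried forward ("delayed") unchanged

def upper_layers(s5_r1: str, s5_r2: str) -> str:
    """Layers 5-8, second hop: reads S[5, r1] and S[5, r2] and returns the
    tail t that ends up in S[8, r2]."""
    return ATOMIC[(s5_r1, s5_r2)]

def generalizing_circuit(h: str, r1: str, r2: str) -> str:
    return upper_layers(*lower_layers(h, r1, r2))

print(generalizing_circuit("alice", "mother", "city"))  # paris
print(generalizing_circuit("alice", "boss", "city"))    # rome
```

The point of the split is that the same atomic lookup is reused in both halves. That reuse is what lets the circuit answer two-hop queries it never saw, as long as the atomic facts themselves were learned.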

> What happens during grokking? To understand the underlying mechanism behind grokking, we track the strengths of causal connections and results from logit lens across different model checkpoints during grokking (the “start” of grokking is the point when training performance saturates). We observe two notable amplifications (within the identified graph) that happen during grokking. The first is the causal connection between S[5, r1] and the final prediction t, which is very weak before grokking and grows significantly during grokking. The second is the r2 component of S[5, r2] via logit lens, for which we plot its mean reciprocal rank (MRR). Additionally, we find that the state S[5, r1] has a large component of the bridge entity b throughout grokking. These observations strongly suggest that the model is gradually forming the second hop in the upper layers (5-8) during grokking. This also indicates that, before grokking, the model is very likely mostly memorizing the examples in train_inferred by directly associating (h, r1, r2) with t, without going through the first hop
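For anyone unfamiliar with the metric: "MRR via logit lens" means decoding an intermediate hidden state with the output unembedding, then checking how highly the token of interest (here r2, read from S[5, r2]) ranks among the vocabulary. The sketch below uses plain lists. The helper names and the tie-breaking rule are my own assumptions. Also, real logit-lens code typically applies the model's final layer norm before unembedding, which is omitted here.

```python
from typing import List, Sequence

def logit_lens(hidden: Sequence[float], unembed: Sequence[Sequence[float]]) -> List[float]:
    """Project a hidden state onto the vocabulary (one row of `unembed` per token)."""
    return [sum(w * h for w, h in zip(row, hidden)) for row in unembed]

def reciprocal_rank(logits: Sequence[float], target: int) -> float:
    """1 / rank of `target`, where rank 1 is the highest logit (ties favour the target)."""
    rank = 1 + sum(1 for x in logits if x > logits[target])
    return 1.0 / rank

def mean_reciprocal_rank(hiddens: Sequence[Sequence[float]], targets: Sequence[int],
                         unembed: Sequence[Sequence[float]]) -> float:
    rrs = [reciprocal_rank(logit_lens(h, unembed), t) for h, t in zip(hiddens, targets)]
    return sum(rrs) / len(rrs)

# Hypothetical 3-token vocabulary with 2-dimensional hidden states.
U = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(mean_reciprocal_rank([[2.0, 0.0], [0.0, 1.0]], [0, 0], U))  # about 0.667
```

An MRR near 1 means the token is almost always the top-1 decode of that state. Tracking it across checkpoints is how you would see the r2 component of S[5, r2] "amplify" during grokking.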

> Why does grokking happen? These observations suggest a natural explanation of why grokking happens through the lens of circuit efficiency. Specifically, as illustrated above, there exist both a memorizing circuit Cmem and a generalizing circuit Cgen that can fit the training data [...]
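The quote cuts off before the efficiency argument itself. One crude way to get intuition is to count lookup-table entries, which is my own proxy and not the paper's measure. A memorizing solution has to store the atomic facts plus one association per memorized (h, r1, r2) → t example. A generalizing one stores only the atomic facts and reuses them through composition. The knowledge base below is entirely hypothetical: every entity has every relation, each pointing to a random entity.

```python
import random
from typing import Dict, Tuple

def build_facts(n_entities: int, n_relations: int, seed: int = 0) -> Dict[Tuple[int, int], int]:
    """Hypothetical KB: (entity, relation) -> random tail entity, for every pair."""
    rng = random.Random(seed)
    return {(e, r): rng.randrange(n_entities)
            for e in range(n_entities) for r in range(n_relations)}

def two_hop(facts: Dict[Tuple[int, int], int], n_relations: int) -> Dict[Tuple[int, int, int], int]:
    """All inferred facts (h, r1, r2) -> t obtained by composing two atomic facts."""
    return {(h, r1, r2): facts[(b, r2)]
            for (h, r1), b in facts.items() for r2 in range(n_relations)}

def table_sizes(n_entities: int, n_relations: int, train_fraction: float,
                seed: int = 0) -> Tuple[int, int]:
    """Entries a memorizing vs. generalizing solution would need to store."""
    facts = build_facts(n_entities, n_relations, seed)
    inferred = two_hop(facts, n_relations)
    n_train_inferred = int(train_fraction * len(inferred))
    mem = len(facts) + n_train_inferred  # atomic facts + each memorized two-hop example
    gen = len(facts)                     # atomic facts only; composition is reused
    return mem, gen

print(table_sizes(100, 5, train_fraction=0.5))  # (1750, 500)
```

The gap grows with the number of two-hop training examples. That fits the intuition that pressure toward compact solutions, such as weight decay, would eventually favour the composing circuit, but treat this as intuition only.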



