Linearity of Relation Decoding in Transformer Language Models

Authors: Evan Hernandez, Arnab Sen Sharma, Tal Haklay, Kevin Meng, Martin Wattenberg, Jacob Andreas, Yonatan Belinkov, David Bau

Abstract: Much of the knowledge encoded in transformer language models (LMs) may be expressed in terms of relations: relations between words and their synonyms, entities and their attributes, etc. We show that, for a subset of relations, this computation is well-approximated by a single linear transformation on the subject representation. Linear relation representations may be obtained by constructing a first-order approximation to the LM from a single prompt, and they exist for a variety of factual, commonsense, and linguistic relations. However, we also identify many cases in which LM predictions capture relational knowledge accurately, but this knowledge is not linearly encoded in their representations. Our results thus reveal a simple, interpretable, but heterogeneously deployed knowledge representation strategy in transformer LMs.

What, Why and How

Here is a summary of the key points from the paper:

What:

  • The paper investigates how transformer language models (LMs) represent and decode knowledge about relations between entities or concepts.
  • It shows that for a subset of relations, the LM’s decoding can be approximated as a linear transformation called a linear relational embedding (LRE).
  • LREs map from a subject’s representation to the representation of the object it’s related to.
  • The authors curate a dataset of 47 relations covering factual, commonsense, and linguistic knowledge for their analysis (one relation is sketched below).
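
For concreteness, one relation in such a dataset pairs a prompt template with (subject, object) examples; the field names below are illustrative, not the paper's actual schema:

relation = {
  "name": "plays instrument",
  "prompt_template": "{} plays the",
  "samples": [("Miles Davis", "trumpet"), ("Carlos Santana", "guitar")],
}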

Why:

  • Understanding how relations are represented in LMs provides insight into their knowledge representation and reasoning capabilities.
  • Showing that some relations have a simple linear structure makes parts of LMs more interpretable.
  • LREs could be used to visualize or edit relational knowledge in LMs.

How:

  • The authors take the Jacobian of the LM at a prompt expressing a relation to get an LRE.
  • They evaluate LREs on their faithfulness to LM decoding and ability to causally influence predictions.
  • LREs achieve high faithfulness and causality for many relations, suggesting linear decoding.
  • But other relations are not decoded linearly even when the LM predicts them correctly, indicating that LMs use more than one decoding scheme.
  • LREs are used to build an “attribute lens” to visualize relational knowledge.

In summary, the paper provides evidence that transformers employ simple linear transformations to decode some relational knowledge from entities, while also showing the limits of this scheme, and demonstrates an application of LREs to analyze model knowledge.

Main Contributions

Here are the main contributions of the paper:

  • Shows that transformer language models employ linear relational embeddings (LREs) to decode some relational knowledge, approximating the complex model as a simple linear function.

  • Provides a method to estimate LREs for relations directly from the model’s Jacobian, without additional training.

  • Evaluates the fidelity and causality of estimated LREs on a dataset of 47 diverse relations. Finds linear decoding for about half.

  • Demonstrates that LREs can be used to visualize and edit relational knowledge in LMs.

  • Introduces the “attribute lens”, which uses LREs to reveal when LMs contradict their own internal knowledge.

  • Highlights the heterogeneity of knowledge encoding in LMs, with evidence of both linear and non-linear decoding schemes for relations.

  • Analyzes the layer-wise evolution of LREs, revealing a potential “mode switch” from relational encoding to next-token prediction.

  • Compares LREs estimated for three language models (GPT-J, GPT-2 XL, and LLaMA-13B).

In summary, the main contributions are introducing LREs as a characterization of some relational reasoning in LMs, providing methods to estimate and evaluate LREs, demonstrating applications like the attribute lens, and analyzing the scope, limits, and model-dependence of linear decoding. The paper expands our understanding of knowledge representation in large language models.

Method Section

Here is a summary of the method section from the paper:

The paper proposes that transformer language models (LMs) implement linear relational embeddings (LREs) to decode some relational knowledge.

To estimate an LRE for a relation r, they compute the Jacobian of the LM's computation from the subject's hidden representation (read at an intermediate layer) to the object's representation (at the final layer), evaluated on a prompt expressing r. This Jacobian is the local linear approximation to how the model maps subject representations to related object representations.

Specifically, the LRE is estimated as:

LRE(s) = β * W * s + b

where s is the subject representation, W is the Jacobian of the subject-to-object map, and b is the corresponding first-order bias term (the object representation the model actually produces, minus W applied to s); both W and b are averaged over several subject samples.

β is a scalar greater than 1, tuned to compensate for the first-order approximation systematically underestimating the subject's contribution to the prediction.
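
In code, W and b can be obtained with automatic differentiation. A minimal sketch, assuming a helper F that runs the transformer from the subject's hidden state at the chosen layer to the final object representation (F, estimate_W_b, and subject_reps are illustrative names, not the paper's code):

import torch
from torch.autograd.functional import jacobian

def estimate_W_b(F, subject_reps):
  # F maps a subject hidden state (at the chosen layer) to the object representation
  Ws, bs = [], []
  for h in subject_reps:
    W_i = jacobian(F, h)        # dF/dh: the local linear approximation at h
    bs.append(F(h) - W_i @ h)   # first-order bias term at h
    Ws.append(W_i)
  return torch.stack(Ws).mean(dim=0), torch.stack(bs).mean(dim=0)

The averaged W and b are then used in LRE(s) = β * W * s + b above.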

The authors evaluate LREs on:

  • Faithfulness - whether decoding LRE(s) through the LM's output head yields the same object token that the full LM predicts. Measured by top-1 agreement on the next token.

  • Causality - whether inverting the LRE to edit s makes the LM predict a different target object o'. Measured by the success rate of such targeted edits (a sketch of the inversion follows below).

Good LREs should be high in both faithfulness and causality.
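
Since W is typically ill-conditioned, inverting the LRE for these edits is usually done with a rank-truncated pseudo-inverse. A minimal sketch (the rank value, the βW scaling, and the helper names are choices made for illustration):

import torch

def low_rank_pinv(M, rank):
  # Pseudo-inverse of M restricted to its top-`rank` singular directions
  U, S, Vh = torch.linalg.svd(M)
  return Vh[:rank].T @ torch.diag(1.0 / S[:rank]) @ U[:, :rank].T

def edit_subject(h_s, o_current, o_target, W, beta, rank=100):
  # Shift the subject representation so that the (approximately linear) decoder
  # maps it to o_target instead of o_current
  return h_s + low_rank_pinv(beta * W, rank) @ (o_target - o_current)

The edit counts as a success if, after patching the edited representation back into the forward pass, the LM's prediction switches to the target object.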

The layer at which the subject representation is read and the rank used when inverting the LRE are selected by grid search to maximize causality.
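
A sketch of that selection loop, with estimate_LRE_at_layer and causality_score as assumed helpers:

best_score, best_config = -1.0, None
for layer in candidate_layers:        # layer at which the subject representation is read
  for rank in candidate_ranks:        # rank of the pseudo-inverse used for edits
    lre = estimate_LRE_at_layer(LM, relation, train_samples, layer)
    score = causality_score(LM, lre, rank, dev_samples)
    if score > best_score:
      best_score, best_config = score, (layer, rank)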

The method provides a way to extract interpretable LREs from LMs without additional training, and quantitatively evaluate how well they approximate the model’s relational reasoning.

High-Level Pseudo Code

Here is high-level pseudocode for the main steps of the paper:

 
# Estimate LREs
learned_LREs = {}
for r in relations:
  samples = get_subject_samples(r)
  Js, bs = [], []
  for s in samples:
    prompt = make_prompt(r, s)
    h = get_hidden_state(LM, prompt, s)   # subject representation at layer l
    o = get_object_state(LM, prompt)      # object representation at the final layer
    Js.append(Jacobian(LM, h))            # local linear map from h to o
    bs.append(o - Js[-1] @ h)             # first-order bias term

  W = mean(Js)
  b = mean(bs)
  learned_LREs[r] = lambda h, W=W, b=b: beta * W @ h + b

# Evaluate LREs
for r in relations:
  LRE = learned_LREs[r]
  faithfulness = test_faithfulness(LRE, r)
  causality = test_causality(LRE, r)

# Application: attribute lens
for r in relations:
  LRE = learned_LREs[r]
  attribute_lens = [LRE(h) for h in LM.hidden_states]   # per-layer readout of the relation's object
 

The main steps are:

  1. Estimate an LRE for each relation by taking the Jacobian of the LM on prompts for that relation.
  2. Evaluate each LRE on faithfulness and causality metrics.
  3. Use LREs for applications like the attribute lens by applying them to LM hidden states (sketched below).

The pseudocode summarizes the process of learning, evaluating, and applying LREs to analyze relational reasoning in LMs.
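
As a sketch of the attribute lens, the LRE for a relation can be applied to the hidden state at every layer and decoded through the LM's output head; decoder_head (the model's final layer norm plus unembedding) and tokenizer are assumed helpers, not names from the paper's code:

def attribute_lens(LRE, hidden_states, decoder_head, tokenizer, top_k=3):
  readouts = []
  for h in hidden_states:                    # one hidden state per layer
    logits = decoder_head(LRE(h))            # project the predicted object representation onto the vocabulary
    top = logits.topk(top_k).indices
    readouts.append(tokenizer.convert_ids_to_tokens(top.tolist()))
  return readouts                            # per-layer view of the attribute the model has encoded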

Detailed Pseudo Code

Here is a more detailed pseudocode implementation of the key algorithms from the paper:

# Estimate LRE

import numpy as np

BETA = 2.0  # scalar beta > 1; the paper tunes this per model (value here is illustrative)

def estimate_LRE(LM, relation, samples):

  Js, bs = [], []

  for s in samples:

    prompt = make_prompt(relation, s)          # e.g. "Miles Davis plays the"
    h = get_representation(LM, prompt, s)      # subject representation at layer l
    o = get_object_representation(LM, prompt)  # object representation at the final layer

    Js.append(compute_jacobian(LM, h))         # local linear map from h to o
    bs.append(o - Js[-1] @ h)                  # first-order bias term

  W = np.mean(Js, axis=0)
  b = np.mean(bs, axis=0)

  return lambda h: BETA * W @ h + b
 
# Compute Jacobian (finite-difference sketch; automatic differentiation is used in practice)

def compute_jacobian(LM, h):

  # LM(h) here denotes running the rest of the transformer from the subject
  # representation h to the object representation at the final layer.
  eps = 0.01
  d = h.shape[0]

  J = np.zeros((d, d))
  y = LM(h)                        # baseline object representation

  for i in range(d):
    h_eps = h.copy()
    h_eps[i] += eps                # perturb one coordinate of h
    J[:, i] = (LM(h_eps) - y) / eps

  return J
 
# Evaluate LRE

def test_faithfulness(LM, LRE, relation, samples):

  correct = 0
  for s in samples:
    prompt = make_prompt(relation, s)
    h = get_representation(LM, prompt, s)
    lre_logits = decode_through_lm_head(LM, LRE(h))   # project predicted object rep onto the vocabulary
    lm_logits = next_token_logits(LM, prompt)         # LM's own next-token distribution
    if np.argmax(lre_logits) == np.argmax(lm_logits):
      correct += 1

  return correct / len(samples)
 
def test_causality(LM, LRE, relation, sample_pairs):

  # Each pair (s1, s2) shares the relation but has different objects o1, o2.
  edits = 0
  for s1, s2 in sample_pairs:
    h1 = get_representation(LM, make_prompt(relation, s1), s1)
    h2 = get_representation(LM, make_prompt(relation, s2), s2)
    delta = invert_LRE(LRE(h2) - LRE(h1))    # apply a low-rank pseudo-inverse of the LRE's weight matrix
    pred = next_token_with_patched_h(LM, make_prompt(relation, s1), h1 + delta)
    if pred == object_token(relation, s2):   # success: the LM now predicts s2's object
      edits += 1

  return edits / len(sample_pairs)

The core functions are:

  • estimate_LRE to approximate the model’s relational decoding
  • compute_jacobian to approximate the Jacobian of the subject-to-object map
  • test_faithfulness to evaluate prediction accuracy
  • test_causality to test targeted edits

Additional utilities (make_prompt, get_representation, get_object_representation, the decoding and patching helpers, and invert_LRE — a low-rank pseudo-inverse of the LRE's weight matrix) would still need to be implemented. The code shows one way to operationalize the key algorithms in Python and NumPy.
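
For get_representation, one possible implementation with a Hugging Face causal LM is sketched below; locating the subject's last token by re-tokenizing the prefix is an assumption for illustration, not the paper's exact code:

import torch

def get_representation(model, tokenizer, prompt, subject, layer):
  # Hidden state at `layer` over the last token of the subject mention
  inputs = tokenizer(prompt, return_tensors="pt")
  prefix = prompt[: prompt.index(subject) + len(subject)]
  subj_end = len(tokenizer(prefix)["input_ids"]) - 1
  with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)
  return out.hidden_states[layer][0, subj_end]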