RAVEN: In-Context Learning with Retrieval Augmented Encoder-Decoder Language Models

Authors: Jie Huang, Wei Ping, Peng Xu, Mohammad Shoeybi, Kevin Chen-Chuan Chang, Bryan Catanzaro

Abstract: In this paper, we investigate the in-context learning ability of retrieval-augmented encoder-decoder language models. We first conduct a comprehensive analysis of the state-of-the-art ATLAS model and identify its limitations in in-context learning, primarily due to a mismatch between pretraining and testing, as well as a restricted context length. To address these issues, we propose RAVEN, a model that combines retrieval-augmented masked language modeling and prefix language modeling. We further introduce Fusion-in-Context Learning to enhance the few-shot performance by enabling the model to leverage more in-context examples without requiring additional training or model modifications. Through extensive experiments, we demonstrate that RAVEN significantly outperforms ATLAS and achieves results comparable to the most advanced language models in certain scenarios, despite having substantially fewer parameters. Our work underscores the potential of retrieval-augmented encoder-decoder language models for in-context learning and encourages further research in this direction.

What, Why and How

Here is a summary of the key points from this paper:

What: This paper investigates the in-context learning ability of retrieval-augmented encoder-decoder language models. The authors first analyze the state-of-the-art Atlas model and identify limitations in its in-context learning ability. They then propose Raven, a model that combines retrieval-augmented masked language modeling and prefix language modeling to enhance in-context learning. Additionally, they introduce Fusion-in-Context Learning and In-Context Example Retrieval to further boost the few-shot performance.

Why: In-context learning with retrieval-augmented encoder-decoder models is relatively unexplored compared to decoder-only models like GPT-3. However, encoder-decoder models offer unique advantages such as the ability to leverage multiple retrieved passages effectively. This paper aims to tap into the potential of these models for in-context learning and encourage further research in this direction.

How:

  • Conduct an analysis of Atlas on open-domain QA, identifying issues such as the training-testing mismatch and the limited context length (a sketch of the few-shot input layout follows this list).
  • Develop Raven by combining masked LM and prefix LM to align pretraining with testing.
  • Propose Fusion-in-Context Learning to enable the model to learn from more examples at test time without modifications.
  • Suggest using the retriever to obtain relevant in-context examples.
  • Demonstrate that Raven significantly outperforms Atlas and achieves results comparable to large decoder-only models on certain tasks, despite having far fewer parameters.
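
To make the few-shot setup concrete, below is a small, runnable sketch of how the in-context examples, the target question, and a retrieved passage might be laid out as the reader's input. The template (the build_reader_input helper and the Question/Answer/Passage markers) is an illustrative assumption, not the paper's exact prompt format; it is only meant to show why long demonstrations collide with a limited context length.

# Illustrative few-shot reader input (the template is an assumption, not the paper's exact format)
def build_reader_input(examples, question, passage):
  # In-context examples (question, answer pairs) are placed before the target question
  demos = "\n".join(f"Question: {q}\nAnswer: {a}" for q, a in examples)
  # The target question is left unanswered; the retrieved passage is appended as context
  return f"{demos}\nQuestion: {question}\nAnswer:\nPassage: {passage}"

examples = [("Who wrote Hamlet?", "William Shakespeare")]
print(build_reader_input(examples, "Where was Shakespeare born?", "Shakespeare was born in Stratford-upon-Avon ..."))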

In summary, this paper performs a comprehensive analysis of Atlas, proposes methods to enhance in-context learning for retrieval-augmented encoder-decoder models, and shows impressive results, highlighting the potential of these models. The key novelty is exploring and improving in-context learning for this class of models.

Main Contributions

Based on my reading, the main contributions of this paper are:

  • A comprehensive analysis of the in-context learning ability of Atlas, the current state-of-the-art retrieval-augmented encoder-decoder language model. The analysis identifies limitations related to training-testing mismatch and restricted context length.

  • The development of Raven, a new model that combines retrieval-augmented masked language modeling and prefix language modeling to better align pretraining with testing.

  • The proposal of Fusion-in-Context Learning, a novel approach to enable models to leverage more in-context examples at test time without requiring additional training or model modifications.

  • The suggestion to use Raven's retriever to obtain relevant in-context examples, further enhancing few-shot performance.

  • Extensive experiments demonstrating Raven significantly outperforms Atlas on various tasks, and achieves comparable results to decoder-only models like GPT-3 and PaLM despite having substantially fewer parameters.

  • An overall demonstration, through the analysis and the proposed methods, of the potential of retrieval-augmented encoder-decoder models for in-context learning.

In summary, the key contributions are: rigorous analysis of Atlas, developing Raven to enhance in-context learning, proposing Fusion-in-Context Learning, demonstrating impressive results, and highlighting the promise of this class of models for in-context learning.

Method Section

Here is a summary of the method section from the paper:

The authors propose two main methods to improve in-context learning for retrieval-augmented encoder-decoder models:

  1. Developing Raven:
  • Raven combines retrieval-augmented masked language modeling (the pretraining objective of Atlas) with prefix language modeling.
  • The prefix LM objective mitigates the mismatch between Atlas pretraining (predicting randomly masked spans) and testing (generating the answer to a target question that follows the in-context examples).
  • Raven is initialized from Atlas weights and first trained with the masked LM objective to obtain a strong reader and retriever, then trained with the prefix LM objective.
  2. Fusion-in-Context Learning (FiCL):
  • Lets the model learn from more examples at test time without changing the model configuration or retraining.
  • Works by feeding a different subset of the in-context examples to the encoder with each retrieved passage, instead of the same examples every time.
  • Enables incorporating more examples within the per-passage length constraint.
  • E.g., in the 64-shot setting, shuffle the 64 examples and feed 5 of them with each passage (see the sketch after this list).
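
As referenced above, here is a small, runnable sketch of how the shuffled examples might be assigned to the retrieved passages under FiCL. The helper name, the round-robin assignment, and the choice of 40 passages are illustrative assumptions; the paper only specifies the idea of feeding a different small set of examples with each passage (e.g., 5 of the 64 examples per passage).

import random

# Illustrative FiCL example assignment: give each passage a different small set of examples
def assign_examples_to_passages(examples, num_passages, m):
  shuffled = examples[:]
  random.shuffle(shuffled)
  example_sets = []
  for i in range(num_passages):
    # Take the next m examples, wrapping around so all examples get used across passages
    start = (i * m) % len(shuffled)
    example_sets.append([shuffled[(start + j) % len(shuffled)] for j in range(m)])
  return example_sets

# E.g., 64 examples, 40 retrieved passages (illustrative), 5 examples per passage
sets = assign_examples_to_passages(list(range(64)), num_passages=40, m=5)
print(len(sets), len(sets[0]))  # -> 40 5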

Additionally, they suggest retrieving relevant in-context examples using Raven’s retriever.
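
A minimal sketch of this idea, where a toy lexical-overlap score stands in for Raven's dense retriever (an assumption purely for illustration): candidate training examples are ranked by relevance to the test question and the top k are used as in-context examples.

# Toy stand-in for the dense retriever: rank candidate examples by relevance to the test question
def overlap_score(question, candidate_question):
  return len(set(question.lower().split()) & set(candidate_question.lower().split()))

def retrieve_in_context_examples(test_question, training_examples, k):
  # training_examples is a list of (question, answer) pairs; keep the k most relevant ones
  ranked = sorted(training_examples, key=lambda ex: overlap_score(test_question, ex[0]), reverse=True)
  return ranked[:k]

train = [("Who wrote Hamlet?", "William Shakespeare"), ("What is the capital of France?", "Paris")]
print(retrieve_in_context_examples("Who wrote Macbeth?", train, k=1))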

In summary, the key methods are developing Raven, which combines masked LM and prefix LM, and proposing FiCL to provide more in-context examples to the model at test time. These aim to enhance in-context learning for retrieval-augmented encoder-decoder models.

High-Level Pseudo Code

Here is a high-level pseudo code overview of the key methods proposed in the paper:

# Raven
# Initialize the reader and retriever from pretrained Atlas weights
model = AtlasModel()

# Stage 1: further pretrain with retrieval-augmented masked LM
for step in range(num_masked_lm_steps):
  batch = get_masked_lm_batch()
  loss = model(batch)
  update_model(model, loss)

# Stage 2: further pretrain with retrieval-augmented prefix LM
for step in range(num_prefix_lm_steps):
  batch = get_prefix_lm_batch()
  loss = model(batch)
  update_model(model, loss)

# Fusion-in-Context Learning (FiCL)
# k = total number of in-context examples, m = examples fed with each passage
def FiCL(model, input, k, m):

  # Gather (and shuffle) k in-context examples
  examples = get_k_examples(input, k)

  # Retrieve passages for the input
  passages = retrieve_passages(input, num_passages)

  # Assign a different m-example subset to each passage, cycling through the k examples
  example_sets = split_examples(examples, m)

  # Encode each (example set, passage, input) combination separately
  encodings = []
  for passage, example_set in zip(passages, example_sets):
    encodings.append(model.encode(example_set, passage, input))

  # The decoder attends over all encodings jointly (fusion-in-decoder) to produce one answer
  return model.decode(encodings)

The key steps are pretraining Raven in two stages, first with the masked LM objective and then with the prefix LM objective, and using FiCL at test time to incorporate more examples by pairing a different example set with each retrieved passage and fusing them in the decoder.
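
For instance, under the 64-shot, 5-per-passage setting mentioned earlier, a call to the pseudo-code function above would look like (the question string is just a placeholder):

answer = FiCL(model, "Where was Shakespeare born?", k=64, m=5)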

Detailed Pseudo Code

Here is a more detailed pseudo code implementation of the key methods in the paper:

# Hyperparameters
NUM_MASKED_LM_STEPS = 5000 
NUM_PREFIX_LM_STEPS = 5000
MASK_PROB = 0.15
 
# Initialize Atlas model
model = AtlasModel()
optimizer = AdamW(model.parameters(), lr=4e-5)
 
# Pretrain with masked LM
for i in range(NUM_MASKED_LM_STEPS):
  
  # Get batch of data
  texts = get_texts(batch_size) 
  
  # Mask random spans: the corrupted text is the input, the masked-out spans form the targets
  masked_texts, targets = mask_random_spans(texts, MASK_PROB)
  
  # Retrieve passages  
  passages = retrieve_passages(masked_texts)
  
  # Masked LM forward pass
  predictions = model(masked_texts, passages)
 
  # Compute loss 
  loss = masked_lm_loss(predictions, targets)
  
  # Update model
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()
 
# Switch to prefix LM with a lower learning rate
for group in optimizer.param_groups:
  group["lr"] = 1e-5
 
for i in range(NUM_PREFIX_LM_STEPS):
 
  # Get batch of texts
  texts = get_texts(batch_size)
 
  # Split each text into a prefix (model input) and a continuation (target)
  prefixes, targets = split_into_prefix_and_target(texts)
 
  # Retrieve passages
  passages = retrieve_passages(prefixes)  
 
  # Prefix LM forward pass
  predictions = model(prefixes, passages)
 
  # Compute loss
  loss = prefix_lm_loss(predictions, targets)
 
  # Update model
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()
  
# Fusion-in-Context Learning (FiCL)
# k = total number of in-context examples, m = examples fed with each passage

def FiCL(model, input, k, m):

  # Get k in-context examples (shuffled)
  examples = get_k_examples(input, k)

  # Split the examples into m-example sets, one per passage
  example_sets = split_into_sets(examples, m)

  # Retrieve passages for the input
  passages = retrieve_passages(input, num_passages)

  # Encode each (example set, passage, input) combination separately
  encodings = []
  for passage, example_set in zip(passages, example_sets):
    encodings.append(model.encode(example_set, passage, input))

  # Fuse in the decoder: attend over all encodings jointly to generate a single answer
  return model.decode(encodings)
 

The key aspects are:

  • Two-stage pretraining: retrieval-augmented masked LM followed by prefix LM
  • Retrieving passages for every batch in both pretraining stages
  • Constructing inputs and targets, then computing the loss and updating the model at each step (a sketch of prefix-LM batch construction follows this list)
  • For FiCL, splitting the in-context examples into small sets, feeding a different set with each retrieved passage, and fusing everything in the decoder.
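
As noted above, here is a small, runnable sketch of how the prefix-LM batch construction (the split_into_prefix_and_target helper in the detailed pseudo code) might work: each text is split at a random point, with the first part serving as the model input and the remainder as the generation target. The word-level random split is an assumption for illustration; the paper may use a different splitting heuristic.

import random

# Illustrative prefix-LM batch construction: split each text into a prefix (input) and a continuation (target)
def split_into_prefix_and_target(texts, min_prefix_tokens=1):
  prefixes, targets = [], []
  for text in texts:
    tokens = text.split()
    # Pick a random split point; everything before it is the prefix, the rest is the target
    split = random.randint(min_prefix_tokens, max(min_prefix_tokens, len(tokens) - 1))
    prefixes.append(" ".join(tokens[:split]))
    targets.append(" ".join(tokens[split:]))
  return prefixes, targets

p, t = split_into_prefix_and_target(["Retrieval-augmented language models combine a retriever with a reader"])
print(p, t)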