RAFT: Teach LLMs to be better at RAG
- Nagesh Singh Chauhan
- Jun 22, 2024
- 7 min read
"Blending the Best of Both Worlds: How 'Retrieval-Augmented Fine-Tuning' Merges Retrieval-Augmented Generation with Fine-Tuning for Enhanced Domain Adaptation"
Introduction
One of the most transformative uses of generative AI for businesses is the development of natural language interfaces that can efficiently tap into existing knowledge bases. This capability is vital for providing accurate answers in specialized fields such as banking, law, and medicine. Currently, there are two main strategies to achieve this:
Domain-Specific Fine-Tuning (DSF): This method involves training a base model on a collection of documents specific to a domain, allowing the model to learn and understand the nuances of that field.
Retrieval Augmented Generation (RAG): This approach involves storing documents in a vector database and, at the time of a query, retrieving documents based on their semantic similarity to the question. These documents are then used to provide context for the language model to generate answers.
Both strategies have their strengths but also significant limitations. DSF can lead to overfitting and hallucinations, where the model produces information that is not actually present in the training data. Conversely, RAG can sometimes retrieve irrelevant documents, resulting in less precise answers.
In this article, we will examine these limitations and introduce a novel approach developed by UC Berkeley researchers Tianjun Zhang and Shishir G. Patil. Known for their work on Gorilla LLM, the team presents a new methodology in their RAFT (Retrieval Augmented Fine-Tuning) paper. They illustrate how they utilized Meta Llama 2 and Azure AI Studio to enhance generative AI models.
The Berkeley researchers have also published a detailed blog post outlining the benefits and drawbacks of DSF and RAG, and demonstrating how the RAFT approach can deliver more effective results. Their implementation of RAFT is available on GitHub, offering a valuable resource for those looking to explore or extend their work.
In the sections that follow, we will provide an overview of the RAFT approach, explaining its mechanics and highlighting why it is a superior method for incorporating domain-specific knowledge into generative AI models.
Understanding the RAFT Method
In a traditional Retrieval-Augmented Generation (RAG) setup, when a user asks a question, the model retrieves a few relevant documents from a database and uses these documents to generate an answer.
Imagine the model answering questions like a student in an open-book exam, where they can look up answers in a textbook. This is different from a closed-book exam, where the student must rely solely on their memory. RAG is like an open-book exam because the model can refer to documents to help formulate its response, making it more effective.
However, both methods have their drawbacks. Fine-tuning a model is like training a student to take a closed-book exam. The model can only use what it has learned, which might not cover everything and can lead to errors or made-up information (hallucinations). On the other hand, RAG relies on the documents it retrieves, but sometimes it may pull in irrelevant documents, leading to less accurate answers.
Researchers aimed to address these limitations of RAG. They proposed that a student who reviews textbooks before an open-book exam would perform better than one who only looks things up during the exam. Translating this idea to language models, they developed Retrieval Augmented Fine Tuning (RAFT). This method trains the model to understand and adapt to the domain before being used in a RAG setup.
Using the Meta Llama 2 7B language model, they created a synthetic dataset where each sample includes:
• A question
• A mix of relevant and irrelevant reference documents
• An answer derived from the relevant documents
• A Chain-of-Thought (CoT) explanation that highlights key excerpts from the relevant documents
They fine-tuned the Meta Llama 2 7B model with this dataset, helping it better understand the domain and extract useful information from the context. The CoT reasoning improves the model’s ability to provide well-thought-out answers and prevents overfitting.
RAFT sits between RAG and Domain-Specific Fine-tuning (DSF). It enhances the model’s domain knowledge and style (like DSF) while improving the quality of answers generated from retrieved documents. This is particularly valuable for pretrained models like Meta Llama 2, which are trained on a broad range of topics and therefore benefit from being adapted to specialized fields such as healthcare or law.
Key Components of RAFT
Let’s take a closer look at how Retrieval-Augmented Fine-Tuning (RAFT) works. RAFT introduces a new way to prepare fine-tuning data for training in-domain Retrieval-Augmented Generation (RAG) models. Each data point in the RAFT training dataset includes:
1. A Question (Q): The query that needs to be answered.
2. A Set of Documents (Dk): These are divided into two categories:
• “Oracle” Documents: These documents contain the answers to the question. There can be multiple oracle documents for each question.
• “Distractor” Documents: These documents do not contain relevant information for answering the question.
3. Chain-of-Thought Style Answer (A*): An answer generated from the oracle documents, including a detailed reasoning process.
In the RAFT fine-tuning dataset, each question is paired with a set of documents, some containing the answers (oracle documents) and some not (distractor documents), along with a chain-of-thought style answer. This setup helps the model learn to distinguish between useful and irrelevant information when forming answers.
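To make this concrete, a single RAFT training sample might look roughly like the following. The question, documents, and answer are invented purely for illustration, and the field names are not the exact schema of any particular implementation:
# Illustrative RAFT training sample (all contents invented for illustration)
raft_sample = {
    "question": "What is the maximum daily dose of drug X for adults?",
    "documents": [
        {"type": "oracle", "text": "Drug X: the recommended adult dose is 200 mg, "
                                   "not to exceed 400 mg per day."},
        {"type": "distractor", "text": "Drug Y is commonly prescribed for seasonal allergies."},
        {"type": "distractor", "text": "Clinical trials for drug Z began in 2019."},
    ],
    # Chain-of-Thought answer: quote the oracle passage, then reason to the final answer
    "cot_answer": (
        "The context states: 'the recommended adult dose is 200 mg, not to exceed "
        "400 mg per day.' Therefore, the maximum daily dose of drug X for adults is 400 mg."
    ),
}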
To further improve the model’s learning, the RAFT training dataset includes a mix of question types:
• Questions with Both Oracle and Distractor Documents: A fraction P of the questions include both types of documents. This teaches the model to identify and prioritize relevant information.
• Questions with Only Distractor Documents: The remaining (1 - P) fraction of the questions include only distractor documents. This mimics traditional fine-tuning, training the model to handle questions without relying on external documents.
Finally, the chain-of-thought style answers incorporate segments from the oracle documents and a detailed reasoning process. This approach enhances the model’s accuracy in answering questions by teaching it to form a reasoning chain using relevant segments from the original context.
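As a rough sketch of how this mixing could be implemented (the value of P and the helper below are assumptions for illustration, not the authors' code):
import random

P = 0.8  # assumed fraction of training questions that keep their oracle document(s)

def assemble_context(oracle_docs, distractor_docs, p=P):
    # With probability p, include the oracle documents alongside the distractors;
    # otherwise present only distractors, so the model must fall back on memorized knowledge.
    if random.random() < p:
        docs = oracle_docs + distractor_docs
    else:
        docs = list(distractor_docs)
    random.shuffle(docs)  # avoid the model learning positional shortcuts
    return docs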
RAFT Evaluation
Datasets and Baselines for Model Evaluation
In the experiments, the model and all baselines were evaluated on the following datasets, chosen for their popularity and for their diversity across domains, including Wikipedia, coding/API documentation, and medical question answering.
1. Natural Questions (NQ), TriviaQA, and HotpotQA: These are open-domain question-answering datasets based on Wikipedia, focusing primarily on general-knowledge topics such as movies and sports.
2. HuggingFace, Torch Hub, and TensorFlow Hub: These datasets are part of the APIBench benchmark introduced in the Gorilla paper. They are designed to evaluate the generation of correct functional API calls based on documentation.
3. PubMed QA: This is a specialized question-answering dataset tailored for biomedical research, focusing on medical and biology questions derived from a set of documents.
Baselines Considered in the Experiments
The following baselines were used for benchmarking in the experiments (a short sketch contrasting the first two prompt setups follows the list):
1. Llama2-7B-chat Model with Zero-Shot Prompting: The widely used instruction-tuned model is prompted with clearly written instructions for the QA task, but without any reference documents.
2. Llama2-7B-chat Model with RAG (Llama2 + RAG): The same setup as zero-shot prompting, but with retrieved reference context added to the prompt. This combination is commonly employed for domain-specific QA tasks.
3. Domain-Specific Fine-Tuning with Zero-Shot Prompting (DSF): This involves performing standard instruction fine-tuning without including documents in the context.
4. Domain-Specific Fine-Tuning with RAG (DSF + RAG): This approach enhances a domain-specific fine-tuned model with external knowledge using RAG. It allows the model to refer to the context for information it does not inherently possess.
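For intuition, the main difference between the first two baselines is whether retrieved context is placed in the prompt. A rough sketch of the two prompt formats (the wording is illustrative, not taken from the paper):
question = "What is the maximum daily dose of drug X for adults?"
retrieved_docs = ["Drug X: the recommended adult dose is 200 mg, not to exceed 400 mg per day."]

# Zero-shot baseline: instructions and the question only, no reference documents
zero_shot_prompt = f"Answer the following question concisely.\nQuestion: {question}\nAnswer:"

# RAG baseline: the same question, but with retrieved context prepended
rag_prompt = (
    "Answer the question using only the context below.\n"
    "Context:\n" + "\n".join(retrieved_docs) +
    f"\nQuestion: {question}\nAnswer:"
)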
Implementation of Retrieval Augmented Fine-Tuning (RAFT)
Let us now walk through an implementation of Retrieval Augmented Fine-Tuning. We start by installing the required libraries with the following commands:
!pip install llama-index
!pip install llama-index-packs-raft-dataset
!pip install llama-index-embeddings-huggingface
Then you can import the RAFTDataset:
from llama_index.packs.raft_dataset import RAFTDatasetPack
import json
Data Preparation for Q/A Generation: RAFTDatasetPack Configuration
For the data preparation process in Q/A generation, the RAFTDatasetPack is set up with the following parameters:
1. file_path: This specifies the path to the file used to generate questions and answers, serving as the primary content source for the dataset.
2. llm: This defines the Large Language Model (LLM) used for generating questions and answers. If no model is specified, GPT-4 is used by default. Selecting the model should be done carefully, considering the associated costs.
3. embed_model: This parameter indicates the embedding model used to calculate the similarity between a query and its context, which is crucial for selecting relevant context chunks.
4. num_questions_per_chunk: This determines how many questions are generated for each data chunk, directly influencing the comprehensiveness of the training dataset.
5. num_distract_docs: This sets the number of random context chunks used as distractors for each question, which helps challenge the model to identify relevant information.
6. chunk_size: Because llama-index uses the SemanticSplitterNodeParser to split the dataset into chunks, this parameter has little practical effect in this setup.
7. default_breakpoint_percentile_threshold: This controls the threshold for combining chunks based on their dissimilarity. A higher value results in fewer, larger chunks, affecting the granularity of the data used for training.
Semantic Splitter Node Parser
The SemanticSplitterNodeParser breaks the data down at the sentence level, initially dividing the text into smaller segments, or ‘chunks’. Here’s how the process works:
1. Initial Chunk Formation: The system splits the text into initial chunks at the sentence level.
2. Cosine Dissimilarity Calculation: For each pair of adjacent chunks, the parser calculates the cosine dissimilarity (1 - cosine similarity). This metric measures how different the chunks are from each other based on their vector representations in a multi-dimensional semantic space.
3. Threshold for Concatenation: A predefined dissimilarity threshold is set. If the dissimilarity between adjacent chunks exceeds this threshold, it indicates that the chunks are significantly different.
4. Chunk Concatenation: If the dissimilarity does not exceed the threshold, the system concatenates the chunks to form larger, unified chunks. This ensures that each chunk represents a cohesive piece of information, enhancing the model’s ability to understand and process the data effectively. We will see this splitting in action on a sample essay below.
Next, configure the LLM used to generate the questions and answers, along with the embedding model:
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import os
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
llm = OpenAI(model="gpt-3.5-turbo")
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
Download the Paul Graham essay dataset.
!wget --user-agent "Mozilla" "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt" -O './paul_graham_essay.txt'
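Before building the RAFT dataset, we can optionally peek at how the semantic splitter described above chunks the essay. This is a minimal sketch assuming llama-index's SemanticSplitterNodeParser API; the buffer size and threshold values here are illustrative:
from llama_index.core import Document
from llama_index.core.node_parser import SemanticSplitterNodeParser

splitter = SemanticSplitterNodeParser(
    buffer_size=1,                        # number of sentences grouped before comparing embeddings
    breakpoint_percentile_threshold=95,   # higher value -> fewer, larger chunks
    embed_model=embed_model,              # reuse the embedding model configured above
)
with open("./paul_graham_essay.txt") as f:
    nodes = splitter.get_nodes_from_documents([Document(text=f.read())])
print(len(nodes), "semantic chunks")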
# create RAFT Dataset object
raft_dataset = RAFTDatasetPack(
    file_path="./paul_graham_essay.txt",
    llm=llm,
    embed_model=embed_model,
    num_questions_per_chunk=1,
    num_distract_docs=2,
    chunk_size=1024,
    default_breakpoint_percentile_threshold=99,
)
# create the dataset
dataset = raft_dataset.run()
# save the dataset in jsonl format
output_path = './raft_dataset'
dataset.to_json(output_path + ".jsonl")
Load the dataset.
with open('./raft_dataset.jsonl', 'r') as json_file:
    dataset = list(json_file)
# We can access the dataset with the following
json.loads(dataset[0]).keys()
# output
# dict_keys(['id', 'type', 'question', 'context', 'oracle_context', 'cot_answer', 'instruction'])
json.loads(dataset[0])['question']
# output
# 'What were the two main things the author worked on before college?'
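From here, the generated JSONL can serve as a supervised fine-tuning corpus: the 'instruction' field already packs the question together with the oracle and distractor context, and 'cot_answer' is the target the model should learn to produce. One possible sketch using Hugging Face's trl library (the model name, hyperparameters, and exact argument names are assumptions and vary across trl versions):
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Load the RAFT samples generated above and flatten each one into a single training text
raft_train = load_dataset("json", data_files="./raft_dataset.jsonl", split="train")
raft_train = raft_train.map(
    lambda ex: {"text": ex["instruction"] + "\n" + ex["cot_answer"]}
)

# Supervised fine-tuning on the RAFT texts (recent trl versions pick up the "text" column)
trainer = SFTTrainer(
    model="meta-llama/Llama-2-7b-hf",       # base model family used in the RAFT paper
    train_dataset=raft_train,
    args=SFTConfig(output_dir="./raft-llama2-7b"),
)
trainer.train()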
Conclusion
The RAFT method represents a significant advancement in language model fine-tuning. By improving the quality of generated answers and enhancing the model’s ability to extract relevant information from retrieved contexts, RAFT shows great promise for a wide range of future applications.
The research utilizing the Meta Llama 2 7B language model highlights its versatility and adaptability across diverse tasks. The insights and recommendations from the research team offer valuable guidance for those looking to fine-tune Meta Llama or similar models.
Additionally, Azure AI Studio plays a crucial role in making cutting-edge generative AI capabilities more accessible. By simplifying the processes of fine-tuning, testing, and deployment, the platform empowers developers and businesses to create innovative and customized solutions without needing extensive machine learning expertise.