How to use BERT for question answering
A step by step guide with PyTorch and Hugging Face’s transformers library
BERT, short for Bidirectional Encoder Representations from Transformers, is a transformer model that can understand natural language and perform various natural language processing tasks, such as question answering, text classification, sentiment analysis, and more. It was released by Google in 2018 and has become one of the most popular and influential models in the field.
Many people have heard of BERT, but far fewer can explain how it works. BERT is based on the transformer architecture, which uses attention mechanisms to learn the relationships between the different parts of its input. BERT consists of an encoder that takes the input text and transforms it into a sequence of vectors, called the encoder output. These vectors capture the meaning and context of each word in the text.
BERT is trained on a large corpus of text data from various sources, such as Wikipedia, books, news articles, etc. It uses two pre-training objectives: masked language modeling and next sentence prediction. Masked language modeling randomly masks some words in the input text and asks the model to predict them based on the rest of the text. Next sentence prediction pairs two sentences from the corpus (sometimes consecutive, sometimes chosen at random) and asks the model to predict whether the second actually follows the first.
By using these two objectives, BERT can learn both the syntactic and semantic aspects of natural language. It can also learn from both left and right contexts of each word, which makes it bidirectional. This gives BERT an advantage over other models that only use unidirectional or shallow bidirectional approaches.
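If you want to see masked language modeling in action before we dive in, the fill-mask pipeline from transformers gives a quick demo; this snippet is only an illustration and is separate from the tutorial code below.
# Illustration: BERT fills in the masked word using both left and right context
from transformers import pipeline
fill_mask = pipeline('fill-mask', model='bert-base-uncased')
for prediction in fill_mask("The capital of France is [MASK]."):
    print(prediction['token_str'], round(prediction['score'], 3))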
BERT can be fine-tuned for various downstream tasks with minimal data and computation. For example, to use BERT for question answering, one can add a simple layer on top of the encoder output that predicts the start and end positions of the answer span in the text. To use BERT for text classification, one can add a simple layer on top of the encoder output that predicts the class label of the text.
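To make the span-prediction head concrete, here is a minimal sketch with illustrative dimensions: one linear layer maps each token's hidden state to two scores, interpreted as start and end logits, which is essentially what BertForQuestionAnswering adds on top of the encoder.
# Minimal sketch of a question-answering head (illustrative dimensions)
import torch
hidden_size, seq_len = 768, 384                        # BERT-base hidden size, example length
qa_head = torch.nn.Linear(hidden_size, 2)              # two scores per token
encoder_output = torch.randn(1, seq_len, hidden_size)  # (batch, seq_len, hidden)
logits = qa_head(encoder_output)                       # (1, seq_len, 2)
start_logits, end_logits = logits.split(1, dim=-1)     # each (1, seq_len, 1)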
In this article, we will show you how to use BERT for question answering using PyTorch and Hugging Face's transformers library. We will use the SQuAD dataset, which is a collection of questions and answers based on Wikipedia articles. We will follow these steps:
Import modules
Load model and tokenizer
Load dataset
Preprocess dataset
Split dataset into train and test sets
Create data loaders for train and test sets
Create optimizer for model
Define training loop
Define evaluation loop
Train and evaluate the model
Let's go through it step by step and understand what's happening.
Import modules
First, we import the necessary modules: torch for deep learning, transformers for accessing the BERT model, and datasets for loading the SQuAD dataset.
# Import modules
import torch
from transformers import BertForQuestionAnswering, BertTokenizerFast, DataCollatorWithPadding
from datasets import load_dataset
Load model and tokenizer
Next, we load the BERT model and tokenizer using the from_pretrained() method of BertForQuestionAnswering and BertTokenizerFast. This method downloads and caches the pre-trained weights from Hugging Face's model hub. We use the bert-base-uncased checkpoint, the smaller of the two original BERT models (12 encoder layers rather than 24), trained on lowercased text. We pick the fast, Rust-backed tokenizer because we will later need its char_to_token() method to map character positions in the context onto token indices. Note that this checkpoint only contains the pre-trained encoder; the question-answering head on top is newly initialized (transformers prints a warning to that effect), which is exactly why we fine-tune.
# Load model and tokenizer
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
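To see what the tokenizer actually produces, you can encode a toy question and context pair; BERT expects the two texts joined by special tokens, and this structure is what our preprocessing below relies on.
# Encode a toy question-context pair: [CLS] question [SEP] context [SEP]
encoding = tokenizer("Who created BERT?", "BERT was created by Google in 2018.")
print(tokenizer.convert_ids_to_tokens(encoding['input_ids']))
# token_type_ids are 0 for the question tokens and 1 for the context tokens
print(encoding['token_type_ids'])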
Load dataset
Then, we load the SQuAD dataset using the load_dataset() function from datasets. This function downloads and caches the dataset from Hugging Face's dataset hub. We use squad, the original SQuAD v1.1 collection of questions and answers based on Wikipedia articles.
# Load dataset
dataset = load_dataset('squad')
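It is worth taking a quick look at what we just loaded: squad is a DatasetDict with train and validation splits, and each example carries a question, a context, and the answer text with its character offset into the context.
# Inspect the dataset and one example
print(dataset)                 # DatasetDict with 'train' and 'validation' splits
example = dataset['train'][0]
print(example['question'])
print(example['answers'])      # {'text': [...], 'answer_start': [...]}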
Preprocess dataset
Next, we preprocess the dataset using the map() function from datasets, which applies a given function to every example. Our function tokenizes the question and context texts with the BERT tokenizer, producing input ids, attention masks, and token type ids, and converts the answer's character offsets into start and end token indices with the fast tokenizer's char_to_token() method. If the answer falls outside the (possibly truncated) input, we point both indices at the [CLS] token at position 0. We also pass remove_columns so the original text columns are dropped and only the numeric fields our data loaders need remain.
# Preprocess dataset
def preprocess(example):
    # Tokenize question and context as a pair: [CLS] question [SEP] context [SEP]
    encoding = tokenizer(example['question'], example['context'], truncation=True)
    # Character positions of the answer inside the context
    start_char = example['answers']['answer_start'][0]
    end_char = start_char + len(example['answers']['text'][0]) - 1
    # Map character positions to token indices (sequence_index=1 selects the context)
    start_index = encoding.char_to_token(start_char, sequence_index=1)
    end_index = encoding.char_to_token(end_char, sequence_index=1)
    # If the answer was truncated away, point both indices at the [CLS] token
    if start_index is None or end_index is None:
        start_index = end_index = 0
    # Return preprocessed example
    return {'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'token_type_ids': encoding['token_type_ids'],
            'start_index': start_index,
            'end_index': end_index}
dataset = dataset.map(preprocess, remove_columns=dataset['train'].column_names)
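As a quick sanity check, decoding the tokens between the computed start and end indices should give back text that reads like the answer:
# Sanity check: the span between start_index and end_index should decode to the answer
ex = dataset['train'][0]
span = ex['input_ids'][ex['start_index']:ex['end_index'] + 1]
print(tokenizer.decode(span))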
Split dataset into train and test sets
Then, we split the data into train and test sets using the train_test_split() method from datasets. This method randomly splits a Dataset into two subsets with a given ratio; note that it lives on individual Dataset objects, not on the DatasetDict, so we call it on the train split. With test_size=0.2 we get an 80% train set and a 20% test set.
# Split dataset into train and test sets
split = dataset['train'].train_test_split(test_size=0.2)
train_set, test_set = split['train'], split['test']
Create data loaders for train and test sets
Next, we create data loaders for the train and test sets using the DataLoader class from torch.utils.data. This class draws batches of examples from the dataset and applies optional transformations, such as shuffling. DataLoader has no padding option of its own, so we pass a collate_fn that pads the input ids, attention masks, and token type ids to the longest sequence in each batch; DataCollatorWithPadding from transformers does exactly that, and passes our integer start and end indices through unchanged. We use a batch size of 16.
# Create data loaders for train and test sets
batch_size = 16
collator = DataCollatorWithPadding(tokenizer)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collator)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, collate_fn=collator)
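You can confirm that batching and padding work by pulling a single batch and checking the tensor shapes:
# Pull one batch and check shapes: sequences are padded to the batch maximum
batch = next(iter(train_loader))
print(batch['input_ids'].shape)      # (16, longest sequence in this batch)
print(batch['start_index'].shape)    # (16,)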
Create optimizer for model
Next, we create an optimizer for the BERT model using the AdamW class from torch.optim. AdamW implements the Adam algorithm with decoupled weight decay, which is the standard choice for fine-tuning BERT (the transformers library used to ship its own AdamW, but it has since been deprecated in favor of the PyTorch one). We use a learning rate of 3e-5 and a weight decay of 0.01.
# Create optimizer for model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
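As a side note, BERT fine-tuning recipes often pair AdamW with a linear learning-rate schedule and warmup. Here is a sketch using the helper from transformers; the warmup step count is an arbitrary illustration.
# Optional: linear learning-rate decay with warmup (common for BERT fine-tuning)
from transformers import get_linear_schedule_with_warmup
num_training_steps = 3 * len(train_loader)  # assumes the 3 epochs we use below
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=num_training_steps)
# If you adopt this, call scheduler.step() right after optimizer.step() in the training loop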
Define training loop
Next, we define a training loop that sets the model to training mode and then iterates over the train data loader for a given number of epochs. For each batch, we perform the following steps:
- We move the batch data to the device (CPU or GPU) that we are using.
- We clear any previous gradients.
- We feed the input ids, attention masks, and token type ids to the model and get the start and end logits as outputs.
- We compute the loss by comparing the logits with the true start and end token indices using the cross_entropy() function from torch.nn.functional.
- We backpropagate the loss and update the model parameters using the optimizer.
- We print the loss value every 100 steps for monitoring.
# Define training loop
def train():
    # Set model to training mode
    model.train()
    # Loop over epochs
    for epoch in range(epochs):
        # Initialize total loss
        total_loss = 0.0
        # Loop over batches
        for step, batch in enumerate(train_loader):
            # Move batch data to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            start_index = batch['start_index'].to(device)
            end_index = batch['end_index'].to(device)
            # Clear previous gradients
            optimizer.zero_grad()
            # Feed inputs to model and get start and end logits as outputs
            outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True)
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits
            # Compute loss by comparing logits with true start and end token indices
            loss = torch.nn.functional.cross_entropy(start_logits, start_index) + torch.nn.functional.cross_entropy(end_logits, end_index)
            # Backpropagate loss and update model parameters
            loss.backward()
            optimizer.step()
            # Accumulate total loss
            total_loss += loss.item()
            # Print loss every 100 steps
            if (step + 1) % 100 == 0:
                print(f'Epoch {epoch + 1}, Step {step + 1}, Loss: {loss.item():.4f}')
        # Print average loss per epoch
        print(f'Epoch {epoch + 1}, Average Loss: {total_loss / len(train_loader):.4f}')
Define evaluation loop
Then, we define an evaluation loop that sets the model to evaluation mode and iterates over the test data loader to calculate the accuracy of the model on the test set; a prediction counts as correct only if both the start and the end index match exactly. For each batch, we perform the following steps:
- We move the batch data to the device that we are using.
- We disable gradient computation, feed the input ids, attention masks, and token type ids to the model, and get the start and end logits as outputs.
- We get the predicted start and end token indices by taking the argmax of the logits along the last dimension.
- We compare the predictions with the true start and end token indices and count how many are correct.
- We compute the accuracy by dividing the number of correct predictions by the total number of examples in the test set.
# Define evaluation loop
def evaluate():
    # Set model to evaluation mode
    model.eval()
    # Initialize number of correct predictions and total number of examples
    correct = 0
    total = 0
    # Loop over batches
    for batch in test_loader:
        # Move batch data to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        start_index = batch['start_index'].to(device)
        end_index = batch['end_index'].to(device)
        # Disable gradient computation
        with torch.no_grad():
            # Feed inputs to model and get start and end logits as outputs
            outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True)
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits
        # Get predicted start and end token indices by taking argmax of logits along last dimension
        pred_start_index = torch.argmax(start_logits, dim=-1)
        pred_end_index = torch.argmax(end_logits, dim=-1)
        # Compare predictions with true start and end token indices and count how many are correct
        correct += ((pred_start_index == start_index) & (pred_end_index == end_index)).sum().item()
        # Accumulate total number of examples
        total += len(input_ids)
    # Compute accuracy by dividing number of correct predictions by total number of examples
    accuracy = correct / total
    # Print accuracy
    print(f'Accuracy: {accuracy:.4f}')
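Note that requiring both token indices to match exactly is stricter than the official SQuAD metrics, which compute exact match and F1 on the answer strings. If you want those numbers, the evaluate library ships a squad metric; here is a minimal sketch of its input format (the ids and strings are made up):
# Official SQuAD metrics (exact match and F1) via the evaluate library.
# We alias the import to avoid clashing with our evaluate() function above.
import evaluate as hf_evaluate
squad_metric = hf_evaluate.load('squad')
predictions = [{'id': '1', 'prediction_text': 'google'}]
references = [{'id': '1', 'answers': {'text': ['Google'], 'answer_start': [20]}}]
print(squad_metric.compute(predictions=predictions, references=references))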
Train and evaluate the model
Finally, we set the number of epochs to 3, pick the device (CPU or GPU) that we are using, move the model to it, and then train and evaluate the model with the functions defined above.
# Define number of epochs
epochs = 3
# Define device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move model to device
model.to(device)
# Train and evaluate the model
train()
evaluate()
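Once training has finished, you can also query the fine-tuned model directly; here is a small sketch with a made-up question and context:
# Ask the fine-tuned model a question (question and context are made up)
question = "Who created BERT?"
context = "BERT was created by Google in 2018."
inputs = tokenizer(question, context, return_tensors='pt').to(device)
with torch.no_grad():
    outputs = model(**inputs)
start = torch.argmax(outputs.start_logits, dim=-1).item()
end = torch.argmax(outputs.end_logits, dim=-1).item()
print(tokenizer.decode(inputs['input_ids'][0][start:end + 1]))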
That's it! You have successfully used BERT for question answering. You can try running this code on Google Colab or your local machine and see how well it performs. You can also experiment with different hyperparameters, such as the batch size, learning rate, and weight decay, try other versions of BERT, such as bert-large-uncased or bert-base-cased, or try other transformer models, such as RoBERTa, ALBERT, or DistilBERT.
I hope you enjoyed this article and learned something new. If you have any questions or feedback, please let me know.