How to use BERT for question answering

A step-by-step guide with PyTorch and Hugging Face’s transformers library

BERT, short for Bidirectional Encoder Representations from Transformers, is a transformer model that understands natural language and can be adapted to many natural language processing tasks, such as question answering, text classification, and sentiment analysis. It was released by Google in 2018 and has become one of the most popular and influential models in the field.

Most people have heard of BERT, but far fewer can explain how the model works. BERT is based on the transformer architecture, which uses attention mechanisms to learn the relationships between the different parts of a sequence. BERT consists of a stack of encoder layers that take the input text and transform it into a sequence of vectors, called the encoder output. These vectors capture the meaning and context of each token in the text.
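
To make the idea of an "encoder output" concrete, the short sketch below (a side example, separate from the tutorial code later on) feeds a single sentence through the plain BertModel encoder and prints the shape of the resulting vectors: one 768-dimensional vector per token. The sentence is just a placeholder.

# Side example: inspect the encoder output of BERT
import torch
from transformers import BertModel, BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
encoder = BertModel.from_pretrained('bert-base-uncased')

inputs = tokenizer("BERT encodes every token in context.", return_tensors='pt')
with torch.no_grad():
  outputs = encoder(**inputs)

# One 768-dimensional vector per token: (batch_size, sequence_length, hidden_size)
print(outputs.last_hidden_state.shape)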

BERT is pre-trained on a large corpus of text, primarily English Wikipedia and a large collection of books. It uses two pre-training objectives: masked language modeling and next sentence prediction. Masked language modeling randomly masks some words in the input text and asks the model to predict them from the rest of the text. Next sentence prediction pairs two sentences from the corpus, half of the time consecutive and half of the time picked at random, and asks the model to predict whether the second sentence actually follows the first.
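
To get a feel for the masked language modeling objective, you can let a pre-trained BERT fill in a blank yourself. The snippet below (a small side example, separate from the fine-tuning code later) uses the fill-mask pipeline from transformers with an arbitrary example sentence.

# Side example: masked language modeling with a pre-trained BERT
from transformers import pipeline

fill_mask = pipeline('fill-mask', model='bert-base-uncased')
for prediction in fill_mask("The capital of France is [MASK]."):
  # Each prediction contains a candidate token and its probability score
  print(prediction['token_str'], round(prediction['score'], 3))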

By using these two objectives, BERT can learn both the syntactic and semantic aspects of natural language. It can also learn from both left and right contexts of each word, which makes it bidirectional. This gives BERT an advantage over other models that only use unidirectional or shallow bidirectional approaches.

BERT can be fine-tuned for various downstream tasks with relatively little task-specific data and computation. For example, to use BERT for question answering, one adds a small layer on top of the encoder output that predicts the start and end positions of the answer span in the text. To use BERT for text classification, one adds a small layer on top of the encoder output that predicts the class label of the text.
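
Conceptually, such a span-prediction head can be as small as a single linear layer that maps each encoder output vector to two scores, one for "the answer starts here" and one for "the answer ends here". The sketch below illustrates the idea with a dummy encoder output and made-up sizes; it is a simplified illustration, not the exact implementation inside the library.

# Simplified sketch of a question answering head on top of the encoder output
import torch
import torch.nn as nn

hidden_size = 768                                    # size of each encoder output vector
qa_head = nn.Linear(hidden_size, 2)                  # two scores per token: start and end

encoder_output = torch.randn(1, 20, hidden_size)     # dummy encoder output for 20 tokens
logits = qa_head(encoder_output)                     # shape: (1, 20, 2)
start_logits, end_logits = logits.split(1, dim=-1)   # one score per token for start / end
print(start_logits.squeeze(-1).shape, end_logits.squeeze(-1).shape)  # (1, 20) each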

In this article, we will show you how to use BERT for question answering using PyTorch and Hugging Face's transformers library. We will use the SQuAD dataset, which is a collection of questions and answers based on Wikipedia articles. We will follow these steps:

  • Import modules

  • Load model and tokenizer

  • Load dataset

  • Preprocess dataset

  • Split dataset into train and test sets

  • Create data loaders for train and test sets

  • Create optimizer for model

  • Define training loop

  • Define evaluation loop

  • Train and evaluate the model

Let's go through these steps one by one and understand what's happening.

Import modules

First, we import the necessary modules: torch for deep learning, transformers for accessing the BERT model, and datasets for loading the SQuAD dataset.

# Import modules
import torch
from transformers import BertForQuestionAnswering, BertTokenizerFast
from datasets import load_dataset

Load model and tokenizer

Next, we load the BERT model and tokenizer using the from_pretrained() function from transformers.BertForQuestionAnswering and transformers.BertTokenizerFast. This function downloads and caches the pre-trained weights and tokenizer files from Hugging Face's model hub. We use the fast tokenizer because it can map character positions in the text to token positions, which we will need when preprocessing the answers. We use the bert-base-uncased checkpoint, the base-sized (12-layer) English model trained on lowercased text. Note that the question answering head on top of the encoder is newly initialized and will be trained during fine-tuning.

# Load model and tokenizer
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

Load dataset

Then, we load the SQuAD dataset using the load_dataset() function from datasets. This function downloads and caches the dataset from Hugging Face's dataset hub. We use the squad dataset (SQuAD v1.1), a collection of questions and answers based on Wikipedia articles, which ships with separate train and validation splits.

# Load dataset
dataset = load_dataset('squad')
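
Before preprocessing, it is worth looking at one raw example: each record contains a question, a context paragraph, and an answers dictionary with the answer text and its character offset (answer_start) in the context.

# Inspect one raw training example
example = dataset['train'][0]
print(example['question'])        # the question text
print(example['context'][:200])   # the Wikipedia paragraph (truncated for display)
print(example['answers'])         # {'text': [...], 'answer_start': [...]}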

Preprocess dataset

Next, we preprocess the dataset using the map() function from datasets, which applies a given function to each example. Our function tokenizes the question and context texts with the BERT tokenizer, producing input ids and attention masks, and converts the character positions of the answer in the context into start and end token indices using the fast tokenizer's char_to_token() mapping. We also pad every example to the same fixed length so that batches can later be stacked directly, and finally we tell the dataset to return PyTorch tensors for the columns we need.

# Preprocess dataset
def preprocess(example):
  # Tokenize question and context together, truncating only the context if needed and
  # padding every example to the same fixed length (384 is a common choice for SQuAD)
  encoding = tokenizer(example['question'], example['context'],
                       truncation='only_second', max_length=384, padding='max_length')
  # Character positions of the answer inside the context
  start_char = example['answers']['answer_start'][0]
  end_char = start_char + len(example['answers']['text'][0]) - 1
  # Map character positions in the context (sequence 1) to token indices
  start_index = encoding.char_to_token(start_char, sequence_index=1)
  end_index = encoding.char_to_token(end_char, sequence_index=1)
  # If the answer was truncated away, fall back to index 0 (the [CLS] token)
  if start_index is None or end_index is None:
    start_index = end_index = 0
  # Return preprocessed example
  return {'input_ids': encoding['input_ids'], 'attention_mask': encoding['attention_mask'],
          'start_index': start_index, 'end_index': end_index}

dataset = dataset.map(preprocess)
# Keep only the tensor columns needed by the model and the training loop
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'start_index', 'end_index'])
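
As a quick sanity check, you can decode the tokens between the computed start and end indices for one example and compare them with the original answer text (they may differ slightly in casing because of the uncased tokenizer).

# Sanity-check the preprocessing on one training example
sample = dataset['train'][0]
answer_ids = sample['input_ids'][sample['start_index']:sample['end_index'] + 1]
print(tokenizer.decode(answer_ids))  # should roughly match the answer text, lowercased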

Split dataset into train and test sets

Then, we split the data into train and test sets using the train_test_split() method from datasets. SQuAD already comes with train and validation splits; here we take the train split and randomly divide it into an 80% train set and a 20% test set by passing test_size=0.2. The method returns a dictionary-like object with 'train' and 'test' entries.

# Split dataset into train and test sets
split = dataset['train'].train_test_split(test_size=0.2)
train_set, test_set = split['train'], split['test']

Create data loaders for train and test sets

Next, we create data loaders for the train and test sets using the DataLoader class from torch.utils.data. This class groups the dataset into batches and can optionally shuffle it. We use a batch size of 16; because every example was already padded to the same length during preprocessing, the default collate function can stack the input ids and attention masks directly.

# Create data loaders for train and test sets
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)

Create optimizer for model

Next, we create an optimizer for the BERT model using the AdamW class from torch.optim, which implements the Adam algorithm with decoupled weight decay and is the usual choice for fine-tuning BERT. We use a learning rate of 3e-5 and a weight decay of 0.01.

# Create optimizer for model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

Define training loop

Next, we define a training loop that iterates over the train data loader for a given number of epochs. For each batch, we perform the following steps:

  • We move the batch data to the device (CPU or GPU) that we are using.

  • We clear any previous gradients from the optimizer (the model is set to training mode once, at the start of the function).

  • We feed the input ids and attention masks to the model and get the start and end logits as outputs.

  • We compute the loss by comparing the logits with the true start and end token indices using the cross_entropy() function from torch.nn.functional.

  • We backpropagate the loss and update the model parameters using the optimizer.

  • We print the loss value every 100 steps for monitoring.

# Define training loop
def train():
  # Set model to training mode
  model.train()
  # Loop over epochs
  for epoch in range(epochs):
    # Initialize total loss
    total_loss = 0.0
    # Loop over batches
    for step, batch in enumerate(train_loader):
      # Move batch data to device
      input_ids = batch['input_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      start_index = batch['start_index'].to(device)
      end_index = batch['end_index'].to(device)
      # Clear previous gradients
      optimizer.zero_grad()
      # Feed input ids and attention masks to model and get start and end logits as outputs
      outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)
      start_logits = outputs.start_logits
      end_logits = outputs.end_logits
      # Compute loss by comparing logits with true start and end token indices
      loss = torch.nn.functional.cross_entropy(start_logits, start_index) + torch.nn.functional.cross_entropy(end_logits, end_index)
      # Backpropagate loss and update model parameters
      loss.backward()
      optimizer.step()
      # Accumulate total loss
      total_loss += loss.item()
      # Print loss every 100 steps
      if (step + 1) % 100 == 0:
        print(f'Epoch {epoch + 1}, Step {step + 1}, Loss: {loss.item():.4f}')
    # Print average loss per epoch
    print(f'Epoch {epoch + 1}, Average Loss: {total_loss / len(train_loader):.4f}')

Define evaluation loop

Then, we define an evaluation loop that iterates over the test data loader and calculates the accuracy of the model on the test set. For each batch, we perform the following steps:

  • We move the batch data to the device that we are using.

  • We disable gradient computation with torch.no_grad() (the model is set to evaluation mode once, at the start of the function).

  • We feed the input ids and attention masks to the model and get the start and end logits as outputs.

  • We get the predicted start and end token indices by taking the argmax of the logits along the last dimension.

  • We compare the predictions with the true start and end token indices and count how many are correct.

  • We compute the accuracy by dividing the number of correct predictions by the total number of examples in the test set.

# Define evaluation loop
def evaluate():
  # Set model to evaluation mode
  model.eval()
  # Initialize number of correct predictions and total number of examples
  correct = 0
  total = 0
  # Loop over batches
  for batch in test_loader:
    # Move batch data to device
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    start_index = batch['start_index'].to(device)
    end_index = batch['end_index'].to(device)
    # Disable gradient computation
    with torch.no_grad():
      # Feed input ids and attention masks to model and get start and end logits as outputs
      outputs = model(input_ids, attention_mask=attention_mask, return_dict=True)
      start_logits = outputs.start_logits
      end_logits = outputs.end_logits
      # Get predicted start and end token indices by taking argmax of logits along last dimension
      pred_start_index = torch.argmax(start_logits, dim=-1)
      pred_end_index = torch.argmax(end_logits, dim=-1)
      # Compare predictions with true start and end token indices and count how many are correct
      correct += ((pred_start_index == start_index) & (pred_end_index == end_index)).sum().item()
    # Accumulate total number of examples
    total += len(input_ids)
  # Compute accuracy by dividing number of correct predictions by total number of examples
  accuracy = correct / total
  # Print accuracy
  print(f'Accuracy: {accuracy:.4f}')

Train and evaluate the model

Finally, we set the number of epochs to 3, choose the device (GPU if available, otherwise CPU), move the model to it, and then run the training and evaluation functions we defined above.

# Define number of epochs
epochs = 3

# Define device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move model to device
model.to(device)

# Train and evaluate the model
train()
evaluate()

That's it! You have successfully used BERT for question answering. You can try running this code on Google Colab or your local machine and see how well it performs. You can also experiment with different hyperparameters, such as batch size, learning rate, weight decay, etc. You can also try different versions of BERT, such as bert-large-uncased or bert-base-cased, or other transformer models, such as RoBERTa, ALBERT, or DistilBERT.
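
If you want to see the fine-tuned model in action on a single question, here is a minimal inference sketch. The question and context are just placeholders; it picks the most likely start and end positions independently, so for a weakly trained model the decoded span may occasionally be empty or malformed.

# Minimal inference sketch with the fine-tuned model
question = "Who created BERT?"
context = "BERT was created by Google in 2018 and has become very influential."

inputs = tokenizer(question, context, return_tensors='pt', truncation=True).to(device)
model.eval()
with torch.no_grad():
  outputs = model(**inputs)
# Most likely start and end token positions
start = torch.argmax(outputs.start_logits, dim=-1).item()
end = torch.argmax(outputs.end_logits, dim=-1).item()
# Decode the tokens between the predicted start and end positions
answer = tokenizer.decode(inputs['input_ids'][0][start:end + 1])
print(answer)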

I hope you enjoyed this article and learned something new. If you have any questions or feedback, please let me know.