Aligning and Extracting Tokens from BERT
Goal: I have a string and certain words from that string, and I want to get the word embeddings for those words using BERT.
The embeddings that BERT produces for a word include contextual information, so the word bear will be represented differently in ‘It is a bear market’ vs. ‘I saw a bear’.
Aligning Words with Tokens
First, I set up the code for producing the tokens that will be used to get the embeddings. Some of the BERT code is borrowed from https://towardsdatascience.com/3-types-of-contextualized-word-embeddings-from-bert-using-transfer-learning-81fcefe3fe6d.
# BERT Stuff
from transformers import BertTokenizer, BertModel
import torch
# Alignment code
import tokenizations # pip install pytokenizations
# Word extraction (used further below)
import re

# Setting up the tokenizer
###################################
# This is the same tokenizer that
# was used in the model to generate
# embeddings to ensure consistency
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def bert_text_preparation(text, tokenizer):
    """Preparing the input for BERT

    Takes a string argument and performs
    pre-processing like adding special tokens,
    tokenization, tokens to ids, and tokens to
    segment ids. All tokens are mapped to
    segment id = 1.

    Args:
        text (str): Text to be converted
        tokenizer (obj): Tokenizer object
            to convert text into BERT-readable
            tokens and ids

    Returns:
        list: List of BERT-readable tokens
        obj: Torch tensor with token ids
        obj: Torch tensor with segment ids
    """
    marked_text = "[CLS] " + text + " [SEP]"
    tokenized_text = tokenizer.tokenize(marked_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [1]*len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    return tokenized_text, tokens_tensor, segments_tensors
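As an aside, on recent versions of transformers the same preparation can be done by calling the tokenizer directly. This is just a minimal sketch (not from the article above); note that it assigns segment id 0 instead of the 1s used in bert_text_preparation, so the resulting embeddings will differ slightly.
# Sketch: equivalent preparation via the tokenizer's __call__ interface.
# Special tokens ([CLS]/[SEP]) are added automatically and token_type_ids
# default to 0 rather than the 1s used above.
encoded = tokenizer("I'm going to the park to play on my phone.",
                    return_tensors='pt')
alt_tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0].tolist())
alt_tokens_tensor = encoded['input_ids']
alt_segments_tensors = encoded['token_type_ids']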
Getting the Words
Then I have some code that can extract the words from some text.
def extract_word_lists(text):
    """Get each word in `text` along with its character offsets.

    The raw words keep their punctuation, so 'hello there, you' gives
    ['hello', 'there,', 'you']; a copy of each word with trailing
    ',?!' stripped is also returned.
    """
    # Get the words: index, start char loc, end char loc, raw word, stripped word
    word_list = [ (i, w.start(), w.end(), w.group(), w.group().rstrip(',?!'))
                  for i, w in enumerate(re.finditer(r"[\S]+", text)) ]
    return word_list
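A quick check of the docstring's example (my own addition, just to show the trailing-punctuation stripping):
extract_word_lists('hello there, you')
# [(0, 0, 5, 'hello', 'hello'),
#  (1, 6, 12, 'there,', 'there'),
#  (2, 13, 16, 'you', 'you')]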
I can then get the words for a piece of text like so:
txt = "I'm going to the park to play on my phone."
word_list = extract_word_lists(txt.strip('.'))
word_list # output below
[(0, 0, 3, "I'm", "I'm"),
(1, 4, 9, 'going', 'going'),
(2, 10, 12, 'to', 'to'),
(3, 13, 16, 'the', 'the'),
(4, 17, 21, 'park', 'park'),
(5, 22, 24, 'to', 'to'),
(6, 25, 29, 'play', 'play'),
(7, 30, 32, 'on', 'on'),
(8, 33, 35, 'my', 'my'),
(9, 36, 41, 'phone', 'phone')]
Getting the Tokens
And I can get the tokens like so:
tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(txt, tokenizer)
tokenized_text # output below
['[CLS]',
'i',
'am',
'going',
'to',
'the',
'park',
'to',
'play',
'on',
'my',
'phone',
'.',
'[SEP]']
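In this sentence the tokens line up almost one-to-one with the words, but BERT's WordPiece tokenizer splits rarer words into several sub-tokens, which is exactly why the alignment step below is needed. A quick check (the exact split depends on the vocabulary):
tokenizer.tokenize('embeddings')
# something like ['em', '##bed', '##ding', '##s']: one word, several tokens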
Alignment
tokens_a = [ w[4] for w in word_list ]
tokens_b = tokenized_text
a2b, b2a = tokenizations.get_alignments(tokens_a, tokens_b)

for i in range(len(tokens_a)):
    print(tokens_a[i])
    for j in a2b[i]:
        print(" ", tokens_b[j])
I'm
  i
  am
going
  going
to
  to
the
  the
park
  park
to
  to
play
  play
on
  on
my
  my
phone
  phone
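The reverse mapping b2a goes the other way, from token indices back to word indices. For instance, based on the output above:
b2a[11]   # token index 11 is 'phone'
# [9]     # -> word index 9, i.e. (9, 36, 41, 'phone', 'phone')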
Embeddings
We can get the embeddings for each token with BERT.
# Loading the pre-trained BERT model
###################################
# Embeddings will be derived from
# the outputs of this model
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states = True)
# Put the model in evaluation mode so that
# dropout is disabled and the outputs are deterministic
model.eval()

def get_bert_embeddings(tokens_tensor, segments_tensors, model):
    """Get embeddings from an embedding model

    Args:
        tokens_tensor (obj): Torch tensor size [1, n_tokens]
            with token ids for each token in text
        segments_tensors (obj): Torch tensor size [1, n_tokens]
            with segment ids for each token in text
        model (obj): Embedding model to generate embeddings
            from token and segment ids

    Returns:
        obj: Numpy array of size
            [n_tokens, n_embedding_dimensions]
            containing the embedding for each token
    """
    # Gradient calculation is disabled
    # Model is in inference mode
    with torch.no_grad():
        # Pass the segment ids explicitly so they are not
        # mistaken for an attention mask
        outputs = model(tokens_tensor, token_type_ids=segments_tensors)
        # Removing the first hidden state
        # The first state is the input (embedding layer) state
        hidden_states = outputs[2][1:]

    # Getting embeddings from the final BERT layer
    token_embeddings = hidden_states[-1]
    # Removing the batch dimension: [1, n_tokens, dim] -> [n_tokens, dim]
    token_embeddings = torch.squeeze(token_embeddings, dim=0)

    return token_embeddings.numpy()
embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)
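As a quick sanity check (my addition), there should be one row per token and 768 dimensions for bert-base:
embeddings.shape
# (14, 768): 14 tokens (including [CLS], '.' and [SEP]) by hidden size 768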
Embeddings for Phrases
So this is the big final step. What we want to do is get the embeddings for the phrases (one or more words) in the original text that I’m interested in. Since there isn’t a one-to-one mapping between words and tokens, we need to map each word to its corresponding token(s). If more than one token is associated with a phrase, then we average the embeddings across those tokens.
For this example, remember that the input sentence was “I’m going to the park to play on my phone.”. Say I want the embeddings for the following phrases:
- I’m going
- play
- my phone
Then I would first build a list where each element is another list containing the indices of the words in that phrase. I can then convert these word indices to token indices using the alignment obtained above.
phrases_in_text = [[0,1], [6], [8,9]]

phrases_in_bert = []
for word_indices in phrases_in_text:
    token_indices = []
    for wi in word_indices:
        token_indices.extend(a2b[wi])
    phrases_in_bert.append(token_indices)

phrases_in_bert
# [[1, 2, 3], [8], [10, 11]] # each element lists the token indices corresponding to that phrase
Finally, we can get the average embedding across all tokens associated with each phrase:
# Now we have the embeddings for each phrase!
phrases_embeddings = [ embeddings[token_indices, :].mean(axis=0) for token_indices in phrases_in_bert ]
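To close the loop on the ‘bear’ example from the start of the post, here is a sketch that runs the whole pipeline on two sentences and compares the two embeddings of ‘bear’ with cosine similarity. The helper phrase_embedding is my own addition for illustration, not part of the code above.
import numpy as np

def phrase_embedding(text, phrase_word_indices):
    """Hypothetical helper: average the token embeddings of the words
    at `phrase_word_indices` in `text`, using the functions defined above."""
    word_list = extract_word_lists(text)
    tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(text, tokenizer)
    embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)
    a2b, b2a = tokenizations.get_alignments([w[4] for w in word_list], tokenized_text)
    token_indices = [ti for wi in phrase_word_indices for ti in a2b[wi]]
    return embeddings[token_indices, :].mean(axis=0)

# 'bear' is word index 3 in both sentences
bear_market = phrase_embedding("It is a bear market", [3])
bear_animal = phrase_embedding("I saw a bear", [3])

# Cosine similarity should be noticeably below 1.0, since the contexts differ
cos = np.dot(bear_market, bear_animal) / (np.linalg.norm(bear_market) * np.linalg.norm(bear_animal))
print(cos)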