Aligning and Extracting Tokens from BERT

Zarrar Shehzad · November 4, 2021

Goal: Given a string and some of the words in it, I want to extract the embeddings for those words using BERT.

Word embeddings taken from BERT include contextual information, so the word ‘bear’ is represented differently in ‘It is a bear market’ than in ‘I saw a bear’.

Aligning Words with Tokens

First, I set up the code that produces the tokens BERT will use to compute the embeddings. Some of the BERT code is borrowed from https://towardsdatascience.com/3-types-of-contextualized-word-embeddings-from-bert-using-transfer-learning-81fcefe3fe6d.

# BERT Stuff
from transformers import BertTokenizer, BertModel
import torch
# Word extraction (used by extract_word_lists below)
import re
# Alignment code
import tokenizations # pip install pytokenizations

# Setting up the tokenizer
###################################
# This is the same tokenizer that
# was used in the model to generate
# embeddings to ensure consistency
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def bert_text_preparation(text, tokenizer):
    """Preparing the input for BERT

    Takes a string argument and performs
    pre-processing like adding special tokens,
    tokenization, tokens to ids, and tokens to
    segment ids. All tokens are mapped to
    segment id = 1.

    Args:
        text (str): Text to be converted
        tokenizer (obj): Tokenizer object
            to convert text into BERT-readable
            tokens and ids

    Returns:
        list: List of BERT-readable tokens
        obj: Torch tensor with token ids
        obj: Torch tensor with segment ids
    """
    marked_text = "[CLS] " + text + " [SEP]"
    tokenized_text = tokenizer.tokenize(marked_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [1]*len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    return tokenized_text, tokens_tensor, segments_tensors

Getting the Words

Next, a small helper extracts the words from a piece of text.

def extract_word_lists(text):
    """Split text on whitespace into words.

    Words keep any attached punctuation, so 'hello there, you'
    gives the raw words ['hello', 'there,', 'you']. Each entry also
    carries a cleaned copy with trailing ',?!' removed.
    """
    # Each entry: (word index, start char offset, end char offset,
    #              raw word, word with trailing ',?!' stripped)
    word_list = [ (i, w.start(), w.end(), w.group(), w.group().rstrip(',?!'))
                  for i,w in enumerate(re.finditer(r"\S+", text)) ]

    return word_list
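Because the function returns plain 5-tuples, later code has to remember that w[4] is the cleaned word. A minimal sketch of one alternative, assuming you are free to change the return type: a typing.NamedTuple keeps the same positional layout (so w[4] still works with the alignment code) while also allowing named access such as w.clean. The names Word and extract_words are my own, hypothetical ones.

```python
import re
from typing import NamedTuple


class Word(NamedTuple):
    """Hypothetical record type; same field order as the original 5-tuples."""
    index: int  # position of the word in the text
    start: int  # start character offset (inclusive)
    end: int    # end character offset (exclusive)
    raw: str    # word as it appears, including attached punctuation
    clean: str  # word with trailing ',?!' stripped


def extract_words(text):
    """Hypothetical variant of extract_word_lists that returns Word records."""
    return [Word(i, m.start(), m.end(), m.group(), m.group().rstrip(",?!"))
            for i, m in enumerate(re.finditer(r"\S+", text))]


words = extract_words("hello there, you")
for w in words:
    print(w.index, w.start, w.end, repr(w.raw), repr(w.clean))
# 0 0 5 'hello' 'hello'
# 1 6 12 'there,' 'there'
# 2 13 16 'you' 'you'
```

Since a NamedTuple is still a tuple, the start and end offsets can be checked directly with text[w.start:w.end] == w.raw.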

I can then get the words for a piece of text like so. The trailing period is stripped first, because extract_word_lists only removes ',?!' from the ends of words:

txt = "I'm going to the park to play on my phone."
word_list = extract_word_lists(txt.strip('.'))
word_list # output below
[(0, 0, 3, "I'm", "I'm"),
 (1, 4, 9, 'going', 'going'),
 (2, 10, 12, 'to', 'to'),
 (3, 13, 16, 'the', 'the'),
 (4, 17, 21, 'park', 'park'),
 (5, 22, 24, 'to', 'to'),
 (6, 25, 29, 'play', 'play'),
 (7, 30, 32, 'on', 'on'),
 (8, 33, 35, 'my', 'my'),
 (9, 36, 41, 'phone', 'phone')]

Getting the Tokens

And I can get the tokens like so:

tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(txt, tokenizer)
tokenized_text # output below
['[CLS]',
 'i',
 "'",
 'm',
 'going',
 'to',
 'the',
 'park',
 'to',
 'play',
 'on',
 'my',
 'phone',
 '.',
 '[SEP]']
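The word list and the token list differ because BERT's tokenizer first lowercases the text and splits it on whitespace and punctuation, and then breaks each piece into WordPiece subwords by greedy longest-match-first lookup in its vocabulary. Below is a simplified sketch of those two stages, assuming a tiny hypothetical vocabulary VOCAB; the real BertTokenizer also handles accents, Chinese characters, control characters, some ASCII symbols and a maximum word length, none of which are modeled here.

```python
import unicodedata


def basic_split(text):
    """Lowercase, split on whitespace, then split off punctuation characters (simplified)."""
    pieces = []
    for word in text.lower().split():
        current = ""
        for ch in word:
            if unicodedata.category(ch).startswith("P"):
                if current:
                    pieces.append(current)
                    current = ""
                pieces.append(ch)  # each punctuation mark becomes its own piece
            else:
                current += ch
        if current:
            pieces.append(current)
    return pieces


def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one piece into subword tokens."""
    tokens, start = [], 0
    while start < len(word):
        end, match = len(word), None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no vocabulary entry fits: the whole piece becomes [UNK]
        tokens.append(match)
        start = end
    return tokens


def toy_tokenize(text, vocab):
    return [tok for piece in basic_split(text) for tok in wordpiece(piece, vocab)]


# Hypothetical toy vocabulary; the real bert-base-uncased vocabulary has ~30k entries
VOCAB = {"hello", ",", "world", "!", "embed", "##ding", "##s"}

print(toy_tokenize("Hello, world!", VOCAB))  # ['hello', ',', 'world', '!']
print(toy_tokenize("Embeddings", VOCAB))     # ['embed', '##ding', '##s']
```

Both effects (punctuation splitting and subword splitting) mean one word can map to several tokens, which is why an explicit alignment step is needed.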

Alignment

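The get_alignments function from pytokenizations maps between the two token lists: a2b[i] holds the indices of the BERT tokens that overlap word i, and b2a[j] holds the indices of the words that overlap token j.

tokens_a = [ w[4] for w in word_list ]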
tokens_b = tokenized_text

a2b, b2a = tokenizations.get_alignments(tokens_a, tokens_b)

for i in range(len(tokens_a)):
    print(tokens_a[i])
    for j in a2b[i]:
        print("    ", tokens_b[j])
I'm
     i
     '
     m
going
     going
to
     to
the
     the
park
     park
to
     to
play
     play
on
     on
my
     my
phone
     phone
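Conceptually, get_alignments works at the character level: both token lists are normalized and concatenated, the two character strings are aligned, and each word is mapped to every token that shares at least one aligned character. As a rough stand-in (not pytokenizations' actual implementation, which uses its own diff routine and may break ties differently), here is a sketch using the standard library's difflib.SequenceMatcher. The function names and the hand-written token list are my own, hypothetical choices.

```python
import difflib
import unicodedata


def normalize(token):
    # Lowercase + NFKD, roughly what an uncased tokenizer does to its input
    return unicodedata.normalize("NFKD", token.lower())


def char_owners(tokens):
    """Concatenate normalized tokens and record which token owns each character."""
    chars, owners = [], []
    for idx, tok in enumerate(tokens):
        for ch in normalize(tok):
            chars.append(ch)
            owners.append(idx)
    return "".join(chars), owners


def align_tokens(tokens_a, tokens_b):
    """Hypothetical stand-in for tokenizations.get_alignments.

    Returns (a2b, b2a), where a2b[i] lists the indices of tokens_b that share
    at least one aligned character with tokens_a[i], and vice versa.
    """
    text_a, owner_a = char_owners(tokens_a)
    text_b, owner_b = char_owners(tokens_b)
    a2b = [set() for _ in tokens_a]
    b2a = [set() for _ in tokens_b]
    matcher = difflib.SequenceMatcher(None, text_a, text_b, autojunk=False)
    for block in matcher.get_matching_blocks():
        for k in range(block.size):
            ia, ib = owner_a[block.a + k], owner_b[block.b + k]
            a2b[ia].add(ib)
            b2a[ib].add(ia)
    return [sorted(s) for s in a2b], [sorted(s) for s in b2a]


words = ["Don't", "stop"]
tokens = ["[CLS]", "don", "'", "t", "stop", "[SEP]"]
a2b, b2a = align_tokens(words, tokens)
print(a2b)  # [[1, 2, 3], [4]]
print(b2a)  # [[], [0], [0], [0], [1], []]
```

In this example the special tokens [CLS] and [SEP] get empty lists in b2a, since no word shares their characters.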

Embeddings

Next, BERT gives us an embedding for each token. Here I take the token embeddings from the final hidden layer.

# Loading the pre-trained BERT model
###################################
# Embeddings will be derived from
# the outputs of this model
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states = True)

def get_bert_embeddings(tokens_tensor, segments_tensors, model):
    """Get embeddings from an embedding model

    Args:
        tokens_tensor (obj): Torch tensor of size [1, n_tokens]
            with token ids for each token in text
        segments_tensors (obj): Torch tensor of size [1, n_tokens]
            with segment ids for each token in text
        model (obj): Embedding model to generate embeddings
            from token and segment ids

    Returns:
        numpy.ndarray: Array of size
            [n_tokens, n_embedding_dimensions]
            containing embeddings for each token

    """

    # Gradient calculation is disabled
    # Model is in inference mode
    with torch.no_grad():
        # Pass segment ids by keyword: the second positional
        # argument of BertModel's forward() is attention_mask
        outputs = model(tokens_tensor, token_type_ids=segments_tensors)
        # outputs[2] holds the hidden states; removing the first one,
        # which is the output of the embedding layer
        hidden_states = outputs[2][1:]

    # Getting embeddings from the final BERT layer
    token_embeddings = hidden_states[-1]
    # Dropping the batch dimension: [1, n_tokens, dim] -> [n_tokens, dim]
    token_embeddings = torch.squeeze(token_embeddings, dim=0)

    return token_embeddings.numpy()

embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)

Embeddings for Phrases

This is the big final step. We want the embeddings for phrases (one or more words) in the original text that I’m interested in. Since there isn’t a one-to-one mapping between words and tokens, we first map each word to its tokens. If more than one token is associated with a phrase, we average the embeddings across those tokens.

For this example, remember that the input sentence was “I’m going to the park to play on my phone.” Say I want the embeddings for the following phrases:

  • I’m going
  • play
  • my phone

I first build a list in which each element is the list of word indices for one phrase. I can then convert these word indices to token indices using the a2b alignment from above.

phrases_in_text = [[0,1], [6], [8,9]]

phrases_in_bert = []
for word_indices in phrases_in_text:
    token_indices = []
    for wi in word_indices:
        token_indices.extend(a2b[wi])
    phrases_in_bert.append(token_indices)
phrases_in_bert
# [[1, 2, 3, 4], [9], [11, 12]] # each element holds the token indices for that phrase

Finally, we can get the average embedding of all tokens associated with each phrase:

# Average the token embeddings (rows of `embeddings`) for each phrase
phrases_embeddings = [ embeddings[token_indices,:].mean(axis=0) for token_indices in phrases_in_bert ]
# Now we have one embedding per phrase!
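For readers who want the whole final step in one place without NumPy, here is a minimal pure-Python sketch of the same two operations: mapping phrases (lists of word indices) through a2b to token indices, then averaging those tokens' vectors element-wise. The helper names and the toy 2-dimensional embeddings are hypothetical; real bert-base vectors have 768 dimensions. One design choice differs from the loop above: if two words in a phrase share a token, that token is counted only once.

```python
def phrase_token_indices(phrases, a2b):
    """Map each phrase (a list of word indices) to the token indices covering it."""
    result = []
    for word_indices in phrases:
        token_indices = []
        for wi in word_indices:
            for ti in a2b[wi]:
                if ti not in token_indices:  # count a token shared by two words once
                    token_indices.append(ti)
        result.append(token_indices)
    return result


def mean_vector(rows):
    """Element-wise mean of equal-length vectors."""
    return [sum(column) / len(rows) for column in zip(*rows)]


def phrase_embeddings(phrases, a2b, token_embeddings):
    """Average the token embeddings that make up each phrase."""
    return [mean_vector([token_embeddings[ti] for ti in token_indices])
            for token_indices in phrase_token_indices(phrases, a2b)]


# Toy data (hypothetical): 3 words, 5 tokens, 2-dimensional embeddings
a2b = [[1, 2], [3], [4]]
token_embeddings = [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
phrases = [[0, 1], [2]]

print(phrase_token_indices(phrases, a2b))                 # [[1, 2, 3], [4]]
print(phrase_embeddings(phrases, a2b, token_embeddings))  # [[3.0, 4.0], [7.0, 8.0]]
```

With the NumPy array from get_bert_embeddings, embeddings[token_indices,:].mean(axis=0) computes the same element-wise mean for each phrase.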
