Building a Generative Transformer Decoder in TensorFlow: A Step-by-Step Guide
This post documents my process of implementing a generative transformer decoder model in TensorFlow, in an effort to learn more about this architecture. We will use a dataset of domain names as the example application for training the model.
Data pre-processing
You can find a step-by-step explanation of how to use BPE to tokenize the dataset in a previous post:
Learn how to train a custom Byte-Pair Encoding (BPE) tokenizer on a dataset of domain names using the Hugging Face library. Improve your NLP models' performance with this efficient tokenization technique.
Let’s start by loading the already tokenized dataset using the Hugging Face datasets library.
from datasets import load_dataset
ds = load_dataset("jeremygf/domains-app-alpha")
ds = ds['train']
Target shifting
During training, the model is fed input sequences together with target sequences shifted by one token. For example, if the tokenized domain name is [STA]example, the input would be [STA]exampl and the target would be example, so at every position the model learns to predict the next token. This is called “teacher forcing.”
We can use the map method to produce the input and target sequences. Here, x['ids'] contains the tokenized sequences as arrays of numerical IDs representing individual tokens.
ds = ds.map(lambda x: {
'input_ids': x['ids'][:-1],
'target_ids': x['ids'][1:]
})
ds_dict = ds.train_test_split(test_size=0.1)
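To make the shift concrete, here is a quick illustration with made-up token IDs (the actual IDs depend on the tokenizer):

ids = [2, 15, 16, 17, 3]            # hypothetical IDs for [STA] e x a [END]
inputs, targets = ids[:-1], ids[1:]
print(inputs)    # [2, 15, 16, 17]
print(targets)   # [15, 16, 17, 3]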
To prepare a TensorFlow dataset, we use a data collator from the Hugging Face Transformers library (here, DataCollatorForTokenClassification). A data collator assembles individual data samples into padded batches in the format the model expects for training or evaluation.
BATCH = 256
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(
tokenizer=tokenizer,
return_tensors="tf"
)
tf_train = ds_dict['train'].to_tf_dataset(
columns="input_ids",
label_cols="target_ids",
collate_fn=data_collator,
batch_size=BATCH,
shuffle=True
)
tf_test = ds_dict['test'].to_tf_dataset(
columns="input_ids",
label_cols="target_ids",
collate_fn=data_collator,
batch_size=BATCH,
shuffle=True
)
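As a quick sanity check, we can pull a single batch out of tf_train and look at its shape; the collator pads every sequence in the batch to a common length:

for inputs, targets in tf_train.take(1):
    print(inputs.shape, targets.shape)   # e.g. (256, padded_length) for both inputs and targets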
Building a Generative Transformer Decoder Model
Let’s break down how generative transformer decoder models work. The primary goal of a generative transformer decoder model is to predict the next element (e.g., word, character, or any other token) in a sequence given the previous elements. It achieves this through a combination of powerful mechanisms from the Transformer architecture.
Positional encoding
Unlike recurrent models (RNNs, LSTMs), Transformers don’t have an inherent sense of sequence order. Positional encodings inject information about the relative positions of tokens into the input representations, which helps the model understand the sequence’s structure.
The original paper uses the following formula for calculating the positional encoding, where pos is the token position, i indexes pairs of embedding dimensions, and d_model is the embedding dimensionality:

PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

This implementation uses sinusoids with varying wavelengths to represent positional information (the code below concatenates the sine and cosine components instead of interleaving them, which serves the same purpose). The different depths i create different frequencies of sine and cosine waves, giving the model a richer way to “understand” the relative positions of tokens.
import numpy as np
import tensorflow as tf

def positional_encoding(length, depth):
    depth = depth / 2
    positions = np.arange(length)[:, np.newaxis]       # (length, 1)
    depths = np.arange(depth)[np.newaxis, :] / depth   # (1, depth/2)
    angle_rates = 1 / (10000**depths)                  # frequencies decrease with depth
    angle_rads = positions * angle_rates               # (length, depth/2)
    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)                                        # (length, depth)
    return tf.cast(pos_encoding, dtype=tf.float32)
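As a quick check, the function returns one depth-dimensional vector per position (the sizes below are just for illustration):

pe = positional_encoding(length=32, depth=128)
print(pe.shape)   # (32, 128): one 128-dimensional encoding per position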
The token sequences have to be converted to vectors using a standard embedding layer. The positional encoding array is added to the input token embeddings before being fed into the Transformer model.
class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim, mask_zero=True)
        # MAX_SEQ_LEN is the maximum tokenized sequence length used during preprocessing
        self.pos_encoding = positional_encoding(length=MAX_SEQ_LEN, depth=embedding_dim)

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # scale the embeddings so they are not drowned out by the positional encoding
        x *= tf.math.sqrt(tf.cast(self.embedding_dim, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x
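We can verify the layer with a dummy batch of token IDs. The sizes and IDs below are placeholders, and MAX_SEQ_LEN is assumed to already be set to the maximum tokenized length (anything >= 5 works for this toy batch):

embed = PositionalEmbedding(vocab_size=100, embedding_dim=64)   # illustrative sizes
dummy = tf.constant([[2, 15, 16, 17, 3]])                       # batch of one tokenized name
print(embed(dummy).shape)                                       # (1, 5, 64)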
Causal self-attention
Self-attention allows the transformer decoder to focus on different parts of the partially generated sequence. An important aspect of the decoder architecture is causal masking in the attention mechanism: the model can only attend to past tokens, not future ones, which prevents it from “cheating” by seeing what the correct next token should be. Because we are not making use of an encoder component, the cross-attention portion of the decoder is omitted.
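To make the causal mask concrete, here is a small sketch that builds the equivalent lower-triangular mask by hand; Keras constructs such a mask internally when use_causal_mask=True is passed to MultiHeadAttention:

import tensorflow as tf

seq_len = 5
causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
print(causal_mask)
# [[1. 0. 0. 0. 0.]
#  [1. 1. 0. 0. 0.]
#  [1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]
# Row i is 1 only for columns <= i: token i can attend to itself and earlier tokens only.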
Multi-headed attention
The model contains multiple instances of the attention mechanism. Each “head” within the attention mechanism can focus on different aspects of the input sequence, enabling it to capture complex patterns and relationships within the partially generated sequence.
Let’s break down the code for the CausalSelfAttention layer class. The constructor takes two parameters:
num_heads: The number of attention heads in the multi-head attention mechanism.
embedding_dim: The dimensionality of the input embeddings.
With these parameters, we instantiate a MultiHeadAttention layer, the core of the self-attention mechanism. A LayerNormalization object is also created for stabilizing the attention mechanism’s output.
The call method defines what happens when the layer is called on an input x:
use_causal_mask=True ensures the model only attends to previous tokens and not to future ones in the sequence.
x = x + attn_output performs a residual connection, adding the attention output to the original input. This helps with gradient flow during training.
class CausalSelfAttention(tf.keras.layers.Layer):
    def __init__(self, num_heads, embedding_dim):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dim
        )
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        attn_output = self.mha(
            query=x,
            value=x,
            use_causal_mask=True   # each position attends only to itself and earlier positions
        )
        x = x + attn_output        # residual connection
        x = self.layer_norm(x)
        return x
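A quick shape check with random inputs (illustrative sizes) confirms the layer preserves the (batch, sequence, embedding) shape:

attn = CausalSelfAttention(num_heads=2, embedding_dim=64)
x = tf.random.uniform((1, 5, 64))    # (batch, sequence, embedding)
print(attn(x).shape)                 # (1, 5, 64): attention plus residual keeps the shape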
Feed-Forward Networks
After the attention layers, feed-forward networks independently process each token’s representation. This adds further non-linear transformations to enhance the model’s expressive power.
class FeedForward(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, feedforward_dim, dropout_rate=0.1):
        super().__init__()
        self.d1 = tf.keras.layers.Dense(feedforward_dim, activation='gelu')
        self.d2 = tf.keras.layers.Dense(embedding_dim)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        ffn_output = self.dropout(self.d2(self.d1(x)))
        x = x + ffn_output         # residual connection around the feed-forward block
        x = self.layer_norm(x)
        return x
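The same kind of check works here; the representation is expanded to feedforward_dim inside the block and projected back to embedding_dim:

ffn = FeedForward(embedding_dim=64, feedforward_dim=128)   # illustrative sizes
x = tf.random.uniform((1, 5, 64))
print(ffn(x).shape)                                        # (1, 5, 64)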
The DecoderLayer class combines the multi-headed attention layer and the feed-forward layer in sequence as a block.
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, num_heads, feedforward_dim, dropout_rate=0.1):
        super().__init__()
        self.causal_self_attention = CausalSelfAttention(num_heads=num_heads, embedding_dim=embedding_dim)
        self.ffn = FeedForward(embedding_dim=embedding_dim, feedforward_dim=feedforward_dim, dropout_rate=dropout_rate)

    def call(self, x):
        x = self.causal_self_attention(x=x)
        x = self.ffn(x)
        return x
Putting it all together
To build a full Transformer decoder, we stack multiple instances of the DecoderLayer. The final layer of the decoder is a dense layer with one output per vocabulary entry. It produces logits over the entire vocabulary, which a softmax turns into a probability distribution; the token with the highest probability is typically chosen as the next prediction.
class TransformerDecoder(tf.keras.Model):
    def __init__(self, vocab_size: int, num_layers: int, embedding_dim: int,
                 num_heads: int, intermediate_dim: int) -> None:
        super().__init__()
        self.embeddings = PositionalEmbedding(vocab_size, embedding_dim)
        self.decoders = [
            DecoderLayer(embedding_dim, num_heads, intermediate_dim)
            for _ in range(num_layers)
        ]
        # final projection to vocabulary-sized logits (no activation: the loss expects raw logits)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs):
        x = self.embeddings(inputs)
        for layer in self.decoders:
            x = layer(x)
        return self.dense(x)
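To see how the logits relate to a probability distribution, we can run a dummy batch through a small instance (illustrative sizes; MAX_SEQ_LEN is assumed to be set and at least as long as the dummy sequence):

tiny = TransformerDecoder(vocab_size=100, num_layers=2, embedding_dim=64,
                          num_heads=2, intermediate_dim=128)
dummy = tf.constant([[2, 15, 16, 17, 3]])        # placeholder token IDs
logits = tiny(dummy)
print(logits.shape)                              # (1, 5, 100): one score per vocabulary entry, per position
probs = tf.nn.softmax(logits, axis=-1)           # softmax turns logits into a probability distribution
next_id = tf.argmax(probs[:, -1, :], axis=-1)    # most likely token to follow the last position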
Training
We are now ready to start training an instance of our Generative Transformer Decoder model.
Masked loss and metrics
Since the target sequences have various lengths and are padded, it is important to apply a padding mask when calculating the loss and metrics. Masking ensures padding tokens don’t corrupt our calculations.
def masked_loss(label, pred):
    mask = label != 0   # padding tokens have ID 0
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    loss = loss_object(label, pred)
    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask
    loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)   # average over non-padding positions only
    return loss

def masked_accuracy(label, pred):
    pred = tf.argmax(pred, axis=2)
    label = tf.cast(label, pred.dtype)
    match = label == pred
    mask = label != 0
    match = match & mask           # only count matches at non-padding positions
    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.reduce_sum(match) / tf.reduce_sum(mask)
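A tiny hand-made example shows the mask at work: the last position below is padding (label 0), so it is excluded from the accuracy:

labels = tf.constant([[4, 7, 0]])                   # third position is padding
preds = tf.one_hot([[4, 2, 0]], depth=10) * 10.0    # pretend logits: correct, wrong, (ignored)
print(masked_accuracy(labels, preds).numpy())       # 0.5: only the two real tokens are counted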
Hyperparameters and compiled model
The following constants define the hyperparameters of our Transformer decoder architecture:
EMBED_DIM: Dimensionality of the token embeddings.
NUM_HEADS: Number of attention heads in the multi-head attention layers.
FEED_FORWARD_DIM: Dimensionality of the hidden layer in the feed-forward sublayers of the decoder.
NUM_LAYERS: The number of Transformer decoder layers stacked in our model.
These parameters are not ‘one-size-fits-all’. It’s common to experiment with different values to find the best configuration for your specific task.
EPOCHS = 10
EMBED_DIM = 256
NUM_HEADS = 4
FEED_FORWARD_DIM = 512
NUM_LAYERS = 2
model = TransformerDecoder(
vocab_size=tokenizer.vocab_size,
num_layers=NUM_LAYERS,
embedding_dim=EMBED_DIM,
num_heads=NUM_HEADS,
intermediate_dim=FEED_FORWARD_DIM
)
model.compile(loss=masked_loss, metrics=[masked_accuracy], optimizer='adam')
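Because the model is defined by subclassing, its weights are only created on the first call. Running a dummy batch through it (placeholder token IDs) builds the weights, so model.summary() can report the parameter count before training:

_ = model(tf.constant([[2, 15, 16, 17]]))   # one forward pass builds the weights
model.summary()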
We are now ready to run the training procedure with the fit method of our model.
model.fit(
tf_train,
validation_data=tf_test,
epochs=EPOCHS
)
Inference
Once our model is trained, we can use it to generate domain names. The keras_nlp library provides convenient utilities for text generation.
prompt = ''
prompt_length = len(tokenizer.encode(prompt)) - 1
prompt = tokenizer.encode(prompt, padding='max_length', max_length=MAX_SEQ_LEN, return_tensors='tf')

def step(prompt, cache, index):
    # no key/value cache here: re-run the full model and read the logits that predict the token at `index`
    logits = model(prompt)[:, index - 1, :]
    hidden_states = None
    return logits, hidden_states, cache
import keras_nlp
sampler = keras_nlp.samplers.TopPSampler(0.5)
output_tokens = sampler(
next=step,
prompt=prompt,
index=prompt_length
)[0]
print(tokenizer.decode(output_tokens).split('[END]')[0])
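To sample several candidate names in one go, the same pieces can be wrapped in a small helper; this is just a sketch that re-runs the sampler and assumes the empty-string prompt and tokenizer behave as above:

def generate_domain():
    seed = tokenizer.encode('', padding='max_length', max_length=MAX_SEQ_LEN, return_tensors='tf')
    start = len(tokenizer.encode('')) - 1
    tokens = sampler(next=step, prompt=seed, index=start)[0]
    return tokenizer.decode(tokens).split('[END]')[0]

for _ in range(5):
    print(generate_domain())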
Here are some example domain names generated by the model:
mydoctory
riskbay
career
wizard
Resources
Learn how to train a custom Byte-Pair Encoding (BPE) tokenizer on a dataset of domain names using the Hugging Face library. Improve your NLP models' performance with this efficient tokenization technique.
Keras documentation