Implementing a Basic Additive Attention Mechanism in Python
Attention mechanisms are a fundamental component in modern deep learning, particularly in sequence-to-sequence models like those used for machine translation or text summarization. They allow a model to dynamically focus on specific parts of the input sequence when processing the output. This challenge focuses on implementing a simplified additive (Bahdanau-style) attention mechanism from scratch in Python.
Problem Description
Your task is to implement a class that represents a basic additive attention mechanism. This mechanism will take a query vector and a set of key-value pairs (where keys and values are typically derived from an encoder's hidden states) and produce an attention-weighted context vector.
Key Requirements:
- Class Structure: Create a Python class named `AdditiveAttention`.
- Initialization (`__init__`):
  - The constructor should accept two integer arguments: `query_dim` (the dimension of the query vector) and `key_dim` (the dimension of the key vectors).
  - It should initialize the necessary learnable parameters:
    - `W_q`: a weight matrix for the query.
    - `W_k`: a weight matrix for the keys.
    - `v`: a weight vector.
  - You can assume these parameters would normally be initialized with a standard scheme (e.g., Xavier initialization); for this challenge, dummy NumPy arrays of the appropriate shapes are sufficient for testing purposes.
- Forward Pass (`forward`):
  - The `forward` method should accept three arguments:
    - `query`: a NumPy array representing the query vectors (shape: `(batch_size, query_dim)`).
    - `keys`: a NumPy array representing the key vectors (shape: `(batch_size, seq_len, key_dim)`).
    - `values`: a NumPy array representing the value vectors (shape: `(batch_size, seq_len, value_dim)`), where `value_dim` can differ from `key_dim`.
  - The method should perform the following steps (a minimal sketch is given after this list):
    - Score Calculation: Compute the attention scores using the additive attention formula $e_{ij} = v^T \tanh(W_q q_i + W_k k_j)$, where $q_i$ is the $i$-th query, $k_j$ is the $j$-th key, $W_q$ and $W_k$ are weight matrices, and $v$ is a weight vector. You'll need to broadcast the query and keys appropriately.
    - Softmax: Apply a softmax function to the calculated scores along the `seq_len` dimension to get the attention weights: $\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{ik})}$.
    - Context Vector Calculation: Compute the weighted sum of the value vectors using the attention weights to produce the context vector $c_i = \sum_j \alpha_{ij} v_j$, where $v_j$ is the $j$-th value.
  - The `forward` method should return two NumPy arrays:
    - `context_vector`: the attention-weighted context vector (shape: `(batch_size, value_dim)`).
    - `attention_weights`: the computed attention weights (shape: `(batch_size, seq_len)`).
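The requirements above fit in a small amount of NumPy. The following is a minimal sketch, assuming dummy `np.random.randn` initialization in place of a proper scheme such as Xavier and a hidden projection size equal to `query_dim`; it is one possible implementation, not a prescribed one.

```python
import numpy as np

class AdditiveAttention:
    """Minimal sketch of Bahdanau-style additive attention in NumPy."""

    def __init__(self, query_dim, key_dim):
        # Dummy random parameters; in a real model these would be trainable
        # and initialized with a scheme such as Xavier.
        self.W_q = np.random.randn(query_dim, query_dim)  # projects queries
        self.W_k = np.random.randn(key_dim, query_dim)    # projects keys
        self.v = np.random.randn(query_dim, 1)            # scoring vector

    def forward(self, query, keys, values):
        # query: (batch, query_dim); keys: (batch, seq_len, key_dim);
        # values: (batch, seq_len, value_dim)
        q_proj = query @ self.W_q                  # (batch, query_dim)
        k_proj = keys @ self.W_k                   # (batch, seq_len, query_dim)
        # Broadcast the projected query across all key positions.
        scores = np.tanh(q_proj[:, None, :] + k_proj) @ self.v  # (batch, seq_len, 1)
        scores = scores.squeeze(-1)                # (batch, seq_len)
        # Numerically stable softmax along the seq_len axis.
        scores -= scores.max(axis=1, keepdims=True)
        exp_scores = np.exp(scores)
        attention_weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)
        # Weighted sum of the values: (batch, 1, seq_len) @ (batch, seq_len, value_dim).
        context_vector = (attention_weights[:, None, :] @ values).squeeze(1)
        return context_vector, attention_weights
```

Subtracting the per-row maximum before exponentiating keeps the softmax numerically stable without changing the resulting weights.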
Expected Behavior:
- The `forward` method should correctly compute the attention scores, apply the softmax to obtain the weights, and produce a context vector that is a weighted average of the input `values`.
- The dimensions of the output arrays should match the specifications.
Edge Cases:
- Consider what happens if `seq_len` is 1: the softmax should still operate correctly (see the check below).
- Ensure that operations are performed element-wise or via broadcasting where appropriate to handle batches.
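A quick sanity check for the `seq_len = 1` case, assuming `numpy` is imported and the `AdditiveAttention` sketch above is defined (the dimensions here are arbitrary):

```python
attention = AdditiveAttention(query_dim=4, key_dim=6)
q = np.random.randn(2, 4)        # batch of 2 queries
k = np.random.randn(2, 1, 6)     # seq_len = 1
vals = np.random.randn(2, 1, 3)  # value_dim = 3

context, weights = attention.forward(q, k, vals)
assert weights.shape == (2, 1) and np.allclose(weights, 1.0)  # single position gets weight 1
assert np.allclose(context, vals[:, 0])                       # context equals that single value
```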
Examples
Example 1:
import numpy as np
# Dummy parameters (replace with actual initialization for a real model)
query_dim = 10
key_dim = 10
value_dim = 5
batch_size = 2
seq_len = 3
# Initialize dummy weights for demonstration
W_q = np.random.randn(query_dim, query_dim)
W_k = np.random.randn(key_dim, query_dim)
v = np.random.randn(query_dim, 1)
# Dummy input
query = np.random.randn(batch_size, query_dim)
keys = np.random.randn(batch_size, seq_len, key_dim)
values = np.random.randn(batch_size, seq_len, value_dim)
# Expected Output Structure:
# context_vector: (batch_size, value_dim)
# attention_weights: (batch_size, seq_len)
Explanation:
This example sets up the dimensions and dummy input for a typical scenario. The AdditiveAttention class will be instantiated with query_dim and key_dim, and the forward method will process the query, keys, and values to produce the context_vector and attention_weights.
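Continuing Example 1 with the hypothetical `AdditiveAttention` sketch from earlier (not a prescribed API), the intended usage looks roughly like this:

```python
attention = AdditiveAttention(query_dim, key_dim)
context_vector, attention_weights = attention.forward(query, keys, values)

print(context_vector.shape)           # (2, 5)  -> (batch_size, value_dim)
print(attention_weights.shape)        # (2, 3)  -> (batch_size, seq_len)
print(attention_weights.sum(axis=1))  # each row sums to (approximately) 1.0
```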
Example 2:
# Assume AdditiveAttention class is defined as per requirements
# Input for a single example in the batch
query_single = np.random.randn(1, query_dim)
keys_single = np.random.randn(1, seq_len, key_dim)
values_single = np.random.randn(1, seq_len, value_dim)
# Call forward for this single example
# context_single, weights_single = attention.forward(query_single, keys_single, values_single)
# Expected: context_single.shape == (1, value_dim), weights_single.shape == (1, seq_len)
Explanation: This example highlights processing a batch of size 1, ensuring the mechanism works correctly for individual data points.
Example 3 (handling a different `value_dim`):
# Input with value_dim different from key_dim
query_dim_ex3 = 8
key_dim_ex3 = 12
value_dim_ex3 = 20
batch_size_ex3 = 1
seq_len_ex3 = 5
# Dummy parameters for this example
W_q_ex3 = np.random.randn(query_dim_ex3, query_dim_ex3)
W_k_ex3 = np.random.randn(key_dim_ex3, query_dim_ex3)
v_ex3 = np.random.randn(query_dim_ex3, 1)
# Dummy input
query_ex3 = np.random.randn(batch_size_ex3, query_dim_ex3)
keys_ex3 = np.random.randn(batch_size_ex3, seq_len_ex3, key_dim_ex3)
values_ex3 = np.random.randn(batch_size_ex3, seq_len_ex3, value_dim_ex3)
# Call forward
# context_ex3, weights_ex3 = attention.forward(query_ex3, keys_ex3, values_ex3)
# Expected: context_ex3.shape == (batch_size_ex3, value_dim_ex3) which is (1, 20)
# Expected: weights_ex3.shape == (batch_size_ex3, seq_len_ex3) which is (1, 5)
Explanation:
This example demonstrates that the value_dim does not need to match the key_dim, as the context vector is a weighted sum of the values. The attention scores and weights are still computed based on query_dim and key_dim.
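To see why `value_dim` is unconstrained, note that the weighted sum only contracts over `seq_len`. A hypothetical `np.einsum` formulation, with uniform weights standing in for real attention weights, makes this explicit:

```python
# alpha: (batch, seq_len), values: (batch, seq_len, value_dim) -> (batch, value_dim)
alpha = np.full((batch_size_ex3, seq_len_ex3), 1.0 / seq_len_ex3)  # placeholder uniform weights
context_ex3 = np.einsum('bs,bsv->bv', alpha, values_ex3)
print(context_ex3.shape)  # (1, 20) -- value_dim passes through untouched
```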
Constraints
- Library Usage: You are allowed to use `numpy` for all numerical computations. No deep learning frameworks (such as TensorFlow or PyTorch) may be used for the core attention logic.
- Input Format: `query` will be a NumPy array of shape `(batch_size, query_dim)`, `keys` of shape `(batch_size, seq_len, key_dim)`, and `values` of shape `(batch_size, seq_len, value_dim)`.
- Output Format: The `forward` method must return two NumPy arrays: `context_vector` of shape `(batch_size, value_dim)` and `attention_weights` of shape `(batch_size, seq_len)`.
- Parameter Initialization: For testing, you can initialize weights using `np.random.randn` or similar. In a real application, these would be trainable parameters.
- Batch Processing: The implementation must handle batched inputs correctly.
Notes
- The additive attention mechanism is also known as Bahdanau attention.
- Pay close attention to the dimensions during matrix multiplications and broadcasting operations. NumPy's `np.tanh` and `np.exp` functions will be useful.
- The softmax needs to be applied along the `seq_len` dimension for each example in the batch; `np.sum` with the `axis` parameter and `keepdims=True` will be helpful for normalization.
- Think about how to efficiently compute $W_q q_i + W_k k_j$ for all $i$ and $j$ across the batch; broadcasting will be your friend here (see the snippet below).
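For that last point, one possible vectorized score computation, reusing the dummy `W_q`, `W_k`, `v`, `query`, and `keys` from Example 1, looks like this:

```python
q_proj = query @ W_q  # (batch_size, query_dim)
k_proj = keys @ W_k   # (batch_size, seq_len, query_dim)
# A length-1 axis lets the projected query broadcast against every key position.
scores = (np.tanh(q_proj[:, None, :] + k_proj) @ v).squeeze(-1)  # (batch_size, seq_len)
# Softmax along seq_len, stabilized by subtracting the row-wise maximum.
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
```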