xLSTM: Extended Long Short-Term Memory
- Nagesh Singh Chauhan
- Jun 9, 2024
With a major upgrade, the LSTM is back to challenge the dominance of Transformer-based LLMs.
Introduction
For years, Long Short-Term Memory (LSTM) neural networks were the cornerstone of sequence data processing, excelling in tasks like language modeling, text generation, and speech recognition. However, the Transformer architecture, introduced in 2017, revolutionized the field: it offered superior efficiency and scalability, particularly on large datasets, and quickly overtook LSTMs.
Recognizing the limitations of traditional LSTMs, such as constrained memory and lack of parallelization, Sepp Hochreiter and his team at NXAI have developed an enhanced variant known as extended LSTM (xLSTM). This new model incorporates exponential gating and enhanced memory structures, significantly improving flexibility and storage capacity. By integrating parallelization and residual stacking techniques from large language models, xLSTMs can efficiently handle long sequences and complex language tasks.
In this article, we will delve into the architectural innovations of xLSTM and examine its performance advancements over existing state-of-the-art models.
If you are new to LSTMs, I recommend first reading an introduction to their architecture and limitations.
The Limitations of LSTM
Before exploring xLSTM, it’s crucial to understand the limitations that have challenged traditional LSTM architectures and prompted the development of xLSTM and other alternatives.
While LSTMs excel in long-term memory and sequence processing, they have three main drawbacks:
Difficulty Revising Storage Decisions: LSTMs struggle to update previously stored information when more relevant data arrives later in a sequence. This is evident in the Nearest Neighbor Search problem, where a sequence is scanned for the vector most similar to a reference vector and the value attached to it must be reported: when a closer match appears later, the LSTM has difficulty overwriting the value it already stored.
Limited Storage Capacity: LSTM memory cells compress information into a single scalar value, restricting the amount of information stored. This limitation leads to poorer performance in tasks involving rare token prediction, as the nuances of infrequent words are inadequately captured.
Lack of Parallelizability: The sequential nature of LSTMs, where each hidden state depends on the previous one, hinders parallelization on modern hardware like GPUs. This prevents LSTMs from fully utilizing available computational power for training and inference.
These limitations have driven the rise of Transformers and other architectures, which surpass LSTMs in scalability and efficiency for larger models.
The xLSTM Architecture
At the heart of xLSTM are two significant modifications to the traditional LSTM framework: exponential gating and novel memory structures. These enhancements give rise to two new variants, known as sLSTM (scalar LSTM) and mLSTM (matrix LSTM).
sLSTM: Scalar LSTM with Exponential Gating and Memory Mixing
sLSTM keeps the scalar memory cell of the classic LSTM but changes how that cell is updated: it adds exponential gating, a normalizer state, and memory mixing across multiple cells. The paper improves the LSTM through these changes to gating and memory rather than by adding complex new network layers.
The sLSTM block uses a post-up-projection. The input is first layer-normalized and passed through a causal convolution with a Swish activation; the convolution output feeds the input and forget gates. The gates and cell input are computed by block-diagonal linear layers with four diagonal blocks, or “heads,” and the sLSTM cell itself is also configured with four heads. Finally, the result is up-projected by a gated MLP with GeLU activation and then projected back down to the model dimension.
Few points to note about sLSTM:
Exponential Gating: sLSTM incorporates exponential activation functions for the input and forget gates, providing more flexible control over information flow.
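Normalization and Stabilization: To keep the exponential gates numerically stable, sLSTM adds a normalizer state that sums each input gate multiplied by all forget gates that follow it, and divides the cell state by this normalizer. A separate stabilizer state keeps the exponentials in a range where they cannot overflow.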
Memory Mixing: sLSTM supports multiple memory cells and allows for memory mixing via recurrent connections, enhancing its ability to extract complex patterns and track states.
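The exponential gating and stabilization described above can be traced on a single scalar cell. The following is a minimal pure-Python sketch, assuming the gate pre-activations (`z`, `i_pre`, `f_pre`) and output gates `o` are already computed and that the forget gate is exp(f_pre). Recurrent weights and memory mixing are left out, and the function names are hypothetical.

```python
import math


def slstm_cell(z_seq, i_pre_seq, f_pre_seq, o_seq):
    """Stabilized scalar sLSTM recurrence with exponential input and forget gates."""
    c, n, m = 0.0, 0.0, 0.0  # cell state, normalizer state, stabilizer state
    hs = []
    for z, i_pre, f_pre, o in zip(z_seq, i_pre_seq, f_pre_seq, o_seq):
        m_new = max(f_pre + m, i_pre)    # running maximum in log space
        i = math.exp(i_pre - m_new)      # exponent <= 0, cannot overflow
        f = math.exp(f_pre + m - m_new)  # exponent <= 0, cannot overflow
        c = f * c + i * z
        n = f * n + i                    # input gates, each decayed by later forget gates
        hs.append(o * c / n)
        m = m_new
    return hs


def slstm_cell_naive(z_seq, i_pre_seq, f_pre_seq, o_seq):
    """Same recurrence without the stabilizer: math.exp can overflow."""
    c, n = 0.0, 0.0
    hs = []
    for z, i_pre, f_pre, o in zip(z_seq, i_pre_seq, f_pre_seq, o_seq):
        i, f = math.exp(i_pre), math.exp(f_pre)
        c = f * c + i * z
        n = f * n + i
        hs.append(o * c / n)
    return hs


# On moderate inputs both versions agree (the exp(-m) scaling cancels in c / n).
print(slstm_cell([0.5, -1.0, 2.0], [0.1, 2.0, -1.0], [0.3, -0.5, 1.0], [1.0, 1.0, 1.0]))
# A much larger input gate at step 2 overwrites the stored 0.5 with -1.0.
# The naive version raises OverflowError here because math.exp(800) is too large.
print(slstm_cell([0.5, -1.0], [800.0, 900.0], [0.0, 0.0], [1.0, 1.0]))  # [0.5, -1.0]
```

The second call also shows, in miniature, how exponential gating lets a cell revise a stored value: the new input gate dominates everything accumulated before it.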
mLSTM: Matrix LSTM with Enhanced Storage Capacities
mLSTM is an LSTM variant featuring matrix memory, which allows it to process and store more information in parallel. This involves a fundamental shift in the memory structure, using matrices instead of scalars to store the LSTM cell states.
The mLSTM block uses a pre-up-projection: the layer-normalized input is first up-projected with a projection factor of 2. One projection output is directed to the mLSTM cell, while the other goes to the output gate. The mLSTM branch passes through a causal convolution and then through block-diagonal projection matrices with a block size of 4, which produce the query and key; the value skips the convolution. The mLSTM cell then combines the query, key, and value.
Few points to note about mLSTM:
Matrix Memory: mLSTM utilizes a matrix memory instead of a scalar memory cell, increasing storage capacity and enabling more efficient information retrieval.
Covariance Update Rule: Inspired by Bidirectional Associative Memories (BAMs), mLSTM employs a covariance update rule to store and retrieve key-value pairs efficiently.
Parallelizability: By eliminating memory mixing, mLSTM achieves full parallelizability, allowing for efficient computations on modern hardware accelerators.
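The covariance update and retrieval can be illustrated with plain lists. This is a minimal pure-Python sketch, assuming a single head, gates fixed to i = f = 1, no stabilizer, and no 1/√d key scaling; the helper names are hypothetical. With orthonormal keys, querying with a stored key returns exactly the value stored under it.

```python
def outer(v, k):
    """Outer product v k^T as a list of rows."""
    return [[vi * kj for kj in k] for vi in v]


def store(C, n, k, v, i=1.0, f=1.0):
    """Covariance update: C <- f*C + i*v k^T and n <- f*n + i*k."""
    vk = outer(v, k)
    C = [[f * c + i * u for c, u in zip(row, vk_row)] for row, vk_row in zip(C, vk)]
    n = [f * nj + i * kj for nj, kj in zip(n, k)]
    return C, n


def retrieve(C, n, q):
    """Retrieval: h = C q / max(|n^T q|, 1)."""
    denom = max(abs(sum(nj * qj for nj, qj in zip(n, q))), 1.0)
    return [sum(c * qj for c, qj in zip(row, q)) / denom for row in C]


d = 3
C = [[0.0] * d for _ in range(d)]
n = [0.0] * d
C, n = store(C, n, k=[1.0, 0.0, 0.0], v=[0.5, 2.0, -1.0])
C, n = store(C, n, k=[0.0, 1.0, 0.0], v=[3.0, 0.0, 1.0])
print(retrieve(C, n, q=[1.0, 0.0, 0.0]))  # [0.5, 2.0, -1.0]
print(retrieve(C, n, q=[0.0, 1.0, 0.0]))  # [3.0, 0.0, 1.0]
```

With correlated keys, retrieval instead returns a weighted mixture of the stored values. A forget gate below 1 would down-weight older pairs.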
These variants, sLSTM and mLSTM, can be integrated into residual block architectures, forming xLSTM blocks. By residually stacking these xLSTM blocks, researchers can create powerful xLSTM architectures tailored for specific tasks and application domains.
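Residual stacking itself is simple: each block adds its output to its input. The sketch below assumes each block is a plain function on a vector. It also uses a hypothetical `layer_pattern` helper to build the kind of 's'/'m' string accepted by the `layers` argument of the `xLSTM` class in the PyTorch implementation later in this article. The paper describes such mixes by their ratio of mLSTM to sLSTM blocks (e.g. xLSTM[7:1]); the placement of the sLSTM block below is an illustrative choice, not the paper's configuration.

```python
def layer_pattern(num_blocks, slstm_every):
    """Hypothetical helper: one sLSTM block ('s') after every (slstm_every - 1) mLSTM blocks ('m')."""
    return "".join("s" if (b + 1) % slstm_every == 0 else "m" for b in range(num_blocks))


def residual_stack(x, blocks):
    """Apply blocks residually: x_{l+1} = x_l + block_l(x_l)."""
    for block in blocks:
        x = [xi + yi for xi, yi in zip(x, block(x))]
    return x


def double(v):
    """Toy stand-in for a block."""
    return [2.0 * a for a in v]


print(layer_pattern(8, 8))                            # mmmmmmms (7 mLSTM : 1 sLSTM)
print(residual_stack([1.0, 2.0], [double, double]))   # [9.0, 18.0]
```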
Comparative study of xLSTM with other LLMs
The authors trained their models on 300B tokens and compared them with other architectures (Transformers, State Space Models, and Recurrent Neural Networks) trained on the same data. They also experimented with the composition and ratio of sLSTM and mLSTM blocks.
The results suggest that xLSTM has an advantage on tasks that require memory mixing (state tracking), such as the parity task. Transformers and State Space Models fail on this task precisely because they do not perform state tracking.
More generally, xLSTM is competitive with other architectures when trained on the same dataset with the same number of parameters.
Ablation studies show that all of the added components are needed to reach this performance.
Because the gates are central to the design, the authors also ablate them selectively. Each gate brings an incremental improvement, although no single gate appears to be essential on its own.
Naturally, the authors also test longer context lengths (after all, this is the theme of the moment): xLSTM models maintain low perplexity on longer contexts, much better than the other models.
Finally, they evaluate downstream tasks and find that xLSTM is the best model at every model size in terms of validation-set perplexity (apart from a few cases where Mamba appears better, which the authors do not investigate).
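Perplexity, the metric behind these comparisons, is the exponential of the average negative log-likelihood that the model assigns to the actual next tokens (lower is better). A minimal sketch with made-up token probabilities:

```python
import math


def perplexity(token_probs):
    """exp of the mean negative log-likelihood of the observed tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


# Probabilities a model assigned to the actual next tokens (illustrative values).
print(perplexity([0.5, 0.25, 0.125]))  # ≈ 4.0: as uncertain as a uniform choice among 4 tokens
print(perplexity([1.0, 1.0]))          # 1.0: a perfect model
```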
Key Features and Advantages of xLSTM
Ability to Revise Storage Decisions: Exponential gating allows xLSTM to dynamically revise stored values when encountering more relevant information, addressing a major limitation of traditional LSTMs.
Enhanced Storage Capacities: The matrix memory in mLSTM significantly increases storage capacity, enabling xLSTM to handle rare tokens, long-range dependencies, and complex data patterns more effectively.
Parallelizability: The mLSTM variant of xLSTM is fully parallelizable, facilitating efficient computations on modern hardware accelerators like GPUs, and supporting scalability to larger models.
Memory Mixing and State Tracking: The sLSTM variant retains the memory mixing capabilities of traditional LSTMs, enabling robust state tracking and making xLSTM more expressive than Transformers and State Space Models for certain tasks.
Scalability: Leveraging the latest techniques from modern Large Language Models (LLMs), xLSTM can scale to billions of parameters, unlocking new possibilities in language modeling and sequence processing tasks.
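Why the mLSTM is parallelizable can be checked numerically. Because its memory update has no hidden-to-hidden connection, every output can be written as a gate-weighted, attention-like sum over past steps and computed independently of the other outputs. The pure-Python sketch below compares that parallel form with the step-by-step recurrence. It assumes a single head, gate values given directly as positive numbers (no stabilizer), and hypothetical function names.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def mlstm_recurrent(qs, ks, vs, i_gates, f_gates):
    """Step by step: C_t = f_t C_{t-1} + i_t v_t k_t^T, n_t = f_t n_{t-1} + i_t k_t."""
    d_k, d_v = len(ks[0]), len(vs[0])
    C = [[0.0] * d_k for _ in range(d_v)]
    n = [0.0] * d_k
    outputs = []
    for q, k, v, i, f in zip(qs, ks, vs, i_gates, f_gates):
        C = [[f * C[a][b] + i * v[a] * k[b] for b in range(d_k)] for a in range(d_v)]
        n = [f * n[b] + i * k[b] for b in range(d_k)]
        denom = max(abs(dot(n, q)), 1.0)
        outputs.append([dot(row, q) / denom for row in C])
    return outputs


def mlstm_parallel(qs, ks, vs, i_gates, f_gates):
    """Each step t on its own: weights D_ts * (q_t . k_s) over all s <= t."""
    outputs = []
    for t in range(len(qs)):  # iterations do not depend on each other
        # D_ts = i_s * f_{s+1} * ... * f_t (the empty product is 1 when s == t)
        w = [i_gates[s] * math.prod(f_gates[s + 1:t + 1]) * dot(qs[t], ks[s]) for s in range(t + 1)]
        denom = max(abs(sum(w)), 1.0)
        outputs.append([sum(w[s] * vs[s][a] for s in range(t + 1)) / denom for a in range(len(vs[0]))])
    return outputs


qs = [[1.0, 0.5], [0.2, -1.0], [0.7, 0.3]]
ks = [[0.5, 1.0], [1.0, 0.0], [-0.3, 0.8]]
vs = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
i_gates, f_gates = [1.0, 0.5, 2.0], [0.9, 0.8, 0.7]
print(mlstm_recurrent(qs, ks, vs, i_gates, f_gates))
print(mlstm_parallel(qs, ks, vs, i_gates, f_gates))  # same values up to rounding
```

The sLSTM cannot be rewritten this way, because its gates depend on the previous hidden state through the recurrent weights.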
Performance and Applications of xLSTM
Language Modeling: xLSTM has demonstrated promising results in language modeling, outperforming traditional LSTMs and competing with state-of-the-art Transformer models. Its ability to scale to billions of parameters while remaining efficient makes it a strong candidate for large-scale language modeling.
Time Series Analysis: xLSTM's handling of long-range dependencies and complex temporal patterns could make it well suited to time series analysis and forecasting. Potential applications range from financial market prediction to weather forecasting, where accurate long-term predictions are crucial.
Machine Translation: xLSTM's state-tracking capabilities could help it maintain context and capture long-range dependencies, which matters in machine translation, where context must be preserved over long sequences.
Speech Recognition and Generation: The parallelizability and scalability of xLSTM could also benefit speech recognition and generation, where long sequences of audio must be processed efficiently, for example in real-time voice interactions and natural language interfaces.
PyTorch Implementation of xLSTM
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Each block processes a whole sequence of shape (batch, seq_len, features);
# the recurrence runs as an explicit loop over time for readability
# (the parallel mLSTM formulation is not implemented here).


class CausalConv1D(nn.Module):
    """Convolution over time that only sees current and past steps. Shape: (batch, channels, seq_len)."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super().__init__()
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=self.padding, dilation=dilation, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        # Conv1d pads both ends; dropping the right-hand padding keeps the layer causal.
        # Guard for kernel_size == 1, where x[..., :-0] would return an empty tensor.
        return x[:, :, :-self.padding] if self.padding > 0 else x


class BlockDiagonal(nn.Module):
    """Linear layer with a block-diagonal weight: one independent nn.Linear per chunk ("head")."""

    def __init__(self, in_features, out_features, num_blocks):
        super().__init__()
        assert in_features % num_blocks == 0
        assert out_features % num_blocks == 0
        self.in_features = in_features
        self.out_features = out_features
        self.num_blocks = num_blocks
        self.blocks = nn.ModuleList([
            nn.Linear(in_features // num_blocks, out_features // num_blocks)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        chunks = x.chunk(self.num_blocks, dim=-1)
        return torch.cat([block(c) for block, c in zip(self.blocks, chunks)], dim=-1)


class sLSTMBlock(nn.Module):
    """sLSTM block with post-up-projection and scalar memory cells."""

    def __init__(self, input_size, hidden_size, num_heads, proj_factor=4/3):
        super().__init__()
        assert hidden_size % num_heads == 0
        assert proj_factor > 0
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.layer_norm = nn.LayerNorm(input_size)
        # Depthwise causal convolution over time (one length-4 filter per channel).
        self.causal_conv = CausalConv1D(input_size, input_size, 4, groups=input_size)
        # Input weights W and block-diagonal recurrent weights R for z, i, f, o:
        # memory mixing therefore happens within each head only.
        self.Wz = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Wi = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Wf = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Wo = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Rz = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.Ri = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.Rf = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.Ro = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.group_norm = nn.GroupNorm(num_heads, hidden_size)
        inner_size = int(hidden_size * proj_factor)
        self.up_proj_left = nn.Linear(hidden_size, inner_size)
        self.up_proj_right = nn.Linear(hidden_size, inner_size)
        self.down_proj = nn.Linear(inner_size, input_size)

    def init_state(self, batch_size, device=None, dtype=None):
        zeros = torch.zeros(batch_size, self.hidden_size, device=device, dtype=dtype)
        return (zeros, zeros.clone(), zeros.clone(), zeros.clone())  # h, c, n, m

    def forward(self, x, state=None):
        assert x.size(-1) == self.input_size
        batch_size, seq_len, _ = x.shape
        if state is None:
            state = self.init_state(batch_size, x.device, x.dtype)
        h, c, n, m = state
        x_norm = self.layer_norm(x)
        # Conv1d expects (batch, channels, seq_len); only the i and f gates see the conv output.
        x_conv = F.silu(self.causal_conv(x_norm.transpose(1, 2)).transpose(1, 2))
        # Input contributions for all steps at once; only the R @ h terms are sequential.
        wz, wo = self.Wz(x_norm), self.Wo(x_norm)
        wi, wf = self.Wi(x_conv), self.Wf(x_conv)
        outputs = []
        for t in range(seq_len):
            z = torch.tanh(wz[:, t] + self.Rz(h))
            o = torch.sigmoid(wo[:, t] + self.Ro(h))
            i_tilde = wi[:, t] + self.Ri(h)
            f_tilde = wf[:, t] + self.Rf(h)
            # Stabilizer m keeps the exp() arguments <= 0 (exponential input and forget gates).
            m_new = torch.maximum(f_tilde + m, i_tilde)
            i = torch.exp(i_tilde - m_new)
            f = torch.exp(f_tilde + m - m_new)
            c = f * c + i * z
            n = f * n + i  # normalizer state
            h = o * c / n
            m = m_new
            outputs.append(h)
        out = torch.stack(outputs, dim=1)  # (batch, seq_len, hidden_size)
        # GroupNorm per time step (folding time into the batch keeps it causal).
        out = self.group_norm(out.reshape(-1, self.hidden_size)).reshape(batch_size, seq_len, -1)
        # Post-up-projection: gated MLP with GeLU, then back down to input_size.
        out = self.up_proj_left(out) * F.gelu(self.up_proj_right(out))
        out = self.down_proj(out)
        return out + x, (h, c, n, m)


class mLSTMBlock(nn.Module):
    """mLSTM block with pre-up-projection and a d x d matrix memory per head."""

    def __init__(self, input_size, hidden_size, num_heads, proj_factor=2):
        super().__init__()
        assert hidden_size % num_heads == 0
        assert proj_factor > 0
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        inner_size = int(input_size * proj_factor)
        self.layer_norm = nn.LayerNorm(input_size)
        self.up_proj_left = nn.Linear(input_size, inner_size)    # branch into the mLSTM cell
        self.up_proj_right = nn.Linear(input_size, hidden_size)  # branch used as output gate
        self.down_proj = nn.Linear(hidden_size, input_size)
        self.causal_conv = CausalConv1D(inner_size, inner_size, 4, groups=inner_size)
        self.skip_connection = nn.Linear(inner_size, hidden_size)
        self.Wq = BlockDiagonal(inner_size, hidden_size, num_heads)
        self.Wk = BlockDiagonal(inner_size, hidden_size, num_heads)
        self.Wv = BlockDiagonal(inner_size, hidden_size, num_heads)
        # Input and forget gates are scalars per head.
        self.Wi = nn.Linear(inner_size, num_heads)
        self.Wf = nn.Linear(inner_size, num_heads)
        self.Wo = nn.Linear(inner_size, hidden_size)
        self.group_norm = nn.GroupNorm(num_heads, hidden_size)

    def init_state(self, batch_size, device=None, dtype=None):
        H, D = self.num_heads, self.head_size
        C = torch.zeros(batch_size, H, D, D, device=device, dtype=dtype)  # matrix memory
        n = torch.zeros(batch_size, H, D, device=device, dtype=dtype)     # normalizer
        m = torch.zeros(batch_size, H, 1, device=device, dtype=dtype)     # stabilizer
        return (C, n, m)

    def forward(self, x, state=None):
        assert x.size(-1) == self.input_size
        batch_size, seq_len, _ = x.shape
        H, D = self.num_heads, self.head_size
        if state is None:
            state = self.init_state(batch_size, x.device, x.dtype)
        C, n, m = state
        x_norm = self.layer_norm(x)
        x_up_left = self.up_proj_left(x_norm)
        x_up_right = self.up_proj_right(x_norm)
        x_conv = F.silu(self.causal_conv(x_up_left.transpose(1, 2)).transpose(1, 2))
        x_skip = self.skip_connection(x_conv)
        # Queries and keys come from the conv branch; values skip the convolution. Shape (B, T, H, D).
        q = self.Wq(x_conv).view(batch_size, seq_len, H, D)
        k = self.Wk(x_conv).view(batch_size, seq_len, H, D) / D ** 0.5
        v = self.Wv(x_up_left).view(batch_size, seq_len, H, D)
        i_tilde = self.Wi(x_conv).unsqueeze(-1)  # (B, T, H, 1)
        f_tilde = self.Wf(x_conv).unsqueeze(-1)
        o = torch.sigmoid(self.Wo(x_up_left))    # (B, T, hidden_size)
        outputs = []
        for t in range(seq_len):
            q_t, k_t, v_t = q[:, t], k[:, t], v[:, t]  # (B, H, D)
            m_new = torch.maximum(f_tilde[:, t] + m, i_tilde[:, t])
            i = torch.exp(i_tilde[:, t] - m_new)  # (B, H, 1)
            f = torch.exp(f_tilde[:, t] + m - m_new)
            # Covariance update: C <- f * C + i * v k^T (outer product per head).
            C = f.unsqueeze(-1) * C + i.unsqueeze(-1) * (v_t.unsqueeze(-1) * k_t.unsqueeze(-2))
            n = f * n + i * k_t
            # Retrieval: C q / max(|n^T q|, 1). C and n are stored scaled by exp(-m),
            # so in this stabilized form the lower bound 1 becomes exp(-m).
            numerator = (C @ q_t.unsqueeze(-1)).squeeze(-1)  # (B, H, D)
            denominator = torch.maximum((n * q_t).sum(-1, keepdim=True).abs(), torch.exp(-m_new))
            outputs.append((numerator / denominator).reshape(batch_size, -1))
            m = m_new
        h = o * torch.stack(outputs, dim=1)  # (B, T, hidden_size)
        out = self.group_norm(h.reshape(-1, self.hidden_size)).reshape(batch_size, seq_len, -1)
        out = (out + x_skip) * F.silu(x_up_right)
        out = self.down_proj(out)
        return out + x, (C, n, m)


class xLSTM(nn.Module):
    """Residual stack of sLSTM ('s') and mLSTM ('m') blocks, e.g. layers='msm'."""
    # TODO: Add dropout, bidirectional

    def __init__(self, input_size, hidden_size, num_heads, layers, batch_first=False,
                 proj_factor_slstm=4/3, proj_factor_mlstm=2):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_layers = len(layers)
        self.batch_first = batch_first
        self.layers = nn.ModuleList()
        for layer_type in layers:
            if layer_type == 's':
                self.layers.append(sLSTMBlock(input_size, hidden_size, num_heads, proj_factor_slstm))
            elif layer_type == 'm':
                self.layers.append(mLSTMBlock(input_size, hidden_size, num_heads, proj_factor_mlstm))
            else:
                raise ValueError(f"Invalid layer type: {layer_type}. Choose 's' for sLSTM or 'm' for mLSTM.")

    def forward(self, x, state=None):
        """x: (seq_len, batch, input_size), or (batch, seq_len, input_size) if batch_first.
        state: None, or the list of per-layer state tuples returned by a previous call.
        The causal convolution restarts from zero padding on every call."""
        assert x.ndim == 3
        if not self.batch_first:
            x = x.transpose(0, 1)  # blocks operate on (batch, seq_len, features)
        if state is None:
            state = [None] * self.num_layers
        assert len(state) == self.num_layers
        new_state = []
        for layer, layer_state in zip(self.layers, state):
            x, layer_state = layer(x, layer_state)
            new_state.append(layer_state)
        if not self.batch_first:
            x = x.transpose(0, 1)
        return x, new_state


class sLSTM(xLSTM):
    """Stack of num_layers sLSTM blocks."""

    def __init__(self, input_size, hidden_size, num_heads, num_layers=1, batch_first=False, proj_factor=4/3):
        super().__init__(input_size, hidden_size, num_heads, 's' * num_layers,
                         batch_first=batch_first, proj_factor_slstm=proj_factor)


class mLSTM(xLSTM):
    """Stack of num_layers mLSTM blocks."""

    def __init__(self, input_size, hidden_size, num_heads, num_layers=1, batch_first=False, proj_factor=2):
        super().__init__(input_size, hidden_size, num_heads, 'm' * num_layers,
                         batch_first=batch_first, proj_factor_mlstm=proj_factor)


# Usage:
# model = xLSTM(input_size=16, hidden_size=32, num_heads=4, layers="msm", batch_first=True)
# y, state = model(torch.randn(2, 10, 16))          # y: (2, 10, 16)
# y2, state = model(torch.randn(2, 5, 16), state)   # continue from the returned state
```
CausalConv1D is a convolution layer tailored to preserve the causal relationship in time-series data processing. It ensures that the convolution operation does not access future information, which is crucial for sequence prediction tasks.
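The causal property can be checked on a toy signal. The sketch below pads only on the left by kernel_size - 1 and then slides the kernel. This is equivalent to the pad-both-sides-then-trim approach in `CausalConv1D`. It is written in pure Python with a single channel and no dilation, and `causal_conv1d` is a hypothetical name. Changing a future input leaves all earlier outputs untouched.

```python
def causal_conv1d(x, kernel):
    """y[t] = sum_j kernel[j] * x[t - (K - 1) + j], treating inputs before the start as 0."""
    K = len(kernel)
    padded = [0.0] * (K - 1) + list(x)  # left padding only: no access to the future
    return [sum(kernel[j] * padded[t + j] for j in range(K)) for t in range(len(x))]


kernel = [1.0, 1.0, 1.0, 1.0]  # window of size 4, as in the xLSTM blocks
y1 = causal_conv1d([1.0, 2.0, 3.0, 4.0], kernel)
y2 = causal_conv1d([1.0, 2.0, 99.0, 4.0], kernel)  # change only the input at t = 2
print(y1)  # [1.0, 3.0, 6.0, 10.0]
print(y2)  # [1.0, 3.0, 102.0, 106.0]: outputs before t = 2 are unchanged
```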
BlockDiagonal implements a specialized linear (fully connected) layer, where the weight matrix is composed of multiple independent blocks arranged on the main diagonal, forming a block diagonal matrix. This configuration allows each block to interact solely with its corresponding input segment, effectively simulating a series of independent linear transformations.
Additionally, BlockDiagonal takes a parameter, num_blocks, that sets how many independent linear layers (one per block) it contains; each block maps in_features / num_blocks inputs to out_features / num_blocks outputs. In Transformer terms, this is akin to the number of attention heads in multi-head attention.
GitHub: https://github.com/NX-AI/xlstm
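The equivalence between a block-diagonal weight matrix and a set of independent per-chunk linear maps can be checked directly. This pure-Python sketch omits biases (which `nn.Linear` includes) and uses hypothetical helper names. With two 2×2 blocks, the layer stores 8 weights instead of the 16 of a dense 4×4 matrix.

```python
def block_diagonal_apply(blocks, x):
    """Apply each block to its own chunk of x and concatenate the results."""
    out, start = [], 0
    for W in blocks:  # W has len(W) output rows and len(W[0]) input columns
        chunk = x[start:start + len(W[0])]
        out.extend(sum(w * xi for w, xi in zip(row, chunk)) for row in W)
        start += len(W[0])
    return out


def dense_from_blocks(blocks):
    """Build the full matrix with the blocks on the diagonal and zeros elsewhere."""
    rows, cols = sum(len(W) for W in blocks), sum(len(W[0]) for W in blocks)
    M = [[0.0] * cols for _ in range(rows)]
    r = c = 0
    for W in blocks:
        for a, row in enumerate(W):
            for b, w in enumerate(row):
                M[r + a][c + b] = w
        r, c = r + len(W), c + len(W[0])
    return M


blocks = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # two "heads"
x = [1.0, 1.0, 2.0, -1.0]
M = dense_from_blocks(blocks)
print(block_diagonal_apply(blocks, x))                      # [3.0, 7.0, 4.0, 6.0]
print([sum(m * xi for m, xi in zip(row, x)) for row in M])  # same result
```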
Limitations of xLSTM
According to the authors, several limitations persist:
Parallelization Issues with sLSTM: Memory mixing in sLSTM prevents parallel computation across time steps. To mitigate this, the authors developed a fast CUDA kernel for sLSTM, which makes an sLSTM block only about 1.5 times slower than the parallel mLSTM implementation.
Suboptimal CUDA Kernels for mLSTM: The CUDA kernels for mLSTM are not yet optimized, so the current implementation is slower than highly optimized kernels such as FlashAttention.
Cost of the Matrix Memory: mLSTM must process d × d matrices, which makes its memory updates computationally expensive. Moreover, because this fixed-size memory does not grow with the sequence, very long contexts might overload it.
Lack of Optimization: There is still room to optimize the architecture and hyperparameters (e.g., the initialization of the forget gates, which must be chosen carefully).
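The size of the mLSTM state is simple arithmetic. For each layer and sequence, every head keeps a d × d matrix C, a d-dimensional normalizer n and a scalar stabilizer m. The state size therefore grows with the square of the head dimension but does not depend on the context length. A minimal sketch, with hypothetical dimensions (4 heads of size 256 or 512):

```python
def mlstm_state_entries(num_heads, head_dim):
    """Numbers stored per layer and sequence: C (d x d), n (d) and m (1) for each head."""
    return num_heads * (head_dim * head_dim + head_dim + 1)


print(mlstm_state_entries(4, 256))  # 263172, the same for any sequence length
print(mlstm_state_entries(4, 512))  # 1050628: doubling d roughly quadruples the state
```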
Potential Impact of xLSTMs
The authors partially answer the question: How far can LSTMs go in language modeling when scaled to billions of parameters? So far, the answer is: “At least as far as current technologies like Transformers or State Space Models.”
Despite the theoretical interest in scaling and parallelizing LSTMs, the practical impact of xLSTMs appears mixed. LSTMs now benefit from numerous NLP advances, many inspired by lessons from Transformers, but this might not be enough to convince the community to abandon the well-established ecosystem of libraries and methods optimized for training and using Transformers.
xLSTM's clearest advantage lies in state tracking. However, recent years have shown how powerful in-context learning is: without retraining the model, techniques such as new prompting strategies can work around some Transformer limitations, which may narrow this advantage.
Vision-LSTM (ViL)
Two days before this article was published, a follow-up applying xLSTM to computer vision appeared: Vision-LSTM (ViL). ViL comprises a stack of xLSTM blocks that process sequences of patch tokens in a bidirectional manner. Early experiments indicate that ViL holds promise as a new generic backbone for computer vision architectures (https://arxiv.org/html/2406.04303v1).
Conclusion
The introduction of xLSTM marks a significant milestone in the evolution of language modeling and sequence processing architectures. By addressing the limitations of traditional LSTMs with exponential gating and matrix memory, xLSTM achieves performance competitive with Transformers and State Space Models across the authors' language modeling benchmarks.
Yet, this achievement is only the beginning. xLSTM opens the door to exciting opportunities for further exploration, refinement, and real-world application. As researchers continue to push the boundaries of what is possible, we can anticipate even more impressive advancements in the fields of natural language processing and artificial intelligence, driven by the pioneering capabilities of xLSTM.