Jamba: The New Hybrid Transformer/Mamba
Faster and better than the transformer but more difficult to train
The transformer neural architecture is state-of-the-art. It scales very well, i.e., larger models learn better, and it is efficient to train thanks to the parallel computation of attention.
However, the transformer also has drawbacks, especially for inference: the computational cost of attention grows quadratically with the length of the sequence to process. Techniques such as ALiBi and RoPE improve how the model encodes positions and handles longer contexts, but they do not remove this quadratic cost.
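To make the quadratic growth concrete, here is a tiny sketch (not Jamba code, arbitrary example dimensions) showing that the attention score matrix has one entry per pair of token positions, so doubling the sequence length quadruples the compute and memory spent on the scores:

```python
import torch

# Toy illustration of why attention cost is quadratic in sequence length:
# the score matrix has shape (seq_len, seq_len), i.e., one entry per pair
# of token positions. The dimensions below are arbitrary examples.
d_head = 64
for seq_len in (1024, 2048, 4096):
    q = torch.randn(seq_len, d_head)
    k = torch.randn(seq_len, d_head)
    scores = q @ k.T                  # shape: (seq_len, seq_len)
    print(seq_len, scores.numel())    # 1048576, 4194304, 16777216
```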
Attention-free alternative architectures have also been proposed, such as RWKV and Mamba, a state-space model (SSM). They are much more efficient for inference than the transformer but still underperform it in terms of accuracy.
To take advantage of both the transformer and SSM architectures, AI21 Labs proposed Jamba. This hybrid model interleaves Transformer and Mamba (SSM) layers, a combination that balances memory usage, training efficiency, and long-context capabilities.
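The sketch below is a purely schematic illustration of such an interleaved stack, not Jamba's actual code. The one-attention-per-eight-layers ratio and the MoE-on-every-other-layer placement follow what the Jamba paper reports; the exact position of the attention layer and everything else here are simplifications.

```python
# Schematic sketch of a hybrid stack interleaving layer types
# (illustration only, not Jamba's implementation).
def hybrid_layer_plan(n_blocks: int = 4, layers_per_block: int = 8) -> list[str]:
    plan = []
    for _ in range(n_blocks):
        for i in range(layers_per_block):
            mixer = "attention" if i == layers_per_block // 2 else "mamba"
            ffn = "moe" if i % 2 == 1 else "mlp"   # MoE replaces the MLP on every other layer
            plan.append(f"{mixer}+{ffn}")
    return plan

print(hybrid_layer_plan(n_blocks=1))
# ['mamba+mlp', 'mamba+moe', 'mamba+mlp', 'mamba+moe',
#  'attention+mlp', 'mamba+moe', 'mamba+mlp', 'mamba+moe']
```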
Jamba performs as well as Mixtral-8x7B, one of the best open LLMs, while being more efficient, especially when dealing with long contexts.
In this article, I review Jamba. We will take a close look at its architecture and how it was trained. We will also see how to fine-tune the model and quantize it to reduce its size.
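As a preview of the quantization step, here is a minimal sketch of loading Jamba in 4-bit with Hugging Face Transformers and bitsandbytes. The model ID, the NF4 settings, and the choice to skip the Mamba modules during quantization are assumptions to adapt to your setup; check the model card for the recommended configuration.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "ai21labs/Jamba-v0.1"  # assumed Hugging Face model ID

# 4-bit NF4 quantization; skipping the Mamba modules is a precaution for
# hybrid models (assumption: adjust per the model card's recommendations).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=["mamba"],
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",   # spread the layers across available GPUs
)

prompt = "Jamba is a hybrid Transformer/Mamba model that"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```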
I made a notebook for fine-tuning Jamba here: