A PyTorch implementation of a Mixtral-inspired transformer model with Mixture of Experts (MoE), designed for text generation and understanding tasks. The model follows the Mixtral architecture and adds Flash Attention, SwiGLU activation, and Liger kernels for optimized performance.
- I trained a 124M-parameter (8x12M) MoE model using an architecture I coded from the ground up.
- It was trained on the TinyStories dataset from Hugging Face (1M texts) for a total of 14,000 steps.
Example outputs are provided in the generated_data/ directory and showcase the model's capabilities in text generation and understanding.
📈 View Training Report: SmolMixtral Training Results on WandB
💾 Download Pre-trained Weights:
- Hugging Face Model: YuvrajSingh9886/SmolMixtral
- WandB Checkpoints: Check the WandB report above for additional trained model checkpoints
- Flash Attention: Efficient attention mechanism with memory optimization
- Mixture of Experts (MoE): 8 experts with top-2 routing and noisy top-k support
- SwiGLU Activation: Gated (Swish-based) activation used in the expert feed-forward layers
- Rotary Positional Embeddings (RoPE): Encode token positions by rotating the query and key vectors
- Liger Kernels: Optimized kernels for faster training (optional)
- Distributed Training: Support for multi-GPU training with DDP
- AdamW Optimizer: AdamW with a custom learning rate schedule
- Gradio Interface: Interactive web interface for text generation
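To make the MoE bullet concrete, here is a minimal pure-Python sketch of top-2 routing for a single token. It is an illustration under assumptions, not the code in model.py: the name `route_token` is hypothetical, and the noise step is a simplification of noisy top-k gating (which normally learns a per-expert noise scale) using a fixed `noise_std`. The router keeps the two largest gate logits and renormalizes them with a softmax so the chosen experts' weights sum to 1.

```python
import math
import random
from typing import List, Optional, Tuple


def route_token(
    gate_logits: List[float],
    top_k: int = 2,
    noisy: bool = False,
    noise_std: float = 1.0,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[float]]:
    """Hypothetical top-k router: returns chosen expert indices and their weights."""
    logits = list(gate_logits)
    if noisy:
        # Simplified noisy top-k: fixed-scale Gaussian noise on the gate logits
        # (assumption: a real implementation may learn the noise scale).
        rng = rng or random.Random()
        logits = [x + rng.gauss(0.0, noise_std) for x in logits]

    # Keep the top_k experts by (possibly noisy) logit.
    chosen = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]

    # Softmax over the selected logits only, so the weights sum to 1.
    m = max(logits[i] for i in chosen)
    exps = [math.exp(logits[i] - m) for i in chosen]
    total = sum(exps)
    return chosen, [e / total for e in exps]


# Example: 8 experts, one token.
experts, weights = route_token([0.1, 2.0, -1.0, 0.5, 1.5, 0.0, -0.3, 0.2])
print(experts, [round(w, 3) for w in weights])  # [1, 4] [0.622, 0.378]
```

The token's output would then be the weighted sum of the two selected experts' outputs; the other six experts are skipped for that token.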
- Embedding Dimensions: 768
- Decoder Layers: 12
- Attention Heads: 8
- MoE Experts: 8 (top-2 routing)
- Block Size: 1024 tokens
- Vocabulary Size: Based on Llama-2-7b tokenizer (~32,000 tokens)
- Batch Size: 16
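Each of the experts listed above is a feed-forward block with SwiGLU activation. The sketch below shows the common SwiGLU formulation, w2 · (SiLU(w1 · x) ⊙ (w3 · x)), in pure Python. The weight names `w1`, `w2`, `w3` and the toy dimensions are illustrative assumptions, not the shapes or names used in model.py.

```python
import math
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    # Dot each row of w with x.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def silu(z: float) -> float:
    # SiLU (Swish-1): z * sigmoid(z), written to avoid overflow for large |z|.
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)


def swiglu_expert(x: Vector, w1: Matrix, w2: Matrix, w3: Matrix) -> Vector:
    """Hypothetical SwiGLU expert: w2 @ (silu(w1 @ x) * (w3 @ x))."""
    gate = [silu(g) for g in matvec(w1, x)]      # gated branch
    up = matvec(w3, x)                           # linear branch
    hidden = [g * u for g, u in zip(gate, up)]   # element-wise product
    return matvec(w2, hidden)                    # project back to model dim


# Toy example: model dim 2, hidden dim 3.
x = [1.0, -1.0]
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w3 = [[1.0, 1.0], [2.0, 0.0], [0.0, 1.0]]
w2 = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
print(swiglu_expert(x, w1, w2, w3))
```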
- epochs: Number of training epochs (default: 4)
- block_size: Maximum sequence length (default: 1024)
- batch_size: Training batch size (default: 16)
- embeddings_dims: Model embedding dimensions (default: 512)
- no_of_heads: Number of attention heads (default: 8)
- no_of_decoder_layers: Number of decoder layers (default: 12)
- attn_dropout: Attention dropout rate (default: 0.1)
- dropout: General dropout rate (default: 0.1)
- experts: Number of MoE experts (default: 8)
- top_experts: Number of experts each token is routed to (default: 2)
- noisy_topk: Use noisy top-k routing (default: False)
- max_lr: Maximum learning rate (default: 6e-4)
- weight_decay_optim: Weight decay for optimizer (default: 0.01)
- beta_1: Beta1 for optimizer (default: 0.9)
- beta_2: Beta2 for optimizer (default: 0.95)
- eps: Epsilon for optimizer (default: 1e-8)
- clip: Gradient clipping value (default: 1.0)
- device: Device to use (default: 'cuda:9')
- use_checkpointing: Use gradient checkpointing (default: False)
- use_liger: Use Liger kernels for optimization (default: True)
- use_flash_attention: Use Flash Attention (default: True)
- use_compile: Use torch.compile (default: True)
- vocab_size: Vocabulary size (default: based on tokenizer + 768)
- val_epochs: Validation frequency (default: 2)
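The parameters above are set in config.py. If you would rather override a few of them from the command line, one possible approach is an argparse parser that mirrors their defaults. This is a hypothetical sketch: the list does not say that trainer.py accepts CLI flags, and only a subset of parameters is shown. `store_true`/`store_false` are used for the boolean switches to stay compatible with Python 3.8.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical CLI mirroring some of the config.py defaults listed above.
    p = argparse.ArgumentParser(description="SmolMixtral training options (sketch)")
    p.add_argument("--epochs", type=int, default=4)
    p.add_argument("--block_size", type=int, default=1024)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--embeddings_dims", type=int, default=512)
    p.add_argument("--experts", type=int, default=8)
    p.add_argument("--top_experts", type=int, default=2)
    p.add_argument("--max_lr", type=float, default=6e-4)
    p.add_argument("--clip", type=float, default=1.0)
    # Boolean switches: noisy_topk defaults to False, use_liger to True.
    p.add_argument("--noisy_topk", action="store_true")
    p.add_argument("--no_liger", dest="use_liger", action="store_false")
    return p


args = build_parser().parse_args(["--batch_size", "8", "--no_liger"])
print(args.batch_size, args.use_liger)  # 8 False
```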
chmod +x install.sh
./install.sh
Since this model uses the Llama-2 tokenizer, you'll need a Hugging Face token to access the gated model.
- Get a Hugging Face token:
  - Go to Hugging Face Settings
  - Create a new token with "Read" permissions
  - Accept the Llama-2 license at meta-llama/Llama-2-7b-hf
- Set your token in config.py:
  TOKEN = 'your_token_here'
- Download model weights:
  - Option 1: Download from Hugging Face - YuvrajSingh9886/SmolMixtral
  - Option 2: Visit the WandB Training Report for additional checkpoints
  - Place downloaded files in the checkpoints/ directory
- Load the pre-trained model for inference:
# Using the Gradio web interface
cd gradio
python app.py
# Or use in your own code
python inference.py
python trainer.py
# Train with larger model (modify config.py)
python trainer.py
# Train with different dataset (modify data.py)
python trainer.py
# 2 GPUs
torchrun --nproc_per_node=2 trainer.py
# 4 GPUs
torchrun --nproc_per_node=4 trainer.py
# 8 GPUs
torchrun --nproc_per_node=8 trainer.py
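torchrun starts one process per GPU and sets the RANK, LOCAL_RANK and WORLD_SIZE environment variables in each of them. The sketch below shows how a training script can read these to choose its device, falling back to single-process values when launched with plain `python trainer.py`. The helper name `ddp_env` is hypothetical, and the actual trainer.py may organize this differently.

```python
import os
from typing import Dict, Mapping, Optional


def ddp_env(env: Optional[Mapping[str, str]] = None) -> Dict[str, object]:
    """Read the variables torchrun exports; default to a single process."""
    env = os.environ if env is None else env
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    return {
        "rank": rank,
        "local_rank": local_rank,
        "world_size": world_size,
        "is_distributed": world_size > 1,
        # Each process drives the GPU matching its local rank.
        "device": f"cuda:{local_rank}",
        # Typically only rank 0 logs and saves checkpoints.
        "is_master": rank == 0,
    }


print(ddp_env({"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "2"})["device"])  # cuda:1
```

When `is_distributed` is true, a PyTorch script would then typically call `torch.distributed.init_process_group(backend="nccl")` and `torch.cuda.set_device(local_rank)` before wrapping the model in DDP.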
Set HF_TOKEN in config.py to use the Gradio interface, and also export it as an environment variable:
export HF_TOKEN=<TOKEN_HERE>
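Because the token can come either from config.py or from the HF_TOKEN environment variable, a small helper can prefer the environment variable, fall back to the config value, and fail early with a clear message when neither is usable. This is a hypothetical sketch: `config_token` stands in for the `TOKEN` value in config.py, and the placeholder check assumes the default string shown earlier.

```python
import os
from typing import Mapping, Optional

PLACEHOLDER = "your_token_here"


def resolve_hf_token(config_token: Optional[str] = None,
                     env: Optional[Mapping[str, str]] = None) -> str:
    """Return HF_TOKEN from the environment, else the config value."""
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN") or config_token
    if not token or token == PLACEHOLDER:
        raise RuntimeError(
            "No Hugging Face token found: export HF_TOKEN or set TOKEN in config.py"
        )
    return token
```

Recent versions of transformers accept the resolved value through the `token` argument of `from_pretrained`, e.g. when loading the gated meta-llama/Llama-2-7b-hf tokenizer.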
# Run the Gradio app
cd gradio
python app.py
# With custom checkpoint (edit app.py to point to your checkpoint)
cd gradio
python app.py
SmolMixtral/
├── config.py # Model configuration and hyperparameters
├── model.py # Model architecture (Mixtral, MoE, Attention, etc.)
├── data.py # Data loading and preparation
├── inference.py # Inference functions and text generation
├── trainer.py # Main training loop with DDP support
├── install.sh # Setup script
├── requirements.txt # Python dependencies
├── model_summary.py # Model architecture summary
├── gradio/
│ └── app.py # Gradio web interface
├── checkpoints/ # Model checkpoints
├── generated_data/ # Generated text outputs
├── images/ # Project images
└── old/ # Original files
- Gradient Accumulation: Configurable batch size scaling
- Learning Rate Scheduling: Cosine decay with warmup
- Gradient Clipping: Prevents gradient explosion
- Wandb Integration: Experiment tracking and logging
- Checkpointing: Regular model checkpoints during training
- Loss Calculation: Optimized cross-entropy with padding token handling
- Distributed Training: Multi-GPU support with DDP
- Memory Optimization: Gradient checkpointing support
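The training features above mention cosine decay with warmup. A common formulation is a linear warmup from 0 to max_lr, followed by a cosine decay to a floor min_lr. In the sketch below, max_lr = 6e-4 matches the default listed earlier; `min_lr`, `warmup_steps`, and using the 14,000 training steps as `total_steps` are illustrative assumptions, not values read from trainer.py.

```python
import math


def lr_at(step: int, max_lr: float = 6e-4, min_lr: float = 6e-5,
          warmup_steps: int = 700, total_steps: int = 14000) -> float:
    """Linear warmup then cosine decay (hypothetical hyperparameters)."""
    if step < warmup_steps:
        # Linear warmup: reaches max_lr at the last warmup step.
        return max_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    # Fraction of the decay phase completed, in [0, 1).
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # decays from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)


for s in (0, 699, 7350, 14000):
    print(s, f"{lr_at(s):.2e}")
```

In a PyTorch loop the value would be applied each step with `for group in optimizer.param_groups: group["lr"] = lr_at(step)`.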
- Top-k Sampling: Traditional sampling with temperature control
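Top-k sampling with temperature divides the logits by the temperature, keeps only the k highest-scoring tokens, renormalizes them with a softmax, and samples one. The pure-Python sketch below works on a single logits vector; `sample_top_k` is a hypothetical name, not the function in inference.py.

```python
import math
import random
from typing import List, Optional


def sample_top_k(logits: List[float], k: int = 50, temperature: float = 1.0,
                 rng: Optional[random.Random] = None) -> int:
    """Return a token id sampled from the k most likely tokens."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    if k < 1:
        raise ValueError("k must be >= 1")
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    k = min(k, len(scaled))
    top = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    # Numerically stable softmax numerators over the kept tokens;
    # random.choices normalizes the weights itself.
    m = max(scaled[i] for i in top)
    probs = [math.exp(scaled[i] - m) for i in top]
    return rng.choices(top, weights=probs, k=1)[0]


print(sample_top_k([0.0, 5.0, 1.0, 4.0], k=2, temperature=0.8))
```

Lower temperatures sharpen the distribution toward the most likely token, and k=1 reduces to greedy decoding.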
All parameters can be configured by modifying config.py:
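from dataclasses import dataclass

@dataclass
class ModelArgs:
    # Type annotations make these proper dataclass fields
    epochs: int = 4
    block_size: int = 1024
    batch_size: int = 16
    embeddings_dims: int = 512
    # ... other parameters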
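Because ModelArgs is a dataclass, one way to try a variant without editing the defaults is `dataclasses.replace`, which returns a copy with selected fields changed. This only works for attributes declared with type annotations, since unannotated class attributes are not dataclass fields. The ModelArgs below is a trimmed stand-in with four of the parameters, not the full class from config.py.

```python
from dataclasses import asdict, dataclass, replace


@dataclass
class ModelArgs:
    # Trimmed stand-in: annotated attributes become dataclass fields.
    epochs: int = 4
    block_size: int = 1024
    batch_size: int = 16
    embeddings_dims: int = 512


base = ModelArgs()
# Memory-constrained variant: smaller batch, everything else unchanged.
small = replace(base, batch_size=8)
print(asdict(small))  # {'epochs': 4, 'block_size': 1024, 'batch_size': 8, 'embeddings_dims': 512}
```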
Modify data.py to use different datasets:
# TinyStories (default)
tinystories = True
fw = False
# FineWeb
tinystories = False
fw = True
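The two flags are meant to be mutually exclusive. A hypothetical helper can validate that exactly one is set and map it to a Hugging Face dataset ID. roneneldan/TinyStories and HuggingFaceFW/fineweb are the public Hub IDs for these datasets, but whether data.py loads exactly these (or a particular subset) is an assumption.

```python
def dataset_id(tinystories: bool, fw: bool) -> str:
    """Map the data.py flags to a Hub dataset ID (hypothetical helper)."""
    # Exactly one source must be enabled.
    if tinystories == fw:
        raise ValueError("Set exactly one of tinystories / fw to True")
    return "roneneldan/TinyStories" if tinystories else "HuggingFaceFW/fineweb"


print(dataset_id(tinystories=True, fw=False))  # roneneldan/TinyStories
```

The ID could then be passed to `datasets.load_dataset(..., split="train")`; for a corpus as large as FineWeb, `streaming=True` avoids downloading everything up front.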
Training automatically logs to WandB with project name "Mixtral-DDP-Pretrain-10-billion-tokens"
- Use Liger Kernels: keep use_liger = True for optimized operations
- Flash Attention: keep use_flash_attention = True for memory efficiency
- Gradient Checkpointing: set use_checkpointing = True for memory-constrained setups
- Batch Size Tuning: start with smaller batch sizes and increase gradually
- Block Size: larger block sizes improve quality but require more memory
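To see why block size dominates memory: naive attention materializes at least one score matrix of shape (batch, heads, T, T) per layer, so its size grows quadratically with block_size. The arithmetic below uses the defaults listed earlier (batch 16, 8 heads, 1024 tokens) and assumes 2-byte bf16 scores. Flash Attention computes attention in tiles without storing this full matrix, which is why it is recommended here.

```python
def attn_scores_bytes(batch: int, heads: int, block_size: int,
                      bytes_per_el: int = 2) -> int:
    # One (T x T) score matrix per head per sequence, for a single layer.
    return batch * heads * block_size * block_size * bytes_per_el


print(f"{attn_scores_bytes(16, 8, 1024) / 2**20:.0f} MiB per layer")  # 256 MiB per layer
print(f"{attn_scores_bytes(16, 8, 2048) / 2**20:.0f} MiB per layer")  # 1024 MiB per layer
```

Doubling block_size quadruples this term, while halving batch_size only halves it.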
# Make sure you have accepted the Llama-2 license and have a valid token
# Visit: https://huggingface.co/meta-llama/Llama-2-7b-hf
# Then set your token in config.py
# Reduce batch size and enable checkpointing in config.py
batch_size = 8
use_checkpointing = True
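Halving batch_size also halves the number of sequences per optimizer step. Gradient accumulation, listed among the training features, can compensate: the effective batch is batch_size × accumulation steps × number of GPUs. The sketch below shows the arithmetic; the accumulation counts are hypothetical and not read from config.py.

```python
def effective_batch(batch_size: int, accum_steps: int, world_size: int = 1) -> int:
    # Sequences contributing to one optimizer step.
    return batch_size * accum_steps * world_size


def tokens_per_step(batch_size: int, accum_steps: int, world_size: int,
                    block_size: int) -> int:
    # Tokens seen per optimizer step, assuming full-length sequences.
    return effective_batch(batch_size, accum_steps, world_size) * block_size


# batch_size 8 with 2 accumulation steps matches batch_size 16 on one GPU.
print(effective_batch(8, 2), tokens_per_step(8, 2, 1, 1024))  # 16 16384
```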
# Enable optimizations in config.py
use_liger = True
use_flash_attention = True
use_compile = True
Feel free to contribute improvements, bug fixes, or new features!
- Python 3.8+
- PyTorch 2.0+
- Transformers
- Datasets
- Gradio
- Wandb
- Liger-kernel (optional)
MIT License