# [Block Diffusion Interpolates Between Autoregressive and Diffusion Language Models](https://openreview.net/forum?id=tyEyYT267x) (ICLR 2025 Oral)
By [Marianne Arriola](https://m-arriola.com/), [Aaron Gokaslan](https://skylion007.github.io), [Justin T Chiu](https://justinchiu.netlify.app), [Zhihan Yang](https://zhihanyang2022.github.io/), [Zhixuan Qi](https://zhixuanqi.com/), [Jiaqi Han](https://hanjq17.github.io/), [Subham Sekhar Sahoo](https://s-sahoo.github.io), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)
[Paper](https://openreview.net/forum?id=tyEyYT267x) | [Blog](https://m-arriola.com/bd3lms/) | [HuggingFace Models](https://huggingface.co/collections/kuleshov-group/bd3-lms-67be95f81b96b15fec50d53f)

We introduce ***BD3-LMs***, a family of **B**lock **D**iscrete **D**enoising **D**iffusion **L**anguage **M**odels that achieve SOTA likelihoods among diffusion models and enable generation of arbitrary-length sequences. BD3-LMs combine the strengths of autoregressive and diffusion language models by decomposing a token sequence into blocks and performing discrete diffusion within each block. By tuning the block size, we interpolate between autoregressive and diffusion models, which introduces a trade-off between quality and sample efficiency. We propose a recipe for building effective BD3-LMs that includes an efficient training algorithm, estimators of gradient variance, and data-driven noise schedules that minimize this variance.
In this repo, we provide:
* **The BD3-LM framework**
  1. Block-autoregressive likelihood parameterization
  2. Data-driven noise schedules to reduce training variance
  3. Arbitrary-length discrete diffusion samplers
* **Baseline implementations**
  1. Autoregressive model [[AR](https://arxiv.org/abs/2406.07524)]
  2. Score Entropy Based Discrete Diffusion [[SEDD](https://arxiv.org/abs/2310.16834)]
  3. Masked Diffusion Language Model [[MDLM](https://arxiv.org/abs/2406.07524)]
  4. Semi-autoregressive Simplex-based Diffusion Language Model [[SSD-LM](https://arxiv.org/pdf/2210.17432)] *(supports sample generation only)*
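To make the blockwise noising described above concrete, here is a toy sketch (not the repo's implementation): each block receives its own noise level, and tokens within that block are replaced by a mask token with that probability, while other blocks are left untouched. The `MASK_ID` below is a placeholder; the actual mask-token id and the noise schedules live in `configs/` and `noise_schedule.py`.

```python
import torch

MASK_ID = 50257  # placeholder mask-token id; the real value comes from the tokenizer/config

def block_mask(tokens: torch.Tensor, block_size: int, mask_id: int = MASK_ID) -> torch.Tensor:
    """Toy forward (noising) process: mask each block independently at its own noise level."""
    noised = tokens.clone()
    for start in range(0, tokens.shape[-1], block_size):
        t = torch.rand(())  # per-block noise level t ~ U(0, 1)
        block = tokens[..., start:start + block_size]
        mask = torch.rand(block.shape) < t  # each token in the block is masked with probability t
        noised[..., start:start + block_size] = torch.where(
            mask, torch.full_like(block, mask_id), block)
    return noised

# Example: one sequence of 16 token ids split into blocks of 4.
x = torch.arange(16).unsqueeze(0)
print(block_mask(x, block_size=4))
```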
## Code Organization
1. ```main.py```: Routines for training and evaluation
2. ```noise_schedule.py```: Noise schedules
3. ```diffusion.py```: Forward/reverse diffusion
4. ```dataloader.py```: Dataloaders
5. ```utils.py```: LR scheduler, logging, `fsspec` handling
6. ```models/```: Network architectures. Supports [DiT](https://arxiv.org/abs/2212.09748) and AR transformer
7. ```configs/```: Config files for datasets/models/noise schedules/LR schedules
8. ```scripts/```: Shell scripts for training/evaluation
   - ``train/``: Training scripts (LM1B, OWT)
   - ``ppl/``: Likelihood evaluation on the pretraining set (LM1B, OWT)
   - ``zs_ppl/``: Zero-shot likelihood evaluation on GPT2 benchmark datasets
   - ``gen_ppl/``: Sample quality (generative perplexity under GPT2; see the sketch after this list)
   - ``var_len/``: Arbitrary-length sequence generation
9. ```ssd-lm/```: SSD-LM codebase
   - ```run_generate_text_batch.sh```: Generates SSD-LM samples
   - ```report_genppl.py```: Reports generative perplexity of SSD-LM samples
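The scripts under `gen_ppl/` and `report_genppl.py` report *generative perplexity*: how likely a pretrained GPT-2 finds the model's samples. As a rough, self-contained illustration (not the repo's actual evaluation code, and the GPT-2 variant used by the scripts may differ):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Score a generated sample under a pretrained GPT-2; lower perplexity = more fluent sample.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large")
scorer = GPT2LMHeadModel.from_pretrained("gpt2-large").eval()

sample = "Generated text from the diffusion model goes here."
input_ids = tokenizer(sample, return_tensors="pt").input_ids
with torch.no_grad():
    nll = scorer(input_ids, labels=input_ids).loss  # mean next-token negative log-likelihood
print(f"generative perplexity: {torch.exp(nll).item():.2f}")
```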
## Getting Started
To get started, create a conda environment containing the required dependencies.
```bash
conda env create -f requirements.yaml
conda activate bd3lm
```
Create the following directories to store saved models and slurm logs:
```bash
mkdir outputs watch_folder logs sample_logs
```
and run the training as a batch job:
```bash
sbatch scripts/train/train_owt_bd3lm.sh
```
### Checkpoints
We have uploaded BD3-LMs trained on OpenWebText with block sizes 4, 8, and 16 for 1M training steps to HuggingFace 🤗:
[kuleshov-group/bd3-lms](https://huggingface.co/collections/kuleshov-group/bd3-lms-67be95f81b96b15fec50d53f). These BD3-LMs are fine-tuned from an MDLM checkpoint trained on OpenWebText for 850K gradient updates; we release that pretraining checkpoint in this [Google Drive folder](https://drive.google.com/drive/folders/1Vm4YZBX7bzVuHhIbkY1RHTUsf8v71oew?usp=sharing).
The MDLM baseline is also available on HuggingFace: [kuleshov-group/mdlm-owt](https://huggingface.co/kuleshov-group/mdlm-owt). The AR and SEDD baselines trained on OpenWebText are available in this [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing).
For arbitrary-length sequence generation, we compare with AR, MDLM (whose variable-length generation is an inference-only technique without a dedicated training objective), and SSD-LM. To generate sequences longer than the training context size (fixed at 1024 tokens for OWT), we retrained AR and MDLM from Sahoo et al. without artificially injecting BOS/EOS tokens into the context. We also provide these checkpoints in this [Google Drive folder](https://drive.google.com/drive/folders/1Vm4YZBX7bzVuHhIbkY1RHTUsf8v71oew?usp=sharing).
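As a quick sanity check that a HuggingFace checkpoint downloads correctly, it can typically be loaded through `transformers` with remote code enabled. This is a minimal sketch, assuming the hosted repos follow the masked-LM interface used for MDLM and that the OWT models use the GPT-2 tokenizer; see the model cards for authoritative usage:

```python
from transformers import AutoModelForMaskedLM, GPT2TokenizerFast

# Repo id pattern taken from the evaluation scripts below; block size 4, 8, or 16.
model_id = "kuleshov-group/bd3lm-owt-block_size4"

model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")  # assumed tokenizer for the OWT models
```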
## Reproducing Experiments
Below, we describe the steps required for reproducing the experiments in the paper.
Throughout, the main entry point for running experiments is the [`main.py`](./main.py) script.
We also provide sample Slurm scripts for launching training and evaluation jobs in the [`scripts/`](./scripts) directory.
### Generate Arbitrary-Length Sequences
To generate arbitrary-length sequences, set `mode=sample_eval`. Example scripts are provided in `scripts/var_len/var_len*.sh`. Here's an example script using BD3-LM:
#### HuggingFace model
```bash
BLOCK_SIZE=4 # 4, 8, 16
LENGTH=2048 # arbitrary; needs to be a multiple of the block size
python -u main.py \
    loader.eval_batch_size=1 \
    model=small \
    algo=bd3lm \
    algo.T=5000 \
    algo.backbone=hf_dit \
    data=openwebtext-split \
    model.length=$LENGTH \
    block_size=$BLOCK_SIZE \
    wandb=null \
    mode=sample_eval \
    eval.checkpoint_path=kuleshov-group/bd3lm-owt-block_size${BLOCK_SIZE} \
    model.attn_backend=sdpa \
    sampling.nucleus_p=0.9 \
    sampling.kv_cache=true \
    sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize${BLOCK_SIZE}
```
#### Local checkpoint
```bash
BLOCK_SIZE=4 # 4, 8, 16
LENGTH=2048 # arbitrary; needs to be a multiple of the block size
python -u main.py \
    loader.eval_batch_size=1 \
    model=small \
    algo=bd3lm \
    algo.T=5000 \
    data=openwebtext-split \
    model.length=$LENGTH \
    block_size=$BLOCK_SIZE \
    wandb=null \
    mode=sample_eval \
    eval.checkpoint_path=/path/to/checkpoint/bd3lm-owt-block_size${BLOCK_SIZE} \
    model.attn_backend=sdpa \
    sampling.nucleus_p=0.9 \
    sampling.kv_cache=true \
    sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize${BLOCK_SIZE}
```
### Likelihood Evaluation
To compute test perplexity, use `mode=ppl_eval`. Example scripts are provided in `scripts/ppl/eval_owt_*.sh`. Here's an example evaluation script on OpenWebText:
```bash
BLOCK_SIZE=4 # 4, 8, 16
python -u main.py \
    loader.eval_batch_size=16 \
    model=small \
    algo=bd3lm \
    algo.backbone=hf_dit \
    data=openwebtext-split \
    data.insert_valid_special=False \
    model.length=1024 \
    model.attn_backend=sdpa \
    block_size=${BLOCK_SIZE} \
    eval.checkpoint_path=kuleshov-group/bd3lm-owt-block_size${BLOCK_SIZE} \
    wandb=null \
    mode=ppl_eval > logs/bd3lm_owt_block_size${BLOCK_SIZE}.log
```
### Training Pipeline
To train BD3-LMs, use `mode=train` (default mode). Example scripts are provided in `scripts/train/train_owt*.sh`. Here's an example training script on OpenWebText:
```bash
BLOCK_SIZE=4 # we recommend 4, 8, or 16. must be a factor of the context length
PRETRAIN_CKPT=$PWD/bd3lm_base_owt_850k.ckpt # to train from scratch, set to null
python -u main.py \
    loader.global_batch_size=512 \
    loader.eval_global_batch_size=512 \
    loader.batch_size=8 \
    loader.eval_batch_size=8 \
    model=small \
    algo=bd3lm \
    data=openwebtext-split \
    model.length=1024 \
    block_size=$BLOCK_SIZE \
    wandb.name=bd3lm-owt-block_size${BLOCK_SIZE} \
    mode=train \
    model.attn_backend=sdpa \
    training.from_pretrained=$PRETRAIN_CKPT
```
The arguments `loader.batch_size` and `loader.eval_batch_size` control the batch size per GPU. If `loader.batch_size * num_gpus` is smaller than `loader.global_batch_size`, PyTorch Lightning uses gradient accumulation to make up the difference. You can also launch a training job on Slurm with `sbatch scripts/train/train_owt_bd3lm.sh`.
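For example, with the settings in the training script above and a hypothetical 8-GPU node (the scripts do not pin a GPU count), the number of accumulation steps works out as follows:

```python
num_gpus = 8              # hypothetical; set to your actual world size
global_batch_size = 512   # loader.global_batch_size
per_gpu_batch_size = 8    # loader.batch_size

sequences_per_micro_batch = per_gpu_batch_size * num_gpus            # 64 sequences per forward/backward pass
accumulation_steps = global_batch_size // sequences_per_micro_batch  # 512 // 64 = 8 micro-batches per optimizer step
print(sequences_per_micro_batch, accumulation_steps)
```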
## Acknowledgements
This repository was built on top of [MDLM](https://github.com/kuleshov-group/mdlm) and [SEDD](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion).
## Citation
```bibtex
@inproceedings{arriola2025block,
  title={Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models},
  author={Marianne Arriola and Subham Sekhar Sahoo and Aaron Gokaslan and Zhihan Yang and Zhixuan Qi and Jiaqi Han and Justin T Chiu and Volodymyr Kuleshov},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=tyEyYT267x}
}
```