# diffrax

**Repository Path**: mirrors_patrick-kidger/diffrax

## Basic Information

- **Project Name**: diffrax
- **Description**: Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
- **Primary Language**: Unknown
- **License**: Apache-2.0
- **Default Branch**: main
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2023-02-27
- **Last Updated**: 2026-04-13

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

Diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Diffrax is a [JAX](https://github.com/google/jax)-based library providing numerical differential equation solvers.

Features include:

- ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
- lots of different solvers (including `Tsit5`, `Dopri8`, symplectic solvers, implicit solvers);
- vmappable _everything_ (including the region of integration);
- using a PyTree as the state;
- dense solutions;
- multiple adjoint methods for backpropagation;
- support for neural differential equations.

_From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small tightly-written library._

## Installation

```
pip install diffrax
```

Requires Python 3.10+.

## Documentation

Available at [https://docs.kidger.site/diffrax](https://docs.kidger.site/diffrax).

## Quick example

```python
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
```

Here, `Dopri5` refers to the Dormand--Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.

## Citation

If you found this library useful in academic research, please cite: [(arXiv link)](https://arxiv.org/abs/2202.02435)

```bibtex
@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}
```

(Also consider starring the project on GitHub.)

## See also: other libraries in the JAX ecosystem

**Always useful**

- [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
- [jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.

**Deep learning**

- [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
- [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
- [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
- [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.

**Scientific computing**

- [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
- [Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
- [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
- [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
- [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

**Awesome JAX**

- [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.