# tensor_annotations

**Repository Path**: mirrors_deepmind/tensor_annotations

## Basic Information

- **Project Name**: tensor_annotations
- **Description**: Annotating tensor shapes using Python types
- **Primary Language**: Unknown
- **License**: Apache-2.0
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2021-02-27
- **Last Updated**: 2025-10-05

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

# TensorAnnotations

:warning: WARNING: TensorAnnotations is no longer being maintained. We recommend that users switch to [jaxtyping](https://github.com/google/jaxtyping) instead. For more information, see [Why TensorAnnotations is being deprecated](https://docs.google.com/document/d/1AAP-wq06j1TQwJPtrlky4lfyPHyl7-itgN5S47oZO98/edit).

---

TensorAnnotations is an experimental library that lets you annotate data-type and semantic shape information using type annotations. For example:

```python
def calculate_loss(frames: Array4[uint8, Time, Batch, Height, Width]):
  ...
```

This annotation states that the data type of `frames` is `uint8` and that its dimensions are time-like, batch-like, and so on. It says nothing about the actual _values_ of those dimensions (e.g. the actual batch size).

Why? Two reasons:

* Shape annotations can be checked _statically_. This can catch a range of bugs, e.g. selecting or reducing the wrong axis, before you run your code, even when those errors would not necessarily throw a runtime exception!
* Interface documentation (which also enables shape autocompletion in IDEs).

To do this, the library provides three things:

* A set of custom tensor types for TensorFlow and JAX, supporting the above kinds of annotations
* A collection of common semantic labels (e.g. `Time`, `Batch`, etc.)
* Type stubs for common library functions that preserve semantic shape information (e.g. `reduce_sum(Tensor[Time, Batch], axis=0) -> Tensor[Batch]`)

TensorAnnotations is being developed for JAX and TensorFlow.

## Example

Here is some code that takes advantage of static shape checking:

```python
import tensorflow as tf

from tensor_annotations import axes
import tensor_annotations.tensorflow as ttf

uint8 = ttf.uint8
Batch, Time = axes.Batch, axes.Time

def sample_batch() -> ttf.Tensor2[uint8, Time, Batch]:
  return tf.zeros((3, 5))

def train_batch(batch: ttf.Tensor2[uint8, Batch, Time]):
  m: ttf.Tensor1[uint8, Batch] = tf.reduce_max(batch, axis=1)
  # Do something useful

def main():
  batch1 = sample_batch()
  batch2 = tf.transpose(batch1)
  train_batch(batch2)
```

This code contains shape annotations in the signatures of `sample_batch` and `train_batch`, and in the line calling `reduce_max`. Otherwise, it is the same code you would have written in an unchecked program.

You can check these annotations for inconsistencies by running a static type checker on your code (see 'General usage' below). For example, passing `batch1` directly to `train_batch` results in the following error from pytype:

```
File "example.py", line 10: Function train_batch was called with the wrong arguments [wrong-arg-types]
         Expected: (batch: Tensor2[uint8, Batch, Time])
  Actually passed: (batch: Tensor2[uint8, Time, Batch])
```

Similarly, changing the call to `reduce_max` from `axis=1` to `axis=0` results in:

```
File "example.py", line 15: Type annotation for m does not match type of assignment [annotation-type-mismatch]
  Annotation: Tensor1[uint8, Batch]
  Assignment: Tensor1[uint8, Time]
```

(These messages were shortened for readability. The actual errors are more verbose because fully qualified type names are displayed. We are looking into improving this.)

See `examples/tf_time_batch.py` for a complete example.

## Requirements

TensorAnnotations requires Python 3.8 or above, due to the use of `typing.Literal`.
## Installation

To install the custom tensor types:

```bash
pip install tensor_annotations
```

Then, depending on whether you use JAX or TensorFlow:

```bash
pip install tensor_annotations_jax_stubs
# and/or
pip install tensor_annotations_tensorflow_stubs
```

If you use pytype, you'll also need to take a few extra steps to let it take advantage of the JAX/TensorFlow stubs, since it doesn't yet support PEP 561 stub packages.

First, make a copy of typeshed in e.g. your home directory:

```bash
git clone https://github.com/python/typeshed "$HOME/typeshed"
```

Next, symlink the stubs into your copy of typeshed:

```bash
# Directory the packages were installed into. `--user-site` is correct for
# `pip install --user`; inside a virtualenv, use e.g.
# python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'
site_packages=$(python3 -m site --user-site)

# Custom tensor classes
mkdir -p "$HOME"/typeshed/stubs/{tensor_annotations/tensor_annotations,tensorflow,jax}
ln -s "$site_packages/tensor_annotations/__init__.py" "$HOME/typeshed/stubs/tensor_annotations/tensor_annotations/__init__.pyi"
ln -s "$site_packages/tensor_annotations/jax.pyi" "$HOME/typeshed/stubs/tensor_annotations/tensor_annotations/jax.pyi"
ln -s "$site_packages/tensor_annotations/tensorflow.pyi" "$HOME/typeshed/stubs/tensor_annotations/tensor_annotations/tensorflow.pyi"
ln -s "$site_packages/tensor_annotations/axes.py" "$HOME/typeshed/stubs/tensor_annotations/tensor_annotations/axes.pyi"

# TensorFlow
ln -s "$site_packages/tensorflow-stubs" "$HOME/typeshed/stubs/tensorflow/tensorflow"

# JAX
ln -s "$site_packages/jax-stubs" "$HOME/typeshed/stubs/jax/jax"
```

## General usage

First, import `tensor_annotations` and start annotating function signatures and variable assignments. This can be done gradually.

Next, run a static type checker on your code. If you use Mypy, it should just work.
If you use pytype, you need to invoke it in a special way so that it knows about the custom typeshed installation:

```bash
TYPESHED_HOME="$HOME/typeshed" pytype your_code.py
```

To be sure you're set up correctly, we recommend deliberately introducing a shape error and confirming that your type checker reports it.

### Annotated tensor classes

TensorAnnotations provides tensor classes for JAX and TensorFlow:

```python
# JAX
import tensor_annotations.jax as tjax
tjax.ArrayN  # Where N is the rank of the tensor

# TensorFlow
import tensor_annotations.tensorflow as ttf
ttf.TensorN  # Where N is the rank of the tensor
```

These classes can be parameterized by semantic axis labels (see below) using generics, similar to `List[int]`. (A separate class is needed for each rank because Python versions before 3.11 do not support variadic generics; see [PEP 646](https://peps.python.org/pep-0646).)

### Data types

TensorAnnotations also provides its own data-type types:

```python
# JAX
from tensor_annotations.jax import uint8, float32  # Etc

# TensorFlow
from tensor_annotations.tensorflow import uint8, float32  # Etc
```

This is because, for various reasons, native data-type types like `tf.uint8` and `jnp.uint8` are unsuitable for use in type annotations. See `tensorflow.py` and `jax.py` for more information.

### Axis labels

Axis labels indicate the semantic meaning of each dimension in a tensor: whether the dimension is batch-like, features-like, etc. Note that no connection is made between a symbol such as `Batch` and the actual _value_ of that dimension (e.g. the batch size). The symbol really does only describe the semantic meaning of the dimension.

See `axes.py` for the list of axis labels we provide out of the box.

To define a custom axis label, simply subclass `tensor_annotations.axes.Axis`.
You can also use `typing.NewType` to do this in a single line:

```python
import typing
from tensor_annotations import axes

CustomAxis = typing.NewType('CustomAxis', axes.Axis)
```

In the future, we intend to support axis types that are tied to the actual size of that axis. Currently, however, we don't have a good way of doing this. If you nonetheless want to annotate certain dimensions with a literal size, e.g. to document interfaces that are hardcoded for specific sizes, we recommend using a custom axis for this purpose. (To be clear, though: these sizes will _not_ be checked, neither statically nor at runtime!)

```python
L64 = typing.NewType('L64', axes.Axis)
```

### Stubs

By default, TensorFlow and JAX are not aware of our annotations. For example, if you have a tensor `x: Array2[uint8, Time, Batch]` and you call `jnp.sum(x, axis=0)`, you won't get an `Array1[uint8, Batch]`; you'll just get `Any`. We therefore provide a set of custom type annotations for TensorFlow and JAX, packaged in 'stub' (`.pyi`) files.

Our stubs currently cover the following parts of the API. All operations are supported for rank 1, 2, 3 and 4 tensors unless otherwise noted. Unary operators are also supported for rank 0 (scalar) tensors.

#### TensorFlow

See [Coverage](docs/coverage.md).

**Tensor unary operators**: For tensor `x`: `abs(x)`, `-x`, `+x`

**Tensor binary operators**: For tensors `a` and `b`: `a + b`, `a / b`, `a // b`, `a ** b`, `a < b`, `a > b`, `a <= b`, `a >= b`, `a * b`. Yet to be typed: `a ? float` and `a ? int` (where `?` is any of these operators) for `Tensor0`, and broadcasting where one axis has size 1.

#### JAX

See [Coverage](docs/coverage.md).

**Tensor unary operators**: For tensor `x`: `abs(x)`, `-x`, `+x`

**Tensor binary operators**: For tensors `a` and `b`: `a + b`, `a / b`, `a // b`, `a ** b`, `a < b`, `a > b`, `a <= b`, `a >= b`, `a * b`. Yet to be typed: `a ? float` and `a ? int` (where `?` is any of these operators) for `Array0`, and broadcasting where one axis has size 1.

### Casting

Some of your code might already be typed with existing library tensor types:

```python
import jax.numpy as jnp

def sample_batch() -> jnp.ndarray:
  ...
```

If this is the case, and you don't want to change these types globally in your code, you can cast to TensorAnnotations classes with `typing.cast`:

```python
from typing import cast
import tensor_annotations.jax as tjax
from tensor_annotations.axes import Batch, Time
from tensor_annotations.jax import uint8

x = cast(tjax.Array2[uint8, Batch, Time], sample_batch())
```

Note that this is only a hint to the type checker; at runtime, `cast` returns its argument unchanged. An alternative syntax that emphasises this fact is:

```python
x: tjax.Array2[uint8, Batch, Time] = sample_batch()  # type: ignore
```

## Gotchas

**Use tuples for shape/axis specifications**

For type inference with TensorFlow and JAX API functions, we often have to match on additional arguments. For example, the rank of a `tf.zeros(...)` tensor depends on the length of the shape argument. This only works with tuples, not with lists:

```python
a = tf.zeros((10, 10))  # Correctly infers a rank-2 Tensor2 (axis types unknown)
b: ttf.Tensor2[uint8, Time, Batch] = get_batch()
c = tf.transpose(b, perm=(0, 1))  # Tracks and infers the axis types of b
```

whereas

```python
a = tf.zeros([10, 10])  # Returns Any
b: ttf.Tensor2[uint8, Time, Batch] = get_batch()
c = tf.transpose(b, perm=[0, 1])  # Does not track permutations and returns Any
```

**Runtime vs static checks**

Note that we do not verify at runtime that the rank of a tensor matches the rank specified in the annotations. If you were in an evil mood, you could create an untyped (`Any`) tensor and statically type it as something completely wrong. This is in line with the rest of Python's type-checking approach, which does not *enforce* consistency with the annotated types at runtime.

**Value consistency**

Not only do we not verify the rank, we don't verify anything about the actual shape values either.
The following will _not_ raise an error:

```python
x: tjax.Array1[uint8, Batch] = jnp.zeros((3,))
y: tjax.Array1[uint8, Batch] = jnp.zeros((5,))
```

Note that _this is by design_! Shape symbols such as `Batch` are _not_ placeholders for actual values like 3 or 5. Symbols only refer to the _semantic meaning_ of a dimension. In the above example, `x` might be a training batch and `y` a test batch, so they have different sizes even though both dimensions are batch-like. This means that even element-wise operations like `z = x + y` would not raise a type-check error in this case.

## FAQs

**Why doesn't e.g. `tjax.ArrayN` subclass `jnp.DeviceArray`?**

We'd *like* this to be the case, but haven't yet figured out how, because of circular dependencies:

* `ArrayN` is defined in `tensor_annotations.jax`, which would need to import `jax.numpy` in order to subclass `jnp.DeviceArray`.
* However, our `jax.numpy` stubs make use of `ArrayN`, so `jax.numpy` itself needs to import `tensor_annotations.jax`.

The ultimate solution will hopefully be to upstream our `ArrayN` classes so that they can be defined in `jax.numpy` too. Until then, we'll try to make e.g. `tjax.ArrayN` look as close to `jnp.DeviceArray` as possible through dummy methods and dummy attributes, so that autocomplete still works. If there are particular methods or attributes you'd like added, please let us know.

**Why are so many methods annotated as `Any` in the JAX stubs?**

We don't yet have a good way of automatically generating stubs in general. For the methods where we *do* generate stubs automatically (all the ones not annotated as `Any`), we've checked their signatures manually and written a stub generator for each method individually.

Ideally, we'd start from stubs generated by e.g. pytype and then customise them to include shape information, but we haven't got around to setting this up yet.
**Why not use [PEP 646](https://peps.python.org/pep-0646)?**

Compatibility. There are two factors: a) concise syntax for PEP 646 is only available from Python 3.11 onwards, which not everyone can migrate to yet; and b) PEP 646 is (as of January 2022) only supported by Pyre and Pyright, not by Mypy or pytype, which are both popular. Type checker support is the biggest factor, so once there _is_ better support for PEP 646 in Mypy and pytype, we may revisit this question.

## See also

This library is one of many approaches to checking tensor shapes. We don't expect it to be the final solution; we created it to explore one point in the space of possibilities.

Other tools for checking tensor shapes include:

* [Pythia](https://yanniss.github.io/tensor-ecoop20.pdf), a static analyzer designed specifically for detecting TensorFlow shape errors
* [tsanley](https://github.com/ofnote/tsanley), which uses string annotations combined with runtime verification
* [PyContracts](https://github.com/AndreaCensi/contracts), a general-purpose library for specifying constraints on function arguments that has special support for NumPy
* [Shape Guard](https://github.com/Qwlouse/shapeguard), another runtime verification tool using concise helper methods
* [swift-tfp](https://github.com/google-research/swift-tfp), a static analyzer for tensor shapes in Swift

To learn more about tensor shape checking in general, see:

* Stephan Hoyer's [Ideas for array shape typing in Python](https://docs.google.com/document/d/1vpMse4c6DrWH5rq2tQSx3qwP_m_0lyn-Ij4WHqQqRHY/edit) document
* The [Typing for multi-dimensional arrays](https://github.com/python/typing/issues/513) GitHub issue in `python/typing`
* Our [Shape annotation feature scoping](https://docs.google.com/document/d/1t-j1MJ9M0f0KMAnM22J97tCHSfVoFjAy9k4Lexi75c4/edit) and [Shape annotation syntax proposal](https://docs.google.com/document/d/1But-hjet8-djv519HEKvBN6Ik2lW3yu0ojZo6pG9osY/edit) documents (the latter a synthesis of the most promising ideas from the full doc)
* The Python [typing-sig](https://mail.python.org/archives/list/typing-sig@python.org/) mailing list (in particular, [this thread](https://mail.python.org/archives/list/typing-sig@python.org/thread/IOBJGI5SJCUHJAUE4BOULGFBBEO5DCVG/))
* [Notes and recordings](https://docs.google.com/document/d/1oaG0V2ZE5BRDjd9N-Tr1N0IKGwZQcraIlZ0N8ayqVg8/edit) from the Tensor Typing Open Design Meetings

## Repository structure

The `tensor_annotations` package contains four types of things:

* **Custom tensor classes**. We provide our own versions of e.g. TensorFlow's `Tensor` class and JAX's `Array` class in order to support shape parameterisation. These are stored in **`tensorflow.py`** and **`jax.py`**. (These classes are only used in type annotations and are never instantiated, which is why they contain no implementation.)
* **Type stubs for custom tensor classes**. We also need to provide type annotations specifying the shape of, say, `x: Tensor[A, B] + y: Tensor[B]`. These are **`tensorflow.pyi`** and **`jax.pyi`**.
  * These are generated from `templates/tensors.pyi` using `tools/render_tensor_template.py`.
* **Type stubs for library functions**. We also need to specify the shape of, say, `tf.reduce_sum(x: Tensor[A, B], axis=0)`. This information is stored in type stubs in **`library_stubs`**. (The `third_party/py` directory structure is necessary to tell pytype exactly which packages these stubs are for.) Ideally, these will eventually live in the libraries themselves.
  * JAX stubs are auto-generated from `templates/jax.pyi` using `tools/render_jax_library_template.py`. Note that we currently specify the signatures of library members we don't generate automatically as `Any`. Ideally, we'd like to automatically generate complete type stubs and then tweak them to include shape information, but we haven't got around to this yet.
  * For TensorFlow stubs, we start from stubs generated by a Google-internal TensorFlow stub generator and then hand-edit them to include shape information. Our edits are demarcated by `BEGIN/END tensor_annotations annotations for ...` blocks. Again, we intend to automate this further in the future.
* **Common axis types**. Finally, we provide a canonical set of common axis labels such as `Time` and `Batch`. These are stored in **`axes.py`**.