See also jaxtyping, which, despite what its name might imply, supports JAX, PyTorch, NumPy, MLX, and TensorFlow arrays and tensors.
https://docs.kidger.site/jaxtyping/
I use jaxtyping as documentation, but the fact that it only supports runtime checking (in a slightly clunky manner) and can't infer shapes from ops really limits its utility imo.