February 27th, 2023 · 7 min read

When doing any sort of tensor/array computation in `python` (via `numpy`, `pytorch`, `jax`, or other), it's common to encounter shape errors like the one below:

```python
import numpy as np

size1 = (2,3)
size2 = (4,3)

M1 = np.random.random(size=size1)
M2 = np.random.random(size=size2)

try:
    print(np.dot(M1,M2))
except Exception as e:
    print(e)
```

```
shapes (2,3) and (4,3) not aligned: 3 (dim 1) != 4 (dim 0)
```

And most of the time, these kinds of errors boil down to something like accidentally forgetting to do a reshape or transpose like so.

```python
import numpy as np

size1 = (2,3)
size2 = (4,3)

M1 = np.random.random(size=size1)
M2 = np.random.random(size=size2).T

try:
    print(np.dot(M1,M2))
except Exception as e:
    print(e)
```

```
[[0.68812413 0.63491692 0.375332   1.22395427]
 [0.57381506 0.42578404 0.19132443 0.8889217 ]]
```

And while this is a mild case, shape bugs like these become more frequent as operations grow more complex and as more dimensions are involved.

Here's a slightly more complex example of a `Linear` implementation in `numpy` with a subtle shape bug.

```python
def Linear(A, x, b):
    """
    Takes matrix A (m x n) times a vector x (n x 1) and
    adds a bias. The resulting ndarray is then ravelled
    into a vector of size (m).
    """
    Ax = np.dot(A, x)
    Axb = np.add(Ax, b)
    return np.ravel(Axb)

A = np.random.random(size=(4,4))
x = np.random.random(size=(4,1))
b = np.random.random(size=(4))

result = Linear(A, x, b)
print(result)
print(result.shape)
```

```
[1.18041914 1.87580329 0.93373901 1.48799234 1.4920404  2.18742455
 1.24536027 1.79961361 2.29649806 2.99188221 2.04981793 2.60407127
 1.31159899 2.00698314 1.06491886 1.6191722 ]
(16,)
```

The docstring of `Linear` clearly says the result should be size `m` (or `4`). But why then did we end up with a vector of size `16`? If we dig into each function we will eventually find that our problem is in how `numpy` handles an `ndarray` of a different shape.

If we break down `Linear`, after `np.dot` we have an `ndarray` of shape `(4,1)`, to which we `np.add` a vector of shape `(4)`. And here lies our bug. We might naturally think that `np.add` will do this addition element wise, but instead we fell into an array broadcasting trap. Array broadcasting is the set of rules `numpy` uses to determine how to do arithmetic on differently shaped `ndarrays`. So instead of doing our computation element wise, `numpy` interprets this as a broadcast addition, resulting in a `(4,4)` matrix, which subsequently gets "raveled" into a size `16` vector.

Now to fix this is easy, we just need to initialize our `b` variable to be of shape `(4,1)` so `numpy` will interpret the `np.add` as an element wise addition.

```python
def Linear(A, x, b):
    """
    Takes matrix A (m x n) times a vector x (n x 1) and
    adds a bias. The resulting ndarray is then ravelled
    into a vector of size (m).
    """
    Ax = np.dot(A, x)
    Axb = np.add(Ax, b)
    return np.ravel(Axb)

A = np.random.random(size=(4,4))
x = np.random.random(size=(4,1))
b = np.random.random(size=(4,1))

result = Linear(A, x, b)
print(result)
print(result.shape)
```

```
[1.15227694 1.24640271 0.63951685 1.13304944]
(4,)
```

We've solved the problem, but how can we be smarter to prevent this error from happening again?

Existing ways to stop shape bugs

The simplest way we can try to stop this shape bug is with good docs. Ideally we should always have good docs, but we can make it a point to include what the shape expectations are like so:

```python
def Linear(A, x, b):
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    Ax = np.dot(A, x)  # Shape (M x 1)
    Axb = np.add(Ax, b)  # (M x 1) + (M x 1)
    return np.ravel(Axb)  # Shape (M)
```

Now while informative, nothing is preventing us from encountering the same bug again. The only benefit this gives us is making the debugging process a bit easier. We can do better.

Another approach in addition to good docs that's more of a preventative action is to use assertions. By sprinkling `assert` throughout `Linear` with an informative error message, we can "fail early" and start debugging like so:

```python
def Linear(A, x, b):
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    assert len(A.shape) == 2, f"A must be of dim 2, not {len(A.shape)}"
    Am, An = A.shape

    assert x.shape == (An, 1), f"X must be shape ({An}, 1) to do dot"
    Ax = np.dot(A, x)  # Shape (M x 1)

    assert b.shape == (Am, 1), f"Bias term must be shape ({Am}, 1)"
    result = np.add(Ax, b)  # (M x 1) + (M x 1)

    ravel_result = np.ravel(result)
    assert ravel_result.shape == (Am,), f"Uh oh, ravel result is shape {ravel_result.shape} and not {(Am,)}"
    return ravel_result
```

At every step of this function we do an `assert` to make sure all the `ndarray` shapes are what we expect.

As a result `Linear` is a bit "safer". But compared to what we had originally, this approach is much less readable. We also inherit some of the baggage that comes with runtime error checking like:

- Incomplete checking: Have we checked all expected shape failure modes?
- Slow debugging cycles: How many refactor->run cycles will we have to do to pass the checks?
- Additional testing: Do we have to update our tests to cover our runtime error checks?

Overall runtime error checking is not a bad thing. In most cases it's very necessary! But when it comes to shape errors, we can leverage an additional approach, static type checking.

Even though `python` is a dynamically typed language, in `python>=3.5` the `typing` module was introduced to enable static type checkers to validate type hinted `python` code. (See this video for more details)

Over time many third party libraries (like `numpy`) have started to type hint their codebases which we can use to our benefit.

In order to help us prevent shape errors, let's see what typing capabilities exist in `numpy`.

`dtype` typing `numpy` arrays

As of writing this post, `numpy==v1.24.2` only supports typing on an `ndarray`'s `dtype` (`uint8`, `float64`, etc.).

Using `numpy`'s existing type hinting tooling, here's how we would include `dtype` type information to our `Linear` example (note: there is an intentional type error)

```python
from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

GenericType = TypeVar("GenericType", bound=np.generic)


def Linear(
    A: NDArray[GenericType],
    x: NDArray[GenericType],
    b: NDArray[GenericType],
) -> NDArray[GenericType]:
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    assert len(A.shape) == 2, f"A must be of dim 2, not {len(A.shape)}"
    Am, An = A.shape

    assert x.shape == (An, 1), f"X must be shape ({An}, 1) to do dot"
    Ax: NDArray[GenericType] = np.dot(A, x)  # Shape (M x 1)

    assert b.shape == (Am, 1), f"Bias term must be shape ({Am}, 1)"
    result: NDArray[GenericType] = np.add(Ax, b)  # (M x 1) + (M x 1)

    ravel_result: NDArray[GenericType] = np.ravel(result)
    assert ravel_result.shape == (Am,), f"Uh oh, ravel result is shape {ravel_result.shape} and not {(Am,)}"
    return ravel_result


A: NDArray[np.float64] = np.random.standard_normal(size=(10, 10))
x: NDArray[np.float64] = np.random.standard_normal(size=(10, 1))
b: NDArray[np.float32] = np.random.standard_normal(size=(10, 1))
y: NDArray[np.float64] = Linear(A, x, b)
print(y)
print(y.dtype)
```

```
[-1.81553298 -4.94471634  3.24041295  3.34200411  2.221593    7.59161372
  3.1321597  -0.37862935 -1.98975116  1.57701057]
float64
```

Even though this code is "runnable" and doesn't produce an error, a typechecker like `pyright` tells us a different story.

```
pyright linear_bad_typing.py
```

```
No configuration file found.
No pyproject.toml file found.
stubPath /mnt/typings is not a valid directory.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/linear_bad_typing.py
  /mnt/linear_bad_typing.py:40:26 - error: Expression of type "ndarray[Any, dtype[float64]]" cannot be assigned to declared type "NDArray[float32]"
    "ndarray[Any, dtype[float64]]" is incompatible with "NDArray[float32]"
      TypeVar "_DType_co@ndarray" is covariant
        "dtype[float64]" is incompatible with "dtype[float32]"
          TypeVar "_DTypeScalar_co@dtype" is covariant
            "float64" is incompatible with "float32" (reportGeneralTypeIssues)
  /mnt/linear_bad_typing.py:41:39 - error: Argument of type "NDArray[float32]" cannot be assigned to parameter "b" of type "NDArray[GenericType@Linear]" in function "Linear"
    "NDArray[float32]" is incompatible with "NDArray[float64]"
      TypeVar "_DType_co@ndarray" is covariant
        "dtype[float32]" is incompatible with "dtype[float64]"
          TypeVar "_DTypeScalar_co@dtype" is covariant
            "float32" is incompatible with "float64" (reportGeneralTypeIssues)
2 errors, 0 warnings, 0 informations
Completed in 0.606sec
```

`pyright` has noticed that when we create our `b` variable, we gave it a `dtype` type that is incompatible with `np.random.standard_normal`.

Now we know to adjust the type hint of `b` to be in line with the `dtype` that is expected of `np.random.standard_normal` (`NDArray[np.float64]`).

Shape typing `numpy` arrays

While `dtype` typing is great, it's not the most useful for preventing shape errors (like from our original example).

Ideally it would be great if in addition to a `dtype` type, we can also include information about an `ndarray`'s shape to do shape typing.

Shape typing is a technique used to annotate information about the dimensionality and size of an array. In the context of `numpy` and the `python` type hinting system, we can use shape typing to catch shape errors before runtime.

For more information about shape typing check out this google doc on a shape typing syntax proposal by Matthew Rahtz, Jörg Bornschein, Vlad Mikulik, Tim Harley, Matthew Willson, Dimitrios Vytiniotis, Sergei Lebedev, Adam Paszke.

As we've seen, `numpy`'s `NDArray` currently only supports `dtype` typing and doesn't have any of this kind of shape typing ability. But why is that? If we dig into the definition of the `NDArray` type:

```python
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)

if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[Any, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))
```

And follow the definition of `np.ndarray` ...

```python
class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
```

We can see that it looks like `numpy` uses a `Shape` type already! But unfortunately if we look at the definition for this ...

```python
# TODO: Set the `bound` to something more suitable once we
# have proper shape support
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
```

😭 Looks like we're stuck with `Any` which doesn't add any useful shape information on our types.

Luckily for us, we don't have to wait for shape support in `numpy`. PEP 646 has the base foundation for shape typing and has already been accepted into `python==3.11`! And it's supported by `pyright`! Theoretically these two things give us most of the ingredients to do basic shape typing.

Now this blog post isn't about the details of PEP 646 or variadic generics. Understanding PEP 646 will help, but it's not needed to understand the rest of this post.

In order to add rudimentary shape typing to `numpy` we can simply change the `Any` type in the `NDArray` type definition to an unpacked variadic generic like so:

```python
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
Shape = TypeVarTuple("Shape")

if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[*Shape, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))
```

Doing so allows us to fill in a `Tuple` based type (indicating shape) in an `NDArray` alongside a `dtype` type. And shape typing with `Tuple`s enables us to define function overloads which describe to a type checker the possible ways a function can change the shape of an `NDArray`.

Let's look at an example of using these concepts to type a wrapper function for `np.random.standard_normal` from our `Linear` example with an intentional type error:

```python
import numpy as np
from numpy.typing import NDArray
from typing import Tuple, TypeVar, Literal

# Generic dimension sizes types
T1 = TypeVar("T1", bound=int)
T2 = TypeVar("T2", bound=int)
T3 = TypeVar("T3", bound=int)

# Dimension types represented as tuples
Shape = Tuple
Shape1D = Shape[T1]
Shape2D = Shape[T1, T2]
Shape3D = Shape[T1, T2, T3]
ShapeND = Shape[T1, ...]
ShapeNDType = TypeVar("ShapeNDType", bound=ShapeND)

def rand_normal_matrix(shape: ShapeNDType) -> NDArray[ShapeNDType, np.float64]:
    """Return a random ND normal matrix."""
    return np.random.standard_normal(size=shape)

# Yay correctly typed 2x2x2 cube!
LENGTH = Literal[2]
cube: NDArray[Shape3D[LENGTH, LENGTH, LENGTH], np.float64] = rand_normal_matrix((2,2,2))
print(cube)

SIDE = Literal[4]

# Uh oh the shapes won't match!
square: NDArray[Shape2D[SIDE, SIDE], np.float64] = rand_normal_matrix((3,3))
print(square)
```

Notice here there are no `assert` statements. And instead of several comments about shape, we indicate shape in the type hint.

Now while this code is "runnable", `pyright` will tell us something else:

```
py -m pyright bad_shape_typing.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/bad_shape_typing.py
  /mnt/bad_shape_typing.py:30:71 - error: Argument of type "tuple[Literal[3], Literal[3]]" cannot be assigned to parameter "shape" of type "ShapeNDType@rand_normal_matrix" in function "rand_normal_matrix"
    Type "Shape2D[SIDE, SIDE]" cannot be assigned to type "tuple[Literal[3], Literal[3]]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.535sec
```

`pyright` is telling us we've incorrectly typed `square` and that it's incompatible with a `3x3` shape. Now we know we need to go back and fix the type to what a type checker should expect.

Huzzah shape typing!!

Moar `numpy` shape typing!

Now that we have shape typed one function, let's step it up a notch. Let's try typing each `numpy` function in our `Linear` example to include shape types. We've already typed `np.random.standard_normal`, so next let's do `np.dot`.

If we look at the docs for `np.dot` there are 5 type cases it supports.

- Both arguments are `1D` arrays
- Both arguments are `2D` arrays (resulting in a `matmul`)
- Either argument is a scalar
- Either argument is an `ND` array and the other is a `1D` array
- One argument is an `ND` array and the other is an `MD` array

We can implement these cases as follows

```python
ShapeVarGen = TypeVarTuple("ShapeVarGen")

@overload
def dot(x1: NDArray[Shape1D[T1], GenericDType], x2: NDArray[Shape1D[T1], GenericDType], /) -> GenericDType:
    ...


@overload
def dot(
    x1: NDArray[Shape[T1, *ShapeVarGen], GenericDType], x2: NDArray[Shape1D[T1], GenericDType], /
) -> NDArray[Shape[*ShapeVarGen], GenericDType]:
    ...


@overload
def dot(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[T2, T3], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T3], GenericDType]:
    ...


@overload
def dot(x1: GenericDType, x2: GenericDType, /) -> GenericDType:
    ...


def dot(x1, x2):
    return np.dot(x1, x2)
```

The only case we can't implement is an `ND` dimensional array with an `MD` dimensional array. Ideally we would try implementing it like so:

```python
ShapeVarGen1 = TypeVarTuple("ShapeVarGen1")
ShapeVarGen2 = TypeVarTuple("ShapeVarGen2")

@overload
def dot(
    x1: NDArray[Shape[*ShapeVarGen1, T1], GenericDType], x2: NDArray[Shape[*ShapeVarGen2, T1, T2], GenericDType], /
) -> NDArray[Shape[*ShapeVarGen1, *ShapeVarGen2], GenericDType]:
    ...
```

But currently using multiple type variable tuples is not allowed. If you know of another way to cover this case let me know!

Luckily for our `Linear` use case, it only uses scalars, vectors, and matrices, which are covered by our four overloads.

Here's how we would use these `dot` overloads to do the dot product between a `2x3` matrix and a `3x2` matrix with type hints:

```python
import numpy as np
from numpy.typing import NDArray
from numpy_shape_typing.dot import dot
from numpy_shape_typing.types import ShapeNDType, Shape2D
from numpy_shape_typing.rand import rand_normal_matrix

from typing import Literal

ROWS = Literal[2]
COLS = Literal[3]
A: NDArray[Shape2D[ROWS, COLS], np.float64] = rand_normal_matrix((2,3))
B: NDArray[Shape2D[COLS, ROWS], np.float64] = rand_normal_matrix((3,2))
C: NDArray[Shape2D[ROWS, ROWS], np.float64] = dot(A, B)
print(C)
```

And if we check with `pyright`:

```
py -m pyright good_dot.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
0 errors, 0 warnings, 0 informations
Completed in 0.909sec
```

Everything looks good as it should!

And if we change the types to invalid matrix shapes:

```python
import numpy as np
from numpy.typing import NDArray
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.types import ShapeNDType, Shape2D

from typing import Literal

ROWS = Literal[2]
COLS = Literal[3]
SLICES = Literal[4]

# uh oh based on these types we can't do a valid dot product!
A: NDArray[Shape2D[ROWS, COLS], np.float64] = rand_normal_matrix((2,3))
B: NDArray[Shape2D[SLICES, COLS], np.float64] = rand_normal_matrix((4,3))
C: NDArray[Shape2D[ROWS, COLS], np.float64] = dot(A, B)
print(C)
```

And if we check with `pyright`:

```
py -m pyright ./bad_dot.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/bad_dot.py
  /mnt/bad_dot.py:16:54 - error: Argument of type "NDArray[Shape2D[SLICES, COLS], float64]" cannot be assigned to parameter "x2" of type "GenericDType@dot" in function "dot"
    Type "NDArray[Shape2D[ROWS, COLS], float64]" cannot be assigned to type "NDArray[Shape2D[SLICES, COLS], float64]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.908sec
```

`pyright` lets us know that the types we are using are incorrect shapes based on the `np.dot` type overloads we've specified.

Even moar `numpy` shape typing!

The next function we are going to type is `np.add`. The `numpy` docs only show two cases.

- Two `ND` array arguments of the same shape are added element wise
- Two `ND` array arguments that are not the same shape must be broadcastable to a common shape

Covering the first case is easy, but the second case is much harder as we would have to come up with a scheme to cover `numpy`'s array broadcasting system. Currently `python==3.11`'s `typing` doesn't have a generic way to cover all the broadcasting rules. (If you know of a way let me know!)

However if we scope down the second case to only two dimensions, we can cover all the array broadcasting rules with a few overloads:

```python
from typing import overload

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D, ShapeVarGen


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape1D[T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape1D[T2], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape1D[ONE], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape1D[ONE], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[T1, ONE], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[ONE, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, ONE], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[ONE, T2], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: GenericDType,
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: GenericDType,
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[*ShapeVarGen, GenericDType],
    x2: NDArray[*ShapeVarGen, GenericDType],
    /,
) -> NDArray[*ShapeVarGen, GenericDType]:
    ...


def add(x1, x2):
    return np.add(x1, x2)
```

Using these overloads, here is how we would catch unexpected array broadcasts (similar to the one from our original `Linear` example).

```python
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.types import ONE, Shape1D, Shape2D

COLS = Literal[4]
A: NDArray[Shape2D[COLS, COLS], np.float64] = rand_normal_matrix((4, 4))
B: NDArray[Shape2D[ONE, COLS], np.float64] = rand_normal_matrix((1, 4))
C: NDArray[Shape2D[ONE, COLS], np.float64] = add(A, B)
print(C)
```

In the example above, our output is a `4x4` matrix, but the output shape we declared in our types is `1x4`. Let's see what `pyright` says

```
py -m pyright unnexpected_broadcast.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/unnexpected_broadcast.py
  /mnt/unnexpected_broadcast.py:14:50 - error: Argument of type "NDArray[Shape2D[COLS, COLS], float64]" cannot be assigned to parameter "x1" of type "NDArray[*ShapeVarGen@add, GenericDType@add]" in function "add"
    "NDArray[Shape2D[COLS, COLS], float64]" is incompatible with "NDArray[Shape2D[ONE, COLS], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape2D[COLS, COLS]]" is incompatible with "*tuple[Shape2D[ONE, COLS]]"
          Tuple entry 1 is incorrect type
            "Shape2D[COLS, COLS]" is incompatible with "Shape2D[ONE, COLS]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 2.757sec
```

`pyright` informs us that our shapes are off and that we got broadcasted to a `4x4`! Huzzah shape typing!

Hitting the limitations of shape typing 😿

The last function we will type to finish off our `Linear` example is `np.ravel`. However this is where we start hitting some major limitations of shape typing as it exists today in `python` and `numpy`.

From the numpy docs on `np.ravel` the only case we need to cover is that any `ND` array gets collapsed into a `1D` array whose size is the total number of elements. Luckily the final `1D` size is just the product of all the input dimension sizes.

Ideally we would try to write code that looks something like this:

```python
ShapeVarGen = TypeVarTuple("ShapeVarGen")

@overload
def ravel(
    arr: NDArray[Shape[*ShapeVarGen], GenericDType]
) -> NDArray[Shape1D[Product[*ShapeVarGen]], GenericDType]:
    ...
```

But unfortunately `python`'s `typing` package currently doesn't have a notion of a `Product` type that provides a way to do algebraic typing.

However for the sake of completion we can fake it!

If we scope down from a generic `ND` typing of `np.ravel` to support up to two dimensions and limit the size of the output dimension to some maximum number, we can overload all the possible factors that multiply to the output dimension size. We would effectively be typing a multiplication table 😆, but it will work and get us to a "partially" typed `np.ravel`.

Here's how we can do it. First we create a bunch of `Literal` types (our factors):

```python
ZERO = Literal[0]
ONE = Literal[1]
TWO = Literal[2]
THREE = Literal[3]
FOUR = Literal[4]
...
```

Then we define "multiply" types for factor pairs of numbers:

```python
SHAPE_2D_MUL_TO_ONE = TypeVar(
    "SHAPE_2D_MUL_TO_ONE",
    bound=Shape2D[Literal[ONE], Literal[ONE]],
)
SHAPE_2D_MUL_TO_TWO = TypeVar(
    "SHAPE_2D_MUL_TO_TWO",
    bound=Union[Shape2D[Literal[ONE], Literal[TWO]], Shape2D[Literal[TWO], Literal[ONE]]],
)
SHAPE_2D_MUL_TO_THREE = TypeVar(
    "SHAPE_2D_MUL_TO_THREE",
    bound=Union[Shape2D[Literal[ONE], Literal[THREE]], Shape2D[Literal[THREE], Literal[ONE]]],
)
SHAPE_2D_MUL_TO_FOUR = TypeVar(
    "SHAPE_2D_MUL_TO_FOUR",
    bound=Union[
        Shape2D[Literal[ONE], Literal[FOUR]],
        Shape2D[Literal[TWO], Literal[TWO]],
        Shape2D[Literal[FOUR], Literal[ONE]],
    ],
)
```

Then lastly we wire these types up into individual `ravel` overloads (and cover a few generic ones while we're at it):

```python
@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_ONE, GenericDType]) -> NDArray[Shape1D[ONE], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_TWO, GenericDType]) -> NDArray[Shape1D[TWO], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_THREE, GenericDType]) -> NDArray[Shape1D[THREE], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_FOUR, GenericDType]) -> NDArray[Shape1D[FOUR], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape2D[T1, ONE], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape2D[ONE, T1], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape1D[T1], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...
```

Now we can rinse and repeat for as many numbers as we like!

Here is how we'd use this typing to catch a shape type error with `ravel`:

```python
import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import FOUR, SEVEN, TWO, Shape1D, Shape2D

A: NDArray[Shape2D[TWO, FOUR], np.float64] = rand_normal_matrix((2, 4))
B: NDArray[Shape1D[SEVEN], np.float64] = ravel(A)
print(B)
```

```
py -m pyright raveling.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/raveling.py
  /mnt/raveling.py:9:42 - error: Expression of type "NDArray[Shape1D[EIGHT], float64]" cannot be assigned to declared type "NDArray[Shape1D[SEVEN], float64]"
    "NDArray[Shape1D[EIGHT], float64]" is incompatible with "NDArray[Shape1D[SEVEN], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape1D[EIGHT]]" is incompatible with "*tuple[Shape1D[SEVEN]]"
          Tuple entry 1 is incorrect type
            "Shape1D[EIGHT]" is incompatible with "Shape1D[SEVEN]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.933sec
```

Putting it all together

So far we've gone through typing a small subset of `numpy`'s functions (`np.random.standard_normal`, `np.dot`, `np.add`, and `np.ravel` in all).

Now we can chain these typed functions together to form a typed `Linear` implementation like so:

```python
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D


def Linear(
    A: NDArray[Shape2D[T1, T2], GenericDType],
    x: NDArray[Shape2D[T2, ONE], GenericDType],
    b: NDArray[Shape2D[T1, ONE], GenericDType],
) -> NDArray[Shape1D[T1], GenericDType]:
    Ax = dot(A, x)
    Axb = add(Ax, b)
    return ravel(Axb)


IN_DIM = Literal[3]
in_dim: IN_DIM = 3

OUT_DIM = Literal[4]
out_dim: OUT_DIM = 4

# bad type >:(
BAD_OUT_DIM = Literal[5]

A: NDArray[Shape2D[OUT_DIM, IN_DIM], np.float64] = rand_normal_matrix((out_dim, in_dim))
x: NDArray[Shape2D[IN_DIM, ONE], np.float64] = rand_normal_matrix((in_dim, 1))
b: NDArray[Shape2D[OUT_DIM, ONE], np.float64] = rand_normal_matrix((out_dim, 1))

# this is a bad type!
y: NDArray[Shape1D[BAD_OUT_DIM], np.float64] = Linear(A, x, b)
```

I've included an intentional type error which should be caught by `pyright` like so:

```
py -m pyright linear_type_bad.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/linear_type_bad.py
  /mnt/linear_type_bad.py:37:55 - error: Argument of type "NDArray[Shape2D[OUT_DIM, IN_DIM], float64]" cannot be assigned to parameter "A" of type "NDArray[Shape2D[T1@Linear, T2@Linear], GenericDType@Linear]" in function "Linear"
    "NDArray[Shape2D[OUT_DIM, IN_DIM], float64]" is incompatible with "NDArray[Shape2D[BAD_OUT_DIM, IN_DIM], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape2D[OUT_DIM, IN_DIM]]" is incompatible with "*tuple[Shape2D[BAD_OUT_DIM, IN_DIM]]"
          Tuple entry 1 is incorrect type
            "Shape2D[OUT_DIM, IN_DIM]" is incompatible with "Shape2D[BAD_OUT_DIM, IN_DIM]" (reportGeneralTypeIssues)
  /mnt/linear_type_bad.py:37:61 - error: Argument of type "NDArray[Shape2D[OUT_DIM, ONE], float64]" cannot be assigned to parameter "b" of type "NDArray[Shape2D[T1@Linear, ONE], GenericDType@Linear]" in function "Linear"
    "NDArray[Shape2D[OUT_DIM, ONE], float64]" is incompatible with "NDArray[Shape2D[BAD_OUT_DIM, ONE], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape2D[OUT_DIM, ONE]]" is incompatible with "*tuple[Shape2D[BAD_OUT_DIM, ONE]]"
          Tuple entry 1 is incorrect type
            "Shape2D[OUT_DIM, ONE]" is incompatible with "Shape2D[BAD_OUT_DIM, ONE]" (reportGeneralTypeIssues)
2 errors, 0 warnings, 0 informations
Completed in 8.155sec
```

`pyright` has caught the shape type error! And huzzah again!

And now we can fix this shape error by changing `BAD_OUT_DIM` to the correct output dimension size.

```python
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D


def Linear(
    A: NDArray[Shape2D[T1, T2], GenericDType],
    x: NDArray[Shape2D[T2, ONE], GenericDType],
    b: NDArray[Shape2D[T1, ONE], GenericDType],
) -> NDArray[Shape1D[T1], GenericDType]:
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    Ax = dot(A, x)
    Axb = add(Ax, b)
    return ravel(Axb)


IN_DIM = Literal[3]
in_dim: IN_DIM = 3

OUT_DIM = Literal[4]
out_dim: OUT_DIM = 4

A: NDArray[Shape2D[OUT_DIM, IN_DIM], np.float64] = rand_normal_matrix((out_dim, in_dim))
x: NDArray[Shape2D[IN_DIM, ONE], np.float64] = rand_normal_matrix((in_dim, 1))
b: NDArray[Shape2D[OUT_DIM, ONE], np.float64] = rand_normal_matrix((out_dim, 1))
y: NDArray[Shape1D[OUT_DIM], np.float64] = Linear(A, x, b)
```

And if we check with `pyright`.

```
py -m pyright linear_type_good.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
0 errors, 0 warnings, 0 informations
Completed in 8.116sec
```

`pyright` tells us that our types are consistent!

What's next?

You tell me! Many open source scientific computing libraries have GitHub issues about shape typing such as:

- `numpy`: https://github.com/numpy/numpy/issues/16544
- `jax`: https://github.com/google/jax/issues/12049
- `pytorch`: https://github.com/pytorch/pytorch/issues/33953

So it's well recognized as a desirable feature. Some of the major technical hurdles we still need to overcome are:

Once these hurdles are overcome I don't see any blockers stopping projects like `numpy` from being fully shape typed.

This post and accompanying repo are just a sample of what shape typing might become. With future PEPs and work on the `python` type hinting system, we'll hopefully make our code incrementally safer.

Thanks for reading! (っ◔◡◔)っ ♥
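
Until static shape typing is available in `numpy` itself, the broadcasting trap from the original `Linear` example can still be guarded against at runtime. The snippet below is only a minimal sketch under a few assumptions: `linear_no_broadcast` is a hypothetical helper name that is not part of the accompanying repo, and it relies on `np.broadcast_shapes`, which needs `numpy>=1.20`. It reproduces the `(4,1) + (4,)` broadcast and then shows the same computation failing early instead.

```python
import numpy as np


def linear_no_broadcast(A, x, b):
    """Hypothetical variant of Linear that refuses to broadcast the bias."""
    Ax = np.dot(A, x)  # (M, N) . (N, 1) -> (M, 1)
    if Ax.shape != b.shape:
        # np.broadcast_shapes reports the shape numpy would silently produce.
        would_be = np.broadcast_shapes(Ax.shape, b.shape)
        raise ValueError(
            f"bias shape {b.shape} != {Ax.shape}; np.add would broadcast to {would_be}"
        )
    return np.ravel(Ax + b)  # (M, 1) -> (M,)


A = np.ones((4, 4))
x = np.ones((4, 1))

# The trap: (4, 1) + (4,) broadcasts to (4, 4), which ravels to 16 elements.
trap = np.ravel(np.add(np.dot(A, x), np.ones(4)))
print(trap.shape)  # (16,)

# With a (4, 1) bias the shapes match and the result has 4 elements.
ok = linear_no_broadcast(A, x, np.ones((4, 1)))
print(ok.shape)  # (4,)

# With a (4,) bias the helper fails early instead of broadcasting.
try:
    linear_no_broadcast(A, x, np.ones(4))
except ValueError as err:
    print(err)
```

This is the runtime counterpart of the `add` overloads: it catches the mismatch only when the code runs, which is exactly the gap that static shape types aim to close.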