My solutions to Sasha Rush’s tensor puzzles. This is useful as a quick refresher for things you can do with vanilla tensor operations. Here’s the official solution video.

arange and where

These are provided and are really useful. The two paradigms I see for tensor operations are masking (apply where to input matrix) and matmuls (need different dimensions, use arange to generate correct size matrix and where to broadcast).

def arange(i: int):
	"Use this function to replace a for-loop."
	return torch.tensor(range(i))
 
def where(q, a, b):
	"Use this function to replace an if-statement."
	return (q * a) + (~q) * b

ones

Use arange to loop over all the indices and set all values to 1.

def ones(i: int) -> TT["i"]:
	return where(arange(i) >= 0, 1, 0)

sum

Do a dot product with the ones array. Effectively perform (1 x i) @ (i x 1) => 1.

def sum(a: TT["i"]) -> TT[1]:
	return ones(a.shape[0]) @ a[:, None]

outer

(i x 1) @ (1 x j) => (i x j).

def outer(a: TT["i"], b: TT["j"]) -> TT["i", "j"]:
	return a[:, None] @ b[None, :]

diag

We can use indexing here to grab every value along the matrix diagonal.

def diag(a: TT["i", "i"]) -> TT["i"]:
	return a[arange(a.shape[0]), arange(a.shape[0])]

eye

Use broadcasting to check each arange value with itself, so that the diagonal has all ones.

def eye(j: int) -> TT["j", "j"]:
	return where(arange(j)[:, None] == arange(j), 1, 0)

triu

Same as eye except we take the entire upper triangle instead of just the diagonal.

def triu(j: int) -> TT["j", "j"]:
	return where(arange(j)[:, None] <= arange(j), 1, 0)

cumsum

For each value i in the array, we want the sum from 0 to i. We can multiply by the triu pattern to achieve exactly this.

def cumsum(a: TT["i"]) -> TT["i"]:
	return a @ triu(a.shape[0])

diff

Simply use indexing against the current and previous values, except for at the beginning where we don’t compute a difference.

def diff(a: TT["i"], i: int) -> TT["i"]:
	return where(arange(i) > 0, a[arange(i)] - a[arange(i) - 1], a[arange(i)])

vstack

Create a 2 x i array by broadcasting [[0], [1]] with the ones array. This works because for the first row you end up with all True and the second row is all False. Then conditionally set elements based on if we’re on row 0 or 1.

def vstack(a: TT["i"], b: TT["i"]) -> TT[2, "i"]:
	return where(arange(2)[:, None] != ones(a.shape[0]), a, b)

roll

Increment and index the array, “rolling” around using a modulo.

def roll(a: TT["i"], i: int) -> TT["i"]:
	return a[(arange(i) + 1) % i]

flip

Reverse index the array. The indices [0, 1, 2] become [2, 1, 0].

def flip(a: TT["i"], i: int) -> TT["i"]:
	return a[i - arange(i) - 1]

compress

In the compress function, the goal is to select elements from v based on a boolean mask g. This is done by creating a matrix where each column corresponds to an index in the input vector v. The cumsum(1 * g) operation computes a cumulative sum over the mask, effectively numbering the True elements. The expression arange(i) == cumsum(1 * g)[:, None] - 1 creates a matrix where each row corresponds to one of the True indices in the mask. This matrix is then multiplied with the input vector v, effectively selecting and arranging the elements according to the mask.

def compress(g: TT["i", bool], v: TT["i"], i:int) -> TT["i"]:
	return v @ where(g[:, None], arange(i) == cumsum(1 * g)[:, None] - 1, 0)

pad_to

In pad_to, the goal is to extend a vector a to a specified length j by padding with zeros. This is achieved by creating a matrix where each column is a one-hot vector representing an index in the original array a. Multiplying a with this matrix effectively pads the array with zeros up to the desired length j.

def pad_to(a: TT["i"], i: int, j: int) -> TT["j"]:
	return a @ (1 * (arange(i)[:, None] == arange(j)))

sequence_mask

The sequence_mask operation is used to mask a 2D tensor based on specified lengths for each row. The length[:, None] > arange(values.shape[1]) operation creates a boolean mask where each row is masked according to its corresponding value in the length vector. Values beyond the specified length in each row are set to zero. This mask is then applied to the values tensor using the where function, effectively zero-padding or truncating each row of the tensor based on the lengths specified in the length tensor.

def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
	return where(length[:, None] > arange(values.shape[1]), values, 0)

bincount

The trick here lies with eye(j)[a]. That part just says that for value in a, get the corresponding row from the identity matrix. We have a guarantee that j = max(a) + 1, so eye(j) will contain all of the relevant rows. Now, we can just do a simple matmul with a ones array to vertically sum the rows, which counts the occurrences of each value.

def bincount(a: TT["i"], j: int) -> TT["j"]:
	return ones(a.shape[0]) @ eye(j)[a]

scatter_add

We use a similar technique from bincount except now we multiply by values since we’re trying to sum to a specific position in a new array. The link tells us where to go and using the eye matrix lets us know which values to pull out for each position.

def scatter_add(values: TT["i"], link: TT["i"], j: int) -> TT["j"]:
	return values @ eye(j)[link]

flatten

We want to use indexing here to iterate across the entire 2D array. We can do this by grabbing the list of numbers from 0 - i * j and using // to count how many i’s have passed and % to get the current j position. This converts arange(i * j) to a list of indices.

def flatten(a: TT["i", "j"], i:int, j:int) -> TT["i * j"]:
	return a[arange(i * j) // j, arange(i * j) % j]

linspace

Copy the linspace spec and use arange(n) instead of looping over the output array.

def linspace(i: TT[1], j: TT[1], n: int) -> TT["n", float]:
	return (i + (j - i) * arange(n) / max(1, n - 1))

heaviside

Once again just copy the spec, subbing in where for the conditional and arange for the loop.

def heaviside(a: TT["i"], b: TT["i"]) -> TT["i"]:
	return where(a[arange(a.shape[0])] == 0, b[arange(a.shape[0])], a[arange(a.shape[0])] > 0)

repeat (1d)

This effectively does an outer product, copying the initial row i times.

def repeat(a: TT["i"], d: TT[1]) -> TT["d", "i"]:
	return ones(d)[:, None] @ a[None, :]

bucketize

We do an outer product to get boolean values that tell us which buckets a value could fit into. Then we can just sum across the columns to get the index to bucket a value into. For instance, if we had the boundaries [3, 5] and a value 10, it would be greater than all the boundaries so it would be at index 2.

def bucketize(v: TT["i"], boundaries: TT["j"]) -> TT["i"]:
	return (1 * (v[:, None] >= boundaries)) @ ones(boundaries.shape[0])