Nested Tensors

A nested tensor is very similar to a regular tensor, except for its shape:

  • for a regular tensor, each dimension has a size

  • for a nested tensor, not all dimensions have regular sizes; some of them are jagged

Nested tensors are a natural solution for representing sequential data within various domains:

  • in NLP, sentences can have variable lengths, so a batch of sentences forms a nested tensor

  • in CV, images can have variable shapes, so a batch of images forms a nested tensor

In this tutorial, we will demonstrate basic usage of nested tensors and motivate their usefulness for operating on sequential data of varying lengths with a real-world example.

The nested tensor operations used here have not been released yet. You will have to install the latest nightly to run this tutorial.

import torch
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Nested Tensor Initialization

From the Python frontend, a nested tensor can be created from a list of tensors.

nt = torch.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], device=device)
print(nt)
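
A nested tensor reports itself as nested and still has a well-defined number of dimensions. As a quick check (assuming Tensor.is_nested and Tensor.dim() are available for nested tensors in your nightly):

print("is nested:", nt.is_nested)
print("number of dimensions:", nt.dim())  # 3: batch dimension, jagged dimension, last regular dimension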

By padding every underlying tensor to the same shape, a nested tensor can be converted to a regular tensor.

pt = nt.to_padded_tensor(0.0)
print(pt)
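
to_padded_tensor also accepts an explicit output size, which is handy when a batch must be padded to a fixed shape; the same overload is used later in this tutorial. A small sketch:

# pad to a requested shape; each requested size must be at least as large as the
# corresponding size of every underlying tensor
pt_fixed = nt.to_padded_tensor(0.0, (2, 5, 6))
print(pt_fixed.shape)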

Conceptually, a nested tensor is implemented as a batch of tensors with different shapes, i.e. dimension 0 is assumed to be the batch dimension. Indexing along dimension 0 gives back the underlying tensor component.

print("0th underlying tensor:", nt[0], sep='\n')
print("last column of 1st underlying tensor:", nt[1, :, -1], sep='\n')

Slicing along dimension 0 is not yet supported.
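
If a sub-batch is needed, a simple workaround is to index the desired components individually (which is supported, as shown above) and rebuild a nested tensor from them; a minimal sketch:

# a "slice" containing only the first component, rebuilt as a new nested tensor
nt_sub = torch.nested_tensor([nt[0]])
print(nt_sub)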

Nested Tensor Operations

As each operation must be explicitly implemented for nested tensors, operation coverage for nested tensors is currently narrower than that for regular tensors. For now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, and bmm are covered. However, coverage is being expanded rapidly. If you need certain operations, please file an issue to help us prioritize coverage.

reshape

The reshape op changes the shape of a tensor. Its full semantics for regular tensors can be found in the torch.Tensor.reshape documentation. For regular tensors, when specifying the new shape, a single dimension may be -1, in which case it is inferred from the remaining dimensions and the number of elements.

The semantics for nested tensors are similar, except that -1 is not inferred; instead, it inherits the old size (here 2 for nt[0] and 3 for nt[1]). -1 is the only legal size to specify for a jagged dimension.

nt1 = nt.reshape(2, -1, 2, 3)
print(nt1)
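
To confirm that -1 on the jagged dimension simply inherits each component's original size, we can inspect the components by indexing along dimension 0:

print(nt1[0].shape)  # expected torch.Size([2, 2, 3]): jagged size 2 inherited, 6 split into 2 x 3
print(nt1[1].shape)  # expected torch.Size([3, 2, 3]): jagged size 3 inherited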

transpose

The transpose op swaps two dimensions of a tensor. Its full semantics can be found in the torch.Tensor.transpose documentation. Note that dimension 0 of a nested tensor is special; it is assumed to be the batch dimension, so transposes involving dimension 0 are not supported.

nt2 = nt1.transpose(1, 2)
print(nt2)
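
A transpose involving dimension 0 should therefore fail; a minimal illustration (the exact error type and message may vary between nightlies, hence the broad catch):

try:
    nt1.transpose(0, 1)
except Exception as err:  # broad catch, since the exact error type may differ
    print("transposing the batch dimension failed as expected:", err)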

others

Other operations have the same semantics as for regular tensors. Applying an operation to a nested tensor is equivalent to applying it to each underlying tensor component, with the result also being a nested tensor.

nt_mm = torch.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
nt3 = torch.matmul(nt2, nt_mm)
print("matmul:", nt3, sep='\n')

nt4 = F.dropout(nt3, 0.1)
print("dropout:", nt4, sep='\n')

nt5 = F.softmax(nt4, -1)
print("softmax:", nt5, sep='\n')

Why Nested Tensor

Before nested tensors existed, one had to pad each data tensor to the same shape in order to form a batch as a regular tensor. For example, given 2 sentences and a vocabulary, we pad with 0.

sentences = [["goodbye", "padding"],
             ["embrace", "nested", "tensor"]]
vocabulary = {"goodbye" : 1.0, "padding" : 2.0,
              "embrace" : 3.0, "nested" : 4.0, "tensor" : 5.0}
padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                 [3.0, 4.0, 5.0]])
nested_sentences = torch.nested_tensor([torch.tensor([1.0, 2.0]),
                                        torch.tensor([3.0, 4.0, 5.0])])
print(padded_sentences)
print(nested_sentences)

Clearly, padding introduces inefficiency. Furthermore, zero padding is not treated as padding by every operation; for softmax, for example, one has to pad with -inf rather than 0 so that the padded entries are ignored.

padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                             [3.0, 4.0, 5.0]])
print(F.softmax(padded_sentences_for_softmax, -1))
print(F.softmax(nested_sentences, -1))
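
For contrast, applying softmax directly to the 0-padded tensor from above shows why naive zero padding is wrong: the padded position receives probability mass as if it were a real token.

print(F.softmax(padded_sentences, -1))  # the 0.0 pad in the first row wrongly gets a nonzero probability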

Let us look at a practical example: the multi-head attention component used in Transformers. The nested tensor version is straightforward.

import math

"""
Args:
    query: query of shape (N, L_t, E_q)
    key: key of shape (N, L_s, E_k)
    value: value of shape (N, L_s, E_v)
    nheads: number of heads in multi-head attention
    W_q: Weight for query input projection of shape (E_total, E_q)
    W_k: Weight for key input projection of shape (E_total, E_k)
    W_v: Weight for value input projection of shape (E_total, E_v)
    W_out: Weight for output projection of shape (E_out, E_total)
    b_q (optional): Bias for query input projection of shape E_total. Default: None
    b_k (optional): Bias for key input projection of shape E_total. Default: None
    b_v (optional): Bias for value input projection of shape E_total. Default: None
    b_out (optional): Bias for output projection of shape E_out. Default: None
    dropout_p: dropout probability. Default: 0.0
    where:
        N is the batch size
        L_t is the target sequence length (jagged)
        L_s is the source sequence length (jagged)
        E_q is the embedding size for query
        E_k is the embedding size for key
        E_v is the embedding size for value
        E_total is the embedding size for all heads combined
        E_out is the output embedding size
Returns:
    attn_output: Output of shape (N, L_t, E_out)
"""
def mha_nested(query, key, value, nheads,
               W_q, W_k, W_v, W_out,
               b_q=None, b_k=None, b_v=None, b_out=None,
               dropout_p=0.0):
    N = query.size(0)
    E_total = W_q.size(0)
    assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
    E_head = E_total // nheads

    # apply input projection
    # (N, L_t, E_q) -> (N, L_t, E_total)
    query = F.linear(query, W_q, b_q)
    # (N, L_s, E_k) -> (N, L_s, E_total)
    key = F.linear(key, W_k, b_k)
    # (N, L_s, E_v) -> (N, L_s, E_total)
    value = F.linear(value, W_v, b_v)

    # reshape query, key, value to separate by head
    # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
    query = query.reshape(-1, -1, nheads, E_head).transpose(1, 2)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
    key = key.reshape(-1, -1, nheads, E_head).transpose(1, 2)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
    value = value.reshape(-1, -1, nheads, E_head).transpose(1, 2)

    # query matmul key^T
    # (N, nheads, L_t, E_head) x (N, nheads, L_s, E_head)^T -> (N, nheads, L_t, L_s)
    keyT = key.transpose(-1, -2)
    attn_weights = torch.matmul(query, keyT)

    # scale down
    attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

    # softmax
    attn_weights = F.softmax(attn_weights, dim=-1)

    # dropout
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # attention_weights matmul value
    # (N, nheads, L_t, L_s) x (N, nheads, L_s, E_head) -> (N, nheads, L_t, E_head)
    attn_output = torch.matmul(attn_weights, value)

    # merge heads
    # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
    attn_output = attn_output.transpose(1, 2).reshape(N, -1, E_total)

    # apply output projection
    # (N, L_t, E_total) -> (N, L_t, E_out)
    attn_output = F.linear(attn_output, W_out, b_out)

    return attn_output

The 0-padded tensor version additionally requires masks in order to handle the padded entries correctly.

"""
Args:
    query: query of shape (N, L_t, E_q)
    key: key of shape (N, L_s, E_k)
    value: value of shape (N, L_s, E_v)
    nheads: number of heads in multi-head attention
    attn_mask_q: boolean mask indicating locations that should not take part in attention for query, shape (N, L_t)
    attn_mask_kv: boolean mask indicating locations that should not take part in attention for key and value, shape (N, L_s)
    W_q: Weight for query input projection of shape (E_total, E_q)
    W_k: Weight for key input projection of shape (E_total, E_k)
    W_v: Weight for value input projection of shape (E_total, E_v)
    W_out: Weight for output projection of shape (E_out, E_total)
    b_q (optional): Bias for query input projection of shape E_total. Default: None
    b_k (optional): Bias for key input projection of shape E_total. Default: None
    b_v (optional): Bias for value input projection of shape E_total. Default: None
    b_out (optional): Bias for output projection of shape E_out. Default: None
    dropout_p: dropout probability. Default: 0.0
    where:
        N is the batch size
        L_t is the target sequence length (padded)
        L_s is the source sequence length (padded)
        E_q is the embedding size for query
        E_k is the embedding size for key
        E_v is the embedding size for value
        E_total is the embedding size for all heads combined
        E_out is the output embedding size
Returns:
    attn_output: Output of shape (N, L_t, E_out)
"""
def mha_padded(query, key, value, nheads,
               attn_mask_q, attn_mask_kv,
               W_q, W_k, W_v, W_out,
               b_q=None, b_k=None, b_v=None, b_out=None,
               dropout_p=0.0):
    N = query.size(0)
    L_t = query.size(1)
    L_s = key.size(1)
    E_total = W_q.size(0)
    assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
    E_head = E_total // nheads

    # apply input projection
    # (N, L_t, E_q) -> (N, L_t, E_total)
    query = F.linear(query, W_q, b_q)
    # (N, L_s, E_k) -> (N, L_s, E_total)
    key = F.linear(key, W_k, b_k)
    # (N, L_s, E_v) -> (N, L_s, E_total)
    value = F.linear(value, W_v, b_v)

    # padding-specific step: remove bias from padded entries
    # for this particular multi-head attention formula, removing the bias is not strictly
    # necessary because the -inf padding in the softmax step below takes care of it,
    # but we demonstrate the removal here for generality
    for i in range(N):
        for j in range(L_t):
            if attn_mask_q[i, j]:
                query[i, j, :] = 0.0
        for j in range(L_s):
            if attn_mask_kv[i, j]:
                key[i, j, :] = 0.0
                value[i, j, :] = 0.0

    # reshape query, key, value to separate by head
    # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) -> (N * nheads, L_t, E_head)
    query = query.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
    key = key.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
    value = value.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)

    # query bmm key^T
    # (N * nheads, L_t, E_head) x (N * nheads, L_s, E_head)^T -> (N * nheads, L_t, L_s)
    keyT = key.transpose(-1, -2)
    # padding-specific step: add -inf mask for padding in softmax
    attn_mask = query.new_zeros((N, nheads, L_t, L_s))
    for i in range(N):
        for j in range(L_t):
            for k in range(L_s):
                if attn_mask_q[i, j] or attn_mask_kv[i, k]:
                    attn_mask[i, :, j, k] = float("-inf")
    attn_mask = attn_mask.reshape((N * nheads, L_t, L_s))
    attn_weights = torch.baddbmm(attn_mask, query, keyT)
    # if no padding, it could have been as simple as
    #     attn_weights = torch.bmm(query, keyT)

    # scale down
    attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

    # softmax
    attn_weights = F.softmax(attn_weights, dim=-1).nan_to_num_(0.0)

    # dropout
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # attention_weights bmm value
    # (N * nheads, L_t, L_s) x (N * nheads, L_s, E_head) -> (N * nheads, L_t, E_head)
    attn_output = attn_weights.bmm(value)

    # merge heads
    # (N * nheads, L_t, E_head) -> (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
    attn_output = attn_output.reshape(N, nheads, -1, E_head).transpose(1, 2).reshape(N, -1, E_total)

    # apply output projection
    # (N, L_t, E_total) -> (N, L_t, E_out)
    attn_output = F.linear(attn_output, W_out, b_out)

    # padding-specific step: remove output projection bias from padded entries
    for i in range(N):
        for j in range(L_t):
            if attn_mask_q[i, j]:
                attn_output[i, j, :] = 0.0

    return attn_output

Set hyperparameters following the Transformer paper.

N = 512
E_q, E_k, E_v, E_total, E_out = 512, 512, 512, 512, 512
nheads = 8

Except for the dropout probability, which is set to 0 for the correctness check.

dropout_p = 0.0

Let us generate some realistic fake data using Zipf's law.

import numpy as np

def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return sentence_lengths

alpha = 1.2

sentence_lengths = zipf_sentence_lengths(alpha, N)
L_t = np.max(sentence_lengths)
L_s = L_t
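
A quick look at the generated lengths gives a sense of how jagged the batch is (purely illustrative):

print("average sentence length:", sentence_lengths.mean())
print("maximum sentence length:", L_t)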

Create the inputs.

# create parameters
W_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)
W_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)
W_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)
W_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)

# create nested input
queries = []
keys    = []
values  = []
for i in range(N):
    l = sentence_lengths[i]
    s = l
    queries.append(torch.randn((l, E_q), device=device))
    keys   .append(torch.randn((s, E_k), device=device))
    values .append(torch.randn((s, E_v), device=device))
query = torch.nested_tensor(queries)
key   = torch.nested_tensor(keys   )
value = torch.nested_tensor(values )

# pad input
padded_query = query.to_padded_tensor(0.0, (N, L_t, E_q))
padded_key   = key  .to_padded_tensor(0.0, (N, L_s, E_k))
padded_value = value.to_padded_tensor(0.0, (N, L_s, E_v))

# create attention masks
attn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)
attn_mask_kv = torch.zeros((N, L_s), dtype=torch.bool)
for i in range(N):
    for j in range(L_t):
        if padded_query[i, j, :].abs().max().item() == 0.0:
            attn_mask_q[i, j] = True
    for j in range(L_s):
        if padded_key[i, j, :].abs().max().item() == 0.0:
            attn_mask_kv[i, j] = True

Check correctness and performance.

import timeit

t0 = timeit.default_timer()
out_nested = mha_nested(
    query, key, value, nheads,
    W_q, W_k, W_v, W_out,
    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
    dropout_p=dropout_p)

t1 = timeit.default_timer()
out_padded = mha_padded(
    padded_query, padded_key, padded_value, nheads,
    attn_mask_q, attn_mask_kv,
    W_q, W_k, W_v, W_out,
    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
    dropout_p=dropout_p)
t2 = timeit.default_timer()

print("nested and padded calculations differ by", (out_nested.to_padded_tensor(0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
print("nested tensor multi-head attention takes", t1 - t0, "seconds")
print("padded tensor multi-head attention takes", t2 - t1, "seconds")

The nested tensor version avoids wasted computation on padding, so in sequential CPU execution it is faster than the padded tensor version, as expected. Optimization for multi-threaded environments is underway.
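
One way to quantify the potential savings is to check what fraction of the padded batch is actually padding (an illustrative estimate, not a precise performance model):

total_tokens = int(sentence_lengths.sum())
padded_tokens = int(N * L_t)
print("fraction of the padded batch that is real data:", total_tokens / padded_tokens)
print("fraction that is pure padding:", 1.0 - total_tokens / padded_tokens)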

For now, performant kernels are provided only for specific use cases, e.g. self-attention evaluation via the multi-head attention formula.

# embeddings are assumed to be the same
E = E_total
mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)
mha_lib.eval()

Extract parameters for the correctness check.

mha_lib.in_proj_weight.requires_grad_(False)
mha_lib.in_proj_bias.requires_grad_(False)
mha_lib.out_proj.weight.requires_grad_(False)
mha_lib.out_proj.bias.requires_grad_(False)
W_q, b_q = mha_lib.in_proj_weight[: E, :], mha_lib.in_proj_bias[: E]
W_k, b_k = mha_lib.in_proj_weight[E : 2 * E, :], mha_lib.in_proj_bias[E : 2 * E]
W_v, b_v = mha_lib.in_proj_weight[2 * E :, :], mha_lib.in_proj_bias[2 * E :]
W_out, b_out = mha_lib.out_proj.weight, mha_lib.out_proj.bias
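
As a quick sanity check: torch.nn.MultiheadAttention stacks the query, key and value projections along dimension 0 of in_proj_weight, so each extracted slice should have shape (E, E), and the output projection weight is also (E, E) here because E_out equals E_total.

print(W_q.shape, W_k.shape, W_v.shape)  # each expected to be torch.Size([512, 512])
print(W_out.shape)                      # expected torch.Size([512, 512])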

Check correctness and performance.

t0 = timeit.default_timer()
out_lib, out_lib_weights = mha_lib(query, query, query)

t1 = timeit.default_timer()
out_nested = mha_nested(
    query, query, query, nheads,
    W_q, W_k, W_v, W_out,
    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
    dropout_p=dropout_p)

t2 = timeit.default_timer()
padded_out = mha_padded(
    padded_query, padded_query, padded_query, nheads,
    attn_mask_q, attn_mask_q,
    W_q, W_k, W_v, W_out,
    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
    dropout_p=dropout_p)
t3 = timeit.default_timer()

print("nested general and library calculations differ by", (out_nested.to_padded_tensor(0.0) - out_lib.to_padded_tensor(0.0)).abs().max().item())
print("nested library multi-head attention takes", t1 - t0, "seconds")
print("nested general multi-head attention takes", t2 - t1, "seconds")
print("padded tensor multi-head attention takes", t3 - t2, "seconds")
