.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "prototype/nestedtensor.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_prototype_nestedtensor.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_prototype_nestedtensor.py:


Nested Tensors
===============================================================

A nested tensor is very similar to a regular tensor, except for the shape:

* for a regular tensor, each dimension has a size
* for a nested tensor, not all dimensions have regular sizes; some of them are jagged

Nested tensors are a natural solution for representing sequential data within various domains:

* in NLP, sentences can have variable lengths, so a batch of sentences forms a nested tensor
* in CV, images can have variable shapes, so a batch of images forms a nested tensor

In this tutorial, we will demonstrate basic usage of nested tensors and motivate their
usefulness for operating on sequential data of varying lengths with a real-world example.

The nested tensor operations used here have not been released yet. You will have to
install the latest nightly to run this tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 23-29

.. code-block:: default


    import torch
    import torch.nn.functional as F

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

.. GENERATED FROM PYTHON SOURCE LINES 30-33

Nested Tensor Initialization
----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 35-36

From the Python frontend, a nested tensor can be created from a list of tensors.

.. GENERATED FROM PYTHON SOURCE LINES 36-39

.. code-block:: default


    nt = torch.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], device=device)
    print(nt)

.. GENERATED FROM PYTHON SOURCE LINES 40-42

By padding every underlying tensor to the same shape, a nested tensor can be
converted to a regular tensor.

.. GENERATED FROM PYTHON SOURCE LINES 42-45

.. code-block:: default


    pt = nt.to_padded_tensor(0.0)
    print(pt)

.. GENERATED FROM PYTHON SOURCE LINES 46-50

For practical reasons, a nested tensor is conceptually implemented as a batch of
tensors with different shapes, i.e. dimension 0 is assumed to be the batch dimension.
Indexing into dimension 0 gives back the underlying tensor component.

.. GENERATED FROM PYTHON SOURCE LINES 50-53

.. code-block:: default


    print("0th underlying tensor:", nt[0], sep='\n')
    print("last column of 1st underlying tensor:", nt[1, :, -1], sep='\n')

.. GENERATED FROM PYTHON SOURCE LINES 54-55

Slicing along dimension 0 is not supported yet.

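If you do need a sub-batch, one workaround is to index the components you want and
rebuild a nested tensor from them. The snippet below is only an illustrative sketch
reusing the indexing shown above, not an official slicing API:

.. code-block:: default


    # hypothetical workaround for the missing dim-0 slicing:
    # index the desired components and wrap them in a fresh nested tensor
    nt_sub = torch.nested_tensor([nt[i] for i in range(0, 2)], device=device)
    print(nt_sub)
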
.. GENERATED FROM PYTHON SOURCE LINES 57-60

Nested Tensor Operations
------------------------

.. GENERATED FROM PYTHON SOURCE LINES 62-81

As each operation must be explicitly implemented for nested tensors, operation coverage
for nested tensors is currently narrower than that of regular tensors. For now, only
basic operations such as index, dropout, softmax, transpose, reshape, linear, and bmm
are covered. However, coverage is being expanded rapidly. If you need certain operations,
please file an `issue <https://github.com/pytorch/pytorch/issues>`__ to help us
prioritize coverage.

**reshape**

The reshape op is for changing the shape of a tensor. Its full semantics for regular
tensors can be found
`here <https://pytorch.org/docs/stable/generated/torch.reshape.html>`__.
For regular tensors, when specifying the new shape, a single dimension may be -1, in
which case it is inferred from the remaining dimensions and the number of elements.

The semantics for nested tensors are similar, except that -1 no longer infers a size;
instead, it inherits the old size of each component (here 2 for ``nt[0]`` and 3 for
``nt[1]``). -1 is the only legal size to specify for a jagged dimension.

.. GENERATED FROM PYTHON SOURCE LINES 81-84

.. code-block:: default


    nt1 = nt.reshape(2, -1, 2, 3)
    print(nt1)

.. GENERATED FROM PYTHON SOURCE LINES 85-93

**transpose**

The transpose op is for swapping two dimensions of a tensor. Its full semantics can be
found `here <https://pytorch.org/docs/stable/generated/torch.transpose.html>`__.
Note that dimension 0 of a nested tensor is special; it is assumed to be the batch
dimension, so transposes involving dimension 0 are forbidden.

.. GENERATED FROM PYTHON SOURCE LINES 93-96

.. code-block:: default


    nt2 = nt1.transpose(1, 2)
    print(nt2)

.. GENERATED FROM PYTHON SOURCE LINES 97-103

**others**

Other operations have the same semantics as for regular tensors. Applying an operation
to a nested tensor is equivalent to applying the operation to each underlying tensor
component, with the result being a nested tensor as well.

.. GENERATED FROM PYTHON SOURCE LINES 103-113

.. code-block:: default


    nt_mm = torch.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
    nt3 = torch.matmul(nt2, nt_mm)
    print("matmul:", nt3, sep='\n')

    nt4 = F.dropout(nt3, 0.1)
    print("dropout:", nt4, sep='\n')

    nt5 = F.softmax(nt4, -1)
    print("softmax:", nt5, sep='\n')

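As a quick sanity check of that componentwise equivalence (an illustrative sketch only,
reusing the tensors created above), the softmax of the nested tensor should match the
softmax applied to each component separately:

.. code-block:: default


    # compare the nested-tensor softmax against a per-component softmax
    for i in range(2):
        component_softmax = F.softmax(nt4[i], -1)
        print("component", i, "matches:", torch.allclose(nt5[i], component_softmax))
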
.. GENERATED FROM PYTHON SOURCE LINES 114-117

Why Nested Tensor
-----------------

.. GENERATED FROM PYTHON SOURCE LINES 119-122

Before nested tensors, one had to manually pad each data tensor to the same shape to
form a batch as a regular tensor. For example, given 2 sentences of different lengths
and a vocabulary, the shorter sentence is padded with 0.

.. GENERATED FROM PYTHON SOURCE LINES 122-133

.. code-block:: default


    sentences = [["goodbye", "padding"],
                 ["embrace", "nested", "tensor"]]
    vocabulary = {"goodbye" : 1.0, "padding" : 2.0,
                  "embrace" : 3.0, "nested" : 4.0, "tensor" : 5.0}
    padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                     [3.0, 4.0, 5.0]])
    nested_sentences = torch.nested_tensor([torch.tensor([1.0, 2.0]),
                                            torch.tensor([3.0, 4.0, 5.0])])
    print(padded_sentences)
    print(nested_sentences)

.. GENERATED FROM PYTHON SOURCE LINES 134-137

Clearly, padding introduces inefficiency. Further, padding with zeros does not treat
the padded entries correctly for every operation, e.g. in softmax one has to pad with
-inf rather than 0 to ignore specific entries.

.. GENERATED FROM PYTHON SOURCE LINES 137-142

.. code-block:: default


    padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                                 [3.0, 4.0, 5.0]])
    print(F.softmax(padded_sentences_for_softmax, -1))
    print(F.softmax(nested_sentences, -1))

.. GENERATED FROM PYTHON SOURCE LINES 143-146

Let us take a look at a practical example: the multi-head attention component used in
`Transformers <https://arxiv.org/abs/1706.03762>`__.
The nested tensor version is straightforward.

.. GENERATED FROM PYTHON SOURCE LINES 146-229

.. code-block:: default


    import math

    """
    Args:
        query: query of shape (N, L_t, E_q)
        key: key of shape (N, L_s, E_k)
        value: value of shape (N, L_s, E_v)
        nheads: number of heads in multi-head attention
        W_q: Weight for query input projection of shape (E_total, E_q)
        W_k: Weight for key input projection of shape (E_total, E_k)
        W_v: Weight for value input projection of shape (E_total, E_v)
        W_out: Weight for output projection of shape (E_out, E_total)
        b_q (optional): Bias for query input projection of shape E_total. Default: None
        b_k (optional): Bias for key input projection of shape E_total. Default: None
        b_v (optional): Bias for value input projection of shape E_total. Default: None
        b_out (optional): Bias for output projection of shape E_out. Default: None
        dropout_p: dropout probability. Default: 0.0
    where:
        N is the batch size
        L_t is the target sequence length (jagged)
        L_s is the source sequence length (jagged)
        E_q is the embedding size for query
        E_k is the embedding size for key
        E_v is the embedding size for value
        E_total is the embedding size for all heads combined
        E_out is the output embedding size
    Returns:
        attn_output: Output of shape (N, L_t, E_out)
    """
    def mha_nested(query, key, value, nheads,
                   W_q, W_k, W_v, W_out,
                   b_q=None, b_k=None, b_v=None, b_out=None,
                   dropout_p=0.0):
        N = query.size(0)
        E_total = W_q.size(0)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        E_head = E_total // nheads

        # apply input projection
        # (N, L_t, E_q) -> (N, L_t, E_total)
        query = F.linear(query, W_q, b_q)
        # (N, L_s, E_k) -> (N, L_s, E_total)
        key = F.linear(key, W_k, b_k)
        # (N, L_s, E_v) -> (N, L_s, E_total)
        value = F.linear(value, W_v, b_v)

        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.reshape(-1, -1, nheads, E_head).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.reshape(-1, -1, nheads, E_head).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.reshape(-1, -1, nheads, E_head).transpose(1, 2)

        # query matmul key^T
        # (N, nheads, L_t, E_head) x (N, nheads, L_s, E_head)^T -> (N, nheads, L_t, L_s)
        keyT = key.transpose(-1, -2)
        attn_weights = torch.matmul(query, keyT)

        # scale down
        attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

        # softmax
        attn_weights = F.softmax(attn_weights, dim=-1)

        # dropout
        if dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=dropout_p)

        # attention_weights matmul value
        # (N, nheads, L_t, L_s) x (N, nheads, L_s, E_head) -> (N, nheads, L_t, E_head)
        attn_output = torch.matmul(attn_weights, value)

        # merge heads
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).reshape(N, -1, E_total)

        # apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = F.linear(attn_output, W_out, b_out)

        return attn_output

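Before moving on to the padded version, a quick smoke test may help build intuition.
The tiny dimensions and the shared weight matrix below are made up purely for
illustration and are not part of the original example:

.. code-block:: default


    # illustrative smoke test of mha_nested with tiny, made-up dimensions
    toy_nheads, toy_E = 2, 4
    toy_query = torch.nested_tensor([torch.randn((2, toy_E)), torch.randn((3, toy_E))], device=device)
    toy_W = torch.randn((toy_E, toy_E), device=device)
    toy_out = mha_nested(toy_query, toy_query, toy_query, toy_nheads,
                         toy_W, toy_W, toy_W, toy_W)
    # each output component keeps its own (jagged) sequence length
    print(toy_out[0].shape, toy_out[1].shape)
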
.. GENERATED FROM PYTHON SOURCE LINES 230-232

The 0-padded tensor version additionally requires masks and more complicated treatment
of the padded entries.

.. GENERATED FROM PYTHON SOURCE LINES 232-347

.. code-block:: default


    """
    Args:
        query: query of shape (N, L_t, E_q)
        key: key of shape (N, L_s, E_k)
        value: value of shape (N, L_s, E_v)
        nheads: number of heads in multi-head attention
        attn_mask_q: boolean mask indicating locations that should not take part in attention for query,
                     shape (N, L_t)
        attn_mask_kv: boolean mask indicating locations that should not take part in attention for key and value,
                      shape (N, L_s)
        W_q: Weight for query input projection of shape (E_total, E_q)
        W_k: Weight for key input projection of shape (E_total, E_k)
        W_v: Weight for value input projection of shape (E_total, E_v)
        W_out: Weight for output projection of shape (E_out, E_total)
        b_q (optional): Bias for query input projection of shape E_total. Default: None
        b_k (optional): Bias for key input projection of shape E_total. Default: None
        b_v (optional): Bias for value input projection of shape E_total. Default: None
        b_out (optional): Bias for output projection of shape E_out. Default: None
        dropout_p: dropout probability. Default: 0.0
    where:
        N is the batch size
        L_t is the target sequence length (padded)
        L_s is the source sequence length (padded)
        E_q is the embedding size for query
        E_k is the embedding size for key
        E_v is the embedding size for value
        E_total is the embedding size for all heads combined
        E_out is the output embedding size
    Returns:
        attn_output: Output of shape (N, L_t, E_out)
    """
    def mha_padded(query, key, value, nheads, attn_mask_q, attn_mask_kv,
                   W_q, W_k, W_v, W_out,
                   b_q=None, b_k=None, b_v=None, b_out=None,
                   dropout_p=0.0):
        N = query.size(0)
        L_t = query.size(1)
        L_s = key.size(1)
        E_total = W_q.size(0)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        E_head = E_total // nheads

        # apply input projection
        # (N, L_t, E_q) -> (N, L_t, E_total)
        query = F.linear(query, W_q, b_q)
        # (N, L_s, E_k) -> (N, L_s, E_total)
        key = F.linear(key, W_k, b_k)
        # (N, L_s, E_v) -> (N, L_s, E_total)
        value = F.linear(value, W_v, b_v)

        # padding-specific step: remove bias from padded entries
        # for this particular multi-head attention formula the bias removal is not strictly
        # necessary, because the -inf padding in the softmax step takes care of it,
        # but we demonstrate it here to stay general
        for i in range(N):
            for j in range(L_t):
                if attn_mask_q[i, j]:
                    query[i, j, :] = 0.0
            for j in range(L_s):
                if attn_mask_kv[i, j]:
                    key[i, j, :] = 0.0
                    value[i, j, :] = 0.0

        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) -> (N * nheads, L_t, E_head)
        query = query.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
        key = key.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
        value = value.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)

        # query bmm key^T
        # (N * nheads, L_t, E_head) x (N * nheads, L_s, E_head)^T -> (N * nheads, L_t, L_s)
        keyT = key.transpose(-1, -2)
        # padding-specific step: add -inf mask for padding in softmax
        attn_mask = query.new_zeros((N, nheads, L_t, L_s))
        for i in range(N):
            for j in range(L_t):
                for k in range(L_s):
                    if attn_mask_q[i, j] or attn_mask_kv[i, k]:
                        attn_mask[i, :, j, k] = float("-inf")
        attn_mask = attn_mask.reshape((N * nheads, L_t, L_s))
        attn_weights = torch.baddbmm(attn_mask, query, keyT)
        # if there were no padding, it could have been as simple as
        # attn_weights = torch.bmm(query, keyT)

        # scale down
        attn_weights = attn_weights * (1.0 / math.sqrt(E_head))

        # softmax
        attn_weights = F.softmax(attn_weights, dim=-1).nan_to_num_(0.0)

        # dropout
        if dropout_p > 0.0:
            attn_weights = F.dropout(attn_weights, p=dropout_p)

        # attention_weights bmm value
        # (N * nheads, L_t, L_s) x (N * nheads, L_s, E_head) -> (N * nheads, L_t, E_head)
        attn_output = attn_weights.bmm(value)

        # merge heads
        # (N * nheads, L_t, E_head) -> (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.reshape(N, nheads, -1, E_head).transpose(1, 2).reshape(N, -1, E_total)

        # apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = F.linear(attn_output, W_out, b_out)

        # padding-specific step: remove output projection bias from padded entries
        for i in range(N):
            for j in range(L_t):
                if attn_mask_q[i, j]:
                    attn_output[i, j, :] = 0.0

        return attn_output

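The Python loops above make the padding-specific bookkeeping easy to follow, but they
are slow. As a side note, the same -inf mask can be built with broadcasting; the
snippet below is a self-contained sketch with made-up toy sizes and masks, not part of
the original example:

.. code-block:: default


    # toy boolean padding masks; True marks a padded position
    toy_mask_q = torch.tensor([[False, False, True],
                               [False, True, True]])
    toy_mask_kv = torch.tensor([[False, False, False, True],
                                [False, False, True, True]])
    toy_N, toy_H, toy_L_t, toy_L_s = 2, 2, 3, 4

    # broadcast the two masks into an (N, L_t, L_s) grid and fill -inf where either side is padded
    toy_combined = toy_mask_q.unsqueeze(-1) | toy_mask_kv.unsqueeze(1)
    toy_attn_mask = torch.zeros((toy_N, toy_H, toy_L_t, toy_L_s))
    toy_attn_mask = toy_attn_mask.masked_fill(toy_combined.unsqueeze(1), float("-inf"))
    toy_attn_mask = toy_attn_mask.reshape(toy_N * toy_H, toy_L_t, toy_L_s)
    print(toy_attn_mask)
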
.. GENERATED FROM PYTHON SOURCE LINES 348-349

set hyperparameters following `the Transformer paper <https://arxiv.org/abs/1706.03762>`__

.. GENERATED FROM PYTHON SOURCE LINES 349-353

.. code-block:: default


    N = 512
    E_q, E_k, E_v, E_total, E_out = 512, 512, 512, 512, 512
    nheads = 8

.. GENERATED FROM PYTHON SOURCE LINES 354-355

except for the dropout probability, which is set to 0 for the correctness check

.. GENERATED FROM PYTHON SOURCE LINES 355-357

.. code-block:: default


    dropout_p = 0.0

.. GENERATED FROM PYTHON SOURCE LINES 358-359

Let us generate some realistic fake data from Zipf's law.

.. GENERATED FROM PYTHON SOURCE LINES 359-379

.. code-block:: default


    import numpy as np

    def zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:
        # generate fake corpus by unigram Zipf distribution
        # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
        sentence_lengths = np.empty(batch_size, dtype=int)
        for ibatch in range(batch_size):
            sentence_lengths[ibatch] = 1
            word = np.random.zipf(alpha)
            while word != 3 and word != 386 and word != 858:
                sentence_lengths[ibatch] += 1
                word = np.random.zipf(alpha)
        return sentence_lengths

    alpha = 1.2
    sentence_lengths = zipf_sentence_lengths(alpha, N)

    L_t = np.max(sentence_lengths)
    L_s = L_t

.. GENERATED FROM PYTHON SOURCE LINES 380-381

create inputs

.. GENERATED FROM PYTHON SOURCE LINES 381-418

.. code-block:: default


    # create parameters
    W_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)
    W_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)
    W_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)
    W_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)

    # create nested input
    queries = []
    keys = []
    values = []
    for i in range(N):
        l = sentence_lengths[i]
        s = l
        queries.append(torch.randn((l, E_q), device=device))
        keys.append(torch.randn((s, E_k), device=device))
        values.append(torch.randn((s, E_v), device=device))
    query = torch.nested_tensor(queries)
    key = torch.nested_tensor(keys)
    value = torch.nested_tensor(values)

    # pad input
    padded_query = query.to_padded_tensor(0.0, (N, L_t, E_q))
    padded_key = key.to_padded_tensor(0.0, (N, L_s, E_k))
    padded_value = value.to_padded_tensor(0.0, (N, L_s, E_v))

    # create attention masks
    attn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)
    attn_mask_kv = torch.zeros((N, L_s), dtype=torch.bool)
    for i in range(N):
        for j in range(L_t):
            if padded_query[i, j, :].abs().max().item() == 0.0:
                attn_mask_q[i, j] = True
        for j in range(L_s):
            if padded_key[i, j, :].abs().max().item() == 0.0:
                attn_mask_kv[i, j] = True

.. GENERATED FROM PYTHON SOURCE LINES 419-420

check correctness and performance

.. GENERATED FROM PYTHON SOURCE LINES 420-443

.. code-block:: default


    import timeit

    t0 = timeit.default_timer()
    out_nested = mha_nested(
        query, key, value, nheads,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)

    t1 = timeit.default_timer()
    out_padded = mha_padded(
        padded_query, padded_key, padded_value, nheads,
        attn_mask_q, attn_mask_kv,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)

    t2 = timeit.default_timer()

    print("nested and padded calculations differ by",
          (out_nested.to_padded_tensor(0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
    print("nested tensor multi-head attention takes", t1 - t0, "seconds")
    print("padded tensor multi-head attention takes", t2 - t1, "seconds")

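To make the "wasted computation" argument below concrete, we can measure how much of
the padded batch consists of real tokens. This is only an illustrative addition using
the quantities defined above:

.. code-block:: default


    # fraction of entries in the padded batch that correspond to real tokens;
    # everything else is padding that the padded version still computes over
    total_real = int(sentence_lengths.sum())
    total_padded = N * int(L_t)
    print("fraction of non-padding entries:", total_real / total_padded)
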
.. GENERATED FROM PYTHON SOURCE LINES 444-450

The nested tensor version avoids wasted computation on padding, so in sequential CPU
execution it is faster than the padded tensor version, as expected. Optimization for
multi-threaded environments is underway. For now, performant kernels are provided for
specific use cases, e.g. self-attention evaluation via the multi-head attention formula.

.. GENERATED FROM PYTHON SOURCE LINES 450-456

.. code-block:: default


    # embedding dimensions are assumed to be the same
    E = E_total
    mha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)
    mha_lib.eval()

.. GENERATED FROM PYTHON SOURCE LINES 457-458

extract parameters for correctness check

.. GENERATED FROM PYTHON SOURCE LINES 458-467

.. code-block:: default


    mha_lib.in_proj_weight.requires_grad_(False)
    mha_lib.in_proj_bias.requires_grad_(False)
    mha_lib.out_proj.weight.requires_grad_(False)
    mha_lib.out_proj.bias.requires_grad_(False)
    W_q, b_q = mha_lib.in_proj_weight[: E, :], mha_lib.in_proj_bias[: E]
    W_k, b_k = mha_lib.in_proj_weight[E : 2 * E, :], mha_lib.in_proj_bias[E : 2 * E]
    W_v, b_v = mha_lib.in_proj_weight[2 * E :, :], mha_lib.in_proj_bias[2 * E :]
    W_out, b_out = mha_lib.out_proj.weight, mha_lib.out_proj.bias

.. GENERATED FROM PYTHON SOURCE LINES 468-469

check correctness and performance

.. GENERATED FROM PYTHON SOURCE LINES 469-493

.. code-block:: default


    t0 = timeit.default_timer()
    out_lib, out_lib_weights = mha_lib(query, query, query)

    t1 = timeit.default_timer()
    out_nested = mha_nested(
        query, query, query, nheads,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)

    t2 = timeit.default_timer()
    padded_out = mha_padded(
        padded_query, padded_query, padded_query, nheads,
        attn_mask_q, attn_mask_q,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)

    t3 = timeit.default_timer()

    print("nested general and library calculations differ by",
          (out_nested.to_padded_tensor(0.0) - out_lib.to_padded_tensor(0.0)).abs().max().item())
    print("nested library multi-head attention takes", t1 - t0, "seconds")
    print("nested general multi-head attention takes", t2 - t1, "seconds")
    print("padded tensor multi-head attention takes", t3 - t2, "seconds")


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.000 seconds)


.. _sphx_glr_download_prototype_nestedtensor.py:

.. only:: html

  .. container:: sphx-glr-footer
     :class: sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

       :download:`Download Python source code: nestedtensor.py <nestedtensor.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

       :download:`Download Jupyter notebook: nestedtensor.ipynb <nestedtensor.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_