228 lines
8.2 KiB
Python
228 lines
8.2 KiB
Python
# Copyright (c) 2019 Shigeki Karita
|
|
# 2020 Mobvoi Inc (Binbin Zhang)
|
|
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import torch
|
|
# NOTE(review): the triple-quoted string below is a bare module-level string
# expression. It is never assigned or executed, so it has no runtime effect.
# It preserves an earlier torch.tril-based implementation of subsequent_mask;
# the active definition is the arange-based subsequent_mask that follows.
# Why the tril version was retired is not recorded here (possibly
# export/tracing compatibility) -- confirm before deleting this string.
'''
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, fully attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is need, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        str device (str): "cpu" or "cuda" or torch.Tensor.device
        dtype (torch.device): result dtype

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=torch.bool)
    return torch.tril(ret)
'''
|
|
|
|
|
|
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a causal mask for subsequent steps, shape (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): device on which the mask is created,
            e.g. torch.device("cpu") or torch.device("cuda")

    Returns:
        torch.Tensor: bool mask of shape (size, size); entry (i, j) is True
            iff j <= i (lower triangle including the diagonal).

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    # Every row of `mask` holds the column indices 0..size-1.
    mask = arange.expand(size, size)
    # Column vector of row indices, broadcast against the rows above.
    arange = arange.unsqueeze(-1)
    # (i, j) is True when column j is not after row i.
    mask = mask <= arange
    return mask
|
|
|
|
|
|
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a chunk-wise attention mask of shape (size, size).

    This is for the streaming encoder. Positions are grouped into consecutive
    chunks of ``chunk_size``. Position ``i`` may attend to every position of
    its own chunk and of up to ``num_left_chunks`` preceding chunks, but never
    to a later chunk. Concretely, row ``i`` is True for columns ``j`` in
    ``[start_i, end_i)`` with::

        end_i   = min((i // chunk_size + 1) * chunk_size, size)
        start_i = max((i // chunk_size - num_left_chunks) * chunk_size, 0)

    where ``start_i`` is 0 when ``num_left_chunks < 0``.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk; must be positive when size > 0
        num_left_chunks (int): number of left chunks
            <0: attend to all left chunks
            >=0: attend to at most num_left_chunks left chunks
        device (torch.device): device on which the mask is created,
            e.g. torch.device("cpu") or torch.device("cuda")

    Returns:
        torch.Tensor: bool mask of shape (size, size)

    Raises:
        ValueError: if size > 0 and chunk_size <= 0.

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    if size > 0 and chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # Built with broadcasting rather than a per-row Python loop of slice
    # assignments, so the cost is a few tensor ops regardless of `size`.
    pos_idx = torch.arange(size, device=device)
    chunk_idx = torch.div(pos_idx, chunk_size, rounding_mode='floor')
    cols = pos_idx.unsqueeze(0)  # (1, size)
    # Column indices never exceed size - 1, so `end` needs no clamp to size.
    end = ((chunk_idx + 1) * chunk_size).unsqueeze(1)  # (size, 1)
    ret = cols < end
    if num_left_chunks >= 0:
        # Column indices are never negative, so `start` needs no clamp to 0.
        start = ((chunk_idx - num_left_chunks) * chunk_size).unsqueeze(1)
        ret = ret & (cols >= start)
    return ret
|
|
|
|
|
|
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk:
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding,
            used when it is greater than 0; ignored if use_dynamic_chunk is
            True.
        num_decoding_left_chunks: number of left chunks, used for dynamic
            decoding (decoding_chunk_size > 0) and for static chunk mode.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool): only for random dynamic chunk training.
            True: chunk size is either in [1, 25] or full context (max_len)
            False: chunk size is in [1, 25] (drawn as randint % 25 + 1, so
            not exactly uniform)

    Returns:
        torch.Tensor: chunk mask of the input xs, (B, L, L) when a chunk mask
            is applied, otherwise `masks` itself.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            num_left_chunks = -1
            if max_len <= 1:
                # torch.randint(1, max_len) requires max_len > 1; with at most
                # one frame every chunk size yields the same mask.
                chunk_size = max(max_len, 1)
            else:
                chunk_size = torch.randint(1, max_len, (1, )).item()
                if chunk_size > max_len // 2 and enable_full_context:
                    chunk_size = max_len
                else:
                    chunk_size = chunk_size % 25 + 1
                    if use_dynamic_left_chunk:
                        max_left_chunks = (max_len - 1) // chunk_size
                        if max_left_chunks > 0:
                            num_left_chunks = torch.randint(
                                0, max_left_chunks, (1, )).item()
                        else:
                            # torch.randint(0, 0) raises. Here the whole
                            # sequence fits in one chunk, so 0 left chunks
                            # gives the same mask as "all left chunks".
                            num_left_chunks = 0
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks
|
|
|
|
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    Entry (b, t) is True when t >= lengths[b], i.e. when position t of
    sequence b is padding.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): width of the returned mask. If <= 0, the largest value
            in `lengths` is used (0 for an empty batch).

    Returns:
        torch.Tensor: bool mask of shape (B, max_len), True at padded
            positions.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    if max_len <= 0:
        # Tensor.max() raises on an empty tensor, so an empty batch gets
        # width 0 instead.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    # (B, max_len): every row holds the positions 0..max_len-1.
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    # (B, 1): per-sequence length, broadcast across positions.
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
|