huangmingming
2023-01-30 adcee8828ef5d78b575043954deb662a35e318f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Length bonus module."""
from typing import Any
from typing import List
from typing import Tuple
 
import torch
 
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
 
 
class LengthBonus(BatchScorerInterface):
    """Length bonus in beam search."""
 
    def __init__(self, n_vocab: int):
        """Initialize class.
 
        Args:
            n_vocab (int): The number of tokens in vocabulary for beam search
 
        """
        self.n = n_vocab
 
    def score(self, y, state, x):
        """Score new token.
 
        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.
 
        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and None
 
        """
        return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None
 
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.
 
        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
 
        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
 
        """
        return (
            torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand(
                ys.shape[0], self.n
            ),
            None,
        )