九耳
2023-02-28 ee06cb9c6870d9e1579015aabfe1a84a61a5c087
funasr/punctuation/espnet_model.py
@@ -14,15 +14,18 @@
class ESPnetPunctuationModel(AbsESPnetModel):
    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0):
    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
        assert check_argument_types()
        super().__init__()
        self.punc_model = punc_model
        self.punc_weight = torch.Tensor(punc_weight)
        self.sos = 1
        self.eos = 2
        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
        self.ignore_id = ignore_id
        if self.punc_model.with_vad():
            print("This is a vad puncuation model.")
    def nll(
        self,
@@ -31,6 +34,8 @@
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        max_length: Optional[int] = None,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll)
@@ -49,19 +54,16 @@
        else:
            text = text[:, :max_length]
            punc = punc[:, :max_length]
        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        # text: (Batch, Length) -> x, y: (Batch, Length + 1)
        #x = F.pad(text, [1, 0], "constant", self.eos)
        #t = F.pad(text, [0, 1], "constant", self.ignore_id)
        #for i, l in enumerate(text_lengths):
        #    t[i, l] = self.sos
        #x_lengths = text_lengths + 1
        if self.punc_model.with_vad():
            # Should be VadRealtimeTransformer
            assert vad_indexes is not None
            y, _ = self.punc_model(text, text_lengths, vad_indexes)
        else:
            # Should be TargetDelayTransformer,
            y, _ = self.punc_model(text, text_lengths)
        # 2. Forward Language model
        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
        y, _ = self.punc_model(text, text_lengths)
        # 3. Calc negative log likelihood
        # Calc negative log likelihood
        # nll: (BxL,)
        if self.training == False:
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
@@ -72,7 +74,8 @@
            nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
            return nll, text_lengths
        else:
            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), reduction="none", ignore_index=self.ignore_id)
            self.punc_weight = self.punc_weight.to(punc.device)
            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
        # nll: (BxL,) -> (BxL,)
        if max_length is None:
            nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -130,9 +133,16 @@
        assert x_lengths.size(0) == total_num
        return nll, x_lengths
    def forward(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor,
                punc_lengths: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths)
    def forward(
        self,
        text: torch.Tensor,
        punc: torch.Tensor,
        text_lengths: torch.Tensor,
        punc_lengths: torch.Tensor,
        vad_indexes: Optional[torch.Tensor] = None,
        vad_indexes_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())
@@ -145,5 +155,12 @@
                      text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
        return {}
    def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        return self.punc_model(text, text_lengths)
    def inference(self,
                  text: torch.Tensor,
                  text_lengths: torch.Tensor,
                  vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
        if self.punc_model.with_vad():
            assert vad_indexes is not None
            return self.punc_model(text, text_lengths, vad_indexes)
        else:
            return self.punc_model(text, text_lengths)