From 8827e26b8d487f123f8d7d5cbd8d00b81dcefcff Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 23 二月 2024 00:58:18 +0800
Subject: [PATCH] fp16

---
 funasr/models/paraformer/cif_predictor.py |   45 ++++++++++++++++++++++++---------------------
 1 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 383d9ca..60ddc24 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -1,23 +1,25 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import torch
-from torch import nn
-from torch import Tensor
 import logging
 import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.scama.utils import sequence_mask
-from typing import Optional, Tuple
 
 from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+
 
 @tables.register("predictor_classes", "CifPredictor")
-class CifPredictor(nn.Module):
+class CifPredictor(torch.nn.Module):
     def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
         super().__init__()
 
-        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
-        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
-        self.cif_output = nn.Linear(idim, 1)
+        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+        self.cif_output = torch.nn.Linear(idim, 1)
         self.dropout = torch.nn.Dropout(p=dropout)
         self.threshold = threshold
         self.smooth_factor = smooth_factor
@@ -137,7 +139,7 @@
         return predictor_alignments.detach(), predictor_alignments_length.detach()
 
 @tables.register("predictor_classes", "CifPredictorV2")
-class CifPredictorV2(nn.Module):
+class CifPredictorV2(torch.nn.Module):
     def __init__(self,
                  idim,
                  l_order,
@@ -153,9 +155,9 @@
                  ):
         super(CifPredictorV2, self).__init__()
 
-        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
-        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
-        self.cif_output = nn.Linear(idim, 1)
+        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
+        self.cif_output = torch.nn.Linear(idim, 1)
         self.dropout = torch.nn.Dropout(p=dropout)
         self.threshold = threshold
         self.smooth_factor = smooth_factor
@@ -184,7 +186,7 @@
         alphas = alphas.squeeze(-1)
         mask = mask.squeeze(-1)
         if target_label_length is not None:
-            target_length = target_label_length
+            target_length = target_label_length.squeeze(-1)
         elif target_label is not None:
             target_length = (target_label != ignore_id).float().sum(-1)
         else:
@@ -205,7 +207,8 @@
 
         return acoustic_embeds, token_num, alphas, cif_peak
 
-    def forward_chunk(self, hidden, cache=None):
+    def forward_chunk(self, hidden, cache=None, **kwargs):
+        is_final = kwargs.get("is_final", False)
         batch_size, len_time, hidden_size = hidden.shape
         h = hidden
         context = h.transpose(1, 2)
@@ -226,14 +229,14 @@
 
         if cache is not None and "chunk_size" in cache:
             alphas[:, :cache["chunk_size"][0]] = 0.0
-            if "is_final" in cache and not cache["is_final"]:
+            if not is_final:
                 alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
         if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
             cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
             cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
             hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
             alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
-        if cache is not None and "is_final" in cache and cache["is_final"]:
+        if cache is not None and is_final:
             tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
             tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
             tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
@@ -277,7 +280,7 @@
 
         max_token_len = max(token_length)
         if max_token_len == 0:
-             return hidden, torch.stack(token_length, 0)
+             return hidden, torch.stack(token_length, 0), None, None
         list_ls = []
         for b in range(batch_size):
             pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
@@ -291,7 +294,7 @@
         cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
         cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
         cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
-        return torch.stack(list_ls, 0), torch.stack(token_length, 0)
+        return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None
 
 
     def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
@@ -425,7 +428,7 @@
         return var_dict_torch_update
 
 
-class mae_loss(nn.Module):
+class mae_loss(torch.nn.Module):
 
     def __init__(self, normalize_length=False):
         super(mae_loss, self).__init__()

--
Gitblit v1.9.1