From 113f8ea30a4a989a31ef993598fca0ff9158668b Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: 星期二, 20 六月 2023 13:04:19 +0800
Subject: [PATCH] Dev lyh (#657)

---
 egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml |    1 
 funasr/models/frontend/default.py                      |  101 +++++++++++++++++++++++++++++++-------------------
 2 files changed, 64 insertions(+), 38 deletions(-)

diff --git a/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
index 47bc6bd..18614dd 100644
--- a/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
+++ b/egs/alimeeting/sa_asr/conf/train_sa_asr_conformer.yaml
@@ -10,6 +10,7 @@
     lfr_m: 1
     lfr_n: 1
     use_channel: 0
+    mc: False
 
 # encoder related
 asr_encoder: conformer
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 6718f3f..abbcd1b 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -77,8 +77,8 @@
             htk=htk,
         )
         self.n_mels = n_mels
-        self.frontend_type = "default"
         self.use_channel = use_channel
+        self.frontend_type = "default"
 
     def output_size(self) -> int:
         return self.n_mels
@@ -146,9 +146,11 @@
     def __init__(
             self,
             fs: Union[int, str] = 16000,
-            n_fft: int = 400,
-            frame_length: int = 25,
-            frame_shift: int = 10,
+            n_fft: int = 512,
+            win_length: int = None,
+            hop_length: int = None,
+            frame_length: int = None,
+            frame_shift: int = None,
             window: Optional[str] = "hann",
             center: bool = True,
             normalized: bool = False,
@@ -162,7 +164,8 @@
             use_channel: int = None,
             lfr_m: int = 1,
             lfr_n: int = 1,
-            cmvn_file: str = None
+            cmvn_file: str = None,
+            mc: bool = True
     ):
         assert check_argument_types()
         super().__init__()
@@ -171,8 +174,18 @@
 
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
-        self.win_length = frame_length * 16
-        self.hop_length = frame_shift * 16
+        if win_length is None and hop_length is None:
+            self.win_length = frame_length * 16
+            self.hop_length = frame_shift * 16
+        elif frame_length is None and frame_shift is None:
+            self.win_length = self.win_length
+            self.hop_length = self.hop_length
+        else:
+            logging.error(
+                "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
+                "can be set."
+            )
+            exit(1)
 
         if apply_stft:
             self.stft = Stft(
@@ -202,17 +215,19 @@
             htk=htk,
         )
         self.n_mels = n_mels
-        self.frontend_type = "default"
         self.use_channel = use_channel
-        if self.use_channel is not None:
-            logging.info("use the channel %d" % (self.use_channel))
-        else:
-            logging.info("random select channel")
-        self.cmvn_file = cmvn_file
-        if self.cmvn_file is not None:
-            mean, std = self._load_cmvn(self.cmvn_file)
-            self.register_buffer("mean", torch.from_numpy(mean))
-            self.register_buffer("std", torch.from_numpy(std))
+        self.mc = mc
+        if not self.mc:
+            if self.use_channel is not None:
+                logging.info("use the channel %d" % (self.use_channel))
+            else:
+                logging.info("random select channel")
+            self.cmvn_file = cmvn_file
+            if self.cmvn_file is not None:
+                mean, std = self._load_cmvn(self.cmvn_file)
+                self.register_buffer("mean", torch.from_numpy(mean))
+                self.register_buffer("std", torch.from_numpy(std))
+        self.frontend_type = "multichannelfrontend"
 
     def output_size(self) -> int:
         return self.n_mels
@@ -233,8 +248,8 @@
             # input_stft: (Batch, Length, [Channel], Freq)
             input_stft, _, mask = self.frontend(input_stft, feats_lens)
 
-        # 3. [Multi channel case]: Select a channel
-        if input_stft.dim() == 4:
+        # 3. [Multi channel case]: Select a channel(sa_asr)
+        if input_stft.dim() == 4 and not self.mc:
             # h: (B, T, C, F) -> h: (B, T, F)
             if self.training:
                 if self.use_channel is not None:
@@ -256,27 +271,37 @@
         # input_power: (Batch, [Channel,] Length, Freq)
         #       -> input_feats: (Batch, Length, Dim)
         input_feats, _ = self.logmel(input_power, feats_lens)
-        
-        # 6. Apply CMVN
-        if self.cmvn_file is not None:
-            if feats_lens is None:
-                feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
-            self.mean = self.mean.to(input_feats.device, input_feats.dtype)
-            self.std = self.std.to(input_feats.device, input_feats.dtype)
-            mask = make_pad_mask(feats_lens, input_feats, 1)
-
-            if input_feats.requires_grad:
-                input_feats = input_feats + self.mean
+        if self.mc:
+            # MFCCA
+            if input_feats.dim() ==4:
+                bt = input_feats.size(0)
+                channel_size = input_feats.size(2)
+                input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
+                feats_lens = feats_lens.repeat(1,channel_size).squeeze()
             else:
-                input_feats += self.mean
-            if input_feats.requires_grad:
-                input_feats = input_feats.masked_fill(mask, 0.0)
-            else:
-                input_feats.masked_fill_(mask, 0.0)
+                channel_size = 1
+            return input_feats, feats_lens, channel_size
+        else:
+            # 6. Apply CMVN
+            if self.cmvn_file is not None:
+                if feats_lens is None:
+                    feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
+                self.mean = self.mean.to(input_feats.device, input_feats.dtype)
+                self.std = self.std.to(input_feats.device, input_feats.dtype)
+                mask = make_pad_mask(feats_lens, input_feats, 1)
 
-            input_feats *= self.std
+                if input_feats.requires_grad:
+                    input_feats = input_feats + self.mean
+                else:
+                    input_feats += self.mean
+                if input_feats.requires_grad:
+                    input_feats = input_feats.masked_fill(mask, 0.0)
+                else:
+                    input_feats.masked_fill_(mask, 0.0)
 
-        return input_feats, feats_lens
+                input_feats *= self.std
+
+            return input_feats, feats_lens
 
     def _compute_stft(
             self, input: torch.Tensor, input_lengths: torch.Tensor
@@ -313,4 +338,4 @@
                     continue
         means = np.array(means_list).astype(np.float)
         vars = np.array(vars_list).astype(np.float)
-        return means, vars
\ No newline at end of file
+        return means, vars

--
Gitblit v1.9.1