From 729ac44000a5c4aa23dcb1a3b80adc119a350b23 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 15 Mar 2023 19:10:16 +0800
Subject: [PATCH] Merge pull request #235 from alibaba-damo-academy/dev_wjm

---
 egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py                  |   10 ++
 funasr/modules/eend_ola/encoder_decoder_attractor.py                                                       |    6 
 egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py |    7 -
 funasr/tasks/diar.py                                                                                       |   68 ++++++++--------
 funasr/bin/diar_inference_launch.py                                                                        |    3 
 funasr/bin/eend_ola_inference.py                                                                           |   24 ++++-
 tests/test_asr_inference_pipeline.py                                                                       |    2 
 funasr/models/frontend/wav_frontend.py                                                                     |   77 ++++++++++++++++--
 funasr/models/e2e_diar_eend_ola.py                                                                         |   35 +++++---
 9 files changed, 163 insertions(+), 69 deletions(-)

diff --git a/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py b/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py
new file mode 100644
index 0000000..81cb2c6
--- /dev/null
+++ b/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py
@@ -0,0 +1,13 @@
+"""Run the EEND-OLA speaker diarization pipeline on a sample recording."""
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+# Build the ModelScope speaker-diarization pipeline for the pinned model revision.
+inference_diar_pipeline = pipeline(
+    task=Tasks.speaker_diarization,
+    model='damo/speech_diarization_eend-ola-en-us-callhome-8k',
+    model_revision="v1.0.0",
+)
+# audio_in is given as a list; here it holds a single remote wav URL.
+results = inference_diar_pipeline(audio_in=["https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record2.wav"])
+print(results)
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
index 3cb31cf..5f4563d 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/unit_test.py
@@ -14,13 +14,12 @@
 )
 
 # 以 audio_list 作为输入，其中第一个音频为待检测语音，后面的音频为不同说话人的声纹注册语音
-audio_list = [[
+audio_list = [
     "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav",
     "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_A.wav",
     "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B.wav",
     "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B1.wav"
-]]
+]
 
 results = inference_diar_pipline(audio_in=audio_list)
-for rst in results:
-    print(rst["value"])
+print(results)
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 7738f4f..70bb947 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -142,6 +142,9 @@
         else:
             kwargs["param_dict"] = param_dict
         return inference_modelscope(mode=mode, **kwargs)
+    elif mode == "eend-ola":
+        from funasr.bin.eend_ola_inference import inference_modelscope
+        return inference_modelscope(mode=mode, **kwargs)
     else:
         logging.info("Unknown decoding mode: {}".format(mode))
         return None
diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py
index d65895f..0483278 100755
--- a/funasr/bin/eend_ola_inference.py
+++ b/funasr/bin/eend_ola_inference.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 import torch
+from scipy.signal import medfilt
 from typeguard import check_argument_types
 
 from funasr.models.frontend.wav_frontend import WavFrontendMel23
@@ -146,7 +147,7 @@
         output_dir: Optional[str] = None,
         batch_size: int = 1,
         dtype: str = "float32",
-        ngpu: int = 0,
+        ngpu: int = 1,
         num_workers: int = 0,
         log_level: Union[int, str] = "INFO",
         key_file: Optional[str] = None,
@@ -179,7 +180,6 @@
         diar_model_file=diar_model_file,
         device=device,
         dtype=dtype,
-        streaming=streaming,
     )
     logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
     speech2diar = Speech2Diarization.from_pretrained(
@@ -209,7 +209,7 @@
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+            data_path_and_name_and_type = [raw_inputs[0], "speech", "bytes"]
         loader = EENDOLADiarTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
@@ -236,9 +236,23 @@
             # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
 
             results = speech2diar(**batch)
+
+            # post process
+            a = results[0][0].cpu().numpy()
+            a = medfilt(a, (11, 1))
+            rst = []
+            for spkid, frames in enumerate(a.T):
+                frames = np.pad(frames, (1, 1), 'constant')
+                changes, = np.where(np.diff(frames, axis=0) != 0)
+                fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
+                for s, e in zip(changes[::2], changes[1::2]):
+                    st = s / 10.
+                    dur = (e - s) / 10.
+                    rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
+
             # Only supporting batch_size==1
-            key, value = keys[0], output_results_str(results, keys[0])
-            item = {"key": key, "value": value}
+            value = "\n".join(rst)
+            item = {"key": keys[0], "value": value}
             result_list.append(item)
             if output_path is not None:
                 output_writer.write(value)
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index f589269..097b23a 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -52,15 +52,15 @@
 
         super().__init__()
         self.frontend = frontend
-        self.encoder = encoder
-        self.encoder_decoder_attractor = encoder_decoder_attractor
+        self.enc = encoder
+        self.eda = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
         if mapping_dict is None:
             mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
             self.mapping_dict = mapping_dict
         # PostNet
-        self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
+        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
         self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
 
     def forward_encoder(self, xs, ilens):
@@ -68,7 +68,7 @@
         pad_shape = xs.shape
         xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
         xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
-        emb = self.encoder(xs, xs_mask)
+        emb = self.enc(xs, xs_mask)
         emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
         emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
         return emb
@@ -76,8 +76,8 @@
     def forward_post_net(self, logits, ilens):
         maxlen = torch.max(ilens).to(torch.int).item()
         logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
-        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
-        outputs, (_, _) = self.PostNet(logits)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+        outputs, (_, _) = self.postnet(logits)
         outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
         outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
         outputs = [self.output_layer(output) for output in outputs]
@@ -112,7 +112,7 @@
         text = text[:, : text_lengths.max()]
 
         # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -190,18 +190,16 @@
                             shuffle: bool = True,
                             threshold: float = 0.5,
                             **kwargs):
-        if self.frontend is not None:
-            speech = self.frontend(speech)
         speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
         emb = self.forward_encoder(speech, speech_lengths)
         if shuffle:
             orders = [np.arange(e.shape[0]) for e in emb]
             for order in orders:
                 np.random.shuffle(order)
-            attractors, probs = self.encoder_decoder_attractor.estimate(
+            attractors, probs = self.eda.estimate(
                 [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
         else:
-            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
+            attractors, probs = self.eda.estimate(emb)
         attractors_active = []
         for p, att, e in zip(probs, attractors, emb):
             if n_speakers and n_speakers >= 0:
@@ -233,10 +231,23 @@
                 pred[i] = pred[i - 1]
             else:
                 pred[i] = 0
-        pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
+        pred = [self.inv_mapping_func(i) for i in pred]
         decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
         decisions = torch.from_numpy(
             np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
             torch.float32)
         decisions = decisions[:, :n_speaker]
         return decisions
+
+    def inv_mapping_func(self, label):
+
+        if not isinstance(label, int):
+            label = int(label)
+        if label in self.mapping_dict['label2dec'].keys():
+            num = self.mapping_dict['label2dec'][label]
+        else:
+            num = -1
+        return num
+
+    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
+        pass
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 445efca..475a939 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -1,14 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from espnet/espnet.
-from abc import ABC
 from typing import Tuple
 
 import numpy as np
 import torch
 import torchaudio.compliance.kaldi as kaldi
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from typeguard import check_argument_types
 from torch.nn.utils.rnn import pad_sequence
+from typeguard import check_argument_types
+
+import funasr.models.frontend.eend_ola_feature as eend_ola_feature
+from funasr.models.frontend.abs_frontend import AbsFrontend
 
 
 def load_cmvn(cmvn_file):
@@ -275,7 +276,8 @@
     # inputs tensor has catted the cache tensor
     # def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
     #               is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
-    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
+    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
+        torch.Tensor, torch.Tensor, int]:
         """
         Apply lfr with data
         """
@@ -376,7 +378,8 @@
             if self.lfr_m != 1 or self.lfr_n != 1:
                 # update self.lfr_splice_cache in self.apply_lfr
                 # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
-                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final)
+                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
+                                                                                     is_final)
             if self.cmvn_file is not None:
                 mat = self.apply_cmvn(mat, self.cmvn)
             feat_length = mat.size(0)
@@ -398,9 +401,10 @@
         assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
         waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths)  # input shape: B T D
         if feats.shape[0]:
-            #if self.reserve_waveforms is None and self.lfr_m > 1:
+            # if self.reserve_waveforms is None and self.lfr_m > 1:
             #    self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
-            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat((self.reserve_waveforms, waveforms), dim=1)
+            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat(
+                (self.reserve_waveforms, waveforms), dim=1)
             if not self.lfr_splice_cache:  # 初始化splice_cache
                 for i in range(batch_size):
                     self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
@@ -409,7 +413,8 @@
                 lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache)  # B T D
                 feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
                 feats_lengths += lfr_splice_cache_tensor[0].shape[0]
-                frame_from_waveforms = int((self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
+                frame_from_waveforms = int(
+                    (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                 minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                 feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                 if self.lfr_m == 1:
@@ -423,14 +428,15 @@
                     self.waveforms = self.waveforms[:, :sample_length]
             else:
                 # update self.reserve_waveforms and self.lfr_splice_cache
-                self.reserve_waveforms = self.waveforms[:, :-(self.frame_sample_length - self.frame_shift_sample_length)]
+                self.reserve_waveforms = self.waveforms[:,
+                                         :-(self.frame_sample_length - self.frame_shift_sample_length)]
                 for i in range(batch_size):
                     self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
                 return torch.empty(0), feats_lengths
         else:
             if is_final:
                 self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
-                feats = torch.stack(self.lfr_splice_cache) 
+                feats = torch.stack(self.lfr_splice_cache)
                 feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
                 feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
         if is_final:
@@ -444,3 +450,76 @@
         self.reserve_waveforms = None
         self.input_cache = None
         self.lfr_splice_cache = []
+
+
+class WavFrontendMel23(AbsFrontend):
+    """Waveform frontend built on ``eend_ola_feature`` (``n_mels`` = 23).
+
+    Feature extraction is delegated to ``eend_ola_feature`` (``stft`` ->
+    ``transform`` -> ``splice``), followed by frame subsampling.
+
+    Args:
+        fs: Stored as-is, not read by ``forward`` (presumably the sample rate).
+        frame_length: Forwarded to ``eend_ola_feature.stft``.
+        frame_shift: Forwarded to ``eend_ola_feature.stft``.
+        lfr_m: Context size passed to ``eend_ola_feature.splice``.
+        lfr_n: Subsampling step; every ``lfr_n``-th spliced frame is kept.
+    """
+
+    def __init__(
+            self,
+            fs: int = 16000,
+            frame_length: int = 25,
+            frame_shift: int = 10,
+            lfr_m: int = 1,
+            lfr_n: int = 1,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self.fs = fs
+        self.frame_length = frame_length
+        self.frame_shift = frame_shift
+        self.lfr_m = lfr_m
+        self.lfr_n = lfr_n
+        self.n_mels = 23
+
+    def output_size(self) -> int:
+        """Return the per-frame feature dimension after splicing."""
+        # NOTE(review): assumes splice() stacks lfr_m frames on each side of the
+        # centre frame, i.e. 2 * lfr_m + 1 frames -- confirm in eend_ola_feature.
+        return self.n_mels * (2 * self.lfr_m + 1)
+
+    def forward(
+            self,
+            input: torch.Tensor,
+            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Extract features for each waveform in the batch.
+
+        Returns:
+            A tuple ``(feats_pad, feats_lens)``: the per-utterance feature
+            matrices zero-padded to a common length (batch first), and the
+            unpadded frame count of each utterance.
+        """
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            waveform_length = input_lengths[i]
+            waveform = input[i][:waveform_length]
+            # Tensor.numpy() raises for CUDA tensors and for tensors that
+            # require grad, so detach and copy to host memory first.
+            waveform = waveform.detach().cpu().numpy()
+            mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
+            mat = eend_ola_feature.transform(mat)
+            mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
+            mat = mat[::self.lfr_n]
+            mat = torch.from_numpy(mat)
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
+
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats,
+                                 batch_first=True,
+                                 padding_value=0.0)
+        return feats_pad, feats_lens
diff --git a/funasr/modules/eend_ola/encoder_decoder_attractor.py b/funasr/modules/eend_ola/encoder_decoder_attractor.py
index db01b00..45ac982 100644
--- a/funasr/modules/eend_ola/encoder_decoder_attractor.py
+++ b/funasr/modules/eend_ola/encoder_decoder_attractor.py
@@ -16,12 +16,12 @@
         self.n_units = n_units
 
     def forward_core(self, xs, zeros):
-        ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.float32).to(xs[0].device)
+        ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.int64)
         xs = [self.enc0_dropout(x) for x in xs]
         xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
         xs = nn.utils.rnn.pack_padded_sequence(xs, ilens, batch_first=True, enforce_sorted=False)
         _, (hx, cx) = self.encoder(xs)
-        zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.float32).to(zeros[0].device)
+        zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.int64)
         max_zlen = torch.max(zlens).to(torch.int).item()
         zeros = [self.enc0_dropout(z) for z in zeros]
         zeros = nn.utils.rnn.pad_sequence(zeros, batch_first=True, padding_value=-1)
@@ -47,4 +47,4 @@
         zeros = [torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs]
         attractors = self.forward_core(xs, zeros)
         probs = [torch.sigmoid(torch.flatten(self.counter(att))) for att in attractors]
-        return attractors, probs
\ No newline at end of file
+        return attractors, probs
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index ae7ee9b..6962915 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -750,47 +750,47 @@
             cls, args: argparse.Namespace, train: bool
     ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
         assert check_argument_types()
-        if args.use_preprocessor:
-            retval = CommonPreprocessor(
-                train=train,
-                token_type=args.token_type,
-                token_list=args.token_list,
-                bpemodel=None,
-                non_linguistic_symbols=None,
-                text_cleaner=None,
-                g2p_type=None,
-                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
-                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
-                # NOTE(kamo): Check attribute existence for backward compatibility
-                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
-                rir_apply_prob=args.rir_apply_prob
-                if hasattr(args, "rir_apply_prob")
-                else 1.0,
-                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
-                noise_apply_prob=args.noise_apply_prob
-                if hasattr(args, "noise_apply_prob")
-                else 1.0,
-                noise_db_range=args.noise_db_range
-                if hasattr(args, "noise_db_range")
-                else "13_15",
-                speech_volume_normalize=args.speech_volume_normalize
-                if hasattr(args, "rir_scp")
-                else None,
-            )
-        else:
-            retval = None
-        assert check_return_type(retval)
-        return retval
+        # if args.use_preprocessor:
+        #     retval = CommonPreprocessor(
+        #         train=train,
+        #         token_type=args.token_type,
+        #         token_list=args.token_list,
+        #         bpemodel=None,
+        #         non_linguistic_symbols=None,
+        #         text_cleaner=None,
+        #         g2p_type=None,
+        #         split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+        #         seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+        #         # NOTE(kamo): Check attribute existence for backward compatibility
+        #         rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+        #         rir_apply_prob=args.rir_apply_prob
+        #         if hasattr(args, "rir_apply_prob")
+        #         else 1.0,
+        #         noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+        #         noise_apply_prob=args.noise_apply_prob
+        #         if hasattr(args, "noise_apply_prob")
+        #         else 1.0,
+        #         noise_db_range=args.noise_db_range
+        #         if hasattr(args, "noise_db_range")
+        #         else "13_15",
+        #         speech_volume_normalize=args.speech_volume_normalize
+        #         if hasattr(args, "rir_scp")
+        #         else None,
+        #     )
+        # else:
+        #     retval = None
+        # assert check_return_type(retval)
+        return None
 
     @classmethod
     def required_data_names(
             cls, train: bool = True, inference: bool = False
     ) -> Tuple[str, ...]:
         if not inference:
-            retval = ("speech", "profile", "binary_labels")
+            retval = ("speech", )
         else:
             # Recognition mode
-            retval = ("speech")
+            retval = ("speech", )
         return retval
 
     @classmethod
@@ -823,7 +823,7 @@
 
         # 2. Encoder
         encoder_class = encoder_choices.get_class(args.encoder)
-        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+        encoder = encoder_class(**args.encoder_conf)
 
         # 3. EncoderDecoderAttractor
         encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index 70dbe89..32b8af5 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -451,7 +451,7 @@
 
     def test_uniasr_2pass_zhcn_16k_common_vocab8358_offline(self):
         inference_pipeline = pipeline(
-            task=Tasks.,
-            model='damo/speech_UniASauto_speech_recognitionR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
+            task=Tasks.auto_speech_recognition,
+            model='damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav',

--
Gitblit v1.9.1