aky15
2023-05-23 71f1059af9bd71ce87483952bbf3964fc3d3a5f9
RNN-T: support wav input
2个文件已修改
31 ■■■■■ 已修改文件
funasr/bin/asr_infer.py 22 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py
@@ -1510,8 +1510,13 @@
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.frontend is not None:
            speech = torch.unsqueeze(speech, axis=0)
            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
@@ -1536,14 +1541,19 @@
        
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.frontend is not None:
            speech = torch.unsqueeze(speech, axis=0)
            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        
        enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
        
        nbest_hyps = self.beam_search(enc_out[0])
        
funasr/tasks/asr.py
@@ -363,12 +363,6 @@
            default=get_default_kwargs(CTC),
            help="The keyword arguments for CTC class.",
        )
        group.add_argument(
            "--joint_network_conf",
            action=NestedDictAction,
            default=None,
            help="The keyword arguments for joint network class.",
        )
        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
@@ -1379,6 +1373,7 @@
    num_optimizers: int = 1
    class_choices_list = [
        model_choices,
        frontend_choices,
        specaug_choices,
        normalize_choices,
@@ -1476,7 +1471,7 @@
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
            model_class = model_choices.get_class("rnnt_unified")
        model = model_class(
            vocab_size=vocab_size,