Yabin Li
2024-03-04 aba47683fd4b2984dbff7fc79b0f532fc2d9f6b7
funasr/auto/auto_model.py
@@ -165,17 +165,18 @@
            kwargs["token_list"] = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
            kwargs["token_list"] = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
            vocab_size = len(kwargs["token_list"])
            vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
        else:
            vocab_size = -1
        
        # build frontend
        frontend = kwargs.get("frontend", None)
        kwargs["input_size"] = None
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs["frontend_conf"])
            kwargs["frontend"] = frontend
            kwargs["input_size"] = frontend.output_size()
            kwargs["input_size"] = frontend.output_size() if hasattr(frontend, "output_size") else None
        
        # build model
        model_class = tables.model_classes.get(kwargs["model"])
@@ -193,7 +194,7 @@
                    path=init_param,
                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
                    oss_bucket=kwargs.get("oss_bucket", None),
                    scope_map=kwargs.get("scope_map", "module.,None"),
                    scope_map=kwargs.get("scope_map", []),
                    excludes=kwargs.get("excludes", None),
                )
            else:
@@ -392,7 +393,8 @@
            # step.3 compute punc model
            if self.punc_model is not None:
                if not len(result["text"]):
                    result['raw_text'] = ''
                    if return_raw_text:
                        result['raw_text'] = ''
                else:
                    self.punc_kwargs.update(cfg)
                    punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
@@ -434,10 +436,13 @@
                distribute_spk(sentence_list, sv_output)
                result['sentence_info'] = sentence_list
            elif kwargs.get("sentence_timestamp", False):
                sentence_list = timestamp_sentence(punc_res[0]['punc_array'],
                                                   result['timestamp'],
                                                   raw_text,
                                                   return_raw_text=return_raw_text)
                if not len(result['text']):
                    sentence_list = []
                else:
                    sentence_list = timestamp_sentence(punc_res[0]['punc_array'],
                                                       result['timestamp'],
                                                       raw_text,
                                                       return_raw_text=return_raw_text)
                result['sentence_info'] = sentence_list
            if "spk_embedding" in result: del result['spk_embedding']