aky15
2023-04-12 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f
funasr/bin/asr_inference_mfcca.py
@@ -41,8 +41,6 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
import pdb
header_colors = '\033[95m'
end_colors = '\033[0m'
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
@@ -55,7 +53,7 @@
    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
@@ -194,8 +192,8 @@
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if(speech.dim()==3):
            speech = torch.squeeze(speech, 2)
        #speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        speech = speech.to(getattr(torch, self.dtype))
        # lenghts: (1,)
@@ -470,6 +468,7 @@
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: dict = None,
    **kwargs,
):
    assert check_argument_types()
@@ -520,6 +519,9 @@
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -530,6 +532,8 @@
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            fs=fs,
            mc=True,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
@@ -587,16 +591,6 @@
        return asr_result_list
    
    return _forward
def set_parameters(language: Optional[str] = None,
                   sample_rate: Optional[Union[int, Dict[Any, int]]] = None):
    """Override the module-level ``global_asr_language`` / ``global_sample_rate``.

    Only arguments that are not ``None`` are applied; passing ``None`` (the
    default) leaves the corresponding module global untouched.

    Args:
        language: New value for ``global_asr_language`` (module default is
            ``'zh-cn'``).
        sample_rate: New value for ``global_sample_rate``; either a single int
            or a mapping whose values are ints.

    Returns:
        None. The function only rebinds module-level globals.
    """
    # A ``global`` statement applies to the whole function body regardless of
    # where it appears; declaring both names up front makes that explicit.
    global global_asr_language, global_sample_rate
    if language is not None:
        global_asr_language = language
    if sample_rate is not None:
        global_sample_rate = sample_rate
def get_parser():
    parser = config_argparse.ArgumentParser(