haoneng.lhn
2023-09-21 fa61ffd3ddc45aad7c0649a21e1cc7b71516369b
funasr/bin/asr_inference_launch.py
@@ -853,7 +853,7 @@
                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
        cache["encoder"] = cache_en
        cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None}
        cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None, "chunk_size": chunk_size}
        cache["decoder"] = cache_de
        return cache
@@ -870,7 +870,7 @@
                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
            cache["encoder"] = cache_en
            cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None}
            cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None, "chunk_size": chunk_size}
            cache["decoder"] = cache_de
        return cache
@@ -953,14 +953,15 @@
        # FIXME(kamo): The output format should be discussed about
        raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
        asr_result_list = []
        cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
        cache = _prepare_cache(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back,
                               decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
        item = {}
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            sample_offset = 0
            speech_length = raw_inputs.shape[1]
            stride_size = chunk_size[1] * 960
            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1,
                                   encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
            cache = _prepare_cache(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back,
                                   decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
            final_result = ""
            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
                if sample_offset + stride_size >= speech_length - 1:
@@ -981,8 +982,8 @@
        asr_result_list.append(item)
        if is_final:
            cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1,
                                 encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
            cache = _cache_reset(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back,
                                 decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
        return asr_result_list
    return _forward