游雁
2023-07-20 b9d6be45fb7da977be51a89455a61149c463aae9
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
2个文件已修改
24 ■■■■ 已修改文件
funasr/bin/asr_inference_launch.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/timestamp_tools.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py
@@ -1340,7 +1340,7 @@
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1:
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
@@ -1371,10 +1371,7 @@
        left_context=left_context,
        right_context=right_context,
    )
    speech2text = Speech2TextTransducer.from_pretrained(
        model_tag=model_tag,
        **speech2text_kwargs,
    )
    speech2text = Speech2TextTransducer(**speech2text_kwargs)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
funasr/utils/timestamp_tools.py
@@ -1,14 +1,10 @@
from itertools import zip_longest
import torch
import copy
import codecs
import logging
import edit_distance
import argparse
import pdb
import numpy as np
from typing import Any, List, Tuple, Union
import edit_distance
from itertools import zip_longest
def ts_prediction_lfr6_standard(us_alphas, 
@@ -36,7 +32,14 @@
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    if num_peak != len(char_list) + 1:
        logging.warning("length mismatch, result might be incorrect.")
        logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
    if num_peak > len(char_list) + 1:
        fire_place = fire_place[:len(char_list) - 1]
    elif num_peak < len(char_list) + 1:
        char_list = char_list[:num_peak + 1]
    # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        # char_list.insert(0, '<sil>')