From b9d6be45fb7da977be51a89455a61149c463aae9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 20 七月 2023 19:07:16 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/utils/timestamp_tools.py | 17 ++++++++++-------
funasr/bin/asr_inference_launch.py | 7 ++-----
2 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index fd0eecd..7ddacf0 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1340,7 +1340,7 @@
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
- if ngpu >= 1:
+ if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
@@ -1371,10 +1371,7 @@
left_context=left_context,
right_context=right_context,
)
- speech2text = Speech2TextTransducer.from_pretrained(
- model_tag=model_tag,
- **speech2text_kwargs,
- )
+ speech2text = Speech2TextTransducer(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 4e7a8a9..5787f1d 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -1,14 +1,10 @@
-from itertools import zip_longest
-
import torch
-import copy
import codecs
import logging
-import edit_distance
import argparse
-import pdb
import numpy as np
-from typing import Any, List, Tuple, Union
+import edit_distance
+from itertools import zip_longest
def ts_prediction_lfr6_standard(us_alphas,
@@ -36,7 +32,14 @@
# so treat the frames between two peaks as the duration of the former token
fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
num_peak = len(fire_place)
- assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
+ if num_peak != len(char_list) + 1:
+ logging.warning("length mismatch, result might be incorrect.")
+ logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
+ if num_peak > len(char_list) + 1:
+ fire_place = fire_place[:len(char_list) - 1]
+ elif num_peak < len(char_list) + 1:
+ char_list = char_list[:num_peak + 1]
+ # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
# begin silence
if fire_place[0] > START_END_THRESHOLD:
# char_list.insert(0, '<sil>')
--
Gitblit v1.9.1