From f59a72d24e917fb2e9560fa646ae80285dba6c95 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Wed, 15 Mar 2023 10:21:32 +0800
Subject: [PATCH] release timestamp related tools
---
funasr/utils/timestamp_tools.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 48 insertions(+), 2 deletions(-)
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 2bccd50..09c3bec 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -1,3 +1,4 @@
+from pydoc import TextRepr
from scipy.fftpack import shift
import torch
import copy
@@ -5,6 +6,7 @@
import logging
import edit_distance
import argparse
+import pdb
import numpy as np
from typing import Any, List, Tuple, Union
@@ -13,7 +15,8 @@
us_peaks,
char_list,
vad_offset=0.0,
- force_time_shift=-1.5
+ force_time_shift=-1.5,
+ sil_in_str=True
):
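+    # sil_in_str: when False, '<sil>' tokens are omitted from the returned timestamp string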
if not len(char_list):
return []
@@ -66,6 +69,8 @@
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
res_txt = ""
for char, timestamp in zip(new_char_list, timestamp_list):
+            # optionally drop '<sil>' tokens from the timestamp string
+            if not sil_in_str and char == '<sil>': continue
res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
res = []
for char, timestamp in zip(new_char_list, timestamp_list):
@@ -233,13 +238,54 @@
return self._accumlated_shift / self._accumlated_tokens
-SUPPORTED_MODES = ['cal_aas']
+def convert_external_alphas(alphas_file, text_file, output_file):
+ from funasr.models.predictor.cif import cif_wo_hidden
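+    # cif_wo_hidden runs CIF integrate-and-fire on the alphas alone to obtain the firing peaks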
+ with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
+ for line1, line2 in zip(f1.readlines(), f2.readlines()):
+ line1 = line1.rstrip()
+ line2 = line2.rstrip()
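+            # both files carry "uttid value ..." lines and must list the same utterances in the same order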
+ assert line1.split()[0] == line2.split()[0]
+ uttid = line1.split()[0]
+ alphas = [float(i) for i in line1.split()[1:]]
+ new_alphas = np.array(remove_chunk_padding(alphas))
+ new_alphas[-1] += 1e-4
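+            # small bump on the last alpha, presumably so the final CIF firing is not lost to rounding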
+            # tokenize the transcript: split on spaces if present, otherwise into single characters
+            text = line2.split(maxsplit=1)[1] if len(line2.split()) > 1 else ""
+            if " " in text:
+                text = text.split()
+            else:
+                text = [i for i in text]
+            if len(text) + 1 != int(new_alphas.sum()):
+                # force resize so new_alphas sums to len(text) + 1
+                new_alphas *= (len(text) + 1) / int(new_alphas.sum())
+            peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0 - 1e-4)
+ res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
+ force_time_shift=-7.0,
+ sil_in_str=False)
+ f3.write("{} {}\n".format(uttid, res_str))
+
+
+def remove_chunk_padding(alphas):
+    # remove the padded frames from alphas when using the chunk-based paraformer for GPU
+ START_ZERO = 45
+ MID_ZERO = 75
+ REAL_FRAMES = 360 # for chunk based encoder 10-120-10 and fsmn padding 5
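+    # assumed frame layout: START_ZERO padded frames first, then repeating blocks of REAL_FRAMES valid frames followed by MID_ZERO padded frames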
+ alphas = alphas[START_ZERO:] # remove the padding at beginning
+ new_alphas = []
+ while True:
+ new_alphas = new_alphas + alphas[:REAL_FRAMES]
+ alphas = alphas[REAL_FRAMES+MID_ZERO:]
+ if len(alphas) < REAL_FRAMES: break
+ return new_alphas
+
+SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']
def main(args):
if args.mode == 'cal_aas':
asc = AverageShiftCalculator()
asc(args.input, args.input2)
+ elif args.mode == 'read_ext_alphas':
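+        # convert externally dumped CIF alphas plus reference text into per-utterance timestamp strings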
+ convert_external_alphas(args.input, args.input2, args.output)
else:
logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
--
Gitblit v1.9.1