From 508e518b12f52eb84843e9cad4b3e51165bb52fb Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Mon, 27 Feb 2023 16:08:57 +0800
Subject: [PATCH] update cif onnx

---
 .gitignore                                |    4 +++-
 funasr/export/models/predictor/cif.py     |   10 ++++++----
 funasr/runtime/python/onnxruntime/demo.py |    4 ++--
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8258377..603f712 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,6 @@
 init_model/
 *.tar.gz
 test_local/
-RapidASR
\ No newline at end of file
+RapidASR
+export/*
+*.pyc
\ No newline at end of file
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index cb26862..6f4601d 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -48,11 +48,11 @@
 		alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
 		mask = mask.transpose(-1, -2).float()
 		alphas = alphas * mask
-		
 		alphas = alphas.squeeze(-1)
-		
 		token_num = alphas.sum(-1)
 		
+		mask = mask.squeeze(-1)
+		hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
 		acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
 		
 		return acoustic_embeds, token_num, alphas, cif_peak
@@ -63,12 +63,14 @@
 		
 		zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
 		ones_t = torch.ones_like(zeros_t)
+
 		mask_1 = torch.cat([mask, zeros_t], dim=1)
 		mask_2 = torch.cat([ones_t, mask], dim=1)
 		mask = mask_2 - mask_1
 		tail_threshold = mask * tail_threshold
-		alphas = torch.cat([alphas, tail_threshold], dim=1)
-		
+		alphas = torch.cat([alphas, zeros_t], dim=1)
+		alphas = torch.add(alphas, tail_threshold)
+
 		zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
 		hidden = torch.cat([hidden, zeros], dim=1)
 		token_num = alphas.sum(dim=-1)
diff --git a/funasr/runtime/python/onnxruntime/demo.py b/funasr/runtime/python/onnxruntime/demo.py
index e9c281c..9c7f2f4 100644
--- a/funasr/runtime/python/onnxruntime/demo.py
+++ b/funasr/runtime/python/onnxruntime/demo.py
@@ -1,10 +1,10 @@
 
 from rapid_paraformer import Paraformer
 
-model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model = Paraformer(model_dir, batch_size=1)
 
-wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
+wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav']
 
 result = model(wav_path)
 print(result)
\ No newline at end of file

--
Gitblit v1.9.1