shixian.shi
2023-02-27 508e518b12f52eb84843e9cad4b3e51165bb52fb
update cif onnx
3个文件已修改
14 ■■■■■ 已修改文件
.gitignore 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/predictor/cif.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
.gitignore
@@ -8,3 +8,5 @@
*.tar.gz
test_local/
RapidASR
export/*
*.pyc
funasr/export/models/predictor/cif.py
@@ -48,11 +48,11 @@
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        mask = mask.transpose(-1, -2).float()
        alphas = alphas * mask
        alphas = alphas.squeeze(-1)
        token_num = alphas.sum(-1)
        
        mask = mask.squeeze(-1)
        hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
        
        return acoustic_embeds, token_num, alphas, cif_peak
@@ -63,11 +63,13 @@
        
        zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
        ones_t = torch.ones_like(zeros_t)
        mask_1 = torch.cat([mask, zeros_t], dim=1)
        mask_2 = torch.cat([ones_t, mask], dim=1)
        mask = mask_2 - mask_1
        tail_threshold = mask * tail_threshold
-        alphas = torch.cat([alphas, tail_threshold], dim=1)
+        alphas = torch.cat([alphas, zeros_t], dim=1)
+        alphas = torch.add(alphas, tail_threshold)
        
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
funasr/runtime/python/onnxruntime/demo.py
@@ -1,10 +1,10 @@
from rapid_paraformer import Paraformer
-model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1)
-wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
+wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav']
result = model(wav_path)
print(result)