From 2f9685797b0c8a420574c2a459c242f90efdf3ee Mon Sep 17 00:00:00 2001
From: aky15 <ankeyuthu@gmail.com>
Date: 星期三, 24 五月 2023 14:04:54 +0800
Subject: [PATCH] support resume model from pai (#544)
---
funasr/bin/asr_infer.py | 23 +++++++++++++++++------
1 files changed, 17 insertions(+), 6 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index fc311c8..760fd07 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -9,6 +9,7 @@
import time
import copy
import os
+import re
import codecs
import tempfile
import requests
@@ -1509,8 +1510,13 @@
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+ if self.frontend is not None:
+ speech = torch.unsqueeze(speech, axis=0)
+ speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
if self.asr_model.normalize is not None:
feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
@@ -1535,14 +1541,19 @@
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
-
- feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
- feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.frontend is not None:
+ speech = torch.unsqueeze(speech, axis=0)
+ speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
feats = to_device(feats, device=self.device)
feats_lengths = to_device(feats_lengths, device=self.device)
- enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
+ enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
nbest_hyps = self.beam_search(enc_out[0])
--
Gitblit v1.9.1