aky15
2023-05-23 71f1059af9bd71ce87483952bbf3964fc3d3a5f9
funasr/bin/asr_infer.py
@@ -1510,8 +1510,13 @@
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.frontend is not None:
            speech = torch.unsqueeze(speech, axis=0)
            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        
        if self.asr_model.normalize is not None:
            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
@@ -1536,14 +1541,19 @@
        
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        if self.frontend is not None:
            speech = torch.unsqueeze(speech, axis=0)
            speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
            feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
        
        feats = to_device(feats, device=self.device)
        feats_lengths = to_device(feats_lengths, device=self.device)
        
        enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
        enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
        
        nbest_hyps = self.beam_search(enc_out[0])