| | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | |
| | | |
| | | @tables.register("model_classes", "SenseVoice") |
| | | class SenseVoice(nn.Module): |
| | | def __init__(self, *args, **kwargs): |
| | |
| | | model.encoder.downsample_rate = kwargs.get("downsample_rate", 4) |
| | | model.encoder.use_padmask = kwargs.get("use_padmask", True) |
| | | from .encoder import sense_voice_encode_forward |
| | | |
| | | model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder) |
| | | |
| | | # decoder |
| | | model.decoder.use_padmask = kwargs.get("use_padmask", True) |
| | | from .decoder import sense_voice_decode_forward |
| | | |
| | | model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder) |
| | | |
| | | self.model = model |
| | |
| | | specaug = specaug_class(**kwargs.get("specaug_conf", {})) |
| | | self.specaug = specaug |
| | | |
| | | |
| | | def forward( |
| | | self, |
| | | speech: torch.Tensor, |
| | |
| | | |
| | | if self.activation_checkpoint: |
| | | from torch.utils.checkpoint import checkpoint |
| | | encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False) |
| | | |
| | | encoder_out, encoder_out_lens = checkpoint( |
| | | self.encode, speech, speech_lengths, use_reentrant=False |
| | | ) |
| | | else: |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | |
| | | return loss, stats, weight |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | **kwargs, |
| | | ) : |
| | | """Encoder. Note that this method is used by asr_inference.py |
| | | Args: |
| | |
| | | if self.specaug is not None and self.training: |
| | | speech, speech_lengths = self.specaug(speech, speech_lengths) |
| | | |
| | | |
| | | # Forward encoder |
| | | encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths) |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
| | | |
| | | def _calc_att_loss( |
| | | self, |
| | |
| | | |
| | | with torch.no_grad(): |
| | | preds = torch.argmax(decoder_out, -1) |
| | | acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id) |
| | | acc_att = compute_accuracy( |
| | | preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id |
| | | ) |
| | | |
| | | return loss_att, acc_att, None, None |
| | | |
| | | |
| | | def inference(self, |
| | | def inference( |
| | | self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list = None, |
| | |
| | | |
| | | if frontend is None and not hasattr(self, "frontend"): |
| | | frontend_class = tables.frontend_classes.get("WhisperFrontend") |
| | | frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)) |
| | | frontend = frontend_class( |
| | | n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True) |
| | | ) |
| | | self.frontend = frontend |
| | | else: |
| | | frontend = frontend if frontend is not None else self.frontend |
| | | |
| | | meta_data = {} |
| | | if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank |
| | | if ( |
| | | isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank" |
| | | ): # fbank |
| | | speech, speech_lengths = data_in, data_lengths |
| | | if len(speech.shape) < 3: |
| | | speech = speech[None, :, :] |
| | |
| | | else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000), |
| | | audio_sample_list = load_audio_text_image_video( |
| | | data_in, |
| | | fs=frontend.fs if hasattr(frontend, "fs") else 16000, |
| | | audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=tokenizer) |
| | | tokenizer=tokenizer, |
| | | ) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | speech, speech_lengths = extract_fbank( |
| | | audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend |
| | | ) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10 |
| | |
| | | |
| | | DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None) |
| | | |
| | | |
| | | if "without_timestamps" not in DecodingOptions: |
| | | DecodingOptions["without_timestamps"] = True |
| | | |
| | | |
| | | options = whisper.DecodingOptions(**DecodingOptions) |
| | | |
| | |
| | | model.encoder.downsample_rate = kwargs.get("downsample_rate", 4) |
| | | model.encoder.use_padmask = kwargs.get("use_padmask", True) |
| | | from .encoder import sense_voice_encode_forward |
| | | |
| | | model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder) |
| | | |
| | | # decoder |
| | | del model.decoder |
| | | decoder = kwargs.get("decoder", "SenseVoiceDecoder") |
| | | decoder_class = tables.decoder_classes.get(decoder) |
| | | decoder = decoder_class(n_vocab=dims.n_vocab, |
| | | decoder = decoder_class( |
| | | n_vocab=dims.n_vocab, |
| | | n_ctx=dims.n_text_ctx, |
| | | n_state=dims.n_text_state, |
| | | n_head=dims.n_text_head, |
| | | n_layer=dims.n_text_layer, |
| | | **kwargs.get("decoder_conf")) |
| | | **kwargs.get("decoder_conf"), |
| | | ) |
| | | model.decoder = decoder |
| | | |
| | | self.model = model |
| | |
| | | |
| | | if self.activation_checkpoint: |
| | | from torch.utils.checkpoint import checkpoint |
| | | encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False) |
| | | |
| | | encoder_out, encoder_out_lens = checkpoint( |
| | | self.encode, speech, speech_lengths, use_reentrant=False |
| | | ) |
| | | else: |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | |
| | | return loss, stats, weight |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | **kwargs, |
| | | ): |
| | | """Encoder. Note that this method is used by asr_inference.py |
| | | Args: |
| | |
| | | |
| | | with torch.no_grad(): |
| | | preds = torch.argmax(decoder_out, -1) |
| | | acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id) |
| | | acc_att = compute_accuracy( |
| | | preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id |
| | | ) |
| | | |
| | | return loss_att, acc_att, None, None |
| | | |
| | | def inference(self, |
| | | def inference( |
| | | self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list = None, |
| | |
| | | |
| | | if frontend is None and not hasattr(self, "frontend"): |
| | | frontend_class = tables.frontend_classes.get("WhisperFrontend") |
| | | frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)) |
| | | frontend = frontend_class( |
| | | n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True) |
| | | ) |
| | | self.frontend = frontend |
| | | else: |
| | | frontend = frontend if frontend is not None else self.frontend |
| | | |
| | | meta_data = {} |
| | | if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank |
| | | if ( |
| | | isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank" |
| | | ): # fbank |
| | | speech, speech_lengths = data_in, data_lengths |
| | | if len(speech.shape) < 3: |
| | | speech = speech[None, :, :] |
| | |
| | | else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, |
| | | audio_sample_list = load_audio_text_image_video( |
| | | data_in, |
| | | fs=frontend.fs if hasattr(frontend, "fs") else 16000, |
| | | audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=tokenizer) |
| | | tokenizer=tokenizer, |
| | | ) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | speech, speech_lengths = extract_fbank( |
| | | audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend |
| | | ) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10 |