| | |
| | | self.prompt_ids_len = 0 |
| | | self.retry = kwargs.get("retry", 5) |
| | | |
| | | self.permute = False |
| | | from funasr.frontends.whisper_frontend import WhisperFrontend |
| | | |
| | | if isinstance(self.frontend, WhisperFrontend): |
| | | self.permute = True |
| | | |
def get_source_len(self, index):
    """Return the source length of the sample stored at *index*.

    Fetches the raw item from the underlying index dataset and delegates
    the length computation back to that dataset.
    """
    return self.index_ds.get_source_len(self.index_ds[index])
| | |
| | | |
| | | if speech_lengths > self.batch_size: |
| | | continue |
| | | if self.permute: |
| | | speech = speech.permute(0, 2, 1) |
| | | target = item["target"] |
| | | if self.preprocessor_text: |
| | |
| | | task = item.get("prompt", "<|ASR|>") |
| | | text_language = item.get("text_language", "<|zh|>") |
| | | |
| | | if isinstance(self.sos, str): |
| | | prompt = f"{self.sos}{task}{text_language}" |
| | | prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") |
| | | else: |
| | | prompt = f"{task}{text_language}" |
| | | prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") |
| | | prompt_ids = [self.sos] + prompt_ids |
| | | |
| | | prompt_ids_len = len(prompt_ids) - 1 # [sos, task] |
| | | self.prompt_ids_len = prompt_ids_len |
| | | |
| | |
| | | if target_ids_len > 200: |
| | | continue |
| | | |
| | | if isinstance(self.eos, str): |
| | | eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] |
| | | else: |
| | | eos = [self.eos] |
| | | |
| | | ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos] |
| | | ids_lengths = len(ids) |
| | |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = tables.normalize_classes.get(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(input_size=input_size, **encoder_conf) |
| | | encoder_output_size = encoder.output_size() |
| | |
| | | self.ignore_id = ignore_id |
| | | |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | | |
| | | self.encoder = encoder |
| | | |
| | | self.decoder = decoder |
| | |
| | | |
| | | self.error_calculator = None |
| | | |
| | | self.share_embedding = share_embedding |
| | | if self.share_embedding: |
| | | self.decoder.embed = None |
| | | |
| | | self.length_normalized_loss = length_normalized_loss |
| | | self.beam_search = None |
| | | self.activation_checkpoint = kwargs.get("activation_checkpoint", False) |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | stats = {} |
| | | |
| | | # 1. Forward decoder |
| | | ys_pad[ys_pad == -1] = 0 |
| | | decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) |
| | | if isinstance(decoder_out, (list, tuple)): |
| | | decoder_out = decoder_out[0] |
| | |
| | | # "TypeError: can't pickle SwigPyObject objects", |
| | | # when giving it as argument of "multiprocessing.Process()". |
| | | self.sp = None |
| | | self._build_sentence_piece_processor() |
| | | |
| | | def __repr__(self): |
| | | return f'{self.__class__.__name__}(model="{self.bpemodel}")' |
| | |
| | | self._build_sentence_piece_processor() |
| | | return self.sp.DecodePieces(list(tokens)) |
| | | |
def encode(self, line: str, **kwargs) -> List[int]:
    """Encode *line* into a list of sentencepiece token ids.

    Args:
        line: raw text to tokenize.
        **kwargs: accepted for tokenizer-interface compatibility; ignored here.

    Returns:
        Token ids produced by the sentencepiece processor.
    """
    # Bug fix: the source contained two consecutive `def encode(...)` lines
    # (a merge/diff artifact) which is a syntax error; the `**kwargs` variant
    # is kept as it is a backward-compatible superset of the other signature.
    # The processor is built lazily because a loaded SwigPyObject cannot be
    # pickled (e.g. when this tokenizer crosses a multiprocessing boundary).
    self._build_sentence_piece_processor()
    return self.sp.EncodeAsIds(line)
| | | |
def decode(self, line: List[int], **kwargs):
    """Decode a list of sentencepiece token ids back into text.

    Args:
        line: token ids to decode.
        **kwargs: accepted for tokenizer-interface compatibility; ignored here.

    Returns:
        The decoded string produced by the sentencepiece processor.
    """
    # Bug fix: the source contained two consecutive `def decode(...)` lines
    # (a merge/diff artifact) which is a syntax error; the `**kwargs` variant
    # is kept as it is a backward-compatible superset of the other signature.
    # Build the processor lazily (it is not picklable, so it is dropped and
    # rebuilt across process boundaries).
    self._build_sentence_piece_processor()
    return self.sp.DecodeIds(line)
| | | |
def get_vocab_size(self):
    """Return the vocabulary size of the sentencepiece model.

    Consistency/robustness fix: ``self.sp`` starts out as ``None`` and is
    only materialized on demand (the processor is not picklable), so —
    exactly as ``encode``/``decode`` do — build it lazily before use;
    otherwise calling this first would fail on a ``None`` processor.
    """
    self._build_sentence_piece_processor()
    return self.sp.GetPieceSize()