Merge pull request #673 from alibaba-damo-academy/dev_clas
contextual paraformer related update: infer and finetune
| | |
| | | |
| | | param_dict = dict() |
| | | param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt" |
| | | param_dict['clas_scale'] = 1.00 # 1.50 # set it larger if you want high recall (sacrifice general accuracy) |
| | | # 13% relative recall raise over internal hotword test set (45%->51%) |
| | | # CER might raise when utterance contains no hotword |
| | | |
| | | inference_pipeline = pipeline( |
| | | task=Tasks.auto_speech_recognition, |
| | | model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", |
| | |
| | | nbest: int = 1, |
| | | frontend_conf: dict = None, |
| | | hotword_list_or_file: str = None, |
| | | clas_scale: float = 1.0, |
| | | decoding_ind: int = 0, |
| | | **kwargs, |
| | | ): |
| | |
| | | # 6. [Optional] Build hotword list from str, local file or url |
| | | self.hotword_list = None |
| | | self.hotword_list = self.generate_hotwords_list(hotword_list_or_file) |
| | | self.clas_scale = clas_scale |
| | | |
| | | is_use_lm = lm_weight != 0.0 and lm_file is not None |
| | | if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: |
| | |
| | | pre_token_length = pre_token_length.round().long() |
| | | if torch.max(pre_token_length) < 1: |
| | | return [] |
| | | if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model, |
| | | NeatContextualParaformer): |
| | | if not isinstance(self.asr_model, ContextualParaformer) and \ |
| | | not isinstance(self.asr_model, NeatContextualParaformer): |
| | | if self.hotword_list: |
| | | logging.warning("Hotword is given but asr model is not a ContextualParaformer.") |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, |
| | | pre_token_length) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | else: |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, |
| | | pre_token_length, hw_list=self.hotword_list) |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, |
| | | enc_len, |
| | | pre_acoustic_embeds, |
| | | pre_token_length, |
| | | hw_list=self.hotword_list, |
| | | clas_scale=self.clas_scale) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | |
| | | if isinstance(self.asr_model, BiCifParaformer): |
| | |
| | | export_mode = param_dict.get("export_mode", False) |
| | | else: |
| | | hotword_list_or_file = None |
| | | clas_scale = param_dict.get('clas_scale', 1.0) |
| | | |
| | | if kwargs.get("device", None) == "cpu": |
| | | ngpu = 0 |
| | |
| | | penalty=penalty, |
| | | nbest=nbest, |
| | | hotword_list_or_file=hotword_list_or_file, |
| | | clas_scale=clas_scale, |
| | | ) |
| | | |
| | | speech2text = Speech2TextParaformer(**speech2text_kwargs) |
| | |
| | | finetune_configs = yaml.safe_load(f) |
| | | # set data_types |
| | | if dataset_type == "large": |
| | | finetune_configs["dataset_conf"]["data_types"] = "sound,text" |
| | | # finetune_configs["dataset_conf"]["data_types"] = "sound,text" |
| | | if 'data_types' not in finetune_configs['dataset_conf']: |
| | | finetune_configs["dataset_conf"]["data_types"] = "sound,text" |
| | | finetune_configs = update_dct(configs, finetune_configs) |
| | | for key, value in finetune_configs.items(): |
| | | if hasattr(args, key): |
| | |
| | | data_types = conf.get("data_types", "kaldi_ark,text") |
| | | |
| | | pre_hwfile = conf.get("pre_hwlist", None) |
| | | pre_prob = conf.get("pre_prob", 0) # unused yet |
| | | |
| | | hw_config = {"sample_rate": conf.get("sample_rate", 0.6), |
| | | "double_rate": conf.get("double_rate", 0.1), |
| | | "hotword_min_length": conf.get("hotword_min_length", 2), |
| | | "hotword_max_length": conf.get("hotword_max_length", 8), |
| | | "pre_prob": conf.get("pre_prob", 0.0)} |
| | | |
| | | # pre_prob = conf.get("pre_prob", 0) # unused yet |
| | | if pre_hwfile is not None: |
| | | pre_hwlist = [] |
| | | with open(pre_hwfile, 'r') as fin: |
| | |
| | | else: |
| | | pre_hwlist = None |
| | | |
| | | hw_config = {"sample_rate": conf.get("sample_rate", 0.6), |
| | | "double_rate": conf.get("double_rate", 0.1), |
| | | "hotword_min_length": conf.get("hotword_min_length", 2), |
| | | "hotword_max_length": conf.get("hotword_max_length", 8), |
| | | "pre_prob": conf.get("pre_prob", 0.0), |
| | | "pre_hwlist": pre_hwlist} |
| | | |
| | | |
| | | |
| | | dataset = AudioDataset(scp_lists, |
| | | data_names, |
| | | data_types, |
| | |
| | | sample_rate, |
| | | double_rate, |
| | | pre_prob, |
| | | pre_index=None): |
| | | pre_index=None, |
| | | pre_hwlist=None): |
| | | if length < hotword_min_length: |
| | | return [-1] |
| | | if random.random() < sample_rate: |
| | |
| | | |
| | | length = len(text) |
| | | if 'hw_tag' in data: |
| | | hotword_indxs = sample_hotword(length, **hw_config) |
| | | if hw_config['pre_hwlist'] is not None and hw_config['pre_prob'] > 0: |
| | | # enable preset hotword detect in sampling |
| | | pre_index = None |
| | | for hw in hw_config['pre_hwlist']: |
| | | hw = " ".join(seg_tokenize(hw, seg_dict)) |
| | | _find = " ".join(text).find(hw) |
| | | if _find != -1: |
| | | # _find = text[:_find].count(" ") # bpe sometimes |
| | | pre_index = [_find, _find + max(hw.count(" "), 1)] |
| | | break |
| | | hotword_indxs = sample_hotword(length, **hw_config, pre_index=pre_index) |
| | | data['hotword_indxs'] = hotword_indxs |
| | | del data['hw_tag'] |
| | | for i in range(length): |
| | |
| | | ys_in_pad: torch.Tensor, |
| | | ys_in_lens: torch.Tensor, |
| | | contextual_info: torch.Tensor, |
| | | clas_scale: float = 1.0, |
| | | return_hidden: bool = False, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Forward decoder. |
| | |
| | | cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask) |
| | | |
| | | if self.bias_output is not None: |
| | | x = torch.cat([x_src_attn, cx], dim=2) |
| | | x = torch.cat([x_src_attn, cx*clas_scale], dim=2) |
| | | x = self.bias_output(x.transpose(1, 2)).transpose(1, 2) # 2D -> D |
| | | x = x_self_attn + self.dropout(x) |
| | | |
| | |
| | | input_mask_expand_dim, 0) |
| | | return sematic_embeds * tgt_mask, decoder_out * tgt_mask |
| | | |
| | | def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None): |
| | | def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0): |
| | | if hw_list is None: |
| | | hw_list = [torch.Tensor([1]).long().to(encoder_out.device)] # empty hotword list |
| | | hw_list_pad = pad_list(hw_list, 0) |
| | |
| | | hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1) |
| | | |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale |
| | | ) |
| | | decoder_out = decoder_outs[0] |
| | | decoder_out = torch.log_softmax(decoder_out, dim=-1) |