Merge pull request #83 from alibaba-damo-academy/dev_lzr
remove useless vars and fix bug in predictor tail_process_fn
| | |
| | | |
| | | return _forward |
| | | |
| | | def set_parameters(language: str = None, |
| | | sample_rate: Union[int, Dict[Any, int]] = None): |
| | | if language is not None: |
| | | global global_asr_language |
| | | global_asr_language = language |
| | | if sample_rate is not None: |
| | | global global_sample_rate |
| | | global_sample_rate = sample_rate |
| | | |
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | description="ASR Decoding", |
| | |
| | | mask_2 = torch.cat([ones_t, mask], dim=1)
|
| | | mask = mask_2 - mask_1
|
| | | tail_threshold = mask * tail_threshold
|
| | | alphas = torch.cat([alphas, tail_threshold], dim=1)
|
| | | alphas = torch.cat([alphas, zeros_t], dim=1)
|
| | | alphas = torch.add(alphas, tail_threshold)
|
| | | else:
|
| | | tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
|
| | | tail_threshold = torch.reshape(tail_threshold, (1, 1))
|
| | |
| | | mask_2 = torch.cat([ones_t, mask], dim=1)
|
| | | mask = mask_2 - mask_1
|
| | | tail_threshold = mask * tail_threshold
|
| | | alphas = torch.cat([alphas, tail_threshold], dim=1)
|
| | | alphas = torch.cat([alphas, zeros_t], dim=1)
|
| | | alphas = torch.add(alphas, tail_threshold)
|
| | | else:
|
| | | tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
|
| | | tail_threshold = torch.reshape(tail_threshold, (1, 1))
|
| | |
| | | |
| | | return wav_list |
| | | |
| | | |
| | | def set_parameters(language: str = None): |
| | | if language is not None: |
| | | global global_asr_language |
| | | global_asr_language = language |
| | | |
| | | |
| | | def compute_wer(hyp_list: List[Any], |
| | | ref_list: List[Any], |
| | | lang: str = None) -> Dict[str, Any]: |
| | | assert len(hyp_list) > 0, 'hyp list is empty' |
| | | assert len(ref_list) > 0, 'ref list is empty' |
| | | |
| | | if lang is not None: |
| | | global global_asr_language |
| | | global_asr_language = lang |
| | | |
| | | rst = { |
| | | 'Wrd': 0, |
| | |
| | | 'wrong_sentences': 0 |
| | | } |
| | | |
| | | if lang is None: |
| | | lang = global_asr_language |
| | | |
| | | for h_item in hyp_list: |
| | | for r_item in ref_list: |
| | | if h_item['key'] == r_item['key']: |
| | | out_item = compute_wer_by_line(h_item['value'], |
| | | r_item['value'], |
| | | global_asr_language) |
| | | lang) |
| | | rst['Wrd'] += out_item['nwords'] |
| | | rst['Corr'] += out_item['cor'] |
| | | rst['wrong_words'] += out_item['wrong'] |