| | |
| | | self.tokenizer = tokenizer |
| | | |
| | | # 6. [Optional] Build hotword list from str, local file or url |
| | | # for None |
| | | if hotword_list_or_file is None: |
| | | self.hotword_list = None |
| | | # for text str input |
| | | elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'): |
| | | logging.info("Attempting to parse hotwords as str...") |
| | | self.hotword_list = [] |
| | | hotword_str_list = [] |
| | | for hw in hotword_list_or_file.strip().split(): |
| | | hotword_str_list.append(hw) |
| | | self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) |
| | | self.hotword_list.append([self.asr_model.sos]) |
| | | hotword_str_list.append('<s>') |
| | | logging.info("Hotword list: {}.".format(hotword_str_list)) |
| | | # for local txt inputs |
| | | elif os.path.exists(hotword_list_or_file): |
| | | logging.info("Attempting to parse hotwords from local txt...") |
| | | self.hotword_list = [] |
| | | hotword_str_list = [] |
| | | with codecs.open(hotword_list_or_file, 'r') as fin: |
| | | for line in fin.readlines(): |
| | | hw = line.strip() |
| | | hotword_str_list.append(hw) |
| | | self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) |
| | | self.hotword_list.append([self.asr_model.sos]) |
| | | hotword_str_list.append('<s>') |
| | | logging.info("Initialized hotword list from file: {}, hotword list: {}." |
| | | .format(hotword_list_or_file, hotword_str_list)) |
| | | # for url, download and generate txt |
| | | else: |
| | | logging.info("Attempting to parse hotwords from url...") |
| | | work_dir = tempfile.TemporaryDirectory().name |
| | | if not os.path.exists(work_dir): |
| | | os.makedirs(work_dir) |
| | | text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file)) |
| | | local_file = requests.get(hotword_list_or_file) |
| | | open(text_file_path, "wb").write(local_file.content) |
| | | hotword_list_or_file = text_file_path |
| | | self.hotword_list = [] |
| | | hotword_str_list = [] |
| | | with codecs.open(hotword_list_or_file, 'r') as fin: |
| | | for line in fin.readlines(): |
| | | hw = line.strip() |
| | | hotword_str_list.append(hw) |
| | | self.hotword_list.append(self.converter.tokens2ids([i for i in hw])) |
| | | self.hotword_list.append([self.asr_model.sos]) |
| | | hotword_str_list.append('<s>') |
| | | logging.info("Initialized hotword list from file: {}, hotword list: {}." |
| | | .format(hotword_list_or_file, hotword_str_list)) |
| | | |
| | | self.hotword_list = None |
| | | self.hotword_list = self.generate_hotwords_list(hotword_list_or_file) |
| | | |
| | | is_use_lm = lm_weight != 0.0 and lm_file is not None |
| | | if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: |
| | |
| | | |
| | | # assert check_return_type(results) |
| | | return results |
| | | |
def generate_hotwords_list(self, hotword_list_or_file):
    """Build the hotword token-id list used for hotword biasing.

    ``hotword_list_or_file`` may be one of:

    * ``None``: hotwords are disabled and ``None`` is returned.
    * A list/tuple of hotword strings: surrounding whitespace is stripped and
      empty entries are dropped.
    * A path to an existing local file ending in ``.txt``: one hotword per
      line, read as UTF-8 (a leading BOM is ignored). Surrounding whitespace
      and blank lines are dropped.
    * A string starting with ``http``: the file is downloaded and parsed
      like a local ``.txt`` file.
    * Any other string not ending in ``.txt``: parsed as whitespace-separated
      hotwords.

    A string ending in ``.txt`` that is neither an existing file nor an
    ``http`` URL yields ``None``, and a warning is logged.

    Each hotword is split into single characters and mapped through
    ``self.converter.tokens2ids``. ``[self.asr_model.sos]`` is always
    appended as the last entry of a non-``None`` result.

    Returns:
        A list of token-id lists, or ``None``.

    Raises:
        requests.HTTPError: if downloading a hotword URL returns an error
            status.
    """

    def _build(words):
        # Character-level token ids per hotword, terminated by the sos entry.
        # Also returns the readable string list used for logging.
        hotword_list = [self.converter.tokens2ids(list(hw)) for hw in words]
        hotword_list.append([self.asr_model.sos])
        hotword_str_list = list(words) + ['<s>']
        return hotword_list, hotword_str_list

    def _clean_lines(lines):
        # One hotword per line: strip surrounding whitespace and skip blank
        # lines, so they do not become empty token-id lists.
        return [line.strip() for line in lines if line.strip()]

    if hotword_list_or_file is None:
        return None

    # In-memory list/tuple of hotword strings.
    if isinstance(hotword_list_or_file, (list, tuple)):
        words = [hw.strip() for hw in hotword_list_or_file if hw.strip()]
        hotword_list, hotword_str_list = _build(words)
        logging.info("Hotword list: {}.".format(hotword_str_list))
        return hotword_list

    # Local txt file.
    if os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
        logging.info("Attempting to parse hotwords from local txt...")
        # Explicit UTF-8 instead of the locale default: hotwords are usually
        # non-ASCII. The "-sig" variant drops a BOM written by some editors.
        with open(hotword_list_or_file, 'r', encoding='utf-8-sig') as fin:
            words = _clean_lines(fin)
        hotword_list, hotword_str_list = _build(words)
        logging.info("Initialized hotword list from file: {}, hotword list: {}."
                     .format(hotword_list_or_file, hotword_str_list))
        return hotword_list

    # Remote file. Parse the downloaded bytes directly instead of staging
    # them in a temporary file, so nothing is left on disk.
    if hotword_list_or_file.startswith('http'):
        logging.info("Attempting to parse hotwords from url...")
        response = requests.get(hotword_list_or_file)
        # Do not turn an HTTP error page into a list of hotwords.
        response.raise_for_status()
        words = _clean_lines(response.content.decode('utf-8-sig').splitlines())
        hotword_list, hotword_str_list = _build(words)
        logging.info("Initialized hotword list from url: {}, hotword list: {}."
                     .format(hotword_list_or_file, hotword_str_list))
        return hotword_list

    # Plain whitespace-separated hotword string.
    if not hotword_list_or_file.endswith('.txt'):
        logging.info("Attempting to parse hotwords as str...")
        hotword_list, hotword_str_list = _build(hotword_list_or_file.strip().split())
        logging.info("Hotword list: {}.".format(hotword_str_list))
        return hotword_list

    # A ".txt" path that does not exist: keep the original None result, but
    # make the misconfiguration visible.
    logging.warning("Hotword file {} does not exist, hotwords are disabled."
                    .format(hotword_list_or_file))
    return None
| | | |
| | | class Speech2TextExport: |
| | | """Speech2TextExport class |
| | |
| | | output_dir_v2: Optional[str] = None, |
| | | fs: dict = None, |
| | | param_dict: dict = None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | hotword_list_or_file = None |
| | | if param_dict is not None: |
| | | hotword_list_or_file = param_dict.get('hotword') |
| | | |
| | | if 'hotword' in kwargs: |
| | | hotword_list_or_file = kwargs['hotword'] |
| | | |
| | | if speech2text.hotword_list is None: |
| | | speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) |
| | | |
| | | # 3. Build data-iterator |
| | | if data_path_and_name_and_type is None and raw_inputs is not None: |
| | | if isinstance(raw_inputs, torch.Tensor): |