| | |
| | | |
| | | # bias encoder |
| | | if self.bias_encoder_type == 'lstm': |
| | | logging.warning("enable bias encoder sampling and contextual training") |
| | | self.bias_encoder = torch.nn.LSTM(self.inner_dim, |
| | | self.inner_dim, |
| | | 2, |
| | |
| | | self.lstm_proj = None |
| | | self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) |
| | | elif self.bias_encoder_type == 'mean': |
| | | logging.warning("enable bias encoder sampling and contextual training") |
| | | self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) |
| | | else: |
| | | logging.error("Unsupport bias encoder type: {}".format(self.bias_encoder_type)) |
| | |
| | | |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | |
| | | |
| | | # hotword |
| | | self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) |
| | | |