| | |
| | | if cfg.get("layer_norm_first"): |
| | | self.norm = make_layer_norm(cfg.get("embed_dim")) |
| | | |
| | | |
| | | vocab_size = kwargs.get("vocab_size", -1) |
| | | self.proj = None |
| | | if vocab_size > 0: |
| | | self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size) |
| | | |
| | | |
| | | def forward( |
| | |
| | | # assert sr == 16e3, "Sample rate should be 16kHz, but got {}in file {}".format(sr, source_file) |
| | | # assert channel == 1, "Channel should be 1, but got {} in file {}".format(channel, source_file) |
| | | granularity = kwargs.get("granularity", "utterance") |
| | | extract_embedding = kwargs.get("extract_embedding", True) |
| | | if self.proj is None: |
| | | extract_embedding = True |
| | | meta_data = {} |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | |
| | | data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | meta_data["batch_data_time"] = len(audio_sample_list[0])/kwargs.get("fs", 16000) |
| | | |
| | | results = [] |
| | | output_dir = kwargs.get("output_dir") |
| | | if output_dir: |
| | |
| | | source = source.view(1, -1) |
| | | |
| | | feats = self.extract_features(source, padding_mask=None) |
| | | x = feats['x'] |
| | | feats = feats['x'].squeeze(0).cpu().numpy() |
| | | if granularity == 'frame': |
| | | feats = feats |
| | | elif granularity == 'utterance': |
| | | feats = np.mean(feats, axis=0) |
| | | |
| | | result_i = {"key": key[i], "feats": feats} |
| | | results.append(result_i) |
| | | if output_dir: |
| | | if output_dir and extract_embedding: |
| | | np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats) |
| | | |
| | | labels = tokenizer.token_list if tokenizer is not None else [] |
| | | scores = [] |
| | | if self.proj: |
| | | x = x.mean(dim=1) |
| | | x = self.proj(x) |
| | | x = torch.softmax(x, dim=-1) |
| | | scores = x[0].tolist() |
| | | |
| | | result_i = {"key": key[i], "labels": labels, "scores": scores} |
| | | if extract_embedding: |
| | | result_i["feats"] = feats |
| | | results.append(result_i) |
| | | |
| | | |
| | | return results, meta_data |