Funasr1.0 (#1343)
* funasr1.0.5
* funasr 1.0.5: support audio samples as input
* batch_type token
* batch_type token
| | |
| | | |
| | | # example2 |
| | | import torchaudio |
| | | import os |
| | | wav_file = os.path.join(model.model_path, "example/asr_example.wav") |
| | | input_tensor, sample_rate = torchaudio.load(wav_file) |
| | | input_tensor = input_tensor.mean(0) |
| | |
| | | |
| | | # example3 |
| | | import soundfile |
| | | import os |
| | | |
| | | wav_file = os.path.join(model.model_path, "example/asr_example.wav") |
| | | speech, sample_rate = soundfile.read(wav_file) |
| | | res = model.generate(input=[speech], batch_size_s=300, is_final=True) |
| | |
| | | if batch_sampler is not None: |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) |
| | | batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) |
| | | batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf")) |
| | | batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf")) |
| | | dataloader_tr = torch.utils.data.DataLoader(dataset_tr, |
| | | collate_fn=dataset_tr.collator, |
| | | batch_sampler=batch_sampler, |
| | |
| | | self.max_token_length = kwargs.get("max_token_length", 5000) |
| | | self.shuffle_idx = np.arange(self.total_samples) |
| | | self.shuffle = shuffle and is_training |
| | | self.length_scale_source = kwargs.get("length_scale_source", 1.0) |
| | | |
| | | |
def __len__(self):
    """Return the number of batches: ceil(total_samples / batch_size).

    Computed with integer arithmetic via divmod — a partial trailing
    batch counts as one extra batch; zero samples yields zero batches.
    """
    full_batches, remainder = divmod(self.total_samples, self.batch_size)
    return full_batches + (1 if remainder else 0)
| | |
| | | |
| | | idx_map = self.shuffle_idx[idx] |
| | | # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] |
| | | sample_len_cur = self.dataset.get_source_len(idx_map) + \ |
| | | self.dataset.get_target_len(idx_map) |
| | | target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 |
| | | source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source |
| | | sample_len_cur = source_len + target_len |
| | | |
| | | |
| | | datalen_with_index.append([idx, sample_len_cur]) |
| | | |
| | |
| | | |
| | | max_token_cur = max(max_token, sample_len_cur_raw) |
| | | max_token_padding = 1 + num_sample |
| | | if self.batch_type == 'length': |
| | | if self.batch_type != 'example': |
| | | max_token_padding *= max_token_cur |
| | | if max_token_padding <= self.batch_size: |
| | | batch.append(idx) |
| | |
| | | |
| | | |
| | | from funasr.models.whisper.utils.decoding import detect_language as detect_language_function, decode as decode_function |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @dataclass |
| | | class ModelDimensions: |
| | |
| | | return x |
| | | |
| | | |
| | | |
| | | @tables.register("encoder_classes", "WhisperEncoder") |
| | | class AudioEncoder(nn.Module): |
| | | def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): |
| | | super().__init__() |
| | |
| | | x = self.ln_post(x) |
| | | return x |
| | | |
| | | |
| | | @tables.register("decoder_classes", "WhisperDecoder") |
| | | class TextDecoder(nn.Module): |
| | | def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): |
| | | super().__init__() |
| | |
| | | |
| | | return logits |
| | | |
| | | |
| | | @tables.register("model_classes", "Whisper") |
| | | class Whisper(nn.Module): |
| | | def __init__(self, dims: dict): |
| | | super().__init__() |