| | |
| | | audio_encoder_output_size = audio_encoder.output_size() |
| | | freeze = audio_encoder_conf.get("freeze", True) |
| | | freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1)) |
| | | if freeze_layer_num > 0: |
| | | freeze_layer_num = range(freeze_layer_num) |
| | | # if freeze_layer_num > 0: |
| | | # freeze_layer_num = range(freeze_layer_num) |
| | | |
| | | if freeze: |
| | | for name, param in audio_encoder.named_parameters(): |
| | | if isinstance(freeze_layer_num, (list, tuple)): |
| | | if freeze_layer_num > 0: |
| | | idx = re.search(r"\.\d+\.", name) |
| | | if idx is not None: |
| | | beg, end = idx.regs[0] |
| | | layer_id = int(name[beg + 1 : end - 1]) |
| | | if layer_id in freeze_layer_num: |
| | | if layer_id < freeze_layer_num: |
| | | param.requires_grad = False |
| | | else: |
| | | param.requires_grad = False |