wuhongsheng
2024-07-05 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8
funasr/train_utils/trainer_ds.py
@@ -30,8 +30,9 @@
            yield
    else:
        if dtype == torch.float16 or dtype == torch.bfloat16:
            with autocast(enabled=True, dtype=dtype):
                yield
            yield
            # with autocast(enabled=True, dtype=dtype):
            #     yield
        else:
            yield
@@ -477,7 +478,7 @@
                            for k_ex in self.excludes:
                                k_tmp = k.replace("module.", "")
                                if k_tmp.startswith(k_ex):
                                    logging.info(f"key: {{k}} matching: {k_ex}, excluded")
                                    logging.info(f"key: {k} matching: {k_ex}, excluded")
                                    excludes_flag = True
                                    break
                        if excludes_flag: