VirtuosoQ
2024-04-26 e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc
funasr/models/fsmn_vad_streaming/model.py
@@ -284,6 +284,7 @@
      encoder_class = tables.encoder_classes.get(encoder)
      encoder = encoder_class(**encoder_conf)
      self.encoder = encoder
      self.encoder_conf = encoder_conf
   
   def ResetDetection(self, cache: dict = {}):
      cache["stats"].continous_silence_frame_count = 0
@@ -542,6 +543,11 @@
      cache["frontend"] = {}
      cache["prev_samples"] = torch.empty(0)
      cache["encoder"] = {}
      if kwargs.get("max_end_silence_time") is not None:
         # update the max_end_silence_time
         self.vad_opts.max_end_silence_time = kwargs.get("max_end_silence_time")
      windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                        self.vad_opts.sil_to_speech_time_thres,
                                        self.vad_opts.speech_to_sil_time_thres,
@@ -625,14 +631,15 @@
         self.init_cache(cache)
      
      ibest_writer = None
      if ibest_writer is None and kwargs.get("output_dir") is not None:
         writer = DatadirWriter(kwargs.get("output_dir"))
         ibest_writer = writer[f"{1}best_recog"]
      if kwargs.get("output_dir") is not None:
         if not hasattr(self, "writer"):
            self.writer = DatadirWriter(kwargs.get("output_dir"))
         ibest_writer = self.writer[f"{1}best_recog"]
      
      results = []
      result_i = {"key": key[0], "value": segments}
      if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
         result_i = json.dumps(result_i)
      # if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
      #    result_i = json.dumps(result_i)
      
      results.append(result_i)
      
@@ -641,6 +648,12 @@
      
      return results, meta_data
   
   def export(self, **kwargs):
      from .export_meta import export_rebuild_model
      models = export_rebuild_model(model=self, **kwargs)
      return models
   def DetectCommonFrames(self, cache: dict = {}) -> int:
      if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
         return 0