@@ -511,63 +511,4 @@ def clear_gpu(self) -> None:
511511 del self .model
512512 self .model = None
513513 gc .collect ()
514- torch .cuda .empty_cache ()
515-
516- def generate_from_prompt (
517- self ,
518- prompt : str ,
519- task : Optional [str ] = None ,
520- image : Optional [Image .Image ] = None ,
521- generate_at_target_size : bool = False
522- ) -> Dict [str , Any ]:
523- """Generate synthetic data based on a provided prompt.
524-
525- Args:
526- prompt: The text prompt to use for generation
527- task: Optional task type ('t2i', 't2v', 'i2i', 'i2v')
528- image: Optional input image for i2i or i2v generation
529- generate_at_target_size: If True, generate at TARGET_IMAGE_SIZE dimensions
530-
531- Returns:
532- Dictionary containing generated data information
533- """
534- bt .logging .info (f"Generating synthetic data from provided prompt: { prompt } " )
535-
536- # Default to t2i if task is not specified
537- if task is None :
538- task = 't2i'
539-
540- # If model_name is not specified, select one based on the task
541- if self .model_name is None and self .use_random_model :
542- bt .logging .warning (f"No model configured. Using random model." )
543- if task == 't2i' :
544- model_candidates = T2I_MODEL_NAMES
545- elif task == 't2v' :
546- model_candidates = T2V_MODEL_NAMES
547- elif task == 'i2i' :
548- model_candidates = I2I_MODEL_NAMES
549- elif task == 'i2v' :
550- model_candidates = I2V_MODEL_NAMES
551- else :
552- raise ValueError (f"Unsupported task: { task } " )
553-
554- self .model_name = random .choice (model_candidates )
555-
556- # Validate input image for tasks that require it
557- if task in ['i2i' , 'i2v' ] and image is None :
558- raise ValueError (f"Input image is required for { task } generation" )
559-
560- # Run the generation with the provided prompt
561- gen_data = self ._run_generation (
562- prompt = prompt ,
563- task = task ,
564- model_name = self .model_name ,
565- image = image ,
566- generate_at_target_size = generate_at_target_size
567- )
568-
569- # Clean up GPU memory
570- self .clear_gpu ()
571-
572- return gen_data
573-
514+ torch .cuda .empty_cache ()
0 commit comments