Skip to content

Commit 2104e71

Browse files
committed
removing generate_from_prompt
1 parent 51b33cf commit 2104e71

File tree

1 file changed

+1
-60
lines changed

1 file changed

+1
-60
lines changed

bitmind/synthetic_data_generation/synthetic_data_generator.py

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -508,63 +508,4 @@ def clear_gpu(self) -> None:
508508
del self.model
509509
self.model = None
510510
gc.collect()
511-
torch.cuda.empty_cache()
512-
513-
def generate_from_prompt(
514-
self,
515-
prompt: str,
516-
task: Optional[str] = None,
517-
image: Optional[Image.Image] = None,
518-
generate_at_target_size: bool = False
519-
) -> Dict[str, Any]:
520-
"""Generate synthetic data based on a provided prompt.
521-
522-
Args:
523-
prompt: The text prompt to use for generation
524-
task: Optional task type ('t2i', 't2v', 'i2i', 'i2v')
525-
image: Optional input image for i2i or i2v generation
526-
generate_at_target_size: If True, generate at TARGET_IMAGE_SIZE dimensions
527-
528-
Returns:
529-
Dictionary containing generated data information
530-
"""
531-
bt.logging.info(f"Generating synthetic data from provided prompt: {prompt}")
532-
533-
# Default to t2i if task is not specified
534-
if task is None:
535-
task = 't2i'
536-
537-
# If model_name is not specified, select one based on the task
538-
if self.model_name is None and self.use_random_model:
539-
bt.logging.warning(f"No model configured. Using random model.")
540-
if task == 't2i':
541-
model_candidates = T2I_MODEL_NAMES
542-
elif task == 't2v':
543-
model_candidates = T2V_MODEL_NAMES
544-
elif task == 'i2i':
545-
model_candidates = I2I_MODEL_NAMES
546-
elif task == 'i2v':
547-
model_candidates = I2V_MODEL_NAMES
548-
else:
549-
raise ValueError(f"Unsupported task: {task}")
550-
551-
self.model_name = random.choice(model_candidates)
552-
553-
# Validate input image for tasks that require it
554-
if task in ['i2i', 'i2v'] and image is None:
555-
raise ValueError(f"Input image is required for {task} generation")
556-
557-
# Run the generation with the provided prompt
558-
gen_data = self._run_generation(
559-
prompt=prompt,
560-
task=task,
561-
model_name=self.model_name,
562-
image=image,
563-
generate_at_target_size=generate_at_target_size
564-
)
565-
566-
# Clean up GPU memory
567-
self.clear_gpu()
568-
569-
return gen_data
570-
511+
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)