     T2V_MODEL_NAMES,
     T2I_MODEL_NAMES,
     I2I_MODEL_NAMES,
+    I2V_MODEL_NAMES,
     TARGET_IMAGE_SIZE,
     select_random_model,
     get_task,
@@ -152,7 +153,12 @@ def batch_generate(self, batch_size: int = 5) -> None:
             image_sample = self.image_cache.sample()
             images.append(image_sample['image'])
             bt.logging.info(f"Sampled image {i+1}/{batch_size} for captioning: {image_sample['path']}")
-            prompts.append(self.generate_prompt(image=image_sample['image'], clear_gpu=i == batch_size - 1))
+            task = get_task(self.model_name) if self.model_name else None
+            prompts.append(self.generate_prompt(
+                image=image_sample['image'],
+                clear_gpu=i == batch_size - 1,
+                task=task
+            ))
             bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}")

         # If specific model is set, use only that model
@@ -163,9 +169,12 @@ def batch_generate(self, batch_size: int = 5) -> None:
             i2i_model_names = random.sample(I2I_MODEL_NAMES, len(I2I_MODEL_NAMES))
             t2i_model_names = random.sample(T2I_MODEL_NAMES, len(T2I_MODEL_NAMES))
             t2v_model_names = random.sample(T2V_MODEL_NAMES, len(T2V_MODEL_NAMES))
+            i2v_model_names = random.sample(I2V_MODEL_NAMES, len(I2V_MODEL_NAMES))
+
             model_names = [
-                m for triple in zip_longest(t2v_model_names, t2i_model_names, i2i_model_names)
-                for m in triple if m is not None
+                m for quad in zip_longest(t2v_model_names, t2i_model_names,
+                                          i2i_model_names, i2v_model_names)
+                for m in quad if m is not None
             ]

         # Generate for each model/prompt combination
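
Note: the comprehension above round-robins the four shuffled task lists so that no single task type monopolizes the head of the queue; zip_longest pads the shorter lists with None, which the final filter drops. A minimal standalone sketch of the same pattern (list contents here are illustrative, not real model names):

    from itertools import zip_longest

    t2v = ['t2v-a', 't2v-b']
    t2i = ['t2i-a']
    i2i = ['i2i-a', 'i2i-b', 'i2i-c']
    i2v = ['i2v-a']

    ordered = [
        m for quad in zip_longest(t2v, t2i, i2i, i2v)
        for m in quad if m is not None
    ]
    # ['t2v-a', 't2i-a', 'i2i-a', 'i2v-a', 't2v-b', 'i2i-b', 'i2i-c']
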
@@ -222,7 +231,7 @@ def generate(
             ValueError: If real_image is None when using annotation prompt type.
             NotImplementedError: If prompt type is not supported.
         """
-        prompt = self.generate_prompt(image, clear_gpu=True)
+        prompt = self.generate_prompt(image, clear_gpu=True, task=task)
         bt.logging.info("Generating synthetic data...")
         gen_data = self._run_generation(prompt, task, model_name, image)
         self.clear_gpu()
@@ -231,7 +240,8 @@ def generate(
     def generate_prompt(
         self,
         image: Optional[Image.Image] = None,
-        clear_gpu: bool = True
+        clear_gpu: bool = True,
+        task: Optional[str] = None
     ) -> str:
         """Generate a prompt based on the specified strategy."""
         bt.logging.info("Generating prompt")
@@ -241,7 +251,7 @@ def generate_prompt(
241251 "image can't be None if self.prompt_type is 'annotation'"
242252 )
243253 self .prompt_generator .load_models ()
244- prompt = self .prompt_generator .generate (image )
254+ prompt = self .prompt_generator .generate (image , task = task )
245255 if clear_gpu :
246256 self .prompt_generator .clear_gpu ()
247257 else :
@@ -261,9 +271,9 @@ def _run_generation(

         Args:
             prompt: The text prompt used to inspire the generation.
-            task: The generation task type ('t2i', 't2v', 'i2i', or None).
+            task: The generation task type ('t2i', 't2v', 'i2i', 'i2v', or None).
             model_name: Optional model name to use for generation.
-            image: Optional input image for image-to-image generation.
+            image: Optional input image for image-to-image or image-to-video generation.
             generate_at_target_size: If True, generate at TARGET_IMAGE_SIZE dimensions.

         Returns:
@@ -272,6 +282,10 @@ def _run_generation(
         Raises:
             RuntimeError: If generation fails.
         """
+        # Clear CUDA cache before loading model
+        torch.cuda.empty_cache()
+        gc.collect()
+
         self.load_model(model_name)
         model_config = MODELS[self.model_name]
         task = get_task(model_name) if task is None else task
@@ -289,14 +303,38 @@ def _run_generation(

             gen_args['mask_image'], mask_center = create_random_mask(image.size)
             gen_args['image'] = image
+        # prep image-to-video generation args
+        elif task == 'i2v':
+            if image is None:
+                raise ValueError("image cannot be None for image-to-video generation")
+            # Get target size from gen_args if specified, otherwise use default
+            target_size = (
+                gen_args.get('height', 768),
+                gen_args.get('width', 768)
+            )
+            if image.size[0] > target_size[0] or image.size[1] > target_size[1]:
+                image = image.resize(target_size, Image.Resampling.LANCZOS)
+            gen_args['image'] = image

         # Prepare generation arguments
         for k, v in gen_args.items():
             if isinstance(v, dict):
                 if "min" in v and "max" in v:
-                    gen_args[k] = np.random.randint(v['min'], v['max'])
+                    # For i2v, use minimum values to save memory
+                    if task == 'i2v':
+                        gen_args[k] = v['min']
+                    else:
+                        gen_args[k] = np.random.randint(v['min'], v['max'])
                 if "options" in v:
                     gen_args[k] = random.choice(v['options'])
+            # Ensure num_frames is always an integer
+            if k == 'num_frames' and isinstance(v, dict):
+                if "min" in v:
+                    gen_args[k] = v['min']
+                elif "max" in v:
+                    gen_args[k] = v['max']
+                else:
+                    gen_args[k] = 24  # Default value

         try:
             if generate_at_target_size:
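
Note: the loop above resolves each ranged ({'min': ..., 'max': ...}) or enumerated ({'options': [...]}) entry of the model's argument spec into a concrete value, pinning ranged values to their minimum for i2v to save memory. A self-contained sketch of that resolution step, using a made-up spec (the key names are illustrative, not the actual MODELS config):

    import random
    import numpy as np

    def resolve_gen_args(spec: dict, task: str) -> dict:
        args = dict(spec)
        for k, v in args.items():
            if isinstance(v, dict):
                if 'min' in v and 'max' in v:
                    # i2v pins ranged values to their minimum to save memory
                    args[k] = v['min'] if task == 'i2v' else np.random.randint(v['min'], v['max'])
                if 'options' in v:
                    args[k] = random.choice(v['options'])
        return args

    spec = {'num_frames': {'min': 24, 'max': 48}, 'fps': {'options': [8, 24]}}
    print(resolve_gen_args(spec, task='i2v'))  # {'num_frames': 24, 'fps': 8 or 24}
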
@@ -307,6 +345,10 @@ def _run_generation(
                 gen_args['width'] = gen_args['resolution'][1]
                 del gen_args['resolution']

+            # Ensure num_frames is an integer before generation
+            if 'num_frames' in gen_args:
+                gen_args['num_frames'] = int(gen_args['num_frames'])
+
             truncated_prompt = truncate_prompt_if_too_long(prompt, self.model)
             bt.logging.info(f"Generating media from prompt: {truncated_prompt}")
             bt.logging.info(f"Generation args: {gen_args}")
@@ -321,8 +363,14 @@ def _run_generation(
                 pretrained_args = model_config.get('from_pretrained_args', {})
                 torch_dtype = pretrained_args.get('torch_dtype', torch.bfloat16)
                 with torch.autocast(self.device, torch_dtype, cache_enabled=False):
+                    # Clear CUDA cache before generation
+                    torch.cuda.empty_cache()
+                    gc.collect()
                     gen_output = generate(truncated_prompt, **gen_args)
             else:
+                # Clear CUDA cache before generation
+                torch.cuda.empty_cache()
+                gc.collect()
                 gen_output = generate(truncated_prompt, **gen_args)

             gen_time = time.time() - start_time
@@ -334,6 +382,8 @@ def _run_generation(
334382 f"default dimensions. Error: { e } "
335383 )
336384 try :
385+ # Clear CUDA cache before retry
386+ torch .cuda .empty_cache ()
337387 gen_output = self .model (prompt = truncated_prompt )
338388 gen_time = time .time () - start_time
339389 except Exception as fallback_error :
@@ -461,5 +511,4 @@ def clear_gpu(self) -> None:
         del self.model
         self.model = None
         gc.collect()
-        torch.cuda.empty_cache()
-
+        torch.cuda.empty_cache()
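
Note: clear_gpu drops the pipeline reference, runs gc.collect(), and only then calls torch.cuda.empty_cache(); reversing the order would leave memory held by not-yet-collected tensors stuck in the caching allocator. A minimal sketch of the pattern, assuming a CUDA device is present:

    import gc
    import torch

    class Holder:
        def clear_gpu(self):
            self.model = None         # drop the last reference to the pipeline
            gc.collect()              # finalize unreachable CUDA tensors first
            torch.cuda.empty_cache()  # then return cached blocks to the driver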