From 8866ef268fa107e09fb458890e551a45d4848643 Mon Sep 17 00:00:00 2001 From: krolhm Date: Fri, 9 May 2025 19:06:41 +0200 Subject: [PATCH] Add macOS (MPS) and CPU fallback device support --- scripts/image_process.py | 14 +++---- scripts/inference_triposg.py | 29 ++++++++++++-- scripts/inference_triposg_scribble.py | 12 ++++-- scripts/inference_vae.py | 13 ++++-- triposg/inference_utils.py | 40 ++++++++++++++----- .../autoencoders/autoencoder_kl_triposg.py | 2 +- triposg/pipelines/pipeline_triposg.py | 17 ++++++-- .../pipelines/pipeline_triposg_scribble.py | 6 +-- 8 files changed, 98 insertions(+), 35 deletions(-) diff --git a/scripts/image_process.py b/scripts/image_process.py index fde03b4..e40fa0d 100755 --- a/scripts/image_process.py +++ b/scripts/image_process.py @@ -17,7 +17,7 @@ def find_bounding_box(gray_image): x, y, w, h = cv2.boundingRect(max_contour) return x, y, w, h -def load_image(img_path, bg_color=None, rmbg_net=None, padding_ratio=0.1): +def load_image(img_path: str, bg_color: np.ndarray = None, rmbg_net=None, padding_ratio: float = 0.1, device: str = "cuda"): img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) if img is None: return f"invalid image path {img_path}" @@ -72,13 +72,13 @@ def rmbg(image: torch.Tensor) -> torch.Tensor: else: return f"invalid image: channels {num_channels}" - rgb_image_gpu = torch.from_numpy(rgb_image).cuda().float().permute(2, 0, 1) / 255. + rgb_image_gpu = torch.from_numpy(rgb_image).to(device).float().permute(2, 0, 1) / 255. 
if alpha is None: resize_transform = transforms.Resize((384, 384), antialias=True) rgb_image_resized = resize_transform(rgb_image_gpu) normalize_image = rgb_image_resized * 2 - 1 - mean_color = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).cuda() + mean_color = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1) resize_transform = transforms.Resize((1024, 1024), antialias=True) rgb_image_resized = resize_transform(rgb_image_gpu) max_value = rgb_image_resized.flatten().max() @@ -105,7 +105,7 @@ def rmbg(image: torch.Tensor) -> torch.Tensor: cleaned_alpha = remove_small_objects(labeled_alpha, min_size=200) cleaned_alpha = (cleaned_alpha > 0).astype(np.uint8) alpha = cleaned_alpha * 255 - alpha_gpu = torch.from_numpy(cleaned_alpha).cuda().float().unsqueeze(0) + alpha_gpu = torch.from_numpy(cleaned_alpha).to(device).float().unsqueeze(0) x, y, w, h = find_bounding_box(alpha) # If alpha is provided, the bounds of all foreground are used @@ -125,7 +125,7 @@ def rmbg(image: torch.Tensor) -> torch.Tensor: raise ValueError(f"input image too small") bg_gray = bg_color[0] - bg_color = torch.from_numpy(bg_color).float().cuda().repeat(alpha_gpu.shape[1], alpha_gpu.shape[2], 1).permute(2, 0, 1) + bg_color = torch.from_numpy(bg_color).float().to(device).repeat(alpha_gpu.shape[1], alpha_gpu.shape[2], 1).permute(2, 0, 1) rgb_image_gpu = rgb_image_gpu * alpha_gpu + bg_color * (1 - alpha_gpu) padding_size = [0] * 6 if w > h: @@ -140,9 +140,9 @@ def rmbg(image: torch.Tensor) -> torch.Tensor: return padded_tensor -def prepare_image(image_path, bg_color, rmbg_net=None): +def prepare_image(image_path: str, bg_color: np.ndarray, rmbg_net=None, device: str = "cuda"): if os.path.isfile(image_path): - img_tensor = load_image(image_path, bg_color=bg_color, rmbg_net=rmbg_net) + img_tensor = load_image(image_path, bg_color=bg_color, rmbg_net=rmbg_net, device=device) img_np = img_tensor.permute(1,2,0).cpu().numpy() img_pil = Image.fromarray((img_np*255).astype(np.uint8)) diff 
--git a/scripts/inference_triposg.py b/scripts/inference_triposg.py index 8076876..c5d0e83 100755 --- a/scripts/inference_triposg.py +++ b/scripts/inference_triposg.py @@ -28,13 +28,24 @@ def run_triposg( num_inference_steps: int = 50, guidance_scale: float = 7.0, faces: int = -1, + device: str = "cuda", + use_flash_decoder: bool = True, ) -> trimesh.Scene: - img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net) + img_pil = prepare_image(image_input, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net, device=device) + + effective_use_flash_decoder = use_flash_decoder + if device == 'mps' and use_flash_decoder: + try: + import diso + print("Note: Using flash_decoder on MPS. If 'diso' library is not fully compatible, issues might occur or performance might vary.") + except ImportError: + print("Warning: 'diso' library not found. 'flash_decoder' cannot be used. Falling back to hierarchical_extract_geometry.") + effective_use_flash_decoder = False outputs = pipe( image=img_pil, - generator=torch.Generator(device=pipe.device).manual_seed(seed), + generator=torch.Generator(device=device).manual_seed(seed), num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, ).samples[0] @@ -66,8 +77,15 @@ def simplify_mesh(mesh: trimesh.Trimesh, n_faces): return mesh if __name__ == "__main__": - device = "cuda" - dtype = torch.float16 + if torch.backends.mps.is_available(): + device = "mps" + dtype = torch.float32 + elif torch.cuda.is_available(): + device = "cuda" + dtype = torch.float16 + else: + device = "cpu" + dtype = torch.float32 parser = argparse.ArgumentParser() parser.add_argument("--image-input", type=str, required=True) @@ -76,6 +94,7 @@ def simplify_mesh(mesh: trimesh.Trimesh, n_faces): parser.add_argument("--num-inference-steps", type=int, default=50) parser.add_argument("--guidance-scale", type=float, default=7.0) parser.add_argument("--faces", type=int, default=-1) + parser.add_argument("--use-flash-decoder", 
action=argparse.BooleanOptionalAction, default=True) args = parser.parse_args() # download pretrained weights @@ -100,5 +119,7 @@ def simplify_mesh(mesh: trimesh.Trimesh, n_faces): num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, faces=args.faces, + device=device, + use_flash_decoder=args.use_flash_decoder, ).export(args.output_path) print(f"Mesh saved to {args.output_path}") diff --git a/scripts/inference_triposg_scribble.py b/scripts/inference_triposg_scribble.py index d325e13..2239caf 100644 --- a/scripts/inference_triposg_scribble.py +++ b/scripts/inference_triposg_scribble.py @@ -32,7 +32,7 @@ def run_triposg_scribble( outputs = pipe( image=img_pil, prompt=prompt, - generator=torch.Generator(device=pipe.device).manual_seed(seed), + generator=torch.Generator(device=pipe.device if hasattr(pipe, 'device') else "cpu").manual_seed(seed), num_inference_steps=num_inference_steps, guidance_scale=0, # this is a CFG-distilled model attention_kwargs={"cross_attention_scale": prompt_confidence, "cross_attention_2_scale": scribble_confidence}, @@ -44,8 +44,14 @@ def run_triposg_scribble( if __name__ == "__main__": - device = "cuda" - dtype = torch.float16 + if torch.backends.mps.is_available(): + device = "mps" + elif torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + dtype = torch.float16 if device != "cpu" else torch.float32 parser = argparse.ArgumentParser() parser.add_argument("--image-input", type=str, required=True) diff --git a/scripts/inference_vae.py b/scripts/inference_vae.py index 7634db8..4458e83 100644 --- a/scripts/inference_vae.py +++ b/scripts/inference_vae.py @@ -18,14 +18,21 @@ def load_surface(data_path, num_pc=204800): ind = rng.choice(surface.shape[0], num_pc, replace=False) surface = torch.FloatTensor(surface[ind]) normal = torch.FloatTensor(normal[ind]) - surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda() + surface = torch.cat([surface, normal], dim=-1).unsqueeze(0) return 
surface if __name__ == "__main__": - device = "cuda" - dtype = torch.float16 + if torch.backends.mps.is_available(): + device = "mps" + elif torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + dtype = torch.float16 if device != "cpu" else torch.float32 + parser = argparse.ArgumentParser() parser.add_argument("--surface-input", type=str, required=True) args = parser.parse_args() diff --git a/triposg/inference_utils.py b/triposg/inference_utils.py index 6ce292b..6adf24c 100755 --- a/triposg/inference_utils.py +++ b/triposg/inference_utils.py @@ -4,11 +4,15 @@ import scipy.ndimage from skimage import measure from einops import repeat -from diso import DiffDMC import torch.nn.functional as F from triposg.utils.typing import * +try: + from diso import DiffDMC +except ImportError: + DiffDMC = None + def generate_dense_grid_points_gpu(bbox_min: torch.Tensor, bbox_max: torch.Tensor, octree_depth: int, @@ -98,7 +102,7 @@ def find_candidates_band(occupancy_grid: torch.Tensor, band_threshold: float, n_ return core_mesh_coords def expand_edge_region_fast(edge_coords, grid_size): - expanded_tensor = torch.zeros(grid_size, grid_size, grid_size, device='cuda', dtype=torch.float16, requires_grad=False) + expanded_tensor = torch.zeros(grid_size, grid_size, grid_size, device=edge_coords.device, dtype=torch.float16, requires_grad=False) expanded_tensor[edge_coords[:, 0], edge_coords[:, 1], edge_coords[:, 2]] = 1 if grid_size < 512: kernel_size = 5 @@ -186,7 +190,10 @@ def hierarchical_extract_geometry(geometric_func: Callable, # breakpoint() high_res_occupancy[indices[:, 0], indices[:, 1], indices[:, 2]] = values grid_logits = high_res_occupancy - torch.cuda.empty_cache() + if device.type == 'cuda': + torch.cuda.empty_cache() + elif device.type == 'mps': + torch.mps.empty_cache() mesh_v_f = [] try: print("final grids shape = ", grid_logits.shape) @@ -195,7 +202,10 @@ def hierarchical_extract_geometry(geometric_func: Callable, mesh_v_f = 
(vertices.astype(np.float32), np.ascontiguousarray(faces)) except Exception as e: print(e) - torch.cuda.empty_cache() + if device.type == 'cuda': + torch.cuda.empty_cache() + elif device.type == 'mps': + torch.mps.empty_cache() mesh_v_f = (None, None) return [mesh_v_f] @@ -463,17 +473,25 @@ def flash_extract_geometry( grid_logits = grid_logits[0] try: print("final grids shape = ", grid_logits.shape) - dmc = DiffDMC(dtype=torch.float32).to(grid_logits.device) - sdf = -grid_logits / octree_resolution - sdf = sdf.to(torch.float32).contiguous() - vertices, faces = dmc(sdf, deform=None, return_quads=False, normalize=False) - vertices = vertices.detach().cpu().numpy() - faces = faces.detach().cpu().numpy()[:, ::-1] + if grid_logits.device.type == 'mps': + print("Warning: DiffDMC (diso library) in flash_extract_geometry might not be compatible with MPS. Using skimage.measure.marching_cubes on CPU as a fallback for this specific call if DiffDMC fails or is unavailable.") + grid_logits_cpu = grid_logits.float().cpu().numpy() + vertices, faces, _, _ = measure.marching_cubes(grid_logits_cpu, mc_level, method="lewiner") + else: + dmc = DiffDMC(dtype=torch.float32).to(grid_logits.device) + sdf = -grid_logits / octree_resolution + sdf = sdf.to(torch.float32).contiguous() + vertices, faces = dmc(sdf, deform=None, return_quads=False, normalize=False) + vertices = vertices.detach().cpu().numpy() + faces = faces.detach().cpu().numpy()[:, ::-1] vertices = vertices / (2 ** octree_depth) * bbox_size + bbox_min mesh_v_f = (vertices.astype(np.float32), np.ascontiguousarray(faces)) except Exception as e: print(e) - torch.cuda.empty_cache() + if latents.device.type == 'cuda': + torch.cuda.empty_cache() + elif latents.device.type == 'mps': + torch.mps.empty_cache() mesh_v_f = (None, None) return [mesh_v_f] \ No newline at end of file diff --git a/triposg/models/autoencoders/autoencoder_kl_triposg.py b/triposg/models/autoencoders/autoencoder_kl_triposg.py index f7ec6cd..672b077 100755 --- 
a/triposg/models/autoencoders/autoencoder_kl_triposg.py +++ b/triposg/models/autoencoders/autoencoder_kl_triposg.py @@ -158,7 +158,7 @@ def query_geometry( ): logits = model_fn(queries, sample) if grad: - with torch.autocast(device_type="cuda", dtype=torch.float32): + with torch.autocast(device_type=queries.device.type, dtype=torch.float32, enabled=queries.device.type != 'cpu'): if self.grad_type == "numerical": interval = self.grad_interval grad_value = [] diff --git a/triposg/pipelines/pipeline_triposg.py b/triposg/pipelines/pipeline_triposg.py index 4d113dd..ed0fab0 100755 --- a/triposg/pipelines/pipeline_triposg.py +++ b/triposg/pipelines/pipeline_triposg.py @@ -295,11 +295,22 @@ def __call__( # 7. decoder mesh - if not use_flash_decoder: + effective_use_flash_decoder = use_flash_decoder + if self.device.type == 'mps' and use_flash_decoder: + try: + import diso # type: ignore + # If diso imports, we assume it might work, but it's experimental on MPS. + # A more robust check would involve testing a small diso operation. + logger.warn("Using flash_decoder on MPS. The 'diso' library's compatibility with MPS is not fully guaranteed. If issues arise, consider setting use_flash_decoder=False.") + except ImportError: + logger.warn("'diso' library not found. 'flash_decoder' cannot be used. 
Falling back to hierarchical_extract_geometry.") + effective_use_flash_decoder = False + + if not effective_use_flash_decoder: geometric_func = lambda x: self.vae.decode(latents, sampled_points=x).sample output = hierarchical_extract_geometry( geometric_func, - device, + self.device, bounds=bounds, dense_octree_depth=dense_octree_depth, hierarchical_octree_depth=hierarchical_octree_depth, @@ -312,7 +323,7 @@ def __call__( bounds=bounds, octree_depth=flash_octree_depth, ) - meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output] + meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output if mesh_v_f[0] is not None and mesh_v_f[1] is not None] # Offload all models self.maybe_free_model_hooks() diff --git a/triposg/pipelines/pipeline_triposg_scribble.py b/triposg/pipelines/pipeline_triposg_scribble.py index ed2b59b..e32a4fa 100644 --- a/triposg/pipelines/pipeline_triposg_scribble.py +++ b/triposg/pipelines/pipeline_triposg_scribble.py @@ -203,7 +203,7 @@ def __call__( dense_octree_depth: int = 8, hierarchical_octree_depth: int = 9, flash_octree_depth: int = 9, - use_flash_decoder: bool = True, + use_flash_decoder: bool = False, # Defaulting to False due to boundary problems and for MPS compatibility (diso library) return_dict: bool = True, ): self._guidance_scale = guidance_scale @@ -250,7 +250,7 @@ def __call__( num_tokens, num_channels_latents, image_embeds.dtype, - device, + self.device, generator, latents, ) @@ -327,7 +327,7 @@ def __call__( bounds=bounds, octree_depth=flash_octree_depth, ) - meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output] + meshes = [trimesh.Trimesh(mesh_v_f[0].astype(np.float32), mesh_v_f[1]) for mesh_v_f in output if mesh_v_f[0] is not None and mesh_v_f[1] is not None] # Offload all models self.maybe_free_model_hooks()