.gitignore (3 changes: 2 additions & 1 deletion)

@@ -10,4 +10,5 @@ util/__pycache__/
 index.html?linkid=2289031
 wget-log
 weights/icon_caption_florence_v2/
-omnitool/gradio/uploads/
+omnitool/gradio/uploads/
+.DS_Store
gradio_demo.py (14 changes: 13 additions & 1 deletion)

@@ -12,6 +12,8 @@
 import torch
 from PIL import Image
 
+os.environ["NO_PROXY"] = "localhost,127.0.0.1"
+
 yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
 caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
 # caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")
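Setting NO_PROXY before the app starts keeps loopback traffic out of any system-wide HTTP proxy; otherwise clients that honor the proxy environment variables may route even 127.0.0.1 requests through the proxy and stall the demo's startup self-check (that rationale is our reading; the PR itself doesn't state it). A minimal sketch to confirm the bypass takes effect, using the stdlib helper urllib.request.proxy_bypass_environment:

# Sketch: confirm loopback hosts bypass the proxy once NO_PROXY is set.
import os
os.environ["NO_PROXY"] = "localhost,127.0.0.1"

from urllib.request import proxy_bypass_environment

print(proxy_bypass_environment("localhost"))    # True: proxy is skipped
print(proxy_bypass_environment("127.0.0.1"))    # True: proxy is skipped
print(proxy_bypass_environment("example.com"))  # False: proxy still applies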
@@ -27,7 +29,17 @@
 OmniParser is a screen parsing tool to convert general GUI screen to structured elements.
 """
 
-DEVICE = torch.device('cuda')
+# DEVICE = torch.device('cuda')
+# Check if MPS is available (for Mac with Apple Silicon)
+if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    DEVICE = torch.device('mps')
+# Fall back to CUDA if MPS is not available
+elif torch.cuda.is_available():
+    DEVICE = torch.device('cuda')
+# Fall back to CPU as last resort
+else:
+    DEVICE = torch.device('cpu')
+    print("Warning: Neither MPS nor CUDA is available. Using CPU instead.")
 
 # @spaces.GPU
 # @torch.inference_mode()
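The fallback order above (MPS, then CUDA, then CPU) is duplicated in util/utils.py below; it could be factored into one shared helper. A minimal sketch, not part of this PR (pick_device is a hypothetical name):

# Hypothetical helper mirroring the MPS -> CUDA -> CPU fallback used above.
import torch

def pick_device() -> torch.device:
    # getattr guards older torch builds that lack the mps backend entirely
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    print("Warning: Neither MPS nor CUDA is available. Using CPU instead.")
    return torch.device("cpu")

DEVICE = pick_device()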
requirements.txt (4 changes: 2 additions & 2 deletions)

@@ -9,14 +9,14 @@ azure-identity
 numpy==1.26.4
 opencv-python
 opencv-python-headless
-gradio
+gradio==5.25.2
 dill
 accelerate
 timm
 einops==0.8.0
 paddlepaddle
 paddleocr
-ruff==0.6.7
+ruff
 pre-commit==3.8.0
 pytest==8.3.3
 pytest-asyncio==0.23.6
util/utils.py (15 changes: 14 additions & 1 deletion)

@@ -46,14 +46,25 @@
 
 def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
     if not device:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        # device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
 
     if model_name == "blip2":
         from transformers import Blip2Processor, Blip2ForConditionalGeneration
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
         if device == 'cpu':
             model = Blip2ForConditionalGeneration.from_pretrained(
                 model_name_or_path, device_map=None, torch_dtype=torch.float32
             )
+        elif device == 'mps':
+            model = Blip2ForConditionalGeneration.from_pretrained(
+                model_name_or_path, device_map=None, torch_dtype=torch.float32
+            ).to(device)
         else:
             model = Blip2ForConditionalGeneration.from_pretrained(
                 model_name_or_path, device_map=None, torch_dtype=torch.float16
@@ -63,6 +74,8 @@ def get_caption_model_processor(...)
         processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
         if device == 'cpu':
             model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
+        elif device == 'mps':
+            model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True).to(device)
         else:
             model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
     return {'model': model.to(device), 'processor': processor}
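With these branches in place, the captioner loads on Apple Silicon in full precision (the MPS branches use float32 rather than float16). A quick sketch of the call; the weights path matches the one used in gradio_demo.py:

# Sketch: load the Florence-2 captioner on the MPS backend via the new branch.
from util.utils import get_caption_model_processor

bundle = get_caption_model_processor(
    model_name="florence2",
    model_name_or_path="weights/icon_caption_florence",
    device="mps",  # or pass device=None to exercise the new auto-detection
)
param = next(bundle["model"].parameters())
print(param.device, param.dtype)  # expected: mps:0 torch.float32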