diff --git a/README.md b/README.md
index fb2cfa1..628ded4 100644
--- a/README.md
+++ b/README.md
@@ -150,15 +150,29 @@ The tracker process aggregates the model's predictions over time, building track
 
 (Configuration key: `enrichment`.)
 
-Enrichment is an optional feature that uses an [Ollama](https://ollama.com) model to generate a more detailed description of the object that triggered a notification. If the Ollama model succeeds, the resulting description is included in the notification's message.
+Enrichment is an optional feature that uses a vision AI model to generate a more detailed description of the object that triggered a notification. If the model succeeds, the resulting description is included in the notification's message.
 
-To use enrichment, you'll need a working Ollama setup with a multimodal model installed. `driveway-monitor` does not provide this, since it's not necessary for the core feature set, and honestly it provides little additional value.
+`driveway-monitor` supports two types of enrichment endpoints:
+
+#### Ollama Enrichment
+
+To use Ollama enrichment, you'll need a working [Ollama](https://ollama.com) setup with a multimodal model installed. `driveway-monitor` does not provide this, since it's not necessary for the core feature set, and honestly it provides little additional value.
 
 The best results I've gotten (which still are not stellar) are using [the LLaVA 13b model](https://ollama.com/library/llava). This usually returns a result in under 3 seconds (when running on a 2080 Ti). On a CPU or less powerful GPU, consider `llava:7b`, [`llava-llama3`](https://ollama.com/library/llava-llama3), or just skip enrichment altogether.
 
-You can change the timeout for Ollama enrichment to generate a response by setting `enrichment.timeout_s` in your config. If you want to use enrichment, I highly recommend setting an aggressive timeout to ensure `driveway-monitor`'s responsiveness.
+Set `enrichment.type` to `ollama` (default) and `enrichment.endpoint` to your Ollama API endpoint (e.g., `http://localhost:11434/api/generate`).
+
+#### OpenAI-Compatible Enrichment
+
+Alternatively, you can use any OpenAI-compatible vision API endpoint (including OpenAI's GPT-4 Vision, Azure OpenAI, or other compatible providers).
+
+Set `enrichment.type` to `openai`, `enrichment.endpoint` to your API endpoint (e.g., `https://api.openai.com/v1/chat/completions`), `enrichment.model` to your model name (e.g., `gpt-4o` or `gpt-4-vision-preview`), and optionally provide `enrichment.api_key` for authentication.
+
+#### General Configuration
+
+You can change the timeout for enrichment to generate a response by setting `enrichment.timeout_s` in your config. If you want to use enrichment, I highly recommend setting an aggressive timeout to ensure `driveway-monitor`'s responsiveness.
 
-Using enrichment requires providing a _prompt file_ for each YOLO object classification (e.g. `car`, `truck`, `person`) you want to enrich. This allows giving different instructions to your Ollama model for people vs. cars, for example. The `enrichment-prompts` directory provides a useful set of prompt files to get you started.
+Using enrichment requires providing a _prompt file_ for each YOLO object classification (e.g. `car`, `truck`, `person`) you want to enrich. This allows giving different instructions to your model for people vs. cars, for example. The `enrichment-prompts` directory provides a useful set of prompt files to get you started.
 
 When running `driveway-monitor` in Docker, keep in mind that your enrichment prompt files must be mounted in the container, and the paths in your config file must reflect the paths inside the container.
 
@@ -193,13 +207,15 @@ The file is a single JSON object containing the following keys, or a subset ther
 - `tracker`: Configures the system that builds tracks from the model's detections over time.
   - `inactive_track_prune_s`: Specifies the number of seconds after which an inactive track is pruned. This prevents incorrectly adding a new prediction to an old track.
   - `track_connect_min_overlap`: Minimum overlap percentage of a prediction box with the average of the last 2 boxes in an existing track for the prediction to be added to that track.
-- `enrichment`: Configures the subsystem that enriches notifications via the Ollama API.
-  - `enable`: Whether to enable enrichment via Ollama. Defaults to `false`.
-  - `endpoint`: Complete URL to the Ollama `/generate` endpoint, e.g. `http://localhost:11434/api/generate`.
-  - `keep_alive`: Ask Ollama to keep the model in memory for this long after the request. String, formatted like `60m`. [See the Ollama API docs](https://github.com/ollama/ollama/blob/main/docs/api.md#parameters).
-  - `model`: The name of the Ollama model to use, e.g. `llava` or `llava:13b`.
-  - `prompt_files`: Map of `YOLO classification name` → `path`. Each path is a file containing the prompt to give Ollama along with an image of that YOLO classification.
-  - `timeout_s`: Timeout for the Ollama request, in seconds. This includes connection/network time _and_ the time Ollama takes to generate a response.
+- `enrichment`: Configures the subsystem that enriches notifications via a vision AI API.
+  - `enable`: Whether to enable enrichment. Defaults to `false`.
+  - `type`: Type of enrichment endpoint to use. Either `ollama` (default) or `openai`.
+  - `endpoint`: Complete URL to the API endpoint. For Ollama: e.g. `http://localhost:11434/api/generate`. For OpenAI-compatible: e.g. `https://api.openai.com/v1/chat/completions`.
+  - `model`: The name of the model to use. For Ollama: e.g. `llava` or `llava:13b`. For OpenAI-compatible: e.g. `gpt-4o` or `gpt-4-vision-preview`.
+  - `prompt_files`: Map of `YOLO classification name` → `path`. Each path is a file containing the prompt to give the model along with an image of that YOLO classification.
+  - `timeout_s`: Timeout for the API request, in seconds. This includes connection/network time _and_ the time the model takes to generate a response.
+  - `api_key`: (Optional) API key for authentication. Used for OpenAI-compatible endpoints.
+  - `keep_alive`: (Ollama only) Ask Ollama to keep the model in memory for this long after the request. String, formatted like `60m`. [See the Ollama API docs](https://github.com/ollama/ollama/blob/main/docs/api.md#parameters).
 - `notifier`: Configures how notifications are sent.
   - `debounce_threshold_s`: Specifies the number of seconds to wait after a notification before sending another one for the same type of object.
   - `default_priority`: Default priority for notifications. ([See Ntfy docs on Message Priority](https://docs.ntfy.sh/publish/#message-priority).)
diff --git a/config.openai-example.json b/config.openai-example.json
new file mode 100644
index 0000000..0c70527
--- /dev/null
+++ b/config.openai-example.json
@@ -0,0 +1,55 @@
+{
+  "model": {
+    "device": "cuda",
+    "confidence": 0.3,
+    "liveness_tick_s": 60,
+    "fps": 15,
+    "healthcheck_ping_url": "https://uptimekuma.example.com:9001/api/push/abcdabcd?status=up&msg=OK&ping="
+  },
+  "tracker": {
+    "inactive_track_prune_s": 1
+  },
+  "notification_criteria": {
+    "classification_allowlist": [
+      "car",
+      "truck",
+      "motorcycle",
+      "bus",
+      "person",
+      "bicycle"
+    ],
+    "min_track_length_s": 1.5,
+    "min_track_length_s_per_classification": {
+      "person": 3
+    },
+    "track_cel": "track.last_box.b.y > 0.4 && track.movement_vector.length > 0.4 && track.movement_vector.direction < 25 && track.movement_vector.direction > -80"
+  },
+  "notifier": {
+    "server": "https://ntfy.example.com",
+    "token": "tk_0123456789ABCDEF",
+    "topic": "driveway-monitor",
+    "priorities": {
+      "car": "4",
+      "truck": "3",
+      "person": "4"
+    },
+    "image_method": "attach"
+  },
+  "enrichment": {
+    "enable": true,
+    "type": "openai",
+    "endpoint": "https://api.openai.com/v1/chat/completions",
+    "model": "gpt-4o",
+    "api_key": "sk-your-api-key-here",
+    "timeout_s": 10,
+    "prompt_files": {
+      "car": "enrichment-prompts/llava_prompt_car.txt",
+      "truck": "enrichment-prompts/llava_prompt_truck.txt",
+      "person": "enrichment-prompts/llava_prompt_person.txt"
+    }
+  },
+  "web": {
+    "port": 5550,
+    "external_base_url": "https://mymachine.tailnet-example.ts.net:5559"
+  }
+}
diff --git a/config.py b/config.py
index df661f4..a31d164 100644
--- a/config.py
+++ b/config.py
@@ -4,7 +4,7 @@ from typing import Optional
 
 from health import HealthPingerConfig
-from ntfy import NtfyConfig, ImageAttachMethod, NtfyPriority
+from ntfy import NtfyConfig, ImageAttachMethod, NtfyPriority, EnrichmentType
 from track import ModelConfig, TrackerConfig
 from web import WebConfig
 
 
@@ -249,6 +249,16 @@ def config_from_file(
         if not isinstance(cfg.notifier.enrichment.enable, bool):
             raise ConfigValidationError("enrichment.enable must be a bool")
         if cfg.notifier.enrichment.enable:
+            enrichment_type_str = enrichment_dict.get("type", "ollama")
+            if enrichment_type_str:
+                try:
+                    cfg.notifier.enrichment.type = EnrichmentType.from_str(
+                        enrichment_type_str
+                    )
+                except (KeyError, AttributeError):
+                    raise ConfigValidationError(
+                        "enrichment.type must be one of: ollama, openai"
+                    )
             cfg.notifier.enrichment.prompt_files = enrichment_dict.get(
                 "prompt_files", cfg.notifier.enrichment.prompt_files
             )
@@ -291,11 +301,19 @@ def config_from_file(
             )
             if not isinstance(cfg.notifier.enrichment.timeout_s, (int, float)):
                 raise ConfigValidationError("enrichment.timeout_s must be a number")
-            cfg.notifier.enrichment.keep_alive = enrichment_dict.get(
-                "keep_alive", cfg.notifier.enrichment.keep_alive
+            if cfg.notifier.enrichment.type == EnrichmentType.OLLAMA:
+                cfg.notifier.enrichment.keep_alive = enrichment_dict.get(
+                    "keep_alive", cfg.notifier.enrichment.keep_alive
+                )
+                if not isinstance(cfg.notifier.enrichment.keep_alive, str):
+                    raise ConfigValidationError("enrichment.keep_alive must be a str")
+            cfg.notifier.enrichment.api_key = enrichment_dict.get(
+                "api_key", cfg.notifier.enrichment.api_key
             )
-            if not isinstance(cfg.notifier.enrichment.keep_alive, str):
-                raise ConfigValidationError("enrichment.keep_alive must be a str")
+            if cfg.notifier.enrichment.api_key is not None and not isinstance(
+                cfg.notifier.enrichment.api_key, str
+            ):
+                raise ConfigValidationError("enrichment.api_key must be a string")
 
     logger.info("config loaded & validated")
     return cfg
diff --git a/ntfy.py b/ntfy.py
index 96dba47..c4ebbd4 100644
--- a/ntfy.py
+++ b/ntfy.py
@@ -72,14 +72,28 @@ class NtfyRecord:
     jpeg_image: Optional[bytes]
 
 
+class EnrichmentType(Enum):
+    OLLAMA = "ollama"
+    OPENAI = "openai"
+
+    @staticmethod
+    def from_str(etype: str) -> "EnrichmentType":
+        return {
+            EnrichmentType.OLLAMA.value.lower(): EnrichmentType.OLLAMA,
+            EnrichmentType.OPENAI.value.lower(): EnrichmentType.OPENAI,
+        }[etype.lower()]
+
+
 @dataclasses.dataclass
 class EnrichmentConfig:
     enable: bool = False
+    type: EnrichmentType = EnrichmentType.OLLAMA
     endpoint: str = ""
     keep_alive: str = "1440m"
     model: str = "llava"
     prompt_files: Dict[str, str] = dataclasses.field(default_factory=lambda: {})
     timeout_s: float = 5.0
+    api_key: Optional[str] = None
 
 
 @dataclasses.dataclass
@@ -276,6 +290,15 @@ def _enrich(self, logger, n: ObjectNotification) -> ObjectNotification:
         if not n.jpeg_image:
             return n
 
+        if self._config.enrichment.type == EnrichmentType.OLLAMA:
+            return self._enrich_ollama(logger, n)
+        elif self._config.enrichment.type == EnrichmentType.OPENAI:
+            return self._enrich_openai(logger, n)
+        else:
+            logger.error(f"unknown enrichment type: {self._config.enrichment.type}")
+            return n
+
+    def _enrich_ollama(self, logger, n: ObjectNotification) -> ObjectNotification:
         prompt_file = self._config.enrichment.prompt_files.get(n.classification)
         if not prompt_file:
             return n
@@ -346,9 +369,103 @@ def _enrich(self, logger, n: ObjectNotification) -> ObjectNotification:
             enriched_class=model_desc,
         )
+    def _enrich_openai(self, logger, n: ObjectNotification) -> ObjectNotification:
+        prompt_file = self._config.enrichment.prompt_files.get(n.classification)
+        if not prompt_file:
+            return n
+        try:
+            with open(prompt_file, "r") as f:
+                enrichment_prompt = f.read()
+        except Exception as e:
+            logger.error(f"error reading enrichment prompt file '{prompt_file}': {e}")
+            return n
+        if not enrichment_prompt:
+            return n
+
+        base64_image = base64.b64encode(n.jpeg_image).decode("ascii")
+
+        headers = {"Content-Type": "application/json"}
+        if self._config.enrichment.api_key:
+            headers["Authorization"] = f"Bearer {self._config.enrichment.api_key}"
+
+        try:
+            resp = requests.post(
+                self._config.enrichment.endpoint,
+                headers=headers,
+                json={
+                    "model": self._config.enrichment.model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": enrichment_prompt},
+                                {
+                                    "type": "image_url",
+                                    "image_url": {
+                                        "url": f"data:image/jpeg;base64,{base64_image}"
+                                    },
+                                },
+                            ],
+                        }
+                    ],
+                    "max_tokens": 300,
+                },
+                timeout=self._config.enrichment.timeout_s,
+            )
+            resp.raise_for_status()
+            parsed = resp.json()
+        except requests.Timeout:
+            logger.error("enrichment request timed out")
+            return n
+        except requests.RequestException as e:
+            logger.error(f"enrichment failed: {e}")
+            return n
+
+        try:
+            model_resp_str = parsed["choices"][0]["message"]["content"]
+        except (KeyError, IndexError) as e:
+            logger.error(f"enrichment response missing expected fields: {e}")
+            return n
+
+        if not model_resp_str:
+            logger.error("enrichment response is empty")
+            return n
+
+        try:
+            model_resp_parsed = json.loads(model_resp_str)
+        except json.JSONDecodeError as e:
+            logger.info(f"enrichment model did not produce valid JSON: {e}")
+            logger.info(f"response: {model_resp_str}")
+            return n
+
+        if "type" not in model_resp_parsed and "error" not in model_resp_parsed:
+            logger.info("enrichment model did not produce expected JSON keys")
+            return n
+
+        model_desc = model_resp_parsed.get("desc", "unknown")
+        if model_desc == "unknown" or model_desc == "":
+            model_err = model_resp_parsed.get("error")
+            if not model_err:
+                model_err = "(no error returned)"
+            logger.info(
+                f"enrichment model could not produce a useful description: {model_err}"
+            )
+            return n
+
+        return ObjectNotification(
+            t=n.t,
+            classification=n.classification,
+            event=n.event,
+            id=n.id,
+            jpeg_image=n.jpeg_image,
+            enriched_class=model_desc,
+        )
+
 
     def _load_enrichment_model(self, logger):
         if not self._config.enrichment.enable:
             return
+        if self._config.enrichment.type != EnrichmentType.OLLAMA:
+            return
         try:
             resp = requests.post(
                 self._config.enrichment.endpoint,
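
The new `_enrich_openai` path above boils down to two pieces: an OpenAI-style chat-completions payload carrying the snapshot as a base64 data URL, and a strict `desc`/`error` JSON contract expected back from the model. A minimal standalone sketch of both (the helper names `build_enrichment_payload` and `parse_enrichment_reply` are illustrative, not part of this patch, and the reply parser simplifies the key checks the notifier performs):

```python
import base64
import json


def build_enrichment_payload(model: str, prompt: str, jpeg_image: bytes) -> dict:
    # Chat-completions payload with the JPEG inlined as a base64 data URL,
    # mirroring the request body _enrich_openai sends.
    b64 = base64.b64encode(jpeg_image).decode("ascii")
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                    },
                ],
            }
        ],
        "max_tokens": 300,
    }


def parse_enrichment_reply(model_resp_str: str):
    # The prompt files ask the model for JSON with a "desc" key (and
    # optionally an "error" key); anything else means enrichment is
    # skipped and the plain notification goes out (None here).
    try:
        parsed = json.loads(model_resp_str)
    except json.JSONDecodeError:
        return None
    desc = parsed.get("desc", "unknown")
    if desc in ("unknown", ""):
        return None
    return desc


payload = build_enrichment_payload("gpt-4o", "Describe the vehicle.", b"\xff\xd8fake")
print(payload["messages"][0]["content"][1]["image_url"]["url"][:30])
print(parse_enrichment_reply('{"type": "car", "desc": "red pickup truck"}'))  # red pickup truck
print(parse_enrichment_reply('{"error": "image too dark"}'))  # None
```

Note the one-way dependency this split keeps: the payload builder needs only config values, so the same prompt files can serve both the Ollama and OpenAI-compatible backends.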