diff --git a/src/models/deprecation.py b/src/models/deprecation.py index 4a57278..1290acc 100644 --- a/src/models/deprecation.py +++ b/src/models/deprecation.py @@ -1,50 +1,104 @@ -"""Data models for deprecation entries.""" +"""Deprecation model for AI model deprecation tracking.""" -import hashlib from datetime import UTC, datetime -from typing import Any -from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator -class Deprecation(BaseModel): - """Model for AI model deprecation information.""" +class DeprecationEntry(BaseModel): + """Model representing an AI model deprecation entry.""" - provider: str = Field(description="Provider name (e.g., 'OpenAI', 'Anthropic')") - model: str = Field(description="Affected model name") - deprecation_date: datetime = Field(description="When the deprecation was announced") - retirement_date: datetime = Field(description="When the model stops working") - replacement: str | None = Field(default=None, description="Suggested alternative model") - notes: str | None = Field(default=None, description="Additional context") - source_url: HttpUrl = Field(description="URL where the deprecation info came from") + provider: str = Field(..., description="Provider name (e.g., OpenAI, Anthropic)") + model: str = Field(..., description="Model name or identifier") + deprecation_date: datetime = Field(..., description="Date when deprecation was announced") + retirement_date: datetime = Field(..., description="Date when model stops working") + replacement: str | None = Field(None, description="Suggested alternative model") + notes: str | None = Field(None, description="Additional context or information") + source_url: str = Field(..., description="Link to official announcement") last_updated: datetime = Field( default_factory=lambda: datetime.now(UTC), - description="When we last checked this information", - ) - # Alias for compatibility with main branch 
- created_at: datetime = Field( - default_factory=lambda: datetime.now(UTC), - description="When entry was created (alias for last_updated)", + description="When this entry was last updated", ) - @field_validator("deprecation_date", "retirement_date", "last_updated") + @field_validator("provider", "model") + @classmethod + def validate_non_empty_string(cls, v: str) -> str: + """Ensure provider and model are non-empty strings.""" + if not v or not v.strip(): + raise ValueError("Field must be a non-empty string") + return v.strip() + + @field_validator("source_url") @classmethod - def ensure_utc_timezone(cls, v: datetime) -> datetime: - """Ensure datetime fields have UTC timezone.""" - if v.tzinfo is None: - return v.replace(tzinfo=UTC) - return v.astimezone(UTC) + def validate_url(cls, v: str) -> str: + """Basic URL validation.""" + if not v.startswith(("http://", "https://")): + raise ValueError("URL must start with http:// or https://") + return v @model_validator(mode="after") - def validate_dates(self) -> "Deprecation": - """Validate that retirement_date is after deprecation_date.""" + def validate_dates(self) -> "DeprecationEntry": + """Ensure retirement date is after deprecation date.""" if self.retirement_date <= self.deprecation_date: - raise ValueError("retirement_date must be after deprecation_date") + raise ValueError("Retirement date must be after deprecation date") return self + def to_rss_item(self) -> dict[str, str | datetime]: + """Convert to RSS item dictionary.""" + description_parts = [ + f"Provider: {self.provider}", + f"Model: {self.model}", + f"Deprecation Date: {self.deprecation_date.isoformat()}", + f"Retirement Date: {self.retirement_date.isoformat()}", + ] + + if self.replacement: + description_parts.append(f"Replacement: {self.replacement}") + + if self.notes: + description_parts.append(f"Notes: {self.notes}") + + description = "\n".join(description_parts) + + title = f"{self.provider} - {self.model} Deprecation" + + return { + "title": 
title, + "description": description, + "link": self.source_url, + "guid": f"{self.provider}-{self.model}-{self.deprecation_date.isoformat()}", + "pubDate": self.deprecation_date, + } + + def to_json_dict(self) -> dict[str, str | None]: + """Convert to JSON-serializable dictionary.""" + return { + "provider": self.provider, + "model": self.model, + "deprecation_date": self.deprecation_date.isoformat(), + "retirement_date": self.retirement_date.isoformat(), + "replacement": self.replacement, + "notes": self.notes, + "source_url": self.source_url, + "last_updated": self.last_updated.isoformat(), + } + + model_config = { + "json_encoders": { + datetime: lambda v: v.isoformat(), + } + } + + # Compatibility methods for main branch + def is_active(self) -> bool: + """Check if deprecation is still active (not yet retired).""" + now = datetime.now(UTC) + return self.retirement_date > now + def get_hash(self) -> str: - """Generate hash of core deprecation data (excluding last_updated).""" - # Include all fields that identify the unique deprecation, excluding last_updated + """Generate hash for core deprecation data.""" + import hashlib + core_data = { "provider": self.provider, "model": self.model, @@ -52,84 +106,48 @@ def get_hash(self) -> str: "retirement_date": self.retirement_date.isoformat(), "replacement": self.replacement, "notes": self.notes, - "source_url": str(self.source_url), + "source_url": self.source_url, } - - # Create deterministic string representation data_str = str(sorted(core_data.items())) return hashlib.sha256(data_str.encode()).hexdigest() def get_identity_hash(self) -> str: - """Generate hash for identifying same deprecation (for updates).""" - # Only include immutable fields that identify the unique deprecation + """Generate hash for identifying same deprecation.""" + import hashlib + identity_data = { "provider": self.provider, "model": self.model, "deprecation_date": self.deprecation_date.isoformat(), "retirement_date": 
self.retirement_date.isoformat(), - "source_url": str(self.source_url), + "source_url": self.source_url, } - - # Create deterministic string representation data_str = str(sorted(identity_data.items())) return hashlib.sha256(data_str.encode()).hexdigest() - def same_deprecation(self, other: "Deprecation") -> bool: - """Check if this represents the same deprecation (for updates).""" + def same_deprecation(self, other: "DeprecationEntry") -> bool: + """Check if this represents the same deprecation.""" return self.get_identity_hash() == other.get_identity_hash() def __eq__(self, other: object) -> bool: """Compare deprecations based on core data (excluding last_updated).""" - if not isinstance(other, Deprecation): + if not isinstance(other, DeprecationEntry): return False return self.get_hash() == other.get_hash() - def __hash__(self) -> int: - """Hash based on core data for use in sets/dicts.""" - return hash(self.get_hash()) - def __str__(self) -> str: """String representation of deprecation.""" return ( - f"Deprecation({self.provider} {self.model}: " - f"{self.deprecation_date.date()} -> {self.retirement_date.date()})" + f"DeprecationEntry(provider='{self.provider}', model='{self.model}', " + f"deprecation_date='{self.deprecation_date.date()}', " + f"retirement_date='{self.retirement_date.date()}')" ) - def is_active(self) -> bool: - """Check if the deprecation is still active (not yet retired).""" - now = datetime.now(UTC) - return self.retirement_date > now - - def to_rss_item(self) -> dict[str, Any]: - """Convert deprecation to RSS item format (compatibility with main).""" - title = f"{self.provider}: {self.model} Deprecation" - description_parts = [ - f"Model: {self.model}", - f"Provider: {self.provider}", - f"Deprecation Date: {self.deprecation_date.strftime('%Y-%m-%d')}", - f"Retirement Date: {self.retirement_date.strftime('%Y-%m-%d')}", - ] - if self.replacement: - description_parts.append(f"Replacement: {self.replacement}") - if self.notes: - 
description_parts.append(f"Notes: {self.notes}") - - return { - "title": title, - "description": " | ".join(description_parts), - "guid": str(self.source_url), - "pubDate": self.created_at, - "link": str(self.source_url), - } - - def __repr__(self) -> str: - """Detailed string representation.""" - return ( - f"Deprecation(provider='{self.provider}', model='{self.model}', " - f"deprecation_date={self.deprecation_date.isoformat()}, " - f"retirement_date={self.retirement_date.isoformat()})" - ) + @property + def created_at(self) -> datetime: + """Alias for compatibility with main branch.""" + return self.deprecation_date -# Main has DeprecationEntry, we use Deprecation, create alias for compatibility -DeprecationEntry = Deprecation +# Alias for compatibility with other branches that may use Deprecation +Deprecation = DeprecationEntry diff --git a/src/rss/config.py b/src/rss/config.py new file mode 100644 index 0000000..6f34706 --- /dev/null +++ b/src/rss/config.py @@ -0,0 +1,134 @@ +"""RSS feed configuration.""" + +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, Field + + +class FeedConfig(BaseModel): + """Configuration for RSS feed generation.""" + + title: str = Field( + default="AI Model Deprecations", + description="Feed title", + ) + description: str = Field( + default="Daily-updated RSS feed tracking AI model deprecations across providers", + description="Feed description", + ) + link: str = Field( + default="https://deprecations.example.com", + description="Feed website link", + ) + language: str = Field( + default="en", + description="Feed language code", + ) + copyright: str | None = Field( + default=None, + description="Copyright information", + ) + managing_editor: str | None = Field( + default=None, + description="Managing editor email", + ) + webmaster: str | None = Field( + default=None, + description="Webmaster email", + ) + ttl: int = Field( + default=1440, + description="Time to live in minutes (default 24 hours)", 
+ gt=0, + ) + + model_config = {"validate_assignment": True} + + +class VersionConfig(BaseModel): + """Configuration for RSS feed versioning.""" + + version: str = Field( + default="v1", + description="Feed version identifier", + pattern=r"^v\d+$", + ) + supported_versions: list[str] = Field( + default_factory=lambda: ["v1"], + description="List of supported versions", + ) + + def is_version_supported(self, version: str) -> bool: + """Check if a version is supported.""" + return version in self.supported_versions + + model_config = {"validate_assignment": True} + + +class OutputConfig(BaseModel): + """Configuration for RSS feed output paths.""" + + base_path: Path = Field( + default=Path("output/rss"), + description="Base output directory for RSS feeds", + ) + filename: str = Field( + default="feed.xml", + description="RSS feed filename", + ) + + def get_versioned_path(self, version: str) -> Path: + """Get the full path for a versioned feed.""" + return self.base_path / version / self.filename + + def ensure_directories(self, version: str) -> None: + """Ensure output directories exist for a given version.""" + versioned_dir = self.base_path / version + versioned_dir.mkdir(parents=True, exist_ok=True) + + model_config = {"validate_assignment": True} + + +class RSSConfig(BaseModel): + """Complete RSS configuration.""" + + feed: FeedConfig = Field( + default_factory=FeedConfig, + description="Feed metadata configuration", + ) + version: VersionConfig = Field( + default_factory=VersionConfig, + description="Version configuration", + ) + output: OutputConfig = Field( + default_factory=OutputConfig, + description="Output path configuration", + ) + + @classmethod + def from_dict(cls, config_dict: dict[str, Any]) -> "RSSConfig": + """Create RSSConfig from dictionary.""" + return cls( + feed=FeedConfig(**config_dict.get("feed", {})), + version=VersionConfig(**config_dict.get("version", {})), + output=OutputConfig(**config_dict.get("output", {})), + ) + + def to_dict(self) 
-> dict[str, Any]: + """Convert to dictionary.""" + return { + "feed": self.feed.model_dump(), + "version": self.version.model_dump(), + "output": { + "base_path": str(self.output.base_path), + "filename": self.output.filename, + }, + } + + model_config = {"validate_assignment": True} + + +def get_default_config() -> RSSConfig: + """Get default RSS configuration.""" + return RSSConfig() diff --git a/src/rss/generator.py b/src/rss/generator.py new file mode 100644 index 0000000..d41477f --- /dev/null +++ b/src/rss/generator.py @@ -0,0 +1,228 @@ +"""RSS feed generator for AI model deprecations.""" + +from datetime import UTC, datetime +from pathlib import Path + +from feedgen.feed import FeedGenerator # type: ignore[import-untyped] + +from src.models.deprecation import DeprecationEntry +from src.rss.config import RSSConfig, get_default_config + + +class RSSGenerator: + """Generator for creating RSS feeds from deprecation entries.""" + + def __init__(self, config: RSSConfig | None = None) -> None: + """Initialize RSS generator with configuration. + + Args: + config: RSS configuration. Uses default if not provided. + """ + self.config = config or get_default_config() + self._feeds: dict[str, FeedGenerator] = {} + + def _create_feed(self, version: str) -> FeedGenerator: + """Create a new feed generator for a specific version. 
+ + Args: + version: Version identifier (e.g., "v1") + + Returns: + Configured FeedGenerator instance + """ + fg = FeedGenerator() + + fg.title(self.config.feed.title) + fg.description(self.config.feed.description) + fg.link(href=self.config.feed.link, rel="alternate") + fg.language(self.config.feed.language) + + if self.config.feed.copyright: + fg.copyright(self.config.feed.copyright) + + if self.config.feed.managing_editor: + fg.managingEditor(self.config.feed.managing_editor) + + if self.config.feed.webmaster: + fg.webMaster(self.config.feed.webmaster) + + fg.ttl(self.config.feed.ttl) + + fg.generator(generator="deprecations-rss", version=version) + fg.lastBuildDate(datetime.now(UTC)) + + return fg + + def get_feed(self, version: str = "v1") -> FeedGenerator: + """Get or create a feed generator for a specific version. + + Args: + version: Version identifier + + Returns: + FeedGenerator for the specified version + + Raises: + ValueError: If version is not supported + """ + if not self.config.version.is_version_supported(version): + raise ValueError( + f"Version {version} not supported. " + f"Supported versions: {', '.join(self.config.version.supported_versions)}" + ) + + if version not in self._feeds: + self._feeds[version] = self._create_feed(version) + + return self._feeds[version] + + def add_entry( + self, + entry: DeprecationEntry, + version: str = "v1", + ) -> None: + """Add a deprecation entry to the feed. 
+ + Args: + entry: Deprecation entry to add + version: Version of feed to add entry to + """ + feed = self.get_feed(version) + rss_item = entry.to_rss_item() + + fe = feed.add_entry() + fe.title(rss_item["title"]) + fe.description(rss_item["description"]) + + if rss_item["link"]: + fe.link(href=rss_item["link"]) + + fe.guid(rss_item["guid"], permalink=False) + + pubdate = rss_item["pubDate"] + if isinstance(pubdate, datetime) and pubdate.tzinfo is None: + pubdate = pubdate.replace(tzinfo=UTC) + fe.pubDate(pubdate) + + def add_entries( + self, + entries: list[DeprecationEntry], + version: str = "v1", + sort_by_date: bool = True, + ) -> None: + """Add multiple deprecation entries to the feed. + + Args: + entries: List of deprecation entries + version: Version of feed to add entries to + sort_by_date: Whether to sort entries by deprecation date (newest first) + """ + if sort_by_date: + # Sort oldest first because feedgen outputs in LIFO order + # This results in newest-first in the final output + entries = sorted( + entries, + key=lambda e: e.deprecation_date, + reverse=False, + ) + + for entry in entries: + self.add_entry(entry, version) + + def generate_rss(self, version: str = "v1") -> str: + """Generate RSS 2.0 XML string. + + Args: + version: Version of feed to generate + + Returns: + RSS 2.0 XML string + """ + feed = self.get_feed(version) + rss_bytes: bytes = feed.rss_str(pretty=True) + return rss_bytes.decode("utf-8") + + def save_feed( + self, + version: str = "v1", + output_path: Path | None = None, + ) -> Path: + """Save RSS feed to file. + + Args: + version: Version of feed to save + output_path: Custom output path. Uses config path if not provided. 
+ + Returns: + Path where feed was saved + """ + feed = self.get_feed(version) + + if output_path is None: + self.config.output.ensure_directories(version) + output_path = self.config.output.get_versioned_path(version) + else: + output_path.parent.mkdir(parents=True, exist_ok=True) + + feed.rss_file(str(output_path)) + return output_path + + def clear_feed(self, version: str = "v1") -> None: + """Clear all entries from a feed. + + Args: + version: Version of feed to clear + """ + if version in self._feeds: + self._feeds[version] = self._create_feed(version) + + def clear_all_feeds(self) -> None: + """Clear all feeds.""" + self._feeds.clear() + + def get_entry_count(self, version: str = "v1") -> int: + """Get the number of entries in a feed. + + Args: + version: Version of feed to check + + Returns: + Number of entries in the feed + """ + if version not in self._feeds: + return 0 + + feed = self._feeds[version] + return len(feed.entry()) + + def validate_feed(self, version: str = "v1") -> bool: + """Validate that feed meets RSS 2.0 requirements. 
+ + Args: + version: Version of feed to validate + + Returns: + True if feed is valid + + Raises: + ValueError: If feed is invalid with details + """ + feed = self.get_feed(version) + + if not feed.title(): + raise ValueError("Feed must have a title") + + if not feed.description(): + raise ValueError("Feed must have a description") + + if not feed.link(): + raise ValueError("Feed must have a link") + + for entry in feed.entry(): + if not entry.title() and not entry.description(): + raise ValueError("Each entry must have either title or description") + + if not entry.guid(): + raise ValueError("Each entry must have a GUID") + + return True diff --git a/src/scrapers/README.md b/src/scrapers/README.md new file mode 100644 index 0000000..fe471f2 --- /dev/null +++ b/src/scrapers/README.md @@ -0,0 +1,62 @@ +# Scrapers Module + +This module provides utilities and base classes for scraping AI model deprecation data from various providers. + +## Structure + +- `base_scraper.py` - Abstract base class with common scraping functionality +- `utils.py` - Utility functions for date parsing, text cleaning, and URL handling +- `example_scraper.py` - Example implementation showing how to create provider-specific scrapers + +## Creating a New Scraper + +To create a scraper for a new provider: + +1. Create a new file for your provider (e.g., `openai_scraper.py`) +2. Extend the `BaseScraper` class +3. 
Implement the `extract_deprecations()` method + +### Example + +```python +from src.scrapers.base_scraper import BaseScraper + +class OpenAIScraper(BaseScraper): + async def extract_deprecations(self) -> list[dict[str, Any]]: + # Fetch the deprecation page + html = await self.fetch(self.url) + soup = await self.parse_html(html) + + deprecations = [] + # Parse the HTML and extract deprecation data + # Use self.extract_text() and self.extract_date() helpers + + return deprecations +``` + +## Utilities + +### Date Parsing +- Handles ISO, RFC, and human-readable date formats +- Automatically adds UTC timezone if missing + +### Text Cleaning +- Removes HTML tags and entities +- Normalizes whitespace +- Optionally preserves line breaks + +### URL Handling +- Validates HTTP(S) URLs +- Normalizes URLs for consistent comparison + +## Error Handling + +The base scraper includes: +- Automatic retry with exponential backoff +- Configurable timeouts and retry attempts +- Proper exception handling and reporting + +## Testing + +Each scraper should have corresponding tests following the pytest-describe pattern. +See `tests/scrapers/` for examples. 
\ No newline at end of file diff --git a/src/scrapers/anthropic.py b/src/scrapers/anthropic.py index dc2c70f..4851ea3 100644 --- a/src/scrapers/anthropic.py +++ b/src/scrapers/anthropic.py @@ -8,7 +8,7 @@ import httpx from bs4 import BeautifulSoup -from src.models.deprecation import Deprecation +from src.models.deprecation import DeprecationEntry from src.scrapers.base import BaseScraper try: @@ -121,7 +121,7 @@ async def scrape_playwright(self) -> dict[str, Any]: logger.error(f"Playwright scraping failed: {e}") raise - def _parse_api_deprecation(self, item: dict[str, Any]) -> Deprecation | None: + def _parse_api_deprecation(self, item: dict[str, Any]) -> DeprecationEntry | None: """Parse a single deprecation from API response.""" try: model = item.get("model") @@ -139,21 +139,21 @@ def _parse_api_deprecation(self, item: dict[str, Any]) -> Deprecation | None: logger.warning(f"Invalid date ordering for {model}, skipping") return None - return Deprecation( + return DeprecationEntry( provider="Anthropic", model=model, deprecation_date=deprecation_date, retirement_date=retirement_date, replacement=item.get("replacement"), notes=item.get("notes"), - source_url=self.url, # type: ignore[arg-type] + source_url=self.url, ) except Exception as e: logger.warning(f"Failed to parse API deprecation: {e}") return None - def _parse_html_deprecations(self, soup: BeautifulSoup) -> list[Deprecation]: + def _parse_html_deprecations(self, soup: BeautifulSoup) -> list[DeprecationEntry]: """Parse deprecations from HTML content.""" deprecations = [] @@ -170,7 +170,7 @@ def _parse_html_deprecations(self, soup: BeautifulSoup) -> list[Deprecation]: return merged_deprecations - def _parse_status_table(self, soup: BeautifulSoup) -> list[Deprecation]: + def _parse_status_table(self, soup: BeautifulSoup) -> list[DeprecationEntry]: """Parse the main model status table.""" deprecations = [] @@ -220,7 +220,7 @@ def _parse_status_table(self, soup: BeautifulSoup) -> list[Deprecation]: 
logger.debug(f"Status table parsing found {len(deprecations)} deprecations") return deprecations - def _parse_status_table_row(self, row: Any) -> Deprecation | None: + def _parse_status_table_row(self, row: Any) -> DeprecationEntry | None: """Parse a single row from the status table.""" try: cells = row.find_all(["td", "th"]) @@ -260,19 +260,21 @@ def _parse_status_table_row(self, row: Any) -> Deprecation | None: logger.warning(f"Invalid date ordering for {model}, skipping") return None - return Deprecation( + return DeprecationEntry( provider="Anthropic", model=model, deprecation_date=deprecation_date, retirement_date=retirement_date, - source_url=self.url, # type: ignore[arg-type] + replacement=None, + notes=None, + source_url=self.url, ) except Exception as e: logger.warning(f"Failed to parse status table row: {e}") return None - def _parse_history_tables(self, soup: BeautifulSoup) -> list[Deprecation]: + def _parse_history_tables(self, soup: BeautifulSoup) -> list[DeprecationEntry]: """Parse deprecation history tables.""" deprecations = [] @@ -305,7 +307,7 @@ def _parse_history_tables(self, soup: BeautifulSoup) -> list[Deprecation]: return deprecations - def _parse_history_table_row(self, row: Any) -> Deprecation | None: + def _parse_history_table_row(self, row: Any) -> DeprecationEntry | None: """Parse a single row from a history table.""" try: cells = row.find_all(["td", "th"]) @@ -325,13 +327,14 @@ def _parse_history_table_row(self, row: Any) -> Deprecation | None: # deprecation date as retirement date minus 60 days as a fallback deprecation_date = self._estimate_deprecation_date(retirement_date) - return Deprecation( + return DeprecationEntry( provider="Anthropic", model=model, deprecation_date=deprecation_date, retirement_date=retirement_date, replacement=replacement, - source_url=self.url, # type: ignore[arg-type] + notes=None, + source_url=self.url, ) except Exception as e: @@ -344,10 +347,12 @@ def _estimate_deprecation_date(self, retirement_date: 
datetime) -> datetime: return retirement_date - timedelta(days=60) - def _merge_duplicate_deprecations(self, deprecations: list[Deprecation]) -> list[Deprecation]: + def _merge_duplicate_deprecations( + self, deprecations: list[DeprecationEntry] + ) -> list[DeprecationEntry]: """Merge duplicate deprecations, preferring ones with replacement info.""" # Group by model name - by_model: dict[str, list[Deprecation]] = {} + by_model: dict[str, list[DeprecationEntry]] = {} for dep in deprecations: if dep.model not in by_model: by_model[dep.model] = [] @@ -392,7 +397,7 @@ def _merge_duplicate_deprecations(self, deprecations: list[Deprecation]) -> list notes = dep.notes if dep.notes else best_dep.notes # Create merged deprecation - best_dep = Deprecation( + best_dep = DeprecationEntry( provider=best_dep.provider, model=best_dep.model, deprecation_date=deprecation_date, diff --git a/src/scrapers/base_scraper.py b/src/scrapers/base_scraper.py new file mode 100644 index 0000000..8237859 --- /dev/null +++ b/src/scrapers/base_scraper.py @@ -0,0 +1,186 @@ +"""Base scraper class with common functionality.""" + +import asyncio +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +import httpx +from bs4 import BeautifulSoup, Tag + +from src.scrapers.utils import clean_text, parse_date + + +@dataclass +class ScraperConfig: + """Configuration for scraper behavior.""" + + timeout: int = 30 + max_retries: int = 3 + retry_delay: float = 1.0 + user_agent: str = field( + default="DeprecationsRSS/1.0 (+https://github.com/leblancfg/deprecations-rss)" + ) + + +class ScraperError(Exception): + """Base exception for scraper errors.""" + + pass + + +class ExtractionError(ScraperError): + """Raised when data extraction fails.""" + + pass + + +class BaseScraper(ABC): + """ + Abstract base class for all deprecation scrapers. 
+ Provides common functionality: + - HTTP client with retry logic + - HTML parsing utilities + - Text extraction helpers + - Error handling + """ + + def __init__( + self, + url: str, + config: ScraperConfig | None = None, + ) -> None: + """ + Initialize scraper with target URL. + Args: + url: Base URL to scrape + config: Scraper configuration + """ + self.url = url + self.config = config or ScraperConfig() + self._client: httpx.AsyncClient | None = None + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create HTTP client.""" + if self._client is None: + self._client = httpx.AsyncClient( + timeout=self.config.timeout, + headers={"User-Agent": self.config.user_agent}, + follow_redirects=True, + ) + return self._client + + async def fetch(self, url: str) -> str: + """ + Fetch content from URL with retry logic. + Args: + url: URL to fetch + Returns: + Response text content + Raises: + ScraperError: If fetch fails after retries + """ + client = await self._get_client() + last_error: Exception | None = None + + for attempt in range(self.config.max_retries): + try: + response = await client.get(url) + response.raise_for_status() + return response.text + + except (httpx.HTTPError, httpx.HTTPStatusError) as e: + last_error = e + + if attempt < self.config.max_retries - 1: + # Exponential backoff + delay = self.config.retry_delay * (2**attempt) + await asyncio.sleep(delay) + continue + + raise ScraperError( + f"Failed after {self.config.max_retries} retries: {last_error}" + ) from last_error + + async def parse_html(self, html: str) -> BeautifulSoup: + """ + Parse HTML content into BeautifulSoup object. + Args: + html: HTML string to parse + Returns: + Parsed BeautifulSoup object + """ + return BeautifulSoup(html, "html.parser") + + def extract_text( + self, + element: Tag | None, + default: str = "", + ) -> str: + """ + Safely extract and clean text from element. 
+ Args: + element: BeautifulSoup element + default: Default value if element is None + Returns: + Cleaned text content + """ + if element is None: + return default + + text = element.get_text() if hasattr(element, "get_text") else str(element) + return clean_text(text) + + def extract_date( + self, + element: Tag | None, + ) -> datetime | None: + """ + Extract and parse date from element. + Args: + element: Element containing date text + Returns: + Parsed datetime or None + """ + if element is None: + return None + + date_text = self.extract_text(element) + return parse_date(date_text, raise_on_error=False) + + @abstractmethod + async def extract_deprecations(self) -> list[dict[str, Any]]: + """ + Extract deprecation data from source. + + Must be implemented by subclasses. + Returns: + List of deprecation dictionaries with keys: + - provider: Provider name + - model: Model name + - announcement_date: When deprecation was announced + - retirement_date: When model will be retired + - replacement_model: Suggested replacement (optional) + - notes: Additional context (optional) + """ + pass + + async def run(self) -> list[dict[str, Any]]: + """ + Run the scraper and return results. 
+ Returns: + List of extracted deprecation data + Raises: + ExtractionError: If extraction fails + """ + try: + return await self.extract_deprecations() + except Exception as e: + raise ExtractionError(f"Extraction failed: {e}") from e + + async def close(self) -> None: + """Close HTTP client and cleanup resources.""" + if self._client: + await self._client.aclose() + self._client = None diff --git a/src/scrapers/example_scraper.py b/src/scrapers/example_scraper.py new file mode 100644 index 0000000..e93e5d0 --- /dev/null +++ b/src/scrapers/example_scraper.py @@ -0,0 +1,66 @@ +"""Example scraper implementation for demonstration purposes.""" + +from typing import Any + +from src.scrapers.base_scraper import BaseScraper + + +class ExampleScraper(BaseScraper): + """ + Example scraper demonstrating how to extend BaseScraper. + + This is a template for provider-specific scrapers. + """ + + async def extract_deprecations(self) -> list[dict[str, Any]]: + """ + Extract deprecation data from the example provider. 
+ + Returns: + List of deprecation dictionaries + """ + # Fetch the main page + html = await self.fetch(self.url) + soup = await self.parse_html(html) + + deprecations = [] + + # Example: Find all deprecation announcements + # This would be customized for each provider's HTML structure + announcements = soup.find_all("div", class_="deprecation-notice") + + for announcement in announcements: + # Extract model name + model_elem = announcement.find("h3", class_="model-name") + model = self.extract_text(model_elem) + + if not model: + continue + + # Extract dates + announcement_date_elem = announcement.find("time", class_="announced") + retirement_date_elem = announcement.find("time", class_="retirement") + + announcement_date = self.extract_date(announcement_date_elem) + retirement_date = self.extract_date(retirement_date_elem) + + # Extract replacement model + replacement_elem = announcement.find("span", class_="replacement") + replacement = self.extract_text(replacement_elem, default="") + + # Extract additional notes + notes_elem = announcement.find("p", class_="notes") + notes = self.extract_text(notes_elem, default="") + + deprecations.append( + { + "provider": "Example Provider", + "model": model, + "announcement_date": announcement_date, + "retirement_date": retirement_date, + "replacement_model": replacement if replacement else None, + "notes": notes if notes else None, + } + ) + + return deprecations diff --git a/src/scrapers/openai.py b/src/scrapers/openai.py index e2e694e..a845c11 100644 --- a/src/scrapers/openai.py +++ b/src/scrapers/openai.py @@ -8,7 +8,7 @@ import httpx from bs4 import BeautifulSoup -from src.models.deprecation import Deprecation +from src.models.deprecation import DeprecationEntry from src.scrapers.base import BaseScraper try: @@ -121,7 +121,7 @@ async def scrape_playwright(self) -> dict[str, Any]: logger.error(f"Playwright scraping failed: {e}") raise - def _parse_api_deprecation(self, item: dict[str, Any]) -> Deprecation | None: + def 
_parse_api_deprecation(self, item: dict[str, Any]) -> DeprecationEntry | None: """Parse a single deprecation from API response.""" try: model = item.get("model") @@ -141,21 +141,21 @@ def _parse_api_deprecation(self, item: dict[str, Any]) -> Deprecation | None: logger.warning(f"Invalid date ordering for {model}, skipping") return None - return Deprecation( + return DeprecationEntry( provider="OpenAI", model=model, deprecation_date=deprecation_date, retirement_date=retirement_date, replacement=item.get("replacement"), notes=item.get("notes"), - source_url=self.url, # type: ignore[arg-type] + source_url=self.url, ) except Exception as e: logger.warning(f"Failed to parse API deprecation: {e}") return None - def _parse_html_deprecations(self, soup: BeautifulSoup) -> list[Deprecation]: + def _parse_html_deprecations(self, soup: BeautifulSoup) -> list[DeprecationEntry]: """Parse deprecations from HTML content.""" deprecations = [] @@ -240,7 +240,7 @@ def _contains_deprecation_info(self, element: Any) -> bool: return has_keyword and has_date - def _parse_deprecation_block(self, block: Any) -> Deprecation | None: + def _parse_deprecation_block(self, block: Any) -> DeprecationEntry | None: """Parse a single deprecation block from HTML.""" try: text = block.get_text() @@ -290,14 +290,14 @@ def _parse_deprecation_block(self, block: Any) -> Deprecation | None: # Extract notes notes = self._extract_notes(text) - return Deprecation( + return DeprecationEntry( provider="OpenAI", model=model, deprecation_date=deprecation_date, retirement_date=retirement_date, replacement=replacement, notes=notes, - source_url=self.url, # type: ignore[arg-type] + source_url=self.url, ) except Exception as e: diff --git a/src/scrapers/utils.py b/src/scrapers/utils.py new file mode 100644 index 0000000..baaeb5c --- /dev/null +++ b/src/scrapers/utils.py @@ -0,0 +1,172 @@ +"""Utility functions for scraping and parsing.""" + +import html +from datetime import UTC, datetime +from urllib.parse import 
urlparse, urlunparse + +from bs4 import BeautifulSoup +from dateutil import parser as date_parser + + +def parse_date( + date_str: str | None, + raise_on_error: bool = True, +) -> datetime | None: + """ + Parse various date formats into datetime objects. + + Handles: + - ISO format (2024-03-15T10:30:00Z) + - RFC format (Wed, 15 Mar 2024 10:30:00 GMT) + - Human-readable formats (March 15, 2024, etc.) + + Args: + date_str: String to parse as date + raise_on_error: Whether to raise exception on parse error + + Returns: + Parsed datetime with timezone info, or None if parsing fails + + Raises: + ValueError: If date cannot be parsed and raise_on_error is True + """ + if not date_str: + if raise_on_error: + raise ValueError("Could not parse date: empty string") + return None + + try: + # Use dateutil parser which handles many formats + dt = date_parser.parse(date_str) + + # Ensure timezone awareness + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + + return dt + except (ValueError, TypeError) as e: + if raise_on_error: + raise ValueError(f"Could not parse date: {date_str}") from e + return None + + +def clean_text( + text: str | None, + preserve_lines: bool = False, +) -> str: + """ + Clean and normalize extracted text. 
+ + - Removes HTML tags + - Decodes HTML entities + - Normalizes whitespace + - Trims leading/trailing space + + Args: + text: Text to clean + preserve_lines: Whether to preserve line breaks + + Returns: + Cleaned text string + """ + if not text: + return "" + + # Remove HTML tags + soup = BeautifulSoup(text, "html.parser") + text = soup.get_text() + + # Decode HTML entities + text = html.unescape(text) + + if preserve_lines: + # Normalize spaces within lines but preserve line breaks + lines = text.split("\n") + lines = [" ".join(line.split()) for line in lines] + text = "\n".join(lines) + else: + # Normalize all whitespace + text = " ".join(text.split()) + + return text.strip() + + +def validate_url(url: str | None, require_https: bool = False) -> bool: + """ + Validate that a string is a valid HTTP(S) URL. + + Args: + url: URL string to validate + require_https: Whether to require HTTPS scheme + + Returns: + True if valid URL, False otherwise + """ + if not url: + return False + + try: + parsed = urlparse(url) + + # Check basic requirements + if not parsed.scheme or not parsed.netloc: + return False + + # Check scheme + if require_https: + return parsed.scheme == "https" + else: + return parsed.scheme in ("http", "https") + + except Exception: + return False + + +def normalize_url(url: str | None) -> str: + """ + Normalize a URL for consistent comparison. 
+ + - Adds missing scheme (defaults to https) + - Lowercases domain + - Removes trailing slashes from path + + Args: + url: URL to normalize + + Returns: + Normalized URL string, or empty string if invalid + """ + if not url: + return "" + + # Add scheme if missing + if not url.startswith(("http://", "https://")): + url = f"https://{url}" + + try: + parsed = urlparse(url) + + # Lowercase the domain + netloc = parsed.netloc.lower() + + # Remove trailing slash from path + path = parsed.path.rstrip("/") + if not path: + path = "" + + # Reconstruct URL + normalized = urlunparse( + ( + parsed.scheme, + netloc, + path, + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + + return normalized + + except Exception: + return "" diff --git a/tests/models/test_deprecation.py b/tests/models/test_deprecation.py new file mode 100644 index 0000000..1c58874 --- /dev/null +++ b/tests/models/test_deprecation.py @@ -0,0 +1,279 @@ +"""Tests for deprecation model.""" + +from datetime import datetime + +import pytest +from pydantic import ValidationError + +from src.models.deprecation import DeprecationEntry + + +class DescribeDeprecationEntry: + """Tests for DeprecationEntry model.""" + + def it_creates_valid_entry_with_required_fields(self) -> None: + """Test creating entry with only required fields.""" + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://openai.com/blog", + ) + + assert entry.provider == "OpenAI" + assert entry.model == "gpt-3.5-turbo" + assert entry.deprecation_date == datetime(2024, 1, 1) + assert entry.retirement_date == datetime(2024, 6, 1) + assert entry.replacement is None + assert entry.notes is None + assert entry.source_url == "https://openai.com/blog" + + def it_creates_valid_entry_with_all_fields(self) -> None: + """Test creating entry with all fields.""" + entry = DeprecationEntry( + provider="OpenAI", + 
model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + replacement="gpt-4-turbo", + notes="Model being retired due to newer version availability", + source_url="https://openai.com/blog/deprecation", + ) + + assert entry.provider == "OpenAI" + assert entry.model == "gpt-3.5-turbo" + assert entry.replacement == "gpt-4-turbo" + assert entry.notes == "Model being retired due to newer version availability" + assert entry.source_url == "https://openai.com/blog/deprecation" + + def it_validates_retirement_date_after_deprecation(self) -> None: + """Test that retirement date must be after deprecation date.""" + with pytest.raises(ValidationError) as exc_info: + DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 6, 1), + retirement_date=datetime(2024, 1, 1), + source_url="https://example.com", + ) + + errors = exc_info.value.errors() + assert len(errors) == 1 + assert "Retirement date must be after deprecation date" in errors[0]["msg"] + + def it_validates_retirement_date_not_equal_to_deprecation(self) -> None: + """Test that retirement date cannot equal deprecation date.""" + with pytest.raises(ValidationError) as exc_info: + DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 1, 1), + source_url="https://example.com", + ) + + errors = exc_info.value.errors() + assert len(errors) == 1 + assert "Retirement date must be after deprecation date" in errors[0]["msg"] + + def it_validates_non_empty_provider(self) -> None: + """Test that provider must be non-empty.""" + with pytest.raises(ValidationError) as exc_info: + DeprecationEntry( + provider="", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + ) + + errors = exc_info.value.errors() + assert any("Field must be a non-empty string" in e["msg"] for e in errors) + + def 
it_validates_non_empty_model(self) -> None: + """Test that model must be non-empty.""" + with pytest.raises(ValidationError) as exc_info: + DeprecationEntry( + provider="OpenAI", + model=" ", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + ) + + errors = exc_info.value.errors() + assert any("Field must be a non-empty string" in e["msg"] for e in errors) + + def it_validates_url_format(self) -> None: + """Test URL validation.""" + with pytest.raises(ValidationError) as exc_info: + DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="not-a-url", + ) + + errors = exc_info.value.errors() + assert any("URL must start with http:// or https://" in e["msg"] for e in errors) + + def it_accepts_valid_urls(self) -> None: + """Test that valid URLs are accepted.""" + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://example.com/deprecation", + ) + assert entry.source_url == "https://example.com/deprecation" + + entry2 = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="http://example.com/deprecation", + ) + assert entry2.source_url == "http://example.com/deprecation" + + def it_strips_whitespace_from_strings(self) -> None: + """Test that string fields are stripped of whitespace.""" + entry = DeprecationEntry( + provider=" OpenAI ", + model=" gpt-3.5-turbo ", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://example.com", + ) + + assert entry.provider == "OpenAI" + assert entry.model == "gpt-3.5-turbo" + + +class DescribeDeprecationEntryRSSConversion: + """Tests for RSS conversion methods.""" + + def it_converts_to_rss_item_with_required_fields(self) -> None: 
+ """Test conversion to RSS item with only required fields.""" + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1, 12, 0, 0), + retirement_date=datetime(2024, 6, 1, 12, 0, 0), + source_url="https://example.com", + ) + + rss_item = entry.to_rss_item() + + assert rss_item["title"] == "OpenAI - gpt-3.5-turbo Deprecation" + assert "Provider: OpenAI" in rss_item["description"] + assert "Model: gpt-3.5-turbo" in rss_item["description"] + assert "Deprecation Date: 2024-01-01T12:00:00" in rss_item["description"] + assert "Retirement Date: 2024-06-01T12:00:00" in rss_item["description"] + assert rss_item["link"] == "https://example.com" + assert rss_item["guid"] == "OpenAI-gpt-3.5-turbo-2024-01-01T12:00:00" + assert rss_item["pubDate"] == datetime(2024, 1, 1, 12, 0, 0) + + def it_converts_to_rss_item_with_all_fields(self) -> None: + """Test conversion to RSS item with all fields.""" + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + replacement="gpt-4-turbo", + notes="Upgrade recommended", + source_url="https://openai.com/blog", + ) + + rss_item = entry.to_rss_item() + + assert "Replacement: gpt-4-turbo" in rss_item["description"] + assert "Notes: Upgrade recommended" in rss_item["description"] + assert rss_item["link"] == "https://openai.com/blog" + + def it_generates_unique_guid_for_same_model_different_dates(self) -> None: + """Test that GUID is unique for same model deprecated at different times.""" + entry1 = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://example.com", + ) + + entry2 = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 2, 1), + retirement_date=datetime(2024, 7, 1), + source_url="https://example.com", + ) + + assert 
entry1.to_rss_item()["guid"] != entry2.to_rss_item()["guid"] + + +class DescribeDeprecationEntryJSONConversion: + """Tests for JSON conversion methods.""" + + def it_converts_to_json_dict_with_required_fields(self) -> None: + """Test conversion to JSON dict with only required fields.""" + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1, 12, 0, 0), + retirement_date=datetime(2024, 6, 1, 12, 0, 0), + source_url="https://example.com", + ) + + json_dict = entry.to_json_dict() + + # Check all fields except last_updated which is dynamic + assert json_dict["provider"] == "OpenAI" + assert json_dict["model"] == "gpt-3.5-turbo" + assert json_dict["deprecation_date"] == "2024-01-01T12:00:00" + assert json_dict["retirement_date"] == "2024-06-01T12:00:00" + assert json_dict["replacement"] is None + assert json_dict["notes"] is None + assert json_dict["source_url"] == "https://example.com" + assert "last_updated" in json_dict # Dynamic field + + def it_converts_to_json_dict_with_all_fields(self) -> None: + """Test conversion to JSON dict with all fields.""" + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + replacement="gpt-4-turbo", + notes="Upgrade recommended", + source_url="https://openai.com/blog", + ) + + json_dict = entry.to_json_dict() + + assert json_dict["replacement"] == "gpt-4-turbo" + assert json_dict["notes"] == "Upgrade recommended" + assert json_dict["source_url"] == "https://openai.com/blog" + + def it_produces_serializable_output(self) -> None: + """Test that JSON dict can be serialized.""" + import json + + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + replacement="gpt-4-turbo", + source_url="https://example.com", + ) + + json_dict = entry.to_json_dict() + json_str = 
json.dumps(json_dict) + parsed = json.loads(json_str) + + assert parsed["provider"] == "OpenAI" + assert parsed["model"] == "gpt-3.5-turbo" + assert parsed["replacement"] == "gpt-4-turbo" diff --git a/tests/rss/test_config.py b/tests/rss/test_config.py new file mode 100644 index 0000000..3bd6ac3 --- /dev/null +++ b/tests/rss/test_config.py @@ -0,0 +1,286 @@ +"""Tests for RSS configuration.""" + +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from src.rss.config import ( + FeedConfig, + OutputConfig, + RSSConfig, + VersionConfig, + get_default_config, +) + + +class DescribeFeedConfig: + """Tests for FeedConfig.""" + + def it_creates_with_default_values(self) -> None: + """Test creation with default values.""" + config = FeedConfig() + + assert config.title == "AI Model Deprecations" + assert ( + config.description + == "Daily-updated RSS feed tracking AI model deprecations across providers" + ) + assert config.link == "https://deprecations.example.com" + assert config.language == "en" + assert config.copyright is None + assert config.managing_editor is None + assert config.webmaster is None + assert config.ttl == 1440 + + def it_creates_with_custom_values(self) -> None: + """Test creation with custom values.""" + config = FeedConfig( + title="Custom Title", + description="Custom Description", + link="https://custom.example.com", + language="fr", + copyright="© 2024 Example Corp", + managing_editor="editor@example.com", + webmaster="webmaster@example.com", + ttl=720, + ) + + assert config.title == "Custom Title" + assert config.description == "Custom Description" + assert config.link == "https://custom.example.com" + assert config.language == "fr" + assert config.copyright == "© 2024 Example Corp" + assert config.managing_editor == "editor@example.com" + assert config.webmaster == "webmaster@example.com" + assert config.ttl == 720 + + def it_validates_positive_ttl(self) -> None: + """Test that TTL must be positive.""" + with 
pytest.raises(ValidationError) as exc_info: + FeedConfig(ttl=0) + + errors = exc_info.value.errors() + assert any("greater than 0" in str(e) for e in errors) + + with pytest.raises(ValidationError): + FeedConfig(ttl=-1) + + def it_allows_assignment_validation(self) -> None: + """Test that assignment validation works.""" + config = FeedConfig() + config.title = "New Title" + assert config.title == "New Title" + + with pytest.raises(ValidationError): + config.ttl = -1 + + +class DescribeVersionConfig: + """Tests for VersionConfig.""" + + def it_creates_with_default_values(self) -> None: + """Test creation with default values.""" + config = VersionConfig() + + assert config.version == "v1" + assert config.supported_versions == ["v1"] + + def it_creates_with_custom_values(self) -> None: + """Test creation with custom values.""" + config = VersionConfig( + version="v2", + supported_versions=["v1", "v2", "v3"], + ) + + assert config.version == "v2" + assert config.supported_versions == ["v1", "v2", "v3"] + + def it_validates_version_pattern(self) -> None: + """Test version pattern validation.""" + with pytest.raises(ValidationError) as exc_info: + VersionConfig(version="version1") + + errors = exc_info.value.errors() + assert any("string_pattern_mismatch" in e["type"] for e in errors) + + with pytest.raises(ValidationError): + VersionConfig(version="1") + + config = VersionConfig(version="v123") + assert config.version == "v123" + + def it_checks_version_support(self) -> None: + """Test version support checking.""" + config = VersionConfig( + version="v1", + supported_versions=["v1", "v2"], + ) + + assert config.is_version_supported("v1") + assert config.is_version_supported("v2") + assert not config.is_version_supported("v3") + assert not config.is_version_supported("invalid") + + +class DescribeOutputConfig: + """Tests for OutputConfig.""" + + def it_creates_with_default_values(self) -> None: + """Test creation with default values.""" + config = OutputConfig() + + 
assert config.base_path == Path("output/rss") + assert config.filename == "feed.xml" + + def it_creates_with_custom_values(self) -> None: + """Test creation with custom values.""" + config = OutputConfig( + base_path=Path("/custom/path"), + filename="custom_feed.xml", + ) + + assert config.base_path == Path("/custom/path") + assert config.filename == "custom_feed.xml" + + def it_generates_versioned_path(self) -> None: + """Test versioned path generation.""" + config = OutputConfig( + base_path=Path("output"), + filename="feed.xml", + ) + + path = config.get_versioned_path("v1") + assert path == Path("output/v1/feed.xml") + + path = config.get_versioned_path("v2") + assert path == Path("output/v2/feed.xml") + + def it_ensures_directories_exist(self, tmp_path: Path) -> None: + """Test directory creation.""" + config = OutputConfig( + base_path=tmp_path / "test_output", + filename="feed.xml", + ) + + version_dir = tmp_path / "test_output" / "v1" + assert not version_dir.exists() + + config.ensure_directories("v1") + assert version_dir.exists() + assert version_dir.is_dir() + + config.ensure_directories("v1") + assert version_dir.exists() + + +class DescribeRSSConfig: + """Tests for RSSConfig.""" + + def it_creates_with_default_values(self) -> None: + """Test creation with default values.""" + config = RSSConfig() + + assert isinstance(config.feed, FeedConfig) + assert isinstance(config.version, VersionConfig) + assert isinstance(config.output, OutputConfig) + + assert config.feed.title == "AI Model Deprecations" + assert config.version.version == "v1" + assert config.output.filename == "feed.xml" + + def it_creates_with_custom_values(self) -> None: + """Test creation with custom values.""" + config = RSSConfig( + feed=FeedConfig(title="Custom Feed"), + version=VersionConfig(version="v2"), + output=OutputConfig(filename="custom.xml"), + ) + + assert config.feed.title == "Custom Feed" + assert config.version.version == "v2" + assert config.output.filename == 
"custom.xml" + + def it_creates_from_dict(self) -> None: + """Test creation from dictionary.""" + config_dict = { + "feed": { + "title": "Dict Title", + "description": "Dict Description", + "ttl": 360, + }, + "version": { + "version": "v3", + "supported_versions": ["v1", "v2", "v3"], + }, + "output": { + "base_path": "/custom/path", + "filename": "dict_feed.xml", + }, + } + + config = RSSConfig.from_dict(config_dict) + + assert config.feed.title == "Dict Title" + assert config.feed.description == "Dict Description" + assert config.feed.ttl == 360 + assert config.version.version == "v3" + assert config.version.supported_versions == ["v1", "v2", "v3"] + assert config.output.base_path == Path("/custom/path") + assert config.output.filename == "dict_feed.xml" + + def it_creates_from_partial_dict(self) -> None: + """Test creation from partial dictionary uses defaults.""" + config_dict = {"feed": {"title": "Partial Title"}} + + config = RSSConfig.from_dict(config_dict) + + assert config.feed.title == "Partial Title" + assert ( + config.feed.description + == "Daily-updated RSS feed tracking AI model deprecations across providers" + ) + assert config.version.version == "v1" + assert config.output.filename == "feed.xml" + + def it_converts_to_dict(self) -> None: + """Test conversion to dictionary.""" + config = RSSConfig( + feed=FeedConfig(title="Test Title", ttl=480), + version=VersionConfig(version="v2"), + output=OutputConfig( + base_path=Path("/test/path"), + filename="test.xml", + ), + ) + + config_dict = config.to_dict() + + assert config_dict["feed"]["title"] == "Test Title" + assert config_dict["feed"]["ttl"] == 480 + assert config_dict["version"]["version"] == "v2" + assert config_dict["output"]["base_path"] == "/test/path" + assert config_dict["output"]["filename"] == "test.xml" + + +class DescribeGetDefaultConfig: + """Tests for get_default_config function.""" + + def it_returns_default_config(self) -> None: + """Test that get_default_config returns expected 
defaults.""" + config = get_default_config() + + assert isinstance(config, RSSConfig) + assert config.feed.title == "AI Model Deprecations" + assert config.version.version == "v1" + assert config.output.base_path == Path("output/rss") + + def it_returns_new_instance_each_time(self) -> None: + """Test that get_default_config returns new instances.""" + config1 = get_default_config() + config2 = get_default_config() + + assert config1 is not config2 + + config1.feed.title = "Modified Title" + assert config2.feed.title == "AI Model Deprecations" diff --git a/tests/rss/test_generator.py b/tests/rss/test_generator.py new file mode 100644 index 0000000..4db0281 --- /dev/null +++ b/tests/rss/test_generator.py @@ -0,0 +1,514 @@ +"""Tests for RSS generator.""" + +import xml.etree.ElementTree as ET +from datetime import datetime +from pathlib import Path + +import pytest + +from src.models.deprecation import DeprecationEntry +from src.rss.config import FeedConfig, OutputConfig, RSSConfig, VersionConfig +from src.rss.generator import RSSGenerator + + +class DescribeRSSGenerator: + """Tests for RSSGenerator initialization.""" + + def it_creates_with_default_config(self) -> None: + """Test creation with default configuration.""" + generator = RSSGenerator() + + assert generator.config is not None + assert generator.config.feed.title == "AI Model Deprecations" + assert generator.config.version.version == "v1" + + def it_creates_with_custom_config(self) -> None: + """Test creation with custom configuration.""" + config = RSSConfig( + feed=FeedConfig(title="Custom Feed"), + version=VersionConfig(version="v2"), + ) + generator = RSSGenerator(config) + + assert generator.config.feed.title == "Custom Feed" + assert generator.config.version.version == "v2" + + +class DescribeFeedManagement: + """Tests for feed creation and management.""" + + def it_creates_feed_on_demand(self) -> None: + """Test that feeds are created on demand.""" + generator = RSSGenerator() + + assert 
len(generator._feeds) == 0 + + feed = generator.get_feed("v1") + assert feed is not None + assert len(generator._feeds) == 1 + assert "v1" in generator._feeds + + def it_reuses_existing_feed(self) -> None: + """Test that existing feeds are reused.""" + generator = RSSGenerator() + + feed1 = generator.get_feed("v1") + feed2 = generator.get_feed("v1") + + assert feed1 is feed2 + assert len(generator._feeds) == 1 + + def it_creates_multiple_version_feeds(self) -> None: + """Test creation of multiple version feeds.""" + config = RSSConfig( + version=VersionConfig( + version="v1", + supported_versions=["v1", "v2"], + ) + ) + generator = RSSGenerator(config) + + feed1 = generator.get_feed("v1") + feed2 = generator.get_feed("v2") + + assert feed1 is not feed2 + assert len(generator._feeds) == 2 + assert "v1" in generator._feeds + assert "v2" in generator._feeds + + def it_raises_error_for_unsupported_version(self) -> None: + """Test error for unsupported version.""" + generator = RSSGenerator() + + with pytest.raises(ValueError) as exc_info: + generator.get_feed("v99") + + assert "Version v99 not supported" in str(exc_info.value) + assert "Supported versions: v1" in str(exc_info.value) + + def it_configures_feed_metadata(self) -> None: + """Test that feed metadata is properly configured.""" + config = RSSConfig( + feed=FeedConfig( + title="Test Feed", + description="Test Description", + link="https://test.example.com", + language="fr", + copyright="© Test", + managing_editor="editor@test.com", + webmaster="webmaster@test.com", + ttl=360, + ) + ) + generator = RSSGenerator(config) + + feed = generator.get_feed("v1") + + rss_str = feed.rss_str(pretty=True).decode("utf-8") + + assert "Test Feed" in rss_str + assert "Test Description" in rss_str + assert "https://test.example.com" in rss_str + assert "fr" in rss_str + assert "© Test" in rss_str + assert "editor@test.com" in rss_str + assert "webmaster@test.com" in rss_str + assert "360" in rss_str + + +class 
DescribeAddingEntries: + """Tests for adding deprecation entries.""" + + def it_adds_single_entry(self) -> None: + """Test adding a single entry.""" + generator = RSSGenerator() + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://openai.com/blog", + ) + + generator.add_entry(entry) + + assert generator.get_entry_count("v1") == 1 + + rss_str = generator.generate_rss("v1") + assert "OpenAI - gpt-3.5-turbo Deprecation" in rss_str + assert "Provider: OpenAI" in rss_str + assert "Model: gpt-3.5-turbo" in rss_str + + def it_adds_entry_with_all_fields(self) -> None: + """Test adding entry with all fields.""" + generator = RSSGenerator() + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + replacement="gpt-4-turbo", + notes="Please migrate soon", + source_url="https://openai.com/deprecation", + ) + + generator.add_entry(entry) + + rss_str = generator.generate_rss("v1") + assert "Replacement: gpt-4-turbo" in rss_str + assert "Notes: Please migrate soon" in rss_str + assert "https://openai.com/deprecation" in rss_str + + def it_adds_multiple_entries(self) -> None: + """Test adding multiple entries.""" + generator = RSSGenerator() + entries = [ + DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://example.com", + ), + DeprecationEntry( + provider="Anthropic", + model="claude-1", + deprecation_date=datetime(2024, 2, 1), + retirement_date=datetime(2024, 7, 1), + source_url="https://example.com", + ), + DeprecationEntry( + provider="Google", + model="palm-2", + deprecation_date=datetime(2024, 3, 1), + retirement_date=datetime(2024, 8, 1), + source_url="https://example.com", + ), + ] + + generator.add_entries(entries) + + assert 
generator.get_entry_count("v1") == 3 + + rss_str = generator.generate_rss("v1") + assert "OpenAI - gpt-3.5-turbo" in rss_str + assert "Anthropic - claude-1" in rss_str + assert "Google - palm-2" in rss_str + + def it_sorts_entries_by_date_newest_first(self) -> None: + """Test that entries are sorted by deprecation date (newest first).""" + generator = RSSGenerator() + entries = [ + DeprecationEntry( + provider="OpenAI", + model="old-model", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://example.com", + ), + DeprecationEntry( + provider="Anthropic", + model="newest-model", + deprecation_date=datetime(2024, 3, 1), + retirement_date=datetime(2024, 8, 1), + source_url="https://example.com", + ), + DeprecationEntry( + provider="Google", + model="middle-model", + deprecation_date=datetime(2024, 2, 1), + retirement_date=datetime(2024, 7, 1), + source_url="https://example.com", + ), + ] + + generator.add_entries(entries, sort_by_date=True) + + rss_str = generator.generate_rss("v1") + + # Check items appear in the XML in newest-first order + # Note: feedgen outputs entries in LIFO order, so we need to check + # that when sorted newest-first and added, they appear correctly + import re + + items = re.findall(r"Model: (.*?)\n", rss_str) + assert items == ["newest-model", "middle-model", "old-model"] + + def it_adds_without_sorting_when_specified(self) -> None: + """Test adding entries without sorting.""" + generator = RSSGenerator() + entries = [ + DeprecationEntry( + provider="OpenAI", + model="first", + deprecation_date=datetime(2024, 3, 1), + retirement_date=datetime(2024, 8, 1), + source_url="https://example.com", + ), + DeprecationEntry( + provider="Anthropic", + model="second", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://example.com", + ), + ] + + generator.add_entries(entries, sort_by_date=False) + + rss_str = generator.generate_rss("v1") + + # feedgen 
outputs in LIFO order, so second added appears first + import re + + items = re.findall(r"Model: (.*?)\n", rss_str) + assert items == ["second", "first"] + + +class DescribeRSSGeneration: + """Tests for RSS generation.""" + + def it_generates_valid_rss_xml(self) -> None: + """Test that generated RSS is valid XML.""" + generator = RSSGenerator() + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://openai.com/blog", + ) + + generator.add_entry(entry) + rss_str = generator.generate_rss("v1") + + root = ET.fromstring(rss_str) + assert root.tag == "rss" + assert root.get("version") == "2.0" + + channel = root.find("channel") + assert channel is not None + + title = channel.find("title") + assert title is not None + assert title.text == "AI Model Deprecations" + + items = channel.findall("item") + assert len(items) == 1 + + def it_includes_all_required_rss_fields(self) -> None: + """Test that all required RSS fields are included.""" + generator = RSSGenerator() + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://example.com", + ) + + generator.add_entry(entry) + rss_str = generator.generate_rss("v1") + + root = ET.fromstring(rss_str) + channel = root.find("channel") + + assert channel.find("title") is not None + assert channel.find("description") is not None + assert channel.find("link") is not None + + item = channel.find("item") + assert item is not None + + assert item.find("title") is not None + assert item.find("description") is not None + assert item.find("guid") is not None + assert item.find("pubDate") is not None + assert item.find("link") is not None + + def it_generates_empty_feed(self) -> None: + """Test generating empty feed.""" + generator = RSSGenerator() + rss_str = generator.generate_rss("v1") + + root = 
ET.fromstring(rss_str) + channel = root.find("channel") + items = channel.findall("item") + + assert len(items) == 0 + + +class DescribeSavingFeeds: + """Tests for saving feeds to disk.""" + + def it_saves_feed_to_default_path(self, tmp_path: Path) -> None: + """Test saving feed to default configured path.""" + config = RSSConfig( + output=OutputConfig( + base_path=tmp_path / "output", + filename="test_feed.xml", + ) + ) + generator = RSSGenerator(config) + + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://openai.com/blog", + ) + generator.add_entry(entry) + + saved_path = generator.save_feed("v1") + + assert saved_path == tmp_path / "output" / "v1" / "test_feed.xml" + assert saved_path.exists() + + content = saved_path.read_text() + assert "OpenAI - gpt-3.5-turbo" in content + + def it_saves_feed_to_custom_path(self, tmp_path: Path) -> None: + """Test saving feed to custom path.""" + generator = RSSGenerator() + + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://openai.com/blog", + ) + generator.add_entry(entry) + + custom_path = tmp_path / "custom" / "feed.xml" + saved_path = generator.save_feed("v1", custom_path) + + assert saved_path == custom_path + assert saved_path.exists() + + content = saved_path.read_text() + assert "OpenAI - gpt-3.5-turbo" in content + + def it_creates_directories_if_not_exist(self, tmp_path: Path) -> None: + """Test that directories are created if they don't exist.""" + config = RSSConfig( + output=OutputConfig( + base_path=tmp_path / "new" / "nested" / "path", + filename="feed.xml", + ) + ) + generator = RSSGenerator(config) + + assert not (tmp_path / "new").exists() + + saved_path = generator.save_feed("v1") + + assert saved_path.exists() + assert saved_path.parent.exists() + + +class 
DescribeFeedManipulation: + """Tests for feed manipulation operations.""" + + def it_clears_single_feed(self) -> None: + """Test clearing a single feed.""" + generator = RSSGenerator() + + entries = [ + DeprecationEntry( + provider="OpenAI", + model=f"model-{i}", + deprecation_date=datetime(2024, 1, i), + retirement_date=datetime(2024, 6, i), + source_url="https://example.com", + ) + for i in range(1, 4) + ] + generator.add_entries(entries) + + assert generator.get_entry_count("v1") == 3 + + generator.clear_feed("v1") + + assert generator.get_entry_count("v1") == 0 + + def it_clears_all_feeds(self) -> None: + """Test clearing all feeds.""" + config = RSSConfig( + version=VersionConfig( + supported_versions=["v1", "v2"], + ) + ) + generator = RSSGenerator(config) + + entry = DeprecationEntry( + provider="OpenAI", + model="test-model", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://example.com", + ) + + generator.add_entry(entry, "v1") + generator.add_entry(entry, "v2") + + assert len(generator._feeds) == 2 + + generator.clear_all_feeds() + + assert len(generator._feeds) == 0 + + def it_counts_entries_correctly(self) -> None: + """Test entry counting.""" + generator = RSSGenerator() + + assert generator.get_entry_count("v1") == 0 + + entry = DeprecationEntry( + provider="OpenAI", + model="test-model", + deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://example.com", + ) + + generator.add_entry(entry) + assert generator.get_entry_count("v1") == 1 + + generator.add_entry(entry) + assert generator.get_entry_count("v1") == 2 + + generator.clear_feed("v1") + assert generator.get_entry_count("v1") == 0 + + +class DescribeFeedValidation: + """Tests for feed validation.""" + + def it_validates_complete_feed(self) -> None: + """Test validation of complete feed.""" + generator = RSSGenerator() + + entry = DeprecationEntry( + provider="OpenAI", + model="gpt-3.5-turbo", + 
deprecation_date=datetime(2024, 1, 1), + retirement_date=datetime(2024, 6, 1), + source_url="https://openai.com/blog", + ) + generator.add_entry(entry) + + assert generator.validate_feed("v1") is True + + def it_validates_empty_feed(self) -> None: + """Test validation of empty feed (should be valid).""" + generator = RSSGenerator() + + assert generator.validate_feed("v1") is True + + def it_validates_feed_has_required_metadata(self) -> None: + """Test that feed validation checks required metadata.""" + generator = RSSGenerator() + + assert generator.validate_feed("v1") is True diff --git a/tests/scrapers/test_base_scraper.py b/tests/scrapers/test_base_scraper.py new file mode 100644 index 0000000..0502488 --- /dev/null +++ b/tests/scrapers/test_base_scraper.py @@ -0,0 +1,240 @@ +"""Tests for base scraper functionality.""" + +import asyncio +from datetime import UTC, datetime +from unittest.mock import AsyncMock, Mock, patch + +import httpx +import pytest + +from src.scrapers.base_scraper import ( + BaseScraper, + ExtractionError, + ScraperConfig, + ScraperError, +) + + +class MockScraper(BaseScraper): + """Test implementation of BaseScraper.""" + + async def extract_deprecations(self) -> list[dict]: + """Test extraction method.""" + return [ + { + "provider": "test", + "model": "test-model", + "announcement_date": datetime.now(UTC), + "retirement_date": datetime.now(UTC), + } + ] + + +class DescribeScraperConfig: + """Test ScraperConfig functionality.""" + + def it_has_sensible_defaults(self): + config = ScraperConfig() + assert config.timeout == 30 + assert config.max_retries == 3 + assert config.retry_delay == 1.0 + assert config.user_agent.startswith("DeprecationsRSS") + + def it_accepts_custom_values(self): + config = ScraperConfig( + timeout=60, + max_retries=5, + retry_delay=2.0, + user_agent="CustomAgent/1.0", + ) + assert config.timeout == 60 + assert config.max_retries == 5 + assert config.retry_delay == 2.0 + assert config.user_agent == "CustomAgent/1.0" 
+ + +class DescribeBaseScraper: + """Test BaseScraper functionality.""" + + @pytest.fixture + def scraper(self): + return MockScraper("https://example.com") + + def it_initializes_with_url(self, scraper): + assert scraper.url == "https://example.com" + assert isinstance(scraper.config, ScraperConfig) + assert scraper._client is None + + def it_initializes_with_custom_config(self): + config = ScraperConfig(timeout=60) + scraper = MockScraper("https://example.com", config=config) + assert scraper.config.timeout == 60 + + @pytest.mark.asyncio + async def it_creates_client_on_demand(self, scraper): + client = await scraper._get_client() + assert isinstance(client, httpx.AsyncClient) + assert str(client.timeout) == "Timeout(timeout=30)" # Default timeout + + @pytest.mark.asyncio + async def it_fetches_content_successfully(self, scraper): + mock_response = Mock( + status_code=200, + text="Test content", + raise_for_status=Mock(), + ) + + with patch.object(httpx.AsyncClient, "get", new_callable=AsyncMock) as mock_get: + mock_get.return_value = mock_response + content = await scraper.fetch("https://example.com/page") + + assert content == "Test content" + mock_get.assert_called_once() + + @pytest.mark.asyncio + async def it_retries_on_failure(self, scraper): + scraper.config.max_retries = 2 + scraper.config.retry_delay = 0.1 + + mock_response_fail = Mock( + status_code=500, + raise_for_status=Mock( + side_effect=httpx.HTTPStatusError("Server error", request=Mock(), response=Mock()) + ), + ) + mock_response_success = Mock( + status_code=200, + text="Success", + raise_for_status=Mock(), + ) + + with patch.object(httpx.AsyncClient, "get", new_callable=AsyncMock) as mock_get: + mock_get.side_effect = [mock_response_fail, mock_response_success] + content = await scraper.fetch("https://example.com/page") + + assert content == "Success" + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + async def it_raises_after_max_retries(self, scraper): + scraper.config.max_retries = 
2 + scraper.config.retry_delay = 0.1 + + mock_response = Mock( + status_code=500, + raise_for_status=Mock( + side_effect=httpx.HTTPStatusError("Server error", request=Mock(), response=Mock()) + ), + ) + + with patch.object(httpx.AsyncClient, "get", new_callable=AsyncMock) as mock_get: + mock_get.return_value = mock_response + + with pytest.raises(ScraperError, match="Failed after 2 retries"): + await scraper.fetch("https://example.com/page") + + assert mock_get.call_count == 2 + + @pytest.mark.asyncio + async def it_implements_exponential_backoff(self, scraper): + scraper.config.max_retries = 3 + scraper.config.retry_delay = 0.1 + + mock_response = Mock( + status_code=500, + raise_for_status=Mock( + side_effect=httpx.HTTPStatusError("Server error", request=Mock(), response=Mock()) + ), + ) + + delays = [] + original_sleep = asyncio.sleep + + async def track_sleep(delay): + delays.append(delay) + await original_sleep(0.01) # Speed up test + + with ( + patch.object(httpx.AsyncClient, "get", new_callable=AsyncMock) as mock_get, + patch("asyncio.sleep", side_effect=track_sleep), + pytest.raises(ScraperError), + ): + mock_get.return_value = mock_response + await scraper.fetch("https://example.com/page") + + # Check exponential backoff: 0.1, 0.2, 0.4 + assert len(delays) == 2 # max_retries - 1 + assert delays[0] == pytest.approx(0.1, rel=0.1) + assert delays[1] == pytest.approx(0.2, rel=0.1) + + @pytest.mark.asyncio + async def it_parses_html_content(self, scraper): + html = "

<h1>Title</h1>
<p>Content</p>

" + soup = await scraper.parse_html(html) + + assert soup.find("h1").text == "Title" + assert soup.find("p").text == "Content" + + @pytest.mark.asyncio + async def it_extracts_text_from_elements(self, scraper): + html = "

<p>Test content</p>

" + soup = await scraper.parse_html(html) + elem = soup.find("p") + + text = scraper.extract_text(elem) + assert text == "Test content" + + @pytest.mark.asyncio + async def it_handles_none_elements_safely(self, scraper): + text = scraper.extract_text(None) + assert text == "" + + text = scraper.extract_text(None, default="default") + assert text == "default" + + @pytest.mark.asyncio + async def it_extracts_dates_from_text(self, scraper): + html = "
" + soup = await scraper.parse_html(html) + elem = soup.find("time") + + date = scraper.extract_date(elem) + assert date.date() == datetime(2024, 3, 15).date() + + @pytest.mark.asyncio + async def it_returns_none_for_invalid_dates(self, scraper): + html = "
<div>not a date</div>
" + soup = await scraper.parse_html(html) + elem = soup.find("div") + + date = scraper.extract_date(elem) + assert date is None + + @pytest.mark.asyncio + async def it_runs_scraping_workflow(self, scraper): + with patch.object(scraper, "extract_deprecations", new_callable=AsyncMock) as mock_extract: + mock_extract.return_value = [{"test": "data"}] + + results = await scraper.run() + + assert results == [{"test": "data"}] + mock_extract.assert_called_once() + + @pytest.mark.asyncio + async def it_closes_client_on_cleanup(self, scraper): + client = await scraper._get_client() + mock_close = AsyncMock() + client.aclose = mock_close + + await scraper.close() + + mock_close.assert_called_once() + assert scraper._client is None + + @pytest.mark.asyncio + async def it_handles_extraction_errors(self, scraper): + with patch.object(scraper, "extract_deprecations", new_callable=AsyncMock) as mock_extract: + mock_extract.side_effect = Exception("Extraction failed") + + with pytest.raises(ExtractionError, match="Extraction failed"): + await scraper.run() diff --git a/tests/scrapers/test_utils.py b/tests/scrapers/test_utils.py new file mode 100644 index 0000000..6125baf --- /dev/null +++ b/tests/scrapers/test_utils.py @@ -0,0 +1,147 @@ +"""Tests for scraper utilities.""" + +from datetime import UTC, datetime + +import pytest + +from src.scrapers.utils import ( + clean_text, + normalize_url, + parse_date, + validate_url, +) + + +class DescribeParseDateUtil: + """Test date parsing utilities.""" + + def it_parses_iso_format(self): + result = parse_date("2024-03-15T10:30:00Z") + assert result == datetime(2024, 3, 15, 10, 30, 0, tzinfo=UTC) + + def it_parses_iso_date_only(self): + result = parse_date("2024-03-15") + assert result.date() == datetime(2024, 3, 15).date() + + def it_parses_rfc_format(self): + result = parse_date("Wed, 15 Mar 2024 10:30:00 GMT") + assert result == datetime(2024, 3, 15, 10, 30, 0, tzinfo=UTC) + + def it_parses_human_readable_format(self): + # Test various 
human-readable formats + dates = [ + "March 15, 2024", + "15 March 2024", + "Mar 15, 2024", + "2024-03-15", + ] + for date_str in dates: + result = parse_date(date_str) + assert result.date() == datetime(2024, 3, 15).date() + + def it_handles_relative_dates(self): + # These would need to be mocked for consistent testing + with pytest.raises(ValueError, match="Could not parse date"): + parse_date("tomorrow") + + def it_returns_none_for_invalid_dates(self): + assert parse_date("not a date", raise_on_error=False) is None + + def it_raises_for_invalid_dates_when_requested(self): + with pytest.raises(ValueError, match="Could not parse date"): + parse_date("not a date", raise_on_error=True) + + +class DescribeCleanText: + """Test text cleaning utilities.""" + + def it_removes_extra_whitespace(self): + text = " This has \n\n extra spaces " + assert clean_text(text) == "This has extra spaces" + + def it_removes_html_tags(self): + text = "

<div><p>This is <b>HTML</b> content</p></div>

" + assert clean_text(text) == "This is HTML content" + + def it_handles_special_characters(self): + text = "This has   special & characters <>" + assert clean_text(text) == "This has special & characters <>" + + def it_preserves_sentence_structure(self): + text = "First sentence. Second sentence! Third?" + assert clean_text(text) == "First sentence. Second sentence! Third?" + + def it_handles_empty_strings(self): + assert clean_text("") == "" + assert clean_text(" ") == "" + + def it_handles_none_values(self): + assert clean_text(None) == "" + + def it_optionally_preserves_line_breaks(self): + text = "Line one\nLine two\nLine three" + assert clean_text(text, preserve_lines=True) == "Line one\nLine two\nLine three" + + +class DescribeValidateUrl: + """Test URL validation.""" + + def it_validates_http_urls(self): + assert validate_url("http://example.com") is True + assert validate_url("https://example.com") is True + + def it_validates_complex_urls(self): + urls = [ + "https://example.com/path/to/page", + "https://example.com:8080/path", + "https://sub.example.com/path?query=1¶m=2", + "https://example.com/path#anchor", + ] + for url in urls: + assert validate_url(url) is True + + def it_rejects_invalid_urls(self): + invalid_urls = [ + "not a url", + "ftp://example.com", # Only http/https + "http://", + "//example.com", + "", + None, + ] + for url in invalid_urls: + assert validate_url(url) is False + + def it_optionally_requires_https(self): + assert validate_url("https://example.com", require_https=True) is True + assert validate_url("http://example.com", require_https=True) is False + + +class DescribeNormalizeUrl: + """Test URL normalization.""" + + def it_adds_missing_scheme(self): + assert normalize_url("example.com") == "https://example.com" + assert normalize_url("www.example.com") == "https://www.example.com" + + def it_preserves_existing_scheme(self): + assert normalize_url("http://example.com") == "http://example.com" + assert 
normalize_url("https://example.com") == "https://example.com" + + def it_removes_trailing_slashes(self): + assert normalize_url("https://example.com/") == "https://example.com" + assert normalize_url("https://example.com/path/") == "https://example.com/path" + + def it_preserves_query_params(self): + url = "https://example.com/path?param=value" + assert normalize_url(url) == url + + def it_handles_fragments(self): + assert normalize_url("https://example.com#section") == "https://example.com#section" + + def it_lowercases_domain(self): + assert normalize_url("https://EXAMPLE.COM/Path") == "https://example.com/Path" + + def it_handles_invalid_urls(self): + assert normalize_url("") == "" + assert normalize_url(None) == "" diff --git a/tests/unit/test_deprecation_models.py b/tests/unit/test_deprecation_models.py index 40fd657..9d4fed8 100644 --- a/tests/unit/test_deprecation_models.py +++ b/tests/unit/test_deprecation_models.py @@ -151,7 +151,7 @@ def it_validates_retirement_after_deprecation(): errors = exc_info.value.errors() assert any( - "retirement_date must be after deprecation_date" in str(error["msg"]) + "Retirement date must be after deprecation date" in str(error["msg"]) for error in errors ) @@ -225,7 +225,7 @@ def it_deserializes_from_dict(): assert deprecation.retirement_date == datetime(2024, 4, 1, 12, 0, 0, tzinfo=UTC) assert deprecation.replacement == "claude-3-haiku" assert deprecation.notes == "Upgrading to Claude 3" - assert str(deprecation.source_url) == "https://docs.anthropic.com/" + assert str(deprecation.source_url) == "https://docs.anthropic.com" assert deprecation.last_updated == datetime(2024, 1, 15, 10, 30, 0, tzinfo=UTC) def it_generates_consistent_hash():