Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"python-json-logger>=3.3.0",
"jinja2>=3.1.0",
"ty>=0.0.5",
"opentelemetry-sdk>=1.38.0",
]

[build-system]
Expand All @@ -31,19 +32,8 @@ dev = [
"ruff>=0.1.0",
"pytest-socket>=0.7.0",
"ipdb>=0.13.13",
"ty>=0.0.5",
]

[tool.ariadne-codegen]
schema_path = "./schema.graphql"
queries_path = "src/saleor_mcp/graphql"
target_package_name = "saleor_client"
target_package_path = "src/saleor_mcp"
convert_to_snake_case = false

[tool.ty.src]
exclude = ["src/saleor_mcp/saleor_client/*"]

[tool.ruff]
target-version = "py312"
line-length = 88
Expand Down Expand Up @@ -109,3 +99,9 @@ addopts = [
"--allow-unix-socket",
]

[tool.ariadne-codegen]
schema_path = "./schema.graphql"
queries_path = "src/saleor_mcp/graphql"
target_package_name = "saleor_client"
target_package_path = "src/saleor_mcp"
convert_to_snake_case = false
5 changes: 4 additions & 1 deletion src/saleor_mcp/ctx_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from .config import get_config_from_headers
from .saleor_client.client import Client
from .saleor_client.graphql_client import instrument_graphql_client


def get_saleor_client() -> Client:
"""Create and return a Saleor GraphQL client using configuration from headers.

Note: This function works only within a request context.
"""

saleor_headers = get_config_from_headers()
headers = {"Authorization": f"Bearer {saleor_headers.auth_token}"}
return Client(url=saleor_headers.api_url, headers=headers)
client = Client(url=saleor_headers.api_url, headers=headers)
return instrument_graphql_client(client)
26 changes: 2 additions & 24 deletions src/saleor_mcp/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import inspect
import logging
import tomllib
from pathlib import Path
from typing import Annotated, Any, get_args, get_origin, get_type_hints

Expand All @@ -12,28 +11,6 @@
logger = logging.getLogger(__name__)


def get_version_from_pyproject() -> str:
"""Read version from pyproject.toml.

Returns:
Version string from pyproject.toml, or "unknown" if not found.

"""
try:
# Find pyproject.toml - go up from this file to project root
project_root = Path(__file__).parent.parent.parent
pyproject_path = project_root / "pyproject.toml"

if pyproject_path.exists():
with open(pyproject_path, "rb") as f:
data = tomllib.load(f)
return data.get("project", {}).get("version", "unknown")
except Exception:
logger.warning("Failed to read version from pyproject.toml")

return "unknown"


def generate_html(output_path: str | None = None) -> str:
"""Generate HTML documentation from tools.

Expand All @@ -50,12 +27,13 @@ def generate_html(output_path: str | None = None) -> str:
"""
# Import here to avoid circular dependency
from saleor_mcp.main import mcp
from saleor_mcp.utils import get_pyproject_value

# Introspect tools from the MCP server and all mounted routers
tools = introspect_from_mcp_server(mcp)

# Get version from pyproject.toml
version = get_version_from_pyproject()
version = get_pyproject_value("project", "version", default="unknown")

# Setup Jinja2 environment
template_dir = Path(__file__).parent / "templates"
Expand Down
6 changes: 6 additions & 0 deletions src/saleor_mcp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from starlette.staticfiles import StaticFiles

from saleor_mcp.docs import generate_html
from saleor_mcp.telemetry import initialise_telemetry
from saleor_mcp.tools import (
channels_router,
customers_router,
orders_router,
products_router,
utils_router,
)
from saleor_mcp.utils import get_pyproject_value

mcp = FastMCP("Saleor MCP Server")
mcp.add_middleware(DetailedTimingMiddleware())
Expand Down Expand Up @@ -56,6 +58,10 @@ async def index(request: Request):
def main():
import uvicorn

initialise_telemetry(
service_name=get_pyproject_value("project", "name", default="saleor-mcp"),
service_version=get_pyproject_value("project", "version", default="unknown")
)
uvicorn.run(app, host="127.0.0.1", port=6000)


Expand Down
90 changes: 90 additions & 0 deletions src/saleor_mcp/saleor_client/graphql_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Optional, Any, Dict, Tuple, IO, List

from opentelemetry.propagate import inject
from opentelemetry.trace import SpanKind

from .client import Client
from ..telemetry import mcp_attributes
from ..telemetry import tracer


def instrument_graphql_client(client: Client) -> Client:
"""Instrumented GraphQL client.

This wraps the generated AsyncBaseClient to add OpenTelemetry tracing
without modifying the code-generated file itself. Since code-gen may
overwrite AsyncBaseClient in future runs, this wrapper provides a safe
and stable place to extend the client with observability features.
"""

original_execute_json = client._execute_json
original_execute_multipart = client._execute_multipart

async def traced_execute_json(
query: str,
operation_name: Optional[str],
variables: Dict[str, Any],
**kwargs: Any,
):
headers = {"Content-Type": "application/json"}
headers.update(kwargs.get("headers", {}))
inject(headers)
merged_kwargs = {**kwargs, "headers": headers}
span_name = f"saleor.graphql.{operation_name or 'query'}"
with tracer.start_as_current_span(
span_name,
kind=SpanKind.CLIENT,
attributes={
mcp_attributes.HTTP_METHOD: "POST",
mcp_attributes.SALEOR_ENDPOINT: client.url,
mcp_attributes.GRAPHQL_OPERATION_NAME: operation_name,
mcp_attributes.GRAPHQL_VARIABLES: str(variables),
},
) as span:
response = await original_execute_json(
query=query,
operation_name=operation_name,
variables=variables,
**merged_kwargs,
)
span.set_attribute(mcp_attributes.HTTP_STATUS_CODE, response.status_code)
return response

async def traced_multipart(
query: str,
operation_name: Optional[str],
variables: Dict[str, Any],
files: Dict[str, Tuple[str, IO[bytes], str]],
files_map: Dict[str, List[str]],
**kwargs: Any,
):
headers = kwargs.get("headers", {})
inject(headers)
merged_kwargs = {**kwargs, "headers": headers}
span_name = f"saleor.graphql.{operation_name or 'multipart'}"
with tracer.start_as_current_span(
span_name,
kind=SpanKind.CLIENT,
attributes={
mcp_attributes.HTTP_METHOD: "POST",
mcp_attributes.SALEOR_ENDPOINT: client.url,
mcp_attributes.GRAPHQL_OPERATION_NAME: operation_name,
mcp_attributes.GRAPHQL_VARIABLES: str(variables),
mcp_attributes.GRAPHQL_MULTIPART: True,
mcp_attributes.GRAPHQL_FILES_COUNT: len(files),
},
) as span:
response = await original_execute_multipart(
query=query,
operation_name=operation_name,
variables=variables,
files=files,
files_map=files_map,
**merged_kwargs,
)
span.set_attribute(mcp_attributes.HTTP_STATUS_CODE, response.status_code)
return response

client._execute_json = traced_execute_json
client._execute_multipart = traced_multipart
return client
149 changes: 149 additions & 0 deletions src/saleor_mcp/telemetry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import functools
import uuid
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, AsyncIterator, Optional, Callable, Sequence

from opentelemetry import trace
from opentelemetry.sdk._configuration import _OTelSDKConfigurator
from opentelemetry.sdk.resources import SERVICE_INSTANCE_ID, SERVICE_NAME, SERVICE_VERSION
from opentelemetry.trace import StatusCode, Link
from opentelemetry.util.types import Attributes

from . import mcp_attributes
from .metric import record_duration, record_operation_count

tracer = trace.get_tracer("saleor-mcp")


class Kind(Enum):
TOOL = "tool"


@asynccontextmanager
async def operation_context(
*,
operation_kind: str,
operation_name: str,
attributes: Attributes | None = None,
links: Sequence[Link] | None = None,
record_exception: bool = True,
set_status_on_exception: bool = True,
end_on_exit: bool = True,
) -> AsyncIterator[tuple[Any, dict[str, Any]]]:
"""Telemetry wrapper for a single MCP operation.

Creates a tracing span and records all related metrics:
- operation count (increments once per call)
- operation duration (nanoseconds)
- operation errors (increments on exception)
"""

metric_attrs: dict[str, Any] = dict(attributes or {})
metric_attrs.setdefault(mcp_attributes.MCP_OPERATION_KIND, operation_kind)
metric_attrs.setdefault(mcp_attributes.MCP_OPERATION_NAME, operation_name)
span_name = f"mcp.{operation_kind}.{operation_name}"
async with record_operation_count(metric_attrs):
async with record_duration(metric_attrs) as (metric_attrs, start_time):
with tracer.start_as_current_span(
span_name,
links=links,
start_time=start_time,
end_on_exit=end_on_exit,
attributes=metric_attrs,
record_exception=record_exception,
set_status_on_exception=set_status_on_exception,
) as span:
yield span, metric_attrs
span.set_status(status=StatusCode.OK)


def instrument(
kind: Kind,
*,
name: str | None = None,
include_args: bool = True,
record_exception: bool = True,
set_status_on_exception: bool = True,
end_on_exit: bool = True,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Decorator to instrument MCP operations with tracing + metrics.

Args:
kind:
Logical kind of operation (e.g. Kind.TOOL).
name:
Optional explicit operation name. Defaults to function __name__.
include_args:
If True, adds function kwargs as span attributes, prefixed with
"mcp.<kind>.input.".
record_exception:
Passed through to operation_context / tracer.
set_status_on_exception:
Passed through to operation_context / tracer.
end_on_exit:
Passed through to operation_context / tracer.
"""

def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
op_name = name or func.__name__

@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
attrs: dict[str, Any] = {}
if kind is Kind.TOOL:
attrs[mcp_attributes.MCP_TOOL_NAME] = op_name

async with operation_context(
operation_kind=kind.value,
operation_name=op_name,
attributes=attrs,
end_on_exit=end_on_exit,
record_exception=record_exception,
set_status_on_exception=set_status_on_exception,
) as (span, metric_attrs):
if include_args and kwargs:
prefix = f"mcp.{kind.value}.input."
for key, value in kwargs.items():
value_str = str(value)
span.set_attribute(prefix + key, value_str)
result = await func(*args, **kwargs)
return result

return wrapper

return decorator


def otel_configure_sdk(
service_name: str,
service_version: str,
additional_attributes: dict[str, Any] | None = None,
):
resource_attributes = {
SERVICE_NAME: service_name,
SERVICE_VERSION: service_version,
SERVICE_INSTANCE_ID: str(uuid.uuid4()),
}

if additional_attributes:
resource_attributes.update(additional_attributes)

configurator = _OTelSDKConfigurator()
configurator.configure(resource_attributes=resource_attributes)


def initialise_telemetry(
service_name: str,
service_version: str,
additional_attributes: Optional[dict[str, Any]] = None,
):
"""
Initialise OpenTelemetry SDK with resource attributes.

Args:
service_name: Name of the MCP server
service_version: Version of the MCP server
additional_attributes: Additional resource attributes
"""
otel_configure_sdk(service_name, service_version, additional_attributes)
22 changes: 22 additions & 0 deletions src/saleor_mcp/telemetry/mcp_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Server
MCP_SERVER_NAME = "mcp.server.name"
MCP_SERVER_VERSION = "mcp.server.version"

# Operation
MCP_OPERATION_KIND = "mcp.operation.kind"
MCP_OPERATION_NAME = "mcp.operation.name"
METRIC_OPERATION_COUNT = "mcp.operation.total"
METRIC_OPERATION_DURATION = "mcp.operation.duration"

# Tool
MCP_TOOL_NAME = "mcp.tool.name"

# HTTP / GraphQL client attributes for Saleor calls
HTTP_METHOD = "http.method"
HTTP_STATUS_CODE = "http.status_code"
HTTP_RESPONSE_CONTENT_LENGTH = "http.response_content_length"
GRAPHQL_OPERATION_NAME = "graphql.operation_name"
GRAPHQL_VARIABLES = "graphql.variables"
GRAPHQL_MULTIPART = "graphql.multipart"
GRAPHQL_FILES_COUNT = "graphql.files.count"
SALEOR_ENDPOINT = "mcp.saleor.endpoint"
Loading