Skip to content

Commit be72982

Browse files
authored
Merge pull request #28 from NASA-IMPACT/feature/code-signals-search
feat: add CodeSignalsSearchTool for code-level repository search
2 parents fee360b + b73d0aa commit be72982

File tree

3 files changed

+227
-0
lines changed

3 files changed

+227
-0
lines changed

akd_ext/tools/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
SDESearchToolInputSchema,
99
SDESearchToolOutputSchema,
1010
)
11+
from .code_search.code_signals import (
12+
CodeSignalsSearchInputSchema,
13+
CodeSignalsSearchOutputSchema,
14+
CodeSignalsSearchTool,
15+
CodeSignalsSearchToolConfig,
16+
)
1117
from .code_search.repository_search import (
1218
RepositorySearchTool,
1319
RepositorySearchToolInputSchema,
@@ -24,6 +30,10 @@
2430
"SDESearchToolOutputSchema",
2531
"SDESearchToolConfig",
2632
"SDEDocument",
33+
"CodeSignalsSearchInputSchema",
34+
"CodeSignalsSearchOutputSchema",
35+
"CodeSignalsSearchTool",
36+
"CodeSignalsSearchToolConfig",
2737
"RepositorySearchTool",
2838
"RepositorySearchToolInputSchema",
2939
"RepositorySearchToolOutputSchema",
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""
2+
Code Signals Search Tool.
3+
4+
Searches LLM-extracted code signals from GitHub repositories.
5+
Use as fallback when README-based search (RepositorySearchTool) is insufficient.
6+
"""
7+
8+
import os
9+
from typing import Any, Literal
10+
from urllib.parse import urljoin
11+
12+
import httpx
13+
from loguru import logger
14+
from pydantic import Field
15+
16+
from akd._base import InputSchema, OutputSchema
17+
from akd.structures import SearchResult
18+
from akd.tools import BaseTool, BaseToolConfig
19+
20+
from akd_ext.mcp import mcp_tool
21+
22+
23+
class CodeSignalsSearchToolConfig(BaseToolConfig):
    """Configuration for the Code Signals Search Tool."""

    # The base URL is resolved each time a config is instantiated rather than
    # once at import time, so an SDE_BASE_URL exported after this module has
    # been imported (e.g. by a late dotenv load or a test) is still honoured.
    sde_base_url: str = Field(
        default_factory=lambda: os.getenv("SDE_BASE_URL", "https://dyejsbdumgpqz.cloudfront.net/"),
        description="Base URL for SDE API",
    )
    # Joined onto sde_base_url with urllib.parse.urljoin by the tool.
    endpoint: str = Field(
        default="/api/code_signals/search",
        description="Code signals search endpoint",
    )
    # Passed straight to httpx.AsyncClient(timeout=...).
    timeout: float = Field(
        default=30.0,
        description="HTTP request timeout in seconds",
    )
    # Sent to the API as the "search_type" field of the request body.
    search_type: Literal["hybrid", "vector", "keyword"] = Field(
        default="hybrid",
        description="Search type: 'hybrid' (vector + keyword, recommended), 'vector' (semantic only), 'keyword' (exact matching)",
    )
    # When true, the tool logs the outgoing request body at debug level.
    debug: bool = Field(default=False, description="Enable debug logging")
43+
44+
45+
class CodeSignalsHit(SearchResult):
    """A single code signals hit from SDE search."""

    # Both fields are optional; the search tool fills them from the
    # "repo_id" and "repo_url" keys of each API document when present.
    repo_id: str | None = Field(
        default=None,
        description="Repository identifier",
    )
    repo_url: str | None = Field(
        default=None,
        description="GitHub repository URL",
    )
50+
51+
52+
class CodeSignalsSearchInputSchema(InputSchema):
    """Input schema for Code Signals search."""

    # Required free-text query.
    query: str = Field(
        ...,
        description="Search query for code functionality",
    )
    # Page size; validated to the inclusive range 1..6.
    limit: int = Field(
        default=5,
        ge=1,
        le=6,
        description="Maximum results to return",
    )
    # Page index; numbering starts at 1.
    page: int = Field(
        default=1,
        ge=1,
        description="Page number for pagination",
    )
58+
59+
60+
class CodeSignalsSearchOutputSchema(OutputSchema):
    """Output schema for Code Signals search."""

    # One entry per document returned by the API, in response order.
    results: list[CodeSignalsHit] = Field(
        ...,
        description="List of matching code signals",
    )
64+
65+
66+
@mcp_tool
class CodeSignalsSearchTool(BaseTool[CodeSignalsSearchInputSchema, CodeSignalsSearchOutputSchema]):
    """
    Search code repositories using LLM-extracted code signals.

    Use this tool when README-based search is insufficient. Searches through
    extracted function names, class names, imports, data formats, and code summaries.
    """

    # NOTE(review): the class docstring above is kept verbatim because
    # @mcp_tool may surface it as the tool description - confirm before rewording.

    input_schema = CodeSignalsSearchInputSchema
    output_schema = CodeSignalsSearchOutputSchema
    config_schema = CodeSignalsSearchToolConfig

    def _extract_summary(self, content: str) -> str:
        """Extract all Code Summary sections.

        Every piece of text following a ``Code Summary:`` marker is cut at the
        first blank line and at the first ``===`` delimiter, then has its
        whitespace collapsed to single spaces. Non-empty pieces are joined
        with a blank line between them.

        Args:
            content: Raw code-signals text from one API document.

        Returns:
            The joined summaries; the first 1500 characters of ``content``
            when no non-empty summary was found; ``""`` for empty input.
        """
        if not content:
            return ""

        summaries = []

        # Element 0 is whatever precedes the first marker, so it is skipped.
        for part in content.split("Code Summary:")[1:]:
            summary = part.split("\n\n")[0].split("===")[0].strip()
            summary = " ".join(summary.split())
            if summary:
                summaries.append(summary)

        # Fall back to a bounded prefix of the raw text when nothing matched.
        return "\n\n".join(summaries) if summaries else content[:1500]

    def _parse_hit(self, doc: dict[str, Any], query: str) -> CodeSignalsHit:
        """Parse a single document from API response.

        Args:
            doc: One entry of the response's ``documents`` list.
            query: The search query, echoed onto the hit.

        Returns:
            A ``CodeSignalsHit`` whose title falls back from ``repo_id`` to
            ``name`` to ``"Unknown"`` and whose score falls back from
            ``score`` to ``_score`` to ``0.0``.
        """
        # A missing or null "code_signals" value is treated as empty text.
        code_signals = doc.get("code_signals") or ""

        return CodeSignalsHit(
            query=query,
            title=doc.get("repo_id") or doc.get("name") or "Unknown",
            content=self._extract_summary(code_signals),
            score=doc.get("score") or doc.get("_score") or 0.0,
            repo_id=doc.get("repo_id"),
            repo_url=doc.get("repo_url"),
        )

    async def _arun(self, params: CodeSignalsSearchInputSchema) -> CodeSignalsSearchOutputSchema:
        """Execute Code Signals search.

        POSTs the query to the configured endpoint and converts each returned
        document into a ``CodeSignalsHit``.

        Args:
            params: Validated search input (query, limit, page).

        Returns:
            The parsed hits, in the order the API returned them.

        Raises:
            TimeoutError: The HTTP request timed out.
            RuntimeError: The API answered with an error status, the request
                or JSON decoding failed for any other reason, or the decoded
                payload is not a JSON object.
        """
        request_body = {
            "search_term": params.query,
            "search_type": self.config.search_type,
            "page_size": params.limit,
            "page": params.page,
        }

        if self.config.debug:
            logger.debug(f"Code Signals API request: {request_body}")

        # With the default endpoint (leading "/"), urljoin resolves against
        # the host root of the base URL, replacing any path the base carries.
        url = urljoin(self.config.sde_base_url, self.config.endpoint)

        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
            try:
                response = await client.post(url, json=request_body)
                response.raise_for_status()
                data = response.json()
            except httpx.TimeoutException as e:
                msg = f"Code Signals API request timed out after {self.config.timeout}s"
                raise TimeoutError(msg) from e
            except httpx.HTTPStatusError as e:
                msg = f"Code Signals API returned error status {e.response.status_code}: {e.response.text}"
                raise RuntimeError(msg) from e
            except Exception as e:
                msg = f"Failed to query Code Signals API: {e}"
                raise RuntimeError(msg) from e

        # A JSON array/scalar body would otherwise escape as an AttributeError
        # from ``data.get``; report it like every other API failure instead.
        if not isinstance(data, dict):
            msg = f"Code Signals API returned unexpected payload of type {type(data).__name__}"
            raise RuntimeError(msg)

        # ``"documents": null`` is treated the same as a missing key.
        raw_documents = data.get("documents") or []
        documents = [self._parse_hit(doc, params.query) for doc in raw_documents]

        return CodeSignalsSearchOutputSchema(results=documents)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Tests for Code Signals Search Tool."""
2+
3+
import pytest
4+
5+
from akd_ext.tools.code_search.code_signals import (
6+
CodeSignalsSearchTool,
7+
CodeSignalsSearchToolConfig,
8+
CodeSignalsSearchInputSchema,
9+
CodeSignalsSearchOutputSchema,
10+
)
11+
12+
13+
class TestCodeSignalsSearchTool:
    """Checks for CodeSignalsSearchTool using its default configuration.

    No mocking is set up in this class, so these tests presumably call the
    configured SDE endpoint - confirm against the project's fixtures.
    """

    @staticmethod
    def _new_tool() -> CodeSignalsSearchTool:
        """Build a tool instance backed by a default config."""
        return CodeSignalsSearchTool(config=CodeSignalsSearchToolConfig())

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "query",
        [
            "MODIS",
            "earth science data processing",
        ],
    )
    async def test_code_signals_search_basic(self, query: str):
        """Test Code Signals Search Tool returns results for a query.

        Args:
            query: Code search query to test
        """
        search_input = CodeSignalsSearchInputSchema(query=query, limit=5)
        output = await self._new_tool().arun(search_input)

        assert isinstance(output, CodeSignalsSearchOutputSchema)
        assert len(output.results) <= 5

        # Every hit must carry the originating query and a score attribute.
        for item in output.results:
            assert hasattr(item, "query")
            assert hasattr(item, "score")
            assert item.query == query

    @pytest.mark.asyncio
    async def test_code_signals_search_pagination(self):
        """Test Code Signals Search Tool pagination."""
        tool = self._new_tool()

        first_input = CodeSignalsSearchInputSchema(query="MODIS", limit=3, page=1)
        first = await tool.arun(first_input)
        second_input = CodeSignalsSearchInputSchema(query="MODIS", limit=3, page=2)
        second = await tool.arun(second_input)

        assert isinstance(first, CodeSignalsSearchOutputSchema)
        assert isinstance(second, CodeSignalsSearchOutputSchema)
        assert len(first.results) <= 3
        assert len(second.results) <= 3

        # Overlap is only checked when both pages actually returned hits.
        if first.results and second.results:
            first_keys = {(hit.title, hit.repo_id) for hit in first.results}
            second_keys = {(hit.title, hit.repo_id) for hit in second.results}
            assert first_keys.isdisjoint(second_keys), "Page 1 and page 2 should not overlap"

    @pytest.mark.parametrize(
        "limit,expected_max",
        [
            (1, 1),
            (5, 5),
            (6, 6),
        ],
    )
    @pytest.mark.asyncio
    async def test_code_signals_search_limit(self, limit: int, expected_max: int):
        """Test Code Signals Search Tool respects limit parameter."""
        search_input = CodeSignalsSearchInputSchema(query="MODIS", limit=limit, page=1)
        output = await self._new_tool().arun(search_input)

        assert len(output.results) <= expected_max

0 commit comments

Comments
 (0)