Skip to content

Commit fac6068

Browse files
authored
feat(parse): add AST-based code skeleton extraction mode (#334)
1 parent 8274cc3 commit fac6068

File tree

19 files changed

+5517
-3206
lines changed

19 files changed

+5517
-3206
lines changed

examples/ov.conf.example

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
"mineru_timeout": 300.0
8787
},
8888
"code": {
89-
"enable_ast": true,
89+
"code_summary_mode": "ast",
9090
"extract_functions": true,
9191
"extract_classes": true,
9292
"extract_imports": true,
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Public API for AST-based code skeleton extraction."""
4+
5+
from typing import Optional
6+
7+
from openviking.parse.parsers.code.ast.extractor import get_extractor
8+
9+
10+
def extract_skeleton(file_name: str, content: str, verbose: bool = False) -> Optional[str]:
    """Build a plain-text skeleton for a single source file.

    The work is delegated to the shared extractor returned by
    ``get_extractor()``. Supports Python, JS/TS, Java, C/C++, Rust and Go
    via tree-sitter.

    Args:
        file_name: File name including its extension; the extension is what
            selects the language.
        content: Source code text.
        verbose: When True, full docstrings are kept (for ast_llm / LLM
            input). When False, only the first line of each docstring is
            kept (for ast / embedding).

    Returns:
        The skeleton as plain text, or None when the language is unsupported
        or extraction failed. None tells the caller to fall back to the LLM.
    """
    shared_extractor = get_extractor()
    return shared_extractor.extract_skeleton(file_name, content, verbose=verbose)


__all__ = ["extract_skeleton"]
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""ASTExtractor: language detection + dispatch to per-language extractors."""
4+
5+
import importlib
6+
import logging
7+
from pathlib import Path
8+
from typing import Dict, Optional
9+
10+
from openviking.parse.parsers.code.ast.languages.base import LanguageExtractor
11+
from openviking.parse.parsers.code.ast.skeleton import CodeSkeleton
12+
13+
logger = logging.getLogger(__name__)
14+
15+
# File extension → internal language key
16+
_EXT_MAP: Dict[str, str] = {
    extension: language
    # Grouped by language key; group order keeps the original insertion order.
    for language, extensions in (
        ("python", (".py",)),
        ("javascript", (".js", ".jsx")),
        ("typescript", (".ts", ".tsx")),
        ("java", (".java",)),
        # C sources and headers are routed to the C++ extractor as well.
        ("cpp", (".c", ".cpp", ".cc", ".h", ".hpp")),
        ("rust", (".rs",)),
        ("go", (".go",)),
    )
    for extension in extensions
}
31+
32+
# Language key → (module path, class name, constructor kwargs)
33+
# Each value is (module path, class name, constructor kwargs). The module is
# imported on demand by ASTExtractor, not when this table is defined.
_EXTRACTOR_REGISTRY: Dict[str, tuple] = dict(
    python=("openviking.parse.parsers.code.ast.languages.python", "PythonExtractor", {}),
    # JavaScript and TypeScript share one extractor class, selected via `lang`.
    javascript=("openviking.parse.parsers.code.ast.languages.js_ts", "JsTsExtractor", {"lang": "javascript"}),
    typescript=("openviking.parse.parsers.code.ast.languages.js_ts", "JsTsExtractor", {"lang": "typescript"}),
    java=("openviking.parse.parsers.code.ast.languages.java", "JavaExtractor", {}),
    cpp=("openviking.parse.parsers.code.ast.languages.cpp", "CppExtractor", {}),
    rust=("openviking.parse.parsers.code.ast.languages.rust", "RustExtractor", {}),
    go=("openviking.parse.parsers.code.ast.languages.go", "GoExtractor", {}),
)
42+
43+
44+
class ASTExtractor:
    """Picks a language from the file extension and delegates to its extractor.

    Per-language extractors are imported on first use and kept in a
    per-instance cache. Any unsupported language or failure yields None,
    which signals the caller to fall back to the LLM.
    """

    def __init__(self):
        # language key -> extractor instance, or None once loading has failed
        self._cache: Dict[str, Optional[LanguageExtractor]] = {}

    def _detect_language(self, file_name: str) -> Optional[str]:
        """Map the (case-insensitive) file extension to a language key, or None."""
        extension = Path(file_name).suffix.lower()
        return _EXT_MAP.get(extension)

    def _get_extractor(self, lang: Optional[str]) -> Optional[LanguageExtractor]:
        """Return the cached extractor for ``lang``, loading it on first request.

        Returns None when ``lang`` is None, is not registered, or its
        extractor could not be imported or constructed.
        """
        if lang is None:
            return None
        registration = _EXTRACTOR_REGISTRY.get(lang)
        if registration is None:
            return None

        try:
            return self._cache[lang]
        except KeyError:
            pass  # first request for this language; load it below

        module_path, class_name, kwargs = registration
        try:
            module = importlib.import_module(module_path)
            instance = getattr(module, class_name)(**kwargs)
        except Exception as e:
            logger.warning("AST extractor unavailable for language '%s', falling back to LLM: %s", lang, e)
            # A failed load is cached as None so it is not retried on every file.
            instance = None

        self._cache[lang] = instance
        return instance

    def extract_skeleton(self, file_name: str, content: str, verbose: bool = False) -> Optional[str]:
        """Produce skeleton text for ``content``.

        Args:
            file_name: File name whose extension selects the language.
            content: Source code text.
            verbose: When True, full docstrings are kept (for ast_llm / LLM
                input). When False, only the first line of each docstring is
                kept (for ast / embedding).

        Returns:
            The skeleton text, or None for unsupported languages or when
            extraction raises; None signals the caller to fall back to LLM.
        """
        lang = self._detect_language(file_name)
        language_extractor = self._get_extractor(lang)
        if language_extractor is None:
            return None

        try:
            skeleton: CodeSkeleton = language_extractor.extract(file_name, content)
            text = skeleton.to_text(verbose=verbose)
        except Exception as e:
            logger.warning("AST extraction failed for '%s' (language: %s), falling back to LLM: %s", file_name, lang, e)
            return None
        return text
97+
98+
99+
# Shared module-level instance; stays None until get_extractor() is first called.
_extractor: Optional[ASTExtractor] = None


def get_extractor() -> ASTExtractor:
    """Return the module-level ASTExtractor, creating it on the first call."""
    global _extractor
    if _extractor is not None:
        return _extractor
    _extractor = ASTExtractor()
    return _extractor
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
2+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Abstract base class for language-specific AST extractors."""
4+
5+
from abc import ABC, abstractmethod
6+
7+
from openviking.parse.parsers.code.ast.skeleton import CodeSkeleton
8+
9+
10+
class LanguageExtractor(ABC):
    """Interface that every per-language AST extractor implements."""

    @abstractmethod
    def extract(self, file_name: str, content: str) -> CodeSkeleton:
        """Build a CodeSkeleton from the given source text.

        Args:
            file_name: Name of the file the source came from.
            content: Source code text.

        Raises:
            Implementations raise on unrecoverable errors.
        """
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""C/C++ AST extractor using tree-sitter-cpp."""
4+
5+
from typing import List
6+
7+
from openviking.parse.parsers.code.ast.languages.base import LanguageExtractor
8+
from openviking.parse.parsers.code.ast.skeleton import ClassSkeleton, CodeSkeleton, FunctionSig
9+
10+
11+
def _node_text(node, content_bytes: bytes) -> str:
    """Decode the bytes spanned by ``node`` (start_byte..end_byte) as UTF-8.

    Undecodable bytes are replaced rather than raising.
    """
    span = content_bytes[node.start_byte:node.end_byte]
    return span.decode("utf-8", errors="replace")
13+
14+
15+
def _parse_block_comment(raw: str) -> str:
    """Strip ``/* ... */`` / ``/** ... */`` markers and leading ``*`` from each line.

    Blank lines are dropped and the remaining lines are joined with newlines.
    Text that is not a block comment (for example a ``//`` line comment) only
    gets the per-line whitespace / leading-asterisk trimming.

    Args:
        raw: Source text of a single comment.

    Returns:
        The comment body without markers; may be an empty string.
    """
    raw = raw.strip()
    if raw.startswith("/*"):
        # Remove the closer before the opener: in "/**/" the two markers share
        # the middle "*", and stripping "/**" first would leave a stray "/".
        # The length guard keeps the two markers from overlapping.
        if len(raw) >= 4 and raw.endswith("*/"):
            raw = raw[:-2]
        # Any extra "*" of a "/**" opener is removed by the per-line lstrip below.
        raw = raw[2:]
    elif raw.endswith("*/"):
        raw = raw[:-2]
    lines = (line.strip().lstrip("*").strip() for line in raw.split("\n"))
    return "\n".join(line for line in lines if line).strip()
26+
27+
28+
def _preceding_doc(siblings: list, idx: int, content_bytes: bytes) -> str:
    """Return the cleaned text of the comment node just before ``siblings[idx]``.

    Only the directly preceding sibling is inspected. An empty string is
    returned when ``idx`` is 0 or that sibling is not a ``comment`` node.
    """
    if idx == 0:
        # No earlier sibling; without this guard, idx - 1 would wrap to the last one.
        return ""
    candidate = siblings[idx - 1]
    if candidate.type != "comment":
        return ""
    return _parse_block_comment(_node_text(candidate, content_bytes))
36+
37+
38+
# Child node types of a function_declarator whose text is the function's name.
# destructor_name / operator_name cover "~Foo" and "operator==" style members,
# which otherwise end up with an empty name.
_DECLARATOR_NAME_TYPES = (
    "identifier",
    "field_identifier",
    "qualified_identifier",
    "destructor_name",
    "operator_name",
)


def _extract_function_declarator(node, content_bytes: bytes):
    """Return ``(name, params)`` for a ``function_declarator`` node.

    Args:
        node: A node exposing ``children``; each child exposes ``type`` and
            the byte offsets used by ``_node_text``.
        content_bytes: UTF-8 encoded source the node was parsed from.

    Returns:
        A ``(name, params)`` tuple of strings. ``name`` is the text of the
        first name-like child. ``params`` is the parameter list text without
        the surrounding parentheses. Either is ``""`` when not found.
    """
    name = ""
    params = ""
    for child in node.children:
        kind = child.type
        if kind in _DECLARATOR_NAME_TYPES:
            # Only the first name-like child is used.
            if not name:
                name = _node_text(child, content_bytes)
        elif kind == "function_declarator":
            # Nested declarator: non-empty inner results override the outer ones.
            inner_name, inner_params = _extract_function_declarator(child, content_bytes)
            if inner_name:
                name = inner_name
            if inner_params:
                params = inner_params
        elif kind == "parameter_list":
            raw = _node_text(child, content_bytes).strip()
            if raw.startswith("(") and raw.endswith(")"):
                raw = raw[1:-1]
            params = raw.strip()
    return name, params
58+
59+
60+
def _extract_function(node, content_bytes: bytes, docstring: str = "") -> FunctionSig:
    """Build a FunctionSig from a ``function_definition`` node.

    The declarator is located either directly under ``node`` or nested inside
    pointer / reference declarators, so functions returning ``T*``, ``T**``
    or ``T&`` keep their name and parameters.

    Args:
        node: The function definition node.
        content_bytes: UTF-8 encoded source the node was parsed from.
        docstring: Documentation text to attach to the signature.

    Returns:
        A FunctionSig. ``return_type`` is the text of the first type-like
        child, so pointer / reference markers are not part of it. Fields
        that could not be determined are empty strings.
    """
    name = ""
    params = ""
    return_type = ""

    for child in node.children:
        if child.type == "function_declarator":
            name, params = _extract_function_declarator(child, content_bytes)
        elif child.type in _RETURN_TYPE_NODE_TYPES:
            # Keep only the first type-like child.
            if not return_type:
                return_type = _node_text(child, content_bytes)
        elif child.type in _DECLARATOR_WRAPPERS:
            declarator = _find_function_declarator(child)
            if declarator is not None:
                name, params = _extract_function_declarator(declarator, content_bytes)

    return FunctionSig(name=name, params=params, return_type=return_type, docstring=docstring)


# Child node types whose text is used as the return type.
_RETURN_TYPE_NODE_TYPES = (
    "type_specifier",
    "primitive_type",
    "type_identifier",
    "qualified_identifier",
    "auto",
)

# Declarator nodes that may wrap the function_declarator of a function
# returning a pointer or a reference.
_DECLARATOR_WRAPPERS = ("pointer_declarator", "reference_declarator")


def _find_function_declarator(node):
    """Return the first function_declarator under ``node``, looking through nested pointer/reference declarators; None if absent."""
    for child in node.children:
        if child.type == "function_declarator":
            return child
        if child.type in _DECLARATOR_WRAPPERS:
            nested = _find_function_declarator(child)
            if nested is not None:
                return nested
    return None
78+
79+
80+
def _extract_class(node, content_bytes: bytes, docstring: str = "") -> ClassSkeleton:
    """Build a ClassSkeleton from a class or struct specifier node.

    The name is the first ``type_identifier`` child. Bases are the
    ``type_identifier`` children of any ``base_class_clause``. Methods come
    from the ``field_declaration_list`` body: inline ``function_definition``
    members, and ``declaration`` / ``field_declaration`` members that contain
    a ``function_declarator``. Each method's docstring is taken from the
    comment node directly before it, if any.

    Args:
        node: The class/struct node.
        content_bytes: UTF-8 encoded source the node was parsed from.
        docstring: Documentation text to attach to the class.
    """
    type_like = ("type_specifier", "primitive_type", "type_identifier", "qualified_identifier")

    name = ""
    bases: List[str] = []
    body_node = None
    for part in node.children:
        kind = part.type
        if kind == "type_identifier" and not name:
            name = _node_text(part, content_bytes)
        elif kind == "base_class_clause":
            bases.extend(
                _node_text(base, content_bytes)
                for base in part.children
                if base.type == "type_identifier"
            )
        elif kind == "field_declaration_list":
            body_node = part

    methods: List[FunctionSig] = []
    if not body_node:
        # No body found (e.g. a forward declaration): nothing to list.
        return ClassSkeleton(name=name, bases=bases, docstring=docstring, methods=methods)

    members = list(body_node.children)
    for position, member in enumerate(members):
        if member.type == "function_definition":
            doc = _preceding_doc(members, position, content_bytes)
            methods.append(_extract_function(member, content_bytes, docstring=doc))
            continue
        if member.type not in ("declaration", "field_declaration"):
            continue

        ret_type = ""
        fn_name = ""
        fn_params = ""
        for piece in member.children:
            if piece.type == "function_declarator":
                fn_name, fn_params = _extract_function_declarator(piece, content_bytes)
                break
            if not ret_type and piece.type in type_like:
                ret_type = _node_text(piece, content_bytes)

        # Members without a function name (e.g. data members) are skipped.
        if fn_name:
            doc = _preceding_doc(members, position, content_bytes)
            methods.append(FunctionSig(name=fn_name, params=fn_params, return_type=ret_type, docstring=doc))

    return ClassSkeleton(name=name, bases=bases, docstring=docstring, methods=methods)
118+
119+
120+
class CppExtractor(LanguageExtractor):
    """C/C++ skeleton extractor built on the tree-sitter C++ grammar."""

    def __init__(self):
        # Imported inside the constructor so the tree-sitter packages are only
        # required when a CppExtractor is actually created.
        import tree_sitter_cpp as tscpp
        from tree_sitter import Language, Parser

        self._language = Language(tscpp.language())
        self._parser = Parser(self._language)

    def extract(self, file_name: str, content: str) -> CodeSkeleton:
        """Parse ``content`` and collect includes, classes/structs and functions.

        Includes are collected from the top level of the file. Classes,
        structs and function definitions are collected from the top level and
        from namespaces at any nesting depth, in source order.

        Args:
            file_name: Name recorded on the returned skeleton.
            content: C/C++ source text.

        Returns:
            A CodeSkeleton with ``language`` set to ``"C/C++"`` and an empty
            ``module_doc``.
        """
        content_bytes = content.encode("utf-8")
        root = self._parser.parse(content_bytes).root_node

        imports: List[str] = []
        classes: List[ClassSkeleton] = []
        functions: List[FunctionSig] = []

        top_level = list(root.children)
        for child in top_level:
            if child.type != "preproc_include":
                continue
            # Keep just the header path: drop the quotes or angle brackets.
            imports.extend(
                _node_text(sub, content_bytes).strip().strip('"<>')
                for sub in child.children
                if sub.type in ("string_literal", "system_lib_string")
            )

        self._collect_definitions(top_level, content_bytes, classes, functions)

        return CodeSkeleton(
            file_name=file_name,
            language="C/C++",
            module_doc="",
            imports=imports,
            classes=classes,
            functions=functions,
        )

    def _collect_definitions(self, siblings, content_bytes, classes, functions) -> None:
        """Append classes/structs and functions found in ``siblings`` to the given lists, recursing into namespace bodies."""
        for idx, child in enumerate(siblings):
            if child.type in ("class_specifier", "struct_specifier"):
                doc = _preceding_doc(siblings, idx, content_bytes)
                classes.append(_extract_class(child, content_bytes, docstring=doc))
            elif child.type == "function_definition":
                doc = _preceding_doc(siblings, idx, content_bytes)
                functions.append(_extract_function(child, content_bytes, docstring=doc))
            elif child.type == "namespace_definition":
                # Namespace members live in the declaration_list child; recursing
                # here also picks up namespaces nested inside namespaces.
                for sub in child.children:
                    if sub.type == "declaration_list":
                        self._collect_definitions(list(sub.children), content_bytes, classes, functions)

0 commit comments

Comments
 (0)