From 81f3575ca7e7422f52ab0dbe242a527ea556ecf0 Mon Sep 17 00:00:00 2001 From: suraciii Date: Sat, 28 Feb 2026 22:47:08 +0800 Subject: [PATCH] feat: Add C# AST extractor support - Add CSharpExtractor class with support for: - Classes, interfaces, structs, records - Methods, constructors, properties - XML documentation comments (/// and /** */) - Namespaces and file-scoped namespaces - using directives - Register .cs file extension and csharp extractor - Add tree-sitter-c-sharp dependency - Add comprehensive tests aligned with other language extractors --- .../parse/parsers/code/ast/extractor.py | 2 + .../parsers/code/ast/languages/csharp.py | 189 ++++++++++++++++++ pyproject.toml | 1 + tests/parse/test_ast_extractor.py | 135 +++++++++++++ 4 files changed, 327 insertions(+) create mode 100644 openviking/parse/parsers/code/ast/languages/csharp.py diff --git a/openviking/parse/parsers/code/ast/extractor.py b/openviking/parse/parsers/code/ast/extractor.py index 5866ceea..514756da 100644 --- a/openviking/parse/parsers/code/ast/extractor.py +++ b/openviking/parse/parsers/code/ast/extractor.py @@ -27,6 +27,7 @@ ".hpp": "cpp", ".rs": "rust", ".go": "go", + ".cs": "csharp", } # Language key → (module path, class name, constructor kwargs) @@ -38,6 +39,7 @@ "cpp": ("openviking.parse.parsers.code.ast.languages.cpp", "CppExtractor", {}), "rust": ("openviking.parse.parsers.code.ast.languages.rust", "RustExtractor", {}), "go": ("openviking.parse.parsers.code.ast.languages.go", "GoExtractor", {}), + "csharp": ("openviking.parse.parsers.code.ast.languages.csharp", "CSharpExtractor", {}), } diff --git a/openviking/parse/parsers/code/ast/languages/csharp.py b/openviking/parse/parsers/code/ast/languages/csharp.py new file mode 100644 index 00000000..f47f0cbb --- /dev/null +++ b/openviking/parse/parsers/code/ast/languages/csharp.py @@ -0,0 +1,189 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""C# AST extractor using tree-sitter-c-sharp.""" + +import re +from typing import List + +from openviking.parse.parsers.code.ast.languages.base import LanguageExtractor +from openviking.parse.parsers.code.ast.skeleton import ClassSkeleton, CodeSkeleton, FunctionSig + + +def _node_text(node, content_bytes: bytes) -> str: + return content_bytes[node.start_byte : node.end_byte].decode("utf-8", errors="replace") + + +def _parse_doc_comment(raw: str) -> str: + """Strip XML doc comment markers (/// or /** */) and extract text from XML tags.""" + raw = raw.strip() + if raw.startswith("///"): + lines = raw.split("\n") + cleaned = [] + for line in lines: + stripped = line.strip() + if stripped.startswith("///"): + stripped = stripped[3:].strip() + if stripped: + cleaned.append(stripped) + raw = " ".join(cleaned) + elif raw.startswith("/**"): + raw = raw[3:] + if raw.endswith("*/"): + raw = raw[:-2] + lines = [l.strip().lstrip("*").strip() for l in raw.split("\n")] + raw = "\n".join(l for l in lines if l).strip() + # Remove XML tags + raw = re.sub(r"]*)?/?>", "", raw) + # Normalize whitespace + raw = re.sub(r"\s+", " ", raw).strip() + return raw + + +def _preceding_doc(siblings: list, idx: int, content_bytes: bytes) -> str: + """Return XML doc comment immediately before siblings[idx], or ''.""" + if idx == 0: + return "" + comments = [] + for i in range(idx - 1, -1, -1): + prev = siblings[i] + if prev.type == "comment": + text = _node_text(prev, content_bytes) + if text.strip().startswith("///") or text.strip().startswith("/**"): + comments.insert(0, _parse_doc_comment(text)) + else: + break + elif prev.type in ("preprocessor_directive", "nullable_directive"): + continue + else: + break + return "\n".join(comments) if comments else "" + + +def _extract_method(node, content_bytes: bytes, docstring: str = "") -> FunctionSig: + name = "" + params = "" + return_type = "" + + for child in node.children: + if child.type == "identifier" and not name: + name = _node_text(child, content_bytes) + elif child.type == "void_keyword": + return_type = "void" + elif child.type in ("predefined_type", "type_identifier", "generic_name"): + if not return_type: + return_type = _node_text(child, content_bytes) + elif child.type == "parameter_list": + raw = _node_text(child, content_bytes).strip() + if raw.startswith("(") and raw.endswith(")"): + raw = raw[1:-1] + params = raw.strip() + + if node.type == "property_declaration": + for child in node.children: + if child.type == "accessor_list": + accessors = [] + for acc in child.children: + if acc.type == "accessor_declaration": + accessor_name = "" + name_node = acc.child_by_field_name("name") + if name_node is not None: + accessor_name = _node_text(name_node, content_bytes).strip() + else: + for sub in acc.children: + if sub.type in ("get", "set", "init"): + accessor_name = sub.type + break + if accessor_name in ("get", "set", "init"): + accessors.append(accessor_name) + if accessors: + params = f"{{ {' '.join(accessors)} }}" + + return FunctionSig(name=name, params=params, return_type=return_type, docstring=docstring) + + +def _extract_class(node, content_bytes: bytes, docstring: str = "") -> ClassSkeleton: + name = "" + bases: List[str] = [] + body_node = None + + for child in node.children: + if child.type == "identifier" and not name: + name = _node_text(child, content_bytes) + elif child.type == "base_list": + for sub in child.children: + if sub.type in ("type_identifier", "identifier"): + bases.append(_node_text(sub, content_bytes)) + elif child.type == "declaration_list": + body_node = child + + methods: List[FunctionSig] = [] + if body_node: + siblings = list(body_node.children) + for idx, child in enumerate(siblings): + if child.type in ("method_declaration", "constructor_declaration"): + doc = _preceding_doc(siblings, idx, content_bytes) + methods.append(_extract_method(child, content_bytes, docstring=doc)) + elif child.type == "property_declaration": + doc = _preceding_doc(siblings, idx, content_bytes) + methods.append(_extract_method(child, content_bytes, docstring=doc)) + + return ClassSkeleton(name=name, bases=bases, docstring=docstring, methods=methods) + + +class CSharpExtractor(LanguageExtractor): + def __init__(self): + import tree_sitter_c_sharp as tscsharp + from tree_sitter import Language, Parser + + self._language = Language(tscsharp.language()) + self._parser = Parser(self._language) + + def extract(self, file_name: str, content: str) -> CodeSkeleton: + content_bytes = content.encode("utf-8") + tree = self._parser.parse(content_bytes) + root = tree.root_node + + imports: List[str] = [] + classes: List[ClassSkeleton] = [] + functions: List[FunctionSig] = [] + + siblings = list(root.children) + for idx, child in enumerate(siblings): + if child.type == "using_directive": + for sub in child.children: + if sub.type == "identifier": + imports.append(_node_text(sub, content_bytes)) + elif sub.type == "qualified_name": + imports.append(_node_text(sub, content_bytes)) + elif child.type in ("namespace_declaration", "file_scoped_namespace_declaration"): + for sub in child.children: + if sub.type == "declaration_list": + ns_siblings = list(sub.children) + for ns_idx, ns_child in enumerate(ns_siblings): + if ns_child.type in ( + "class_declaration", + "interface_declaration", + "struct_declaration", + "record_declaration", + ): + doc = _preceding_doc(ns_siblings, ns_idx, content_bytes) + classes.append( + _extract_class(ns_child, content_bytes, docstring=doc) + ) + elif child.type in ( + "class_declaration", + "interface_declaration", + "struct_declaration", + "record_declaration", + ): + doc = _preceding_doc(siblings, idx, content_bytes) + classes.append(_extract_class(child, content_bytes, docstring=doc)) + + return CodeSkeleton( + file_name=file_name, + language="C#", + module_doc="", + imports=imports, + classes=classes, + functions=functions, + ) diff --git a/pyproject.toml b/pyproject.toml index 0793c1d5..82553269 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dependencies = [ "tree-sitter-cpp>=0.23.0", "tree-sitter-rust>=0.23.0", "tree-sitter-go>=0.23.0", + "tree-sitter-c-sharp>=0.23.0", ] [tool.uv.sources] diff --git a/tests/parse/test_ast_extractor.py b/tests/parse/test_ast_extractor.py index f505b826..e7018c10 100644 --- a/tests/parse/test_ast_extractor.py +++ b/tests/parse/test_ast_extractor.py @@ -30,6 +30,11 @@ def _ts_extractor(): return JsTsExtractor(lang="typescript") +def _csharp_extractor(): + from openviking.parse.parsers.code.ast.languages.csharp import CSharpExtractor + return CSharpExtractor() + + # --------------------------------------------------------------------------- # Python @@ -452,6 +457,130 @@ def test_to_text_verbose(self): assert "@return sum of a and b" in text +# --------------------------------------------------------------------------- +# C# +# --------------------------------------------------------------------------- + +class TestCSharpExtractor: + SAMPLE = """ +using System; +using System.Collections.Generic; + +namespace MyApp.Services +{ + /// + /// A simple calculator service. + /// + /// Supports basic arithmetic operations. + /// + public class Calculator + { + /// + /// Add two integers. + /// + /// First operand + /// Second operand + /// Sum of a and b + /// + public int Add(int a, int b) + { + return a + b; + } + + /// + /// Subtract b from a. + /// + public int Subtract(int a, int b) + { + return a - b; + } + } +} +""" + + def setup_method(self): + self.e = _csharp_extractor() + + def test_imports(self): + sk = self.e.extract("Calculator.cs", self.SAMPLE) + assert "System" in sk.imports + assert "System.Collections.Generic" in sk.imports + + def test_class_extracted(self): + sk = self.e.extract("Calculator.cs", self.SAMPLE) + names = {c.name for c in sk.classes} + assert "Calculator" in names + + def test_class_docstring(self): + sk = self.e.extract("Calculator.cs", self.SAMPLE) + cls = next(c for c in sk.classes if c.name == "Calculator") + assert "simple calculator service" in cls.docstring + assert "Supports basic arithmetic" in cls.docstring + + def test_methods_extracted(self): + sk = self.e.extract("Calculator.cs", self.SAMPLE) + cls = next(c for c in sk.classes if c.name == "Calculator") + methods = {m.name: m for m in cls.methods} + assert "Add" in methods + assert "Subtract" in methods + + def test_method_docstring(self): + sk = self.e.extract("Calculator.cs", self.SAMPLE) + cls = next(c for c in sk.classes if c.name == "Calculator") + methods = {m.name: m for m in cls.methods} + assert "Add two integers." in methods["Add"].docstring + assert "First operand" in methods["Add"].docstring + + def test_to_text_compact(self): + sk = self.e.extract("Calculator.cs", self.SAMPLE) + text = sk.to_text(verbose=False) + assert "# Calculator.cs [C#]" in text + assert "class Calculator" in text + assert "+ Add(" in text + assert "First operand" not in text + + def test_to_text_verbose(self): + sk = self.e.extract("Calculator.cs", self.SAMPLE) + text = sk.to_text(verbose=True) + assert "simple calculator service" in text + assert "First operand" in text + + def test_file_scoped_namespace(self): + code = ''' +using System; + +namespace MyApp.Services; + +public class Calculator +{ + public int Add(int a, int b) + { + return a + b; + } +} +''' + sk = self.e.extract("Calculator.cs", code) + names = {c.name for c in sk.classes} + assert "Calculator" in names + + def test_property_accessor_signature(self): + code = ''' +public class Calculator +{ + /// + /// Current result. + /// + public int Result { get; set; } +} +''' + sk = self.e.extract("Calculator.cs", code) + cls = next(c for c in sk.classes if c.name == "Calculator") + methods = {m.name: m for m in cls.methods} + assert "Result" in methods + assert "get" in methods["Result"].params + assert "set" in methods["Result"].params + + # --------------------------------------------------------------------------- # C/C++ # --------------------------------------------------------------------------- @@ -853,6 +982,12 @@ def test_go_dispatch(self): assert "# main.go [Go]" in text assert "Run" in text + def test_csharp_dispatch(self): + code = 'namespace Demo;\n\npublic class Util { public int Add(int a, int b) { return a + b; } }\n' + text = self.extractor.extract_skeleton("util.cs", code) + assert "# util.cs [C#]" in text + assert "class Util" in text + def test_unknown_extension_returns_none(self): code = "def foo(x): pass\nclass Bar: pass\n" result = self.extractor.extract_skeleton("script.lua", code)