Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions aeon/facade/driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from dataclasses import dataclass
from functools import reduce
from typing import Any, Iterable
from typing import Any, Iterable, Optional

from aeon.backend.evaluator import EvaluationContext
from aeon.backend.evaluator import eval
Expand Down Expand Up @@ -67,7 +67,7 @@ def parse_core(self, filename: str):
self.core_ast = parse_term(aeon_code)
self.metadata: Metadata = {}

def parse(self, filename: str = None, aeon_code: str = None) -> Iterable[AeonError]:
def parse(self, filename: str, aeon_code: str | None = None) -> Iterable[AeonError]:
if aeon_code is None:
aeon_code = read_file(filename)

Expand Down Expand Up @@ -135,7 +135,7 @@ def has_synth(self) -> bool:
def synth(self) -> STerm:
with RecordTime("Synthesis"):
synthesizer = make_synthesizer(self.cfg.synthesizer)
mapping: dict[Name, Term] = synthesize_holes(
mapping: dict[Name, Optional[Term]] = synthesize_holes(
self.typing_ctx,
self.evaluation_ctx,
self.core,
Expand All @@ -157,7 +157,7 @@ def synth(self) -> STerm:

return lift(core_ast_anf)

def pretty_print(self, filename: str = None, should_be_fixed: bool = False) -> None:
def pretty_print(self, filename: str, should_be_fixed: bool = False) -> None:
aeon_code = read_file(filename)
prog: Program = parse_main_program(aeon_code, filename=filename)
prog = bind_program(prog, [])
Expand Down
30 changes: 25 additions & 5 deletions aeon/lsp/aeon_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import re
import urllib.parse
from dataclasses import dataclass
from typing import Dict, List, TextIO
from typing import Any, Dict, List, Optional, TextIO

import requests
from lark.exceptions import UnexpectedToken
Expand All @@ -40,9 +40,11 @@
@dataclass(frozen=True)
class ParseResult:
diagnostics: List[Diagnostic]
# Not needed for now but could be added on the future
# core_ast: Any = None
# typing_ctx: Any = None
# Added for code actions
core_ast: Optional[Any] = None
typing_ctx: Optional[Any] = None
evaluation_ctx: Optional[Any] = None
metadata: Optional[Any] = None


_parse_result_cache: Dict[URI, ParseResult] = {}
Expand Down Expand Up @@ -163,7 +165,25 @@ async def _parse(
severity=DiagnosticSeverity.Error,
)
)
return ParseResult(diagnostics)

# If parsing succeeded without errors, include AST and context for code actions
core_ast = None
typing_ctx = None
evaluation_ctx = None
metadata = None
if not diagnostics and hasattr(driver, "core") and hasattr(driver, "typing_ctx"):
core_ast = driver.core
typing_ctx = driver.typing_ctx
evaluation_ctx = getattr(driver, "evaluation_ctx", None)
metadata = getattr(driver, "metadata", None)

return ParseResult(
diagnostics=diagnostics,
core_ast=core_ast,
typing_ctx=typing_ctx,
evaluation_ctx=evaluation_ctx,
metadata=metadata,
)


async def _open(
Expand Down
Loading
Loading