Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: refactor
channels:
- defaults
dependencies:
- python>=3.8.2
- pytest
5 changes: 3 additions & 2 deletions refactor/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ def apply(self, context: Context, source: str) -> str:
indentation, start_prefix = find_indent(source_lines[0][:col_offset])
end_suffix = source_lines[-1][end_col_offset:]
replacement = split_lines(self._resynthesize(context))
replacement.apply_indentation(
indentation, start_prefix=start_prefix, end_suffix=end_suffix
# Applies the block indentation only if the replacement lines are different
replacement.apply_indentation_from_source(
indentation, source_lines.data, start_prefix=start_prefix, end_suffix=end_suffix
)

lines[view] = replacement
Expand Down
29 changes: 28 additions & 1 deletion refactor/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import cached_property
from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast
from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, List

from refactor import common
from refactor.common import find_common_chars

DEFAULT_ENCODING = "utf-8"

Expand Down Expand Up @@ -51,6 +52,32 @@ def apply_indentation(
if len(self.data) >= 1:
self.data[-1] += str(end_suffix) # type: ignore

def apply_indentation_from_source(
    self,
    indentation: StringType,
    source_data: List[StringType],
    *,
    start_prefix: AnyStringType = "",
    end_suffix: AnyStringType = "",
) -> None:
    """Indent replacement lines that differ from the original source.

    The first line always receives ``indentation`` plus ``start_prefix``;
    every following line is indented only when it is NOT carried over
    verbatim from ``source_data`` (so lines the user wrote — e.g. the
    deliberately mis-aligned arguments of a multi-line call, or the body
    of a multi-line string literal — keep their own alignment).
    ``end_suffix`` is appended to the final line.

    Args:
        indentation: The block indentation to prepend to changed lines.
        source_data: The original source lines this replacement maps onto.
        start_prefix: Text inserted between the indentation and line 0.
        end_suffix: Text appended to the last line.
    """

    def _is_original(i: int) -> bool:
        # Guard the index FIRST: the replacement may contain more lines
        # than the original source did.
        if i >= len(source_data):
            return False
        line = str(self.data[i])
        common = find_common_chars(line, str(source_data[i]))
        # A line counts as "original" when it matches the source entirely,
        # or when it opens a multi-line string shared with the source.
        is_multiline_string: bool = (
            line.startswith(common) and common in ("'''", '"""')
        )
        return line == common or is_multiline_string

    for index, line in enumerate(self.data):
        if index == 0:
            self.data[index] = indentation + str(start_prefix) + str(line)  # type: ignore
        elif not _is_original(index):
            self.data[index] = indentation + line  # type: ignore

    if len(self.data) >= 1:
        self.data[-1] += str(end_suffix)  # type: ignore

@cached_property
def _newline_type(self) -> str:
"""Guess the used newline type."""
Expand Down
10 changes: 10 additions & 0 deletions refactor/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,16 @@ def find_indent(source: str) -> tuple[str, str]:
return source[:index], source[index:]


def find_common_chars(source: str, compare: str) -> str:
    """Return the longest common prefix of ``source`` and ``compare``.

    Walks both strings in lockstep and stops at the first mismatch or
    when either string is exhausted. Bounding the walk by the shorter
    string fixes the IndexError the naive scan raised whenever
    ``compare`` was a strict prefix of ``source`` (or empty).
    """
    limit = min(len(source), len(compare))
    index = 0
    while index < limit and source[index] == compare[index]:
        index += 1
    return source[:index]


def find_closest(node: ast.AST, *targets: ast.AST) -> ast.AST:
"""Find the closest node to the given ``node`` from the given
sequence of ``targets`` (uses absolute distance from starting points)."""
Expand Down
76 changes: 68 additions & 8 deletions tests/test_complete_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,37 @@ def match(self, node):
return AsyncifierAction(node)


class AwaitifierAction(LazyReplace):
    """Wrap the matched node's call in an ``await`` expression."""

    def build(self):
        node = self.node
        if isinstance(node, ast.Expr):
            # Statement form: await the inner value, keep the Expr wrapper.
            node.value = ast.Await(node.value)
            return node
        if isinstance(node, ast.Call):
            # Bare call: wrap the whole call node directly.
            return ast.Await(node)


class MakeCallAwait(Rule):
    """End-to-end rule: prefix a bare multi-line call statement with ``await``.

    Exercises indentation preservation — the comment says the argument
    alignment is intentionally irregular, and the expected output keeps
    that alignment untouched while only the first line gains ``await``.
    """

    INPUT_SOURCE = """
def somefunc():
call(
arg0,
arg1) # Intentional mis-alignment
"""

    EXPECTED_SOURCE = """
def somefunc():
await call(
arg0,
arg1) # Intentional mis-alignment
"""

    def match(self, node):
        # Only fire on expression statements whose value is a call.
        assert isinstance(node, ast.Expr)
        assert isinstance(node.value, ast.Call)
        return AwaitifierAction(node)


class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule):
context_providers = (context.Scope,)

Expand All @@ -322,7 +353,7 @@ def match(self, node: ast.AST) -> BaseAction | None:
assert any(kw_default is None for kw_default in node.args.kw_defaults)

if isinstance(node, ast.Lambda) and not (
isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load)
isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load)
):
scope = self.context["scope"].resolve(node.body)
scope.definitions.get(node.body.id, [])
Expand All @@ -331,8 +362,8 @@ def match(self, node: ast.AST) -> BaseAction | None:
for stmt in node.body:
for identifier in ast.walk(stmt):
if not (
isinstance(identifier, ast.Name)
and isinstance(identifier.ctx, ast.Load)
isinstance(identifier, ast.Name)
and isinstance(identifier.ctx, ast.Load)
):
continue

Expand Down Expand Up @@ -598,13 +629,13 @@ class DownstreamAnalyzer(Representative):
context_providers = (context.Scope,)

def iter_dependents(
self, name: str, source: ast.Import | ast.ImportFrom
self, name: str, source: ast.Import | ast.ImportFrom
) -> Iterator[ast.Name]:
for node in ast.walk(self.context.tree):
if (
isinstance(node, ast.Name)
and isinstance(node.ctx, ast.Load)
and node.id == name
isinstance(node, ast.Name)
and isinstance(node.ctx, ast.Load)
and node.id == name
):
node_scope = self.context.scope.resolve(node)
definitions = node_scope.get_definitions(name)
Expand Down Expand Up @@ -699,7 +730,7 @@ def match(self, node: ast.AST) -> Iterator[Replace]:

[alias] = aliases
for dependent in self.context.downstream_analyzer.iter_dependents(
alias.asname or alias.name, node
alias.asname or alias.name, node
):
yield Replace(dependent, ast.Name("b", ast.Load()))

Expand Down Expand Up @@ -936,11 +967,38 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]:
new_trys.append(new_try)

first_try, *remaining_trys = new_trys
print(ast.unparse(node))
print(ast.unparse(first_try))
yield Replace(node, first_try)
for remaining_try in reversed(remaining_trys):
yield InsertAfter(node, remaining_try)


class WrapInMultilineFstring(Rule):
    """End-to-end rule: wrap a multi-line string constant in a call to ``F``.

    Exercises the case where replacement lines belong to a multi-line
    string literal — their interior alignment must be carried over
    verbatim rather than re-indented.
    """

    INPUT_SOURCE = '''
def f():
return """
a
"""
'''
    EXPECTED_SOURCE = '''
def f():
return F("""
a
""")
'''

    def match(self, node):
        assert isinstance(node, ast.Constant)

        # Prevent wrapping F-strings that are already wrapped in F()
        # Otherwise you get infinite F(F(F(F(...))))
        parent = self.context.ancestry.get_parent(node)
        assert not (isinstance(parent, ast.Call) and isinstance(parent.func, ast.Name) and parent.func.id == 'F')

        return Replace(node, ast.Call(func=ast.Name(id="F"), args=[node], keywords=[]))


@pytest.mark.parametrize(
"rule",
[
Expand All @@ -949,6 +1007,7 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]:
PropagateConstants,
TypingAutoImporter,
MakeFunctionAsync,
MakeCallAwait,
OnlyKeywordArgumentDefaultNotSetCheckRule,
InternalizeFunctions,
RemoveDeadCode,
Expand All @@ -957,6 +1016,7 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]:
PropagateAndDelete,
FoldMyConstants,
AtomicTryBlock,
WrapInMultilineFstring,
],
)
def test_complete_rules(rule, tmp_path):
Expand Down