Skip to content

Commit 1288f20

Browse files
authored
Narrowing for comparisons against x.__class__ (#20642)
1 parent 0c63401 commit 1288f20

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

mypy/checker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6734,6 +6734,8 @@ def narrow_type_by_identity_equality(
67346734
and len(type_expr.args) == 1
67356735
):
67366736
expr_in_type_expr = type_expr.args[0]
6737+
elif isinstance(type_expr, MemberExpr) and type_expr.name == "__class__":
6738+
expr_in_type_expr = type_expr.expr
67376739
else:
67386740
continue
67396741
for j in expr_indices:

test-data/unit/check-narrowing.test

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3136,3 +3136,31 @@ if type(x) is not C:
31363136
reveal_type(x) # N: Revealed type is "__main__.D"
31373137
else:
31383138
reveal_type(x) # N: Revealed type is "__main__.C"
3139+
3140+
[case testDunderClassNarrowing]
3141+
# flags: --warn-unreachable
3142+
from typing import Any
3143+
3144+
def foo(y: object):
3145+
if y.__class__ == int:
3146+
reveal_type(y) # N: Revealed type is "builtins.int"
3147+
else:
3148+
reveal_type(y) # N: Revealed type is "builtins.object"
3149+
3150+
if y.__class__ is int:
3151+
reveal_type(y) # N: Revealed type is "builtins.int"
3152+
else:
3153+
reveal_type(y) # N: Revealed type is "builtins.object"
3154+
3155+
3156+
def bar(y: Any):
3157+
if y.__class__ == int:
3158+
reveal_type(y) # N: Revealed type is "Any"
3159+
else:
3160+
reveal_type(y) # N: Revealed type is "Any"
3161+
3162+
if y.__class__ is int:
3163+
reveal_type(y) # N: Revealed type is "builtins.int"
3164+
else:
3165+
reveal_type(y) # N: Revealed type is "Any"
3166+
[builtins fixtures/dict-full.pyi]

test-data/unit/fixtures/dict-full.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ KT = TypeVar('KT')
1313
VT = TypeVar('VT')
1414

1515
class object:
16+
__class__: object
1617
def __init__(self) -> None: pass
1718
def __init_subclass__(cls) -> None: pass
1819
def __eq__(self, other: object) -> bool: pass
@@ -75,8 +76,7 @@ class float: pass
7576
class complex: pass
7677
class bool(int): pass
7778

78-
class ellipsis:
79-
__class__: object
79+
class ellipsis: pass
8080
def isinstance(x: object, t: Union[type, Tuple[type, ...]]) -> bool: pass
8181
class BaseException: pass
8282

0 commit comments

Comments
 (0)