diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 44855f49afaf..3ad253597df7 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3801,7 +3801,17 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: if not encountered_partial_type and not failed_out: iterable_type = UnionType.make_union(iterable_types) - if not is_subtype(left_type, iterable_type): + matches_iterable_item = False + if iterable_types: + left_item_type = remove_instance_last_known_values(left_type) + erased_iterable_type = remove_instance_last_known_values(iterable_type) + matches_iterable_item = is_subtype( + left_item_type, erased_iterable_type + ) or ( + not is_literal_type_like(left_type) + and is_subtype(erased_iterable_type, left_item_type) + ) + if not matches_iterable_item: if not container_types: self.msg.unsupported_operand_types("in", left_type, right_type, e) else: diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 7123285e5eca..a5bde3fa1346 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -412,6 +412,20 @@ class D(Iterable[A]): def __iter__(self) -> Iterator[A]: pass [builtins fixtures/bool.pyi] +[case testInOperatorWithOverlappingIterableItemType] +from typing import Iterator + +class Number(int): pass + +def numbers() -> Iterator[Number]: + yield Number() + +1 in numbers() +Number() in numbers() +'' in numbers() # E: Unsupported operand types for in ("str" and "Iterator[Number]") +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + [case testNotInOperator] from typing import Iterator, Iterable, Any a: A