Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions Lib/test/support/import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,45 @@ def _save_and_remove_modules(names):
return orig_modules


_PARENT_ATTR_MISSING = object()


def _save_parent_attrs(names):
parent_attrs = {}
prefixes = tuple(name + '.' for name in names)
removed = {modname for modname in sys.modules
if modname in names or modname.startswith(prefixes)}
for modname in set(names) | removed:
parent_name, _, attr = modname.rpartition('.')
if not parent_name:
continue
if parent_name in removed or parent_name.startswith(prefixes):
continue
parent = sys.modules.get(parent_name)
if parent is None:
continue
try:
value = getattr(parent, attr)
except AttributeError:
value = _PARENT_ATTR_MISSING
parent_attrs[parent_name, attr] = value
return parent_attrs


def _restore_parent_attrs(parent_attrs):
for (parent_name, attr), value in parent_attrs.items():
parent = sys.modules.get(parent_name)
if parent is None:
continue
if value is _PARENT_ATTR_MISSING:
try:
delattr(parent, attr)
except AttributeError:
pass
else:
setattr(parent, attr, value)


@contextlib.contextmanager
def frozen_modules(enabled=True):
"""Force frozen modules to be used (or not).
Expand Down Expand Up @@ -179,6 +218,7 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
fresh = list(fresh)
blocked = list(blocked)
names = {name, *fresh, *blocked}
orig_parent_attrs = _save_parent_attrs(names)
orig_modules = _save_and_remove_modules(names)
for modname in blocked:
sys.modules[modname] = None
Expand All @@ -195,6 +235,7 @@ def import_fresh_module(name, fresh=(), blocked=(), *,
finally:
_save_and_remove_modules(names)
sys.modules.update(orig_modules)
_restore_parent_attrs(orig_parent_attrs)


class CleanImport(object):
Expand Down
42 changes: 42 additions & 0 deletions Lib/test/test_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,48 @@ def test_import_module(self):
def test_import_fresh_module(self):
import_helper.import_fresh_module("ftplib")

def test_import_fresh_module_restores_parent_attr(self):
import importlib.util

name = "importlib.util"
original_module = sys.modules[name]
self.assertIs(importlib.util, original_module)

fresh_module = import_helper.import_fresh_module(name)

self.assertIsNot(fresh_module, original_module)
self.assertIs(sys.modules[name], original_module)
self.assertIs(importlib.util, original_module)
self.assertIs(importlib.util, sys.modules[name])

def test_import_fresh_module_removes_added_parent_attr(self):
import xml

name = "xml.sax"
original_module = sys.modules.pop(name, None)
original_attr = getattr(xml, "sax", None)
had_attr = hasattr(xml, "sax")
if had_attr:
del xml.sax
try:
self.assertFalse(hasattr(xml, "sax"))

fresh_module = import_helper.import_fresh_module(name)

self.assertIsNotNone(fresh_module)
self.assertNotIn(name, sys.modules)
self.assertFalse(hasattr(xml, "sax"))
finally:
if original_module is not None:
sys.modules[name] = original_module
if had_attr:
xml.sax = original_attr
else:
try:
del xml.sax
except AttributeError:
pass

def test_get_attribute(self):
self.assertEqual(support.get_attribute(self, "test_get_attribute"),
self.test_get_attribute)
Expand Down
Loading