add test for "compare_modules" function

This commit is contained in:
Mateus Latrova 2023-09-28 20:28:32 -03:00
parent 6ac4357cf4
commit ed46d270e9
3 changed files with 35 additions and 0 deletions

0
tests/_data/empty.txt Normal file
View File

3
tests/_data/imports.txt Normal file
View File

@ -0,0 +1,3 @@
pandas==2.0.0
numpy>=1.2.3
torch<4.0.0

View File

@ -39,6 +39,13 @@ class TestPipreqs(unittest.TestCase):
self.project = os.path.join(os.path.dirname(__file__), "_data")
self.project_clean = os.path.join(os.path.dirname(__file__), "_data_clean")
self.project_invalid = os.path.join(os.path.dirname(__file__), "_invalid_data")
self.parsed_packages = [
{"name": "pandas", "version": "2.0.0"},
{"name": "numpy", "version": "1.2.3"},
{"name": "torch", "version": "4.0.0"},
]
self.empty_filepath = os.path.join(self.project, "empty.txt")
self.imports_filepath = os.path.join(self.project, "imports.txt")
self.project_with_ignore_directory = os.path.join(
os.path.dirname(__file__), "_data_ignore"
)
@ -427,6 +434,31 @@ class TestPipreqs(unittest.TestCase):
data = f.read().lower()
self.assertTrue(cleaned_module not in data)
def test_compare_modules(self):
test_cases = [
(self.empty_filepath, [], set()), # both empty
(self.empty_filepath, self.parsed_packages, set()), # only file empty
(
self.imports_filepath,
[],
set(package["name"] for package in self.parsed_packages),
), # only imports empty
(self.imports_filepath, self.parsed_packages, set()), # no difference
(
self.imports_filepath,
self.parsed_packages[1:],
set([self.parsed_packages[0]["name"]]),
), # common case
]
for test_case in test_cases:
with self.subTest(test_case):
filename, imports, expected_modules_not_imported = test_case
modules_not_imported = pipreqs.compare_modules(filename, imports)
self.assertSetEqual(modules_not_imported, expected_modules_not_imported)
def tearDown(self):
"""
Remove requiremnts.txt files that were written