diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 776725a..3690756 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -97,10 +97,10 @@ def get_locally_installed_packages(): package_import = f.read().strip().split("\n") for item in package_import: if item not in ["tests", "_tests"]: - packages[item] = { - 'version': package[1].replace(".dist", ""), - 'name': package[0] - } + packages[item] = { + 'version': package[1].replace(".dist", ""), + 'name': package[0] + } return packages @@ -113,17 +113,17 @@ def get_import_local(imports): return result def get_pkg_names_from_import_names(pkgs): - result = [] - with open(os.path.join(os.path.dirname(__file__), "mapping"), "r") as f: - data = [x.strip().split(":") for x in f.readlines()] - for pkg in pkgs: - toappend = pkg - for item in data: - if item[0] == pkg: - toappend = item[1] - break - result.append(toappend) - return result + result = [] + with open(os.path.join(os.path.dirname(__file__), "mapping"), "r") as f: + data = [x.strip().split(":") for x in f.readlines()] + for pkg in pkgs: + toappend = pkg + for item in data: + if item[0] == pkg: + toappend = item[1] + break + result.append(toappend) + return result def get_import_name_without_alias(import_name): return import_name.partition(' as ')[0].partition('.')[0].strip() diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index c1548dd..eedcb60 100755 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -68,6 +68,11 @@ class TestPipreqs(unittest.TestCase): for item in self.modules[:-1]: self.assertTrue(item in data) + def test_get_import_name_without_alias(self): + import_name_with_alias = "requests as R" + expected_import_name_without_alias = "requests" + import_name_without_aliases = pipreqs.get_import_name_without_alias(import_name_with_alias) + self.assertEqual(import_name_without_aliases, expected_import_name_without_alias, "The import alias was not correctly stripped") def tearDown(self): try: