diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index ace2967..115edc0 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -28,46 +28,44 @@ REGEXP = [ ] -def get_all_imports(start_path): +def get_all_imports(path): imports = [] - packages = [] - logging.debug('Traversing tree, start: {0}'.format(start_path)) - for root, dirs, files in os.walk(start_path): - packages.append(os.path.basename(root)) + candidates = [] + + for root, dirs, files in os.walk(path): + candidates.append(os.path.basename(root)) files = [fn for fn in files if os.path.splitext(fn)[1] == ".py"] - packages += [os.path.splitext(fn)[0] for fn in files] + candidates += [os.path.splitext(fn)[0] for fn in files] for file_name in files: - with open(os.path.join(root, file_name), "r") as file_object: - lines = filter( - lambda l: len(l) > 0, map(lambda l: l.strip(), file_object)) + with open(os.path.join(root, file_name), "r") as f: + lines = filter(filter_line, map(lambda l: l.strip(), f)) for line in lines: - if line[0] == "#": - continue if "(" in line: break for rex in REGEXP: - s = rex.match(line) - if not s: - continue - for item in s.groups(): - if "," in item: - for match in item.split(","): - imports.append(get_import_name_without_alias(match)) - else: - imports.append(get_import_name_without_alias(item)) - third_party_packages = set(imports) - set(set(packages) & set(imports)) - logging.debug( - 'Found third-party packages: {0}'.format(third_party_packages)) - with open(os.path.join(os.path.dirname(__file__), "stdlib"), "r") as f: + s = rex.findall(line) + for item in s: + res = map(get_name_without_alias, item.split(",")) + imports = [x for x in imports + res if len(x) > 0] + + packages = set(imports) - set(set(candidates) & set(imports)) + logging.debug('Found packages: {0}'.format(packages)) + + with open(join("stdlib"), "r") as f: data = [x.strip() for x in f.readlines()] - return sorted(list(set(third_party_packages) - set(data))) + return sorted(list(set(packages) - set(data))) + + +def filter_line(l): + return len(l) > 0 and l[0] != "#" def generate_requirements_file(path, imports): with open(path, "w") as out_file: - logging.debug('Writing {num} requirements to {file}'.format( + logging.debug('Writing {num} requirements: {imports} to {file}'.format( num=len(imports), - file=path + file=path, + imports=", ".join([x['name'] for x in imports]) )) fmt = '{name} == {version}' out_file.write('\n'.join(fmt.format(**item) @@ -80,17 +78,16 @@ def get_imports_info(imports): try: data = yarg.get(item) except HTTPError: - logging.debug('Package does not exist or network problems') + logging.debug( + 'Package %s does not exist or network problems', item) continue - last_release = data.latest_release_id - result.append({'name': item, 'version': last_release}) + result.append({'name': item, 'version': data.latest_release_id}) return result def get_locally_installed_packages(): - path = get_python_lib() packages = {} - for root, dirs, files in os.walk(path): + for root, dirs, files in os.walk(get_python_lib()): for item in files: if "top_level" in item: with open(os.path.join(root, item), "r") as f: @@ -114,9 +111,9 @@ def get_import_local(imports): return result -def get_pkg_names_from_import_names(pkgs): +def get_pkg_names(pkgs): result = [] - with open(os.path.join(os.path.dirname(__file__), "mapping"), "r") as f: + with open(join("mapping"), "r") as f: data = [x.strip().split(":") for x in f.readlines()] for pkg in pkgs: toappend = pkg @@ -128,33 +125,36 @@ def get_pkg_names_from_import_names(pkgs): return result -def get_import_name_without_alias(import_name): - return import_name.partition(' as ')[0].partition('.')[0].strip() +def get_name_without_alias(name): + if "import" in name: + name = REGEXP[0].match(name.strip()).groups(0)[0] + return name.partition(' as ')[0].partition('.')[0].strip() + + +def join(f): + return os.path.join(os.path.dirname(__file__), f) def init(args): - print("Looking for imports") - imports = get_all_imports(args['']) - imports = get_pkg_names_from_import_names(imports) - print("Found third-party imports: " + ", ".join(imports)) + candidates = get_all_imports(args['']) + candidates = get_pkg_names(get_all_imports(args[''])) + logging.debug("Found imports: " + ", ".join(candidates)) if args['--use-local']: - print( - "Getting package version information ONLY from local installation.") - imports_with_info = get_import_local(imports) + logging.debug( + "Getting package information ONLY from local installation.") + imports = get_import_local(candidates) else: - print( - "Getting latest version information about packages from Local/PyPI") - imports_local = get_import_local(imports) - difference = [x for x in imports if x not in [z['name'] - for z in imports_local]] - imports_pypi = get_imports_info(difference) - imports_with_info = imports_local + imports_pypi - print("Imports written to requirements file:", ", ".join( - [x['name'] for x in imports_with_info])) + logging.debug("Getting packages information from Local/PyPI") + local = get_import_local(candidates) + # Get packages that were not found locally + difference = [x for x in candidates if x not in [z['name'] + for z in local]] + imports = local + get_imports_info(difference) + path = args[ "--savepath"] if args["--savepath"] else os.path.join(args[''], "requirements.txt") - generate_requirements_file(path, imports_with_info) + generate_requirements_file(path, imports) print("Successfully saved requirements file in " + path) diff --git a/tests/_data/test.py b/tests/_data/test.py index 0634a4b..0c8dfb9 100644 --- a/tests/_data/test.py +++ b/tests/_data/test.py @@ -11,6 +11,7 @@ import signal import bs4 import requests import nonexistendmodule +import boto as b, import peewee as p, # import django import flask.ext.somext from sqlalchemy import model @@ -23,3 +24,5 @@ import models def main(): pass + +import after_method_should_be_ignored diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index d9139f5..47d49c7 100755 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -18,7 +18,7 @@ class TestPipreqs(unittest.TestCase): def setUp(self): self.modules = ['flask', 'requests', 'sqlalchemy', - 'docopt', 'ujson', 'nonexistendmodule', 'bs4'] + 'docopt', 'boto', 'peewee', 'ujson', 'nonexistendmodule', 'bs4',] self.modules2 = ['beautifulsoup4'] self.project = os.path.join(os.path.dirname(__file__), "_data") self.requirements_path = os.path.join(self.project, "requirements.txt") @@ -27,7 +27,7 @@ class TestPipreqs(unittest.TestCase): def test_get_all_imports(self): imports = pipreqs.get_all_imports(self.project) - self.assertEqual(len(imports), 7, "Incorrect Imports array length") + self.assertEqual(len(imports), 9) for item in imports: self.assertTrue( item.lower() in self.modules, "Import is missing: " + item) @@ -43,7 +43,7 @@ class TestPipreqs(unittest.TestCase): with_info = pipreqs.get_imports_info(imports) # Should contain only 5 Elements without the "nonexistendmodule" self.assertEqual( - len(with_info), 5, "Length of imports array with info is wrong") + len(with_info), 7) for item in with_info: self.assertTrue(item['name'].lower( ) in self.modules, "Import item appears to be missing " + item['name']) @@ -84,7 +84,7 @@ class TestPipreqs(unittest.TestCase): def test_get_import_name_without_alias(self): import_name_with_alias = "requests as R" expected_import_name_without_alias = "requests" - import_name_without_aliases = pipreqs.get_import_name_without_alias(import_name_with_alias) + import_name_without_aliases = pipreqs.get_name_without_alias(import_name_with_alias) self.assertEqual(import_name_without_aliases, expected_import_name_without_alias) def tearDown(self):