diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 29bdd13..b2d63e2 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -42,8 +42,11 @@ else: open_func = codecs.open +import ast, traceback + def get_all_imports(path, encoding=None): - imports = [] + imports = set() + raw_imports = set() candidates = [] ignore_dirs = [".hg", ".svn", ".git", "__pycache__", "env", "venv"] @@ -56,19 +59,27 @@ def get_all_imports(path, encoding=None): candidates += [os.path.splitext(fn)[0] for fn in files] for file_name in files: with open_func(os.path.join(root, file_name), "r", encoding=encoding) as f: - contents = re.sub(re.compile("'''.+?'''", re.DOTALL), '', f.read()) - contents = re.sub(re.compile('""".+?"""', re.DOTALL), "", contents) - lines = contents.split("\n") - lines = filter( - filter_line, map(lambda l: l.partition("#")[0].strip(), lines)) - for line in lines: - if "(" in line: - break - for rex in REGEXP: - s = rex.findall(line) - for item in s: - res = map(get_name_without_alias, item.split(",")) - imports = imports + [x for x in res if len(x) > 0] + contents = f.read() + try: + tree = ast.parse(contents) + except Exception: + traceback.print_exc() + print("Failed on file: %s" % os.path.join(root, file_name)) + raise SystemExit(1) + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for subnode in node.names: + raw_imports.add(subnode.name) + elif isinstance(node, ast.ImportFrom): + raw_imports.add(node.module) + + # Clean up imports + for name in [n for n in raw_imports if n]: + # Sanity check: the name is None when the import statement was `from . import X`. + # Cleanup: we only want the first part of the import. + # Ex: `from django.conf import X` yields "django.conf", but we only want "django". + cleaned_name, _, _ = name.partition('.') + imports.add(cleaned_name) packages = set(imports) - set(set(candidates) & set(imports)) logging.debug('Found packages: {0}'.format(packages))