diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 5cf763e..9a466f3 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -87,22 +87,22 @@ def get_all_imports( file_name = os.path.join(root, file_name) with open_func(file_name, "r", encoding=encoding) as f: contents = f.read() - try: - tree = ast.parse(contents) - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for subnode in node.names: - raw_imports.add(subnode.name) - elif isinstance(node, ast.ImportFrom): - raw_imports.add(node.module) - except Exception as exc: - if ignore_errors: - traceback.print_exc(exc) - logging.warn("Failed on file: %s" % file_name) - continue - else: - logging.error("Failed on file: %s" % file_name) - raise exc + try: + tree = ast.parse(contents) + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for subnode in node.names: + raw_imports.add(subnode.name) + elif isinstance(node, ast.ImportFrom): + raw_imports.add(node.module) + except Exception as exc: + if ignore_errors: + traceback.print_exc(exc) + logging.warn("Failed on file: %s" % file_name) + continue + else: + logging.error("Failed on file: %s" % file_name) + raise exc # Clean up imports for name in [n for n in raw_imports if n]: @@ -114,13 +114,14 @@ def get_all_imports( cleaned_name, _, _ = name.partition('.') imports.add(cleaned_name) - packages = set(imports) - set(set(candidates) & set(imports)) + packages = imports - (set(candidates) & imports) logging.debug('Found packages: {0}'.format(packages)) with open(join("stdlib"), "r") as f: - data = [x.strip() for x in f.readlines()] - data = [x for x in data if x not in py2_exclude] if py2 else data - return list(set(packages) - set(data)) + data = {x.strip() for x in f} + + data = {x for x in data if x not in py2_exclude} if py2 else data + return list(packages - data) def filter_line(l):