diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 4778c13..776725a 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -27,7 +27,6 @@ REGEXP = [ re.compile(r'^from ((?!\.+).*?) import (?:.*)$') ] - def get_all_imports(start_path): imports = [] packages = [] @@ -52,11 +51,9 @@ def get_all_imports(start_path): for item in s.groups(): if "," in item: for match in item.split(","): - imports.append(match.strip()) + imports.append(get_import_name_without_alias(match)) else: - to_append = item.partition( - ' as ')[0].partition('.')[0] - imports.append(to_append.strip()) + imports.append(get_import_name_without_alias(item)) third_party_packages = set(imports) - set(set(packages) & set(imports)) logging.debug( 'Found third-party packages: {0}'.format(third_party_packages)) @@ -128,6 +125,8 @@ def get_pkg_names_from_import_names(pkgs): result.append(toappend) return result +def get_import_name_without_alias(import_name): + return import_name.partition(' as ')[0].partition('.')[0].strip() def init(args): print("Looking for imports")