diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index d21b071..5cf763e 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -120,7 +120,7 @@ def get_all_imports( with open(join("stdlib"), "r") as f: data = [x.strip() for x in f.readlines()] data = [x for x in data if x not in py2_exclude] if py2 else data - return sorted(list(set(packages) - set(data))) + return list(set(packages) - set(data)) def filter_line(l): @@ -224,18 +224,24 @@ def get_import_local(imports, encoding=None): def get_pkg_names(pkgs): - result = [] + """Get PyPI package names from a list of imports. + + Args: + pkgs (List[str]): List of import names. + + Returns: + List[str]: The corresponding PyPI package names. + + """ + result = set() with open(join("mapping"), "r") as f: - data = [x.strip().split(":") for x in f.readlines()] - for pkg in pkgs: - toappend = pkg - for item in data: - if item[0] == pkg: - toappend = item[1] - break - if toappend not in result: - result.append(toappend) - return result + data = dict(x.strip().split(":") for x in f) + for pkg in pkgs: + # Look up the mapped requirement. If a mapping isn't found, + # simply use the package name. + result.add(data.get(pkg, pkg)) + # Return a sorted list for backward compatibility. + return sorted(result) def get_name_without_alias(name):