diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 03b3961..7d33bf9 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -13,45 +13,58 @@ Options: import os, re, logging from docopt import docopt import yarg +from yarg.exceptions import HTTPError -REGEXP = [re.compile(r'import ([a-zA-Z123456789]+)'),re.compile(r'from (.*?) import (?:.*)')] +REGEXP = [ + re.compile(r'^import (.+)$'), + re.compile(r'from (.*?) import (?:.*)') +] def get_all_imports(start_path): imports = [] packages = [] for root, dirs, files in os.walk(start_path): - path = root.split('/') - packages.append(os.path.basename(root)) - for file in files: - if file[-3:] != ".py": - continue - for rex in REGEXP: - with open(root + "/" + file, "r") as f: - s = rex.match(f.read()) - if s: - for item in s.groups(): - if "." in item: - imports.append(item.split(".")[0]) - else: - imports.append(item) - local_packages = list(set(packages) & set(imports)) - third_party_packages = list(set(imports) - set(local_packages)) - with open(os.path.dirname(__file__)+"/stdlib", "r") as f: + path = root.split('/') + packages.append(os.path.basename(root)) + for file in files: + if file[-3:] != ".py": + continue + for rex in REGEXP: + with open(os.path.join(root,file), "r") as f: + lines = f.readlines() + for line in lines: + if line[0] == "#": + continue + if "(" in line: + break + s = rex.match(line) + if not s: + continue + for item in s.groups(): + if "," in item: + for match in item.split(","): + imports.append(match.strip()) + else: + to_append = item if "." not in item else item.split(".")[0] + imports.append(to_append.strip()) + third_party_packages = set(imports) - set(set(packages) & set(imports)) + with open(os.path.join(os.path.dirname(__file__), "stdlib"), "r") as f: data = [x.strip() for x in f.readlines()] return list(set(third_party_packages) - set(data)) def generate_requirements_file(path, imports): with open(path, "w") as ff: for item in imports: - ff.write(item['name']) - ff.write("==") - ff.write(item['version']) + ff.write(item['name'] + "==" + item['version']) ff.write("\n") def get_imports_info(imports): result = [] for item in imports: - data = yarg.get(item) + try: + data = yarg.get(item) + except HTTPError: + continue if not data or len(data.release_ids) < 1: continue last_release = data.release_ids[-1]