diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index b622103..cf7b14e 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -51,11 +51,8 @@ def get_all_imports(start_path): if "," in item: for match in item.split(","): imports.append(match.strip()) - elif " as " in item: - to_append = item.split(" as ")[0] - imports.append(to_append.strip()) else: - to_append = item if "." not in item else item.split(".")[0] + to_append = item.partition(' as ')[0].partition('.')[0] imports.append(to_append.strip()) third_party_packages = set(imports) - set(set(packages) & set(imports)) logging.debug('Found third-party packages: %s', third_party_packages) @@ -65,11 +62,10 @@ def get_all_imports(start_path): def generate_requirements_file(path, imports): - with open(path, "w") as ff: - logging.debug('Writing requirements to file %s', path) - for item in imports: - ff.write(item['name'] + "==" + item['version']) - ff.write("\n") + with open(path, "w") as out_file: + logging.debug('Writing %d requirements to file %s', (len(imports), path)) + fmt = '{name} == {version}' + out_file.write('\n'.join(fmt.format(**item) for item in imports) + '\n') def get_imports_info(imports):