diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index d21b071..c0f9c55 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -33,6 +33,7 @@ Options: that are not imported in project. """ from __future__ import print_function, absolute_import +from contextlib import contextmanager import os import sys import re @@ -61,6 +62,37 @@ else: py2_exclude = ["concurrent", "concurrent.futures"] +@contextmanager +def _open(filename=None, mode='r'): + """Open a file or ``sys.stdout`` depending on the provided filename. + + Args: + filename (str): The path to the file that should be opened. If + ``None`` or ``'-'``, ``sys.stdout`` or ``sys.stdin`` is + returned depending on the desired mode. Defaults to ``None``. + mode (str): The mode that should be used to open the file. + + Yields: + A file handle. + + """ + if not filename or filename == '-': + if not mode or 'r' in mode: + file = sys.stdin + elif 'w' in mode: + file = sys.stdout + else: + raise ValueError('Invalid mode for file: {}'.format(mode)) + else: + file = open(filename, mode) + + try: + yield file + finally: + if file not in (sys.stdin, sys.stdout): + file.close() + + def get_all_imports( path, encoding=None, extra_ignore_dirs=None, follow_links=True): imports = set() @@ -128,7 +160,7 @@ def filter_line(l): def generate_requirements_file(path, imports): - with open(path, "w") as out_file: + with _open(path, "w") as out_file: logging.debug('Writing {num} requirements: {imports} to {file}'.format( num=len(imports), file=path, @@ -141,14 +173,7 @@ def generate_requirements_file(path, imports): def output_requirements(imports): - logging.debug('Writing {num} requirements: {imports} to stdout'.format( - num=len(imports), - imports=", ".join([x['name'] for x in imports]) - )) - fmt = '{name}=={version}' - print('\n'.join( - fmt.format(**item) if item['version'] else '{name}'.format(**item) - for item in imports)) + generate_requirements_file('-', imports) def get_imports_info(