From bbb7d967654e6bb4af02aeb5fc37936ed32123f9 Mon Sep 17 00:00:00 2001 From: fernandocrz Date: Fri, 3 Nov 2023 15:57:21 -0300 Subject: [PATCH] Improved code readability Credits to @mateuslatrova for the suggestions. --- pipreqs/pipreqs.py | 31 ++++++++++++++----------------- tests/test_pipreqs.py | 10 +++++----- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 044960a..9426d4c 100644 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -122,25 +122,25 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links dirs[:] = [d for d in dirs if d not in ignore_dirs] candidates.append(os.path.basename(root)) - if PythonExporter and not ignore_notebooks: - files = [fn for fn in files if filter_ext(fn, [".py", ".ipynb"])] + if notebooks_are_enabled(): + files = [fn for fn in files if file_ext_is_allowed(fn, [".py", ".ipynb"])] else: - files = [fn for fn in files if filter_ext(fn, [".py"])] + files = [fn for fn in files if file_ext_is_allowed(fn, [".py"])] candidates = list( map( lambda fn: os.path.splitext(fn)[0], - filter(lambda fn: filter_ext(fn, [".py"]), files), + filter(lambda fn: file_ext_is_allowed(fn, [".py"]), files), ) ) for file_name in files: file_name = os.path.join(root, file_name) contents = "" - if filter_ext(file_name, [".py"]): + if file_ext_is_allowed(file_name, [".py"]): with open(file_name, "r", encoding=encoding) as f: contents = f.read() - elif filter_ext(file_name, [".ipynb"]) and PythonExporter and not ignore_notebooks: + elif file_ext_is_allowed(file_name, [".ipynb"]) and notebooks_are_enabled(): contents = ipynb_2_py(file_name, encoding=encoding) try: tree = ast.parse(contents) @@ -178,7 +178,11 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links return list(packages - data) -def filter_ext(file_name, acceptable): +def notebooks_are_enabled(): + return PythonExporter and not ignore_notebooks + + +def file_ext_is_allowed(file_name, acceptable): return os.path.splitext(file_name)[1] in acceptable @@ -313,7 +317,7 @@ def get_import_local(imports, encoding="utf-8"): # had to use second method instead of the previous one, # because we have a list in the 'exports' field # https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python - result_unique = [i for n, i in enumerate(result) if i not in result[n + 1:]] + result_unique = [i for n, i in enumerate(result) if i not in result[n + 1 :]] return result_unique @@ -500,15 +504,8 @@ def init(args): if extra_ignore_dirs: extra_ignore_dirs = extra_ignore_dirs.split(",") - path = ( - args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt") - ) - if ( - not args["--print"] - and not args["--savepath"] - and not args["--force"] - and os.path.exists(path) - ): + path = args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt") + if not args["--print"] and not args["--savepath"] and not args["--force"] and os.path.exists(path): logging.warning("requirements.txt already exists, " "use --force to overwrite it") return diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index da229ba..02faf1d 100644 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -543,13 +543,13 @@ class TestPipreqs(unittest.TestCase): parsed = pipreqs.get_all_imports(self.compatible_files["notebook"], encoding="utf-8") self.assertEqual(expected, parsed) - def test_filter_ext(self): + def test_file_ext_is_allowed(self): """ - Test the function filter_ext() + Test the function file_ext_is_allowed() """ - self.assertTrue(pipreqs.filter_ext("main.py", [".py"])) - self.assertTrue(pipreqs.filter_ext("main.py", [".py", ".ipynb"])) - self.assertFalse(pipreqs.filter_ext("main.py", [".ipynb"])) + self.assertTrue(pipreqs.file_ext_is_allowed("main.py", [".py"])) + self.assertTrue(pipreqs.file_ext_is_allowed("main.py", [".py", ".ipynb"])) + self.assertFalse(pipreqs.file_ext_is_allowed("main.py", [".ipynb"])) def test_parse_requirements(self): """