Improved code readability

Credits to @mateuslatrova for the suggestions.
This commit is contained in:
fernandocrz 2023-11-03 15:57:21 -03:00
parent 861c072aa3
commit bbb7d96765
2 changed files with 19 additions and 22 deletions

View File

@ -122,25 +122,25 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links
dirs[:] = [d for d in dirs if d not in ignore_dirs]
candidates.append(os.path.basename(root))
if PythonExporter and not ignore_notebooks:
files = [fn for fn in files if filter_ext(fn, [".py", ".ipynb"])]
if notebooks_are_enabled():
files = [fn for fn in files if file_ext_is_allowed(fn, [".py", ".ipynb"])]
else:
files = [fn for fn in files if filter_ext(fn, [".py"])]
files = [fn for fn in files if file_ext_is_allowed(fn, [".py"])]
candidates = list(
map(
lambda fn: os.path.splitext(fn)[0],
filter(lambda fn: filter_ext(fn, [".py"]), files),
filter(lambda fn: file_ext_is_allowed(fn, [".py"]), files),
)
)
for file_name in files:
file_name = os.path.join(root, file_name)
contents = ""
if filter_ext(file_name, [".py"]):
if file_ext_is_allowed(file_name, [".py"]):
with open(file_name, "r", encoding=encoding) as f:
contents = f.read()
elif filter_ext(file_name, [".ipynb"]) and PythonExporter and not ignore_notebooks:
elif file_ext_is_allowed(file_name, [".ipynb"]) and notebooks_are_enabled():
contents = ipynb_2_py(file_name, encoding=encoding)
try:
tree = ast.parse(contents)
@ -178,7 +178,11 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links
return list(packages - data)
def filter_ext(file_name, acceptable):
def notebooks_are_enabled():
return PythonExporter and not ignore_notebooks
def file_ext_is_allowed(file_name, acceptable):
return os.path.splitext(file_name)[1] in acceptable
@ -313,7 +317,7 @@ def get_import_local(imports, encoding="utf-8"):
# had to use second method instead of the previous one,
# because we have a list in the 'exports' field
# https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python
result_unique = [i for n, i in enumerate(result) if i not in result[n + 1:]]
result_unique = [i for n, i in enumerate(result) if i not in result[n + 1 :]]
return result_unique
@ -500,15 +504,8 @@ def init(args):
if extra_ignore_dirs:
extra_ignore_dirs = extra_ignore_dirs.split(",")
path = (
args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt")
)
if (
not args["--print"]
and not args["--savepath"]
and not args["--force"]
and os.path.exists(path)
):
path = args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt")
if not args["--print"] and not args["--savepath"] and not args["--force"] and os.path.exists(path):
logging.warning("requirements.txt already exists, " "use --force to overwrite it")
return

View File

@ -543,13 +543,13 @@ class TestPipreqs(unittest.TestCase):
parsed = pipreqs.get_all_imports(self.compatible_files["notebook"], encoding="utf-8")
self.assertEqual(expected, parsed)
def test_filter_ext(self):
def test_file_ext_is_allowed(self):
"""
Test the function filter_ext()
Test the function file_ext_is_allowed()
"""
self.assertTrue(pipreqs.filter_ext("main.py", [".py"]))
self.assertTrue(pipreqs.filter_ext("main.py", [".py", ".ipynb"]))
self.assertFalse(pipreqs.filter_ext("main.py", [".ipynb"]))
self.assertTrue(pipreqs.file_ext_is_allowed("main.py", [".py"]))
self.assertTrue(pipreqs.file_ext_is_allowed("main.py", [".py", ".ipynb"]))
self.assertFalse(pipreqs.file_ext_is_allowed("main.py", [".ipynb"]))
def test_parse_requirements(self):
"""