Improved code readability

Credits to @mateuslatrova for the suggestions.
This commit is contained in:
fernandocrz 2023-11-03 15:57:21 -03:00
parent 861c072aa3
commit bbb7d96765
2 changed files with 19 additions and 22 deletions

View File

@ -122,25 +122,25 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links
dirs[:] = [d for d in dirs if d not in ignore_dirs] dirs[:] = [d for d in dirs if d not in ignore_dirs]
candidates.append(os.path.basename(root)) candidates.append(os.path.basename(root))
if PythonExporter and not ignore_notebooks: if notebooks_are_enabled():
files = [fn for fn in files if filter_ext(fn, [".py", ".ipynb"])] files = [fn for fn in files if file_ext_is_allowed(fn, [".py", ".ipynb"])]
else: else:
files = [fn for fn in files if filter_ext(fn, [".py"])] files = [fn for fn in files if file_ext_is_allowed(fn, [".py"])]
candidates = list( candidates = list(
map( map(
lambda fn: os.path.splitext(fn)[0], lambda fn: os.path.splitext(fn)[0],
filter(lambda fn: filter_ext(fn, [".py"]), files), filter(lambda fn: file_ext_is_allowed(fn, [".py"]), files),
) )
) )
for file_name in files: for file_name in files:
file_name = os.path.join(root, file_name) file_name = os.path.join(root, file_name)
contents = "" contents = ""
if filter_ext(file_name, [".py"]): if file_ext_is_allowed(file_name, [".py"]):
with open(file_name, "r", encoding=encoding) as f: with open(file_name, "r", encoding=encoding) as f:
contents = f.read() contents = f.read()
elif filter_ext(file_name, [".ipynb"]) and PythonExporter and not ignore_notebooks: elif file_ext_is_allowed(file_name, [".ipynb"]) and notebooks_are_enabled():
contents = ipynb_2_py(file_name, encoding=encoding) contents = ipynb_2_py(file_name, encoding=encoding)
try: try:
tree = ast.parse(contents) tree = ast.parse(contents)
@ -178,7 +178,11 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links
return list(packages - data) return list(packages - data)
def filter_ext(file_name, acceptable): def notebooks_are_enabled():
return PythonExporter and not ignore_notebooks
def file_ext_is_allowed(file_name, acceptable):
return os.path.splitext(file_name)[1] in acceptable return os.path.splitext(file_name)[1] in acceptable
@ -313,7 +317,7 @@ def get_import_local(imports, encoding="utf-8"):
# had to use second method instead of the previous one, # had to use second method instead of the previous one,
# because we have a list in the 'exports' field # because we have a list in the 'exports' field
# https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python # https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python
result_unique = [i for n, i in enumerate(result) if i not in result[n + 1:]] result_unique = [i for n, i in enumerate(result) if i not in result[n + 1 :]]
return result_unique return result_unique
@ -500,15 +504,8 @@ def init(args):
if extra_ignore_dirs: if extra_ignore_dirs:
extra_ignore_dirs = extra_ignore_dirs.split(",") extra_ignore_dirs = extra_ignore_dirs.split(",")
path = ( path = args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt")
args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt") if not args["--print"] and not args["--savepath"] and not args["--force"] and os.path.exists(path):
)
if (
not args["--print"]
and not args["--savepath"]
and not args["--force"]
and os.path.exists(path)
):
logging.warning("requirements.txt already exists, " "use --force to overwrite it") logging.warning("requirements.txt already exists, " "use --force to overwrite it")
return return

View File

@ -543,13 +543,13 @@ class TestPipreqs(unittest.TestCase):
parsed = pipreqs.get_all_imports(self.compatible_files["notebook"], encoding="utf-8") parsed = pipreqs.get_all_imports(self.compatible_files["notebook"], encoding="utf-8")
self.assertEqual(expected, parsed) self.assertEqual(expected, parsed)
def test_filter_ext(self): def test_file_ext_is_allowed(self):
""" """
Test the function filter_ext() Test the function file_ext_is_allowed()
""" """
self.assertTrue(pipreqs.filter_ext("main.py", [".py"])) self.assertTrue(pipreqs.file_ext_is_allowed("main.py", [".py"]))
self.assertTrue(pipreqs.filter_ext("main.py", [".py", ".ipynb"])) self.assertTrue(pipreqs.file_ext_is_allowed("main.py", [".py", ".ipynb"]))
self.assertFalse(pipreqs.filter_ext("main.py", [".ipynb"])) self.assertFalse(pipreqs.file_ext_is_allowed("main.py", [".ipynb"]))
def test_parse_requirements(self): def test_parse_requirements(self):
""" """