add support for .pyw files

Now, pipreqs will also scan imports in .pyw files by default.
This commit is contained in:
mateuslatrova 2023-12-06 15:34:59 -03:00 committed by Alan Barzilay
parent de68691438
commit 4a9176b39a
4 changed files with 51 additions and 4 deletions

View File

@ -52,7 +52,7 @@ from yarg.exceptions import HTTPError
from pipreqs import __version__
REGEXP = [re.compile(r"^import (.+)$"), re.compile(r"^from ((?!\.+).*?) import (?:.*)$")]
DEFAULT_EXTENSIONS = [".py", ".pyw"]
scan_noteboooks = False
@ -126,7 +126,7 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links
dirs[:] = [d for d in dirs if d not in ignore_dirs]
candidates.append(os.path.basename(root))
py_files = [file for file in files if file_ext_is_allowed(file, [".py"])]
py_files = [file for file in files if file_ext_is_allowed(file, DEFAULT_EXTENSIONS)]
candidates.extend([os.path.splitext(filename)[0] for filename in py_files])
files = [fn for fn in files if file_ext_is_allowed(fn, extensions)]
@ -172,11 +172,11 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links
def get_file_extensions():
return [".py", ".ipynb"] if scan_noteboooks else [".py"]
return DEFAULT_EXTENSIONS + [".ipynb"] if scan_noteboooks else DEFAULT_EXTENSIONS
def read_file_content(file_name: str, encoding="utf-8"):
if file_ext_is_allowed(file_name, [".py"]):
if file_ext_is_allowed(file_name, DEFAULT_EXTENSIONS):
with open(file_name, "r", encoding=encoding) as f:
contents = f.read()
elif file_ext_is_allowed(file_name, [".ipynb"]) and scan_noteboooks:

5
tests/_data_pyw/py.py Normal file
View File

@ -0,0 +1,5 @@
import airflow
import numpy
airflow
numpy

3
tests/_data_pyw/pyw.pyw Normal file
View File

@ -0,0 +1,3 @@
import matplotlib
import pandas
import tensorflow

View File

@ -629,6 +629,45 @@ class TestPipreqs(unittest.TestCase):
assert os.path.exists(notebook_requirement_path) == 1
assert os.path.getsize(notebook_requirement_path) == 1 # file only has a "\n", meaning it's empty
def test_pipreqs_get_imports_from_pyw_file(self):
pyw_test_dirpath = os.path.join(os.path.dirname(__file__), "_data_pyw")
requirements_path = os.path.join(pyw_test_dirpath, "requirements.txt")
pipreqs.init(
{
"<path>": pyw_test_dirpath,
"--savepath": None,
"--print": False,
"--use-local": None,
"--force": True,
"--proxy": None,
"--pypi-server": None,
"--diff": None,
"--clean": None,
"--mode": None,
}
)
self.assertTrue(os.path.exists(requirements_path))
expected_imports = [
"airflow",
"matplotlib",
"numpy",
"pandas",
"tensorflow",
]
with open(requirements_path, "r") as f:
imports_data = f.read().lower()
for _import in expected_imports:
self.assertTrue(
_import.lower() in imports_data,
f"'{_import}' import was expected but not found.",
)
os.remove(requirements_path)
def mock_scan_notebooks(self):
pipreqs.scan_noteboooks = Mock(return_value=True)
pipreqs.handle_scan_noteboooks()