Add support for Jupyter notebooks

Commit rebased from https://github.com/bndr/pipreqs/pull/210 .
Credits to: pakio https://github.com/pakio

Fix coverage report bug

Update nbconvert version

Fix lint issues

Add Black formatting rules
This commit is contained in:
pakiosann@gmail.com 2020-07-02 15:00:31 +09:00 committed by fernandocrz
parent 55eee298ec
commit c82c7203c3
9 changed files with 367 additions and 109 deletions

View File

@ -39,7 +39,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install coverage docopt yarg requests pip install coverage docopt yarg requests nbconvert
- name: Calculate coverage - name: Calculate coverage
run: coverage run --source=pipreqs -m unittest discover run: coverage run --source=pipreqs -m unittest discover

View File

@ -47,17 +47,18 @@ from docopt import docopt
import requests import requests
from yarg import json2package from yarg import json2package
from yarg.exceptions import HTTPError from yarg.exceptions import HTTPError
from nbconvert import PythonExporter
from pipreqs import __version__ from pipreqs import __version__
REGEXP = [ REGEXP = [
re.compile(r'^import (.+)$'), re.compile(r"^import (.+)$"),
re.compile(r'^from ((?!\.+).*?) import (?:.*)$') re.compile(r"^from ((?!\.+).*?) import (?:.*)$"),
] ]
@contextmanager @contextmanager
def _open(filename=None, mode='r'): def _open(filename=None, mode="r"):
"""Open a file or ``sys.stdout`` depending on the provided filename. """Open a file or ``sys.stdout`` depending on the provided filename.
Args: Args:
@ -70,13 +71,13 @@ def _open(filename=None, mode='r'):
A file handle. A file handle.
""" """
if not filename or filename == '-': if not filename or filename == "-":
if not mode or 'r' in mode: if not mode or "r" in mode:
file = sys.stdin file = sys.stdin
elif 'w' in mode: elif "w" in mode:
file = sys.stdout file = sys.stdout
else: else:
raise ValueError('Invalid mode for file: {}'.format(mode)) raise ValueError("Invalid mode for file: {}".format(mode))
else: else:
file = open(filename, mode) file = open(filename, mode)
@ -87,13 +88,21 @@ def _open(filename=None, mode='r'):
file.close() file.close()
def get_all_imports( def get_all_imports(path, encoding=None, extra_ignore_dirs=None, follow_links=True):
path, encoding=None, extra_ignore_dirs=None, follow_links=True):
imports = set() imports = set()
raw_imports = set() raw_imports = set()
candidates = [] candidates = []
ignore_errors = False ignore_errors = False
ignore_dirs = [".hg", ".svn", ".git", ".tox", "__pycache__", "env", "venv"] ignore_dirs = [
".hg",
".svn",
".git",
".tox",
"__pycache__",
"env",
"venv",
".ipynb_checkpoints",
]
if extra_ignore_dirs: if extra_ignore_dirs:
ignore_dirs_parsed = [] ignore_dirs_parsed = []
@ -106,13 +115,23 @@ def get_all_imports(
dirs[:] = [d for d in dirs if d not in ignore_dirs] dirs[:] = [d for d in dirs if d not in ignore_dirs]
candidates.append(os.path.basename(root)) candidates.append(os.path.basename(root))
files = [fn for fn in files if os.path.splitext(fn)[1] == ".py"] files = [fn for fn in files if filter_ext(fn, [".py", ".ipynb"])]
candidates = list(
map(
lambda fn: os.path.splitext(fn)[0],
filter(lambda fn: filter_ext(fn, [".py"]), files),
)
)
candidates += [os.path.splitext(fn)[0] for fn in files]
for file_name in files: for file_name in files:
file_name = os.path.join(root, file_name) file_name = os.path.join(root, file_name)
with open(file_name, "r", encoding=encoding) as f: contents = ""
contents = f.read() if filter_ext(file_name, [".py"]):
with open(file_name, "r", encoding=encoding) as f:
contents = f.read()
elif filter_ext(file_name, [".ipynb"]):
contents = ipynb_2_py(file_name, encoding=encoding)
try: try:
tree = ast.parse(contents) tree = ast.parse(contents)
for node in ast.walk(tree): for node in ast.walk(tree):
@ -128,6 +147,8 @@ def get_all_imports(
continue continue
else: else:
logging.error("Failed on file: %s" % file_name) logging.error("Failed on file: %s" % file_name)
if filter_ext(file_name, [".ipynb"]):
logging.error("Magic command without % might be failed")
raise exc raise exc
# Clean up imports # Clean up imports
@ -137,11 +158,11 @@ def get_all_imports(
# Cleanup: We only want to first part of the import. # Cleanup: We only want to first part of the import.
# Ex: from django.conf --> django.conf. But we only want django # Ex: from django.conf --> django.conf. But we only want django
# as an import. # as an import.
cleaned_name, _, _ = name.partition('.') cleaned_name, _, _ = name.partition(".")
imports.add(cleaned_name) imports.add(cleaned_name)
packages = imports - (set(candidates) & imports) packages = imports - (set(candidates) & imports)
logging.debug('Found packages: {0}'.format(packages)) logging.debug("Found packages: {0}".format(packages))
with open(join("stdlib"), "r") as f: with open(join("stdlib"), "r") as f:
data = {x.strip() for x in f} data = {x.strip() for x in f}
@ -149,58 +170,81 @@ def get_all_imports(
return list(packages - data) return list(packages - data)
def filter_line(line):
return len(line) > 0 and line[0] != "#"
def filter_ext(file_name, acceptable):
return os.path.splitext(file_name)[1] in acceptable
def ipynb_2_py(file_name, encoding=None):
"""
Args:
file_name (str): notebook file path to parse as python script
encoding (str): encoding of file
Returns:
str: parsed string
"""
exporter = PythonExporter()
(body, _) = exporter.from_filename(file_name)
return body.encode(encoding if encoding is not None else "utf-8")
def generate_requirements_file(path, imports, symbol): def generate_requirements_file(path, imports, symbol):
with _open(path, "w") as out_file: with _open(path, "w") as out_file:
logging.debug('Writing {num} requirements: {imports} to {file}'.format( logging.debug(
num=len(imports), "Writing {num} requirements: {imports} to {file}".format(
file=path, num=len(imports),
imports=", ".join([x['name'] for x in imports]) file=path,
)) imports=", ".join([x["name"] for x in imports]),
fmt = '{name}' + symbol + '{version}' )
out_file.write('\n'.join( )
fmt.format(**item) if item['version'] else '{name}'.format(**item) fmt = "{name}" + symbol + "{version}"
for item in imports) + '\n') out_file.write(
"\n".join(fmt.format(**item) if item["version"] else "{name}".format(**item) for item in imports) + "\n"
)
def output_requirements(imports, symbol): def output_requirements(imports, symbol):
generate_requirements_file('-', imports, symbol) generate_requirements_file("-", imports, symbol)
def get_imports_info( def get_imports_info(imports, pypi_server="https://pypi.python.org/pypi/", proxy=None):
imports, pypi_server="https://pypi.python.org/pypi/", proxy=None):
result = [] result = []
for item in imports: for item in imports:
try: try:
logging.warning( logging.warning(
'Import named "%s" not found locally. ' 'Import named "%s" not found locally. ' "Trying to resolve it at the PyPI server.",
'Trying to resolve it at the PyPI server.', item,
item
) )
response = requests.get( response = requests.get("{0}{1}/json".format(pypi_server, item), proxies=proxy)
"{0}{1}/json".format(pypi_server, item), proxies=proxy)
if response.status_code == 200: if response.status_code == 200:
if hasattr(response.content, 'decode'): if hasattr(response.content, "decode"):
data = json2package(response.content.decode()) data = json2package(response.content.decode())
else: else:
data = json2package(response.content) data = json2package(response.content)
elif response.status_code >= 300: elif response.status_code >= 300:
raise HTTPError(status_code=response.status_code, raise HTTPError(status_code=response.status_code, reason=response.reason)
reason=response.reason)
except HTTPError: except HTTPError:
logging.warning( logging.warning('Package "%s" does not exist or network problems', item)
'Package "%s" does not exist or network problems', item)
continue continue
logging.warning( logging.warning(
'Import named "%s" was resolved to "%s:%s" package (%s).\n' 'Import named "%s" was resolved to "%s:%s" package (%s).\n'
'Please, verify manually the final list of requirements.txt ' "Please, verify manually the final list of requirements.txt "
'to avoid possible dependency confusions.', "to avoid possible dependency confusions.",
item, item,
data.name, data.name,
data.latest_release_id, data.latest_release_id,
data.pypi_url data.pypi_url,
) )
result.append({'name': item, 'version': data.latest_release_id}) result.append({"name": item, "version": data.latest_release_id})
return result return result
@ -225,25 +269,23 @@ def get_locally_installed_packages(encoding=None):
filtered_top_level_modules = list() filtered_top_level_modules = list()
for module in top_level_modules: for module in top_level_modules:
if ( if (module not in ignore) and (package[0] not in ignore):
(module not in ignore) and
(package[0] not in ignore)
):
# append exported top level modules to the list # append exported top level modules to the list
filtered_top_level_modules.append(module) filtered_top_level_modules.append(module)
version = None version = None
if len(package) > 1: if len(package) > 1:
version = package[1].replace( version = package[1].replace(".dist", "").replace(".egg", "")
".dist", "").replace(".egg", "")
# append package: top_level_modules pairs # append package: top_level_modules pairs
# instead of top_level_module: package pairs # instead of top_level_module: package pairs
packages.append({ packages.append(
'name': package[0], {
'version': version, "name": package[0],
'exports': filtered_top_level_modules "version": version,
}) "exports": filtered_top_level_modules,
}
)
return packages return packages
@ -256,14 +298,14 @@ def get_import_local(imports, encoding=None):
# if candidate import name matches export name # if candidate import name matches export name
# or candidate import name equals to the package name # or candidate import name equals to the package name
# append it to the result # append it to the result
if item in package['exports'] or item == package['name']: if item in package["exports"] or item == package["name"]:
result.append(package) result.append(package)
# removing duplicates of package/version # removing duplicates of package/version
# had to use second method instead of the previous one, # had to use second method instead of the previous one,
# because we have a list in the 'exports' field # because we have a list in the 'exports' field
# https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python # https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python
result_unique = [i for n, i in enumerate(result) if i not in result[n+1:]] result_unique = [i for n, i in enumerate(result) if i not in result[n + 1 :]]
return result_unique return result_unique
@ -294,7 +336,7 @@ def get_name_without_alias(name):
match = REGEXP[0].match(name.strip()) match = REGEXP[0].match(name.strip())
if match: if match:
name = match.groups(0)[0] name = match.groups(0)[0]
return name.partition(' as ')[0].partition('.')[0].strip() return name.partition(" as ")[0].partition(".")[0].strip()
def join(f): def join(f):
@ -353,6 +395,7 @@ def parse_requirements(file_):
return modules return modules
def compare_modules(file_, imports): def compare_modules(file_, imports):
"""Compare modules in a file to imported modules in a project. """Compare modules in a file to imported modules in a project.
@ -379,7 +422,8 @@ def diff(file_, imports):
logging.info( logging.info(
"The following modules are in {} but do not seem to be imported: " "The following modules are in {} but do not seem to be imported: "
"{}".format(file_, ", ".join(x for x in modules_not_imported))) "{}".format(file_, ", ".join(x for x in modules_not_imported))
)
def clean(file_, imports): def clean(file_, imports):
@ -427,30 +471,27 @@ def dynamic_versioning(scheme, imports):
def init(args): def init(args):
encoding = args.get('--encoding') encoding = args.get("--encoding")
extra_ignore_dirs = args.get('--ignore') extra_ignore_dirs = args.get("--ignore")
follow_links = not args.get('--no-follow-links') follow_links = not args.get("--no-follow-links")
input_path = args['<path>'] input_path = args["<path>"]
if input_path is None: if input_path is None:
input_path = os.path.abspath(os.curdir) input_path = os.path.abspath(os.curdir)
if extra_ignore_dirs: if extra_ignore_dirs:
extra_ignore_dirs = extra_ignore_dirs.split(',') extra_ignore_dirs = extra_ignore_dirs.split(",")
path = (args["--savepath"] if args["--savepath"] else path = args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt")
os.path.join(input_path, "requirements.txt")) if not args["--print"] and not args["--savepath"] and not args["--force"] and os.path.exists(path):
if (not args["--print"] logging.warning("requirements.txt already exists, " "use --force to overwrite it")
and not args["--savepath"]
and not args["--force"]
and os.path.exists(path)):
logging.warning("requirements.txt already exists, "
"use --force to overwrite it")
return return
candidates = get_all_imports(input_path, candidates = get_all_imports(
encoding=encoding, input_path,
extra_ignore_dirs=extra_ignore_dirs, encoding=encoding,
follow_links=follow_links) extra_ignore_dirs=extra_ignore_dirs,
follow_links=follow_links,
)
candidates = get_pkg_names(candidates) candidates = get_pkg_names(candidates)
logging.debug("Found imports: " + ", ".join(candidates)) logging.debug("Found imports: " + ", ".join(candidates))
pypi_server = "https://pypi.python.org/pypi/" pypi_server = "https://pypi.python.org/pypi/"
@ -459,11 +500,10 @@ def init(args):
pypi_server = args["--pypi-server"] pypi_server = args["--pypi-server"]
if args["--proxy"]: if args["--proxy"]:
proxy = {'http': args["--proxy"], 'https': args["--proxy"]} proxy = {"http": args["--proxy"], "https": args["--proxy"]}
if args["--use-local"]: if args["--use-local"]:
logging.debug( logging.debug("Getting package information ONLY from local installation.")
"Getting package information ONLY from local installation.")
imports = get_import_local(candidates, encoding=encoding) imports = get_import_local(candidates, encoding=encoding)
else: else:
logging.debug("Getting packages information from Local/PyPI") logging.debug("Getting packages information from Local/PyPI")
@ -473,20 +513,21 @@ def init(args):
# the list of exported modules, installed locally # the list of exported modules, installed locally
# and the package name is not in the list of local module names # and the package name is not in the list of local module names
# it add to difference # it add to difference
difference = [x for x in candidates if difference = [
# aggregate all export lists into one x
# flatten the list for x in candidates
# check if candidate is in exports if
x.lower() not in [y for x in local for y in x['exports']] # aggregate all export lists into one
and # flatten the list
# check if candidate is package names # check if candidate is in exports
x.lower() not in [x['name'] for x in local]] x.lower() not in [y for x in local for y in x["exports"]] and
# check if candidate is package names
x.lower() not in [x["name"] for x in local]
]
imports = local + get_imports_info(difference, imports = local + get_imports_info(difference, proxy=proxy, pypi_server=pypi_server)
proxy=proxy,
pypi_server=pypi_server)
# sort imports based on lowercase name of package, similar to `pip freeze`. # sort imports based on lowercase name of package, similar to `pip freeze`.
imports = sorted(imports, key=lambda x: x['name'].lower()) imports = sorted(imports, key=lambda x: x["name"].lower())
if args["--diff"]: if args["--diff"]:
diff(args["--diff"], imports) diff(args["--diff"], imports)
@ -501,8 +542,7 @@ def init(args):
if scheme in ["compat", "gt", "no-pin"]: if scheme in ["compat", "gt", "no-pin"]:
imports, symbol = dynamic_versioning(scheme, imports) imports, symbol = dynamic_versioning(scheme, imports)
else: else:
raise ValueError("Invalid argument for mode flag, " raise ValueError("Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead")
"use 'compat', 'gt' or 'no-pin' instead")
else: else:
symbol = "==" symbol = "=="
@ -516,8 +556,8 @@ def init(args):
def main(): # pragma: no cover def main(): # pragma: no cover
args = docopt(__doc__, version=__version__) args = docopt(__doc__, version=__version__)
log_level = logging.DEBUG if args['--debug'] else logging.INFO log_level = logging.DEBUG if args["--debug"] else logging.INFO
logging.basicConfig(level=log_level, format='%(levelname)s: %(message)s') logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")
try: try:
init(args) init(args)
@ -525,5 +565,5 @@ def main(): # pragma: no cover
sys.exit(0) sys.exit(0)
if __name__ == '__main__': if __name__ == "__main__":
main() # pragma: no cover main() # pragma: no cover

View File

@ -1,3 +1,5 @@
wheel==0.38.1 wheel==0.38.1
Yarg==0.1.9 Yarg==0.1.9
docopt==0.6.2 docopt==0.6.2
nbconvert==7.9.2

View File

@ -15,7 +15,7 @@ with open('HISTORY.rst') as history_file:
history = history_file.read().replace('.. :changelog:', '') history = history_file.read().replace('.. :changelog:', '')
requirements = [ requirements = [
'docopt', 'yarg' 'docopt', 'yarg', 'nbconvert', 'ipython'
] ]
setup( setup(

View File

@ -0,0 +1,37 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Markdown test\n",
"import sklearn\n",
"\n",
"```python\n",
"import FastAPI\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

View File

@ -0,0 +1,102 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"unused import\"\"\"\n",
"# pylint: disable=undefined-all-variable, import-error, no-absolute-import, too-few-public-methods, missing-docstring\n",
"import xml.etree # [unused-import]\n",
"import xml.sax # [unused-import]\n",
"import os.path as test # [unused-import]\n",
"from sys import argv as test2 # [unused-import]\n",
"from sys import flags # [unused-import]\n",
"# +1:[unused-import,unused-import]\n",
"from collections import deque, OrderedDict, Counter\n",
"# All imports above should be ignored\n",
"import requests # [unused-import]\n",
"\n",
"# setuptools\n",
"import zipimport # command/easy_install.py\n",
"\n",
"# twisted\n",
"from importlib import invalidate_caches # python/test/test_deprecate.py\n",
"\n",
"# astroid\n",
"import zipimport # manager.py\n",
"# IPython\n",
"from importlib.machinery import all_suffixes # core/completerlib.py\n",
"import importlib # html/notebookapp.py\n",
"\n",
"from IPython.utils.importstring import import_item # Many files\n",
"\n",
"# pyflakes\n",
"# test/test_doctests.py\n",
"from pyflakes.test.test_imports import Test as TestImports\n",
"\n",
"# Nose\n",
"from nose.importer import Importer, add_path, remove_path # loader.py\n",
"\n",
"import atexit\n",
"from __future__ import print_function\n",
"from docopt import docopt\n",
"import curses, logging, sqlite3\n",
"import logging\n",
"import os\n",
"import sqlite3\n",
"import time\n",
"import sys\n",
"import signal\n",
"import bs4\n",
"import nonexistendmodule\n",
"import boto as b, peewee as p\n",
"# import django\n",
"import flask.ext.somext # # #\n",
"from sqlalchemy import model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import ujson as json\n",
"except ImportError:\n",
" import json\n",
"\n",
"import models\n",
"\n",
"\n",
"def main():\n",
" pass\n",
"\n",
"import after_method_is_valid_even_if_not_pep8"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -0,0 +1,34 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cd ."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@ -37,7 +37,7 @@ class TestPipreqs(unittest.TestCase):
"after_method_is_valid_even_if_not_pep8", "after_method_is_valid_even_if_not_pep8",
] ]
self.modules2 = ["beautifulsoup4"] self.modules2 = ["beautifulsoup4"]
self.local = ["docopt", "requests", "nose", "pyflakes"] self.local = ["docopt", "requests", "nose", "pyflakes", "ipython"]
self.project = os.path.join(os.path.dirname(__file__), "_data") self.project = os.path.join(os.path.dirname(__file__), "_data")
self.empty_filepath = os.path.join(self.project, "empty.txt") self.empty_filepath = os.path.join(self.project, "empty.txt")
self.imports_filepath = os.path.join(self.project, "imports.txt") self.imports_filepath = os.path.join(self.project, "imports.txt")
@ -66,19 +66,20 @@ class TestPipreqs(unittest.TestCase):
self.project_clean = os.path.join(os.path.dirname(__file__), "_data_clean") self.project_clean = os.path.join(os.path.dirname(__file__), "_data_clean")
self.project_invalid = os.path.join(os.path.dirname(__file__), "_invalid_data") self.project_invalid = os.path.join(os.path.dirname(__file__), "_invalid_data")
self.parsed_packages = [ self.project_with_ignore_directory = os.path.join(
{"name": "pandas", "version": "2.0.0"}, os.path.dirname(__file__), "_data_ignore"
{"name": "numpy", "version": "1.2.3"}, )
{"name": "torch", "version": "4.0.0"}, self.project_with_duplicated_deps = os.path.join(
] os.path.dirname(__file__), "_data_duplicated_deps"
self.empty_filepath = os.path.join(self.project, "empty.txt") )
self.imports_filepath = os.path.join(self.project, "imports.txt")
self.project_with_ignore_directory = os.path.join(os.path.dirname(__file__), "_data_ignore")
self.project_with_duplicated_deps = os.path.join(os.path.dirname(__file__), "_data_duplicated_deps")
self.requirements_path = os.path.join(self.project, "requirements.txt") self.requirements_path = os.path.join(self.project, "requirements.txt")
self.alt_requirement_path = os.path.join(self.project, "requirements2.txt") self.alt_requirement_path = os.path.join(self.project, "requirements2.txt")
self.project_with_notebooks = os.path.join(os.path.dirname(__file__), "_data_notebook")
self.project_with_invalid_notebooks = os.path.join(os.path.dirname(__file__), "_invalid_data_notebook")
self.compatible_files = {
"original": os.path.join(os.path.dirname(__file__), "_data/test.py"),
"notebook": os.path.join(os.path.dirname(__file__), "_data_notebook/test.ipynb"),
}
def test_get_all_imports(self): def test_get_all_imports(self):
imports = pipreqs.get_all_imports(self.project) imports = pipreqs.get_all_imports(self.project)
@ -471,7 +472,7 @@ class TestPipreqs(unittest.TestCase):
modules_not_imported = pipreqs.compare_modules(filename, imports) modules_not_imported = pipreqs.compare_modules(filename, imports)
self.assertSetEqual(modules_not_imported, expected_modules_not_imported) self.assertSetEqual(modules_not_imported, expected_modules_not_imported)
def test_output_requirements(self): def test_output_requirements(self):
""" """
Test --print parameter Test --print parameter
@ -515,6 +516,48 @@ class TestPipreqs(unittest.TestCase):
stdout_content = capturedOutput.getvalue().lower() stdout_content = capturedOutput.getvalue().lower()
self.assertTrue(file_content == stdout_content) self.assertTrue(file_content == stdout_content)
def test_import_notebooks(self):
"""
Test the function get_all_imports() using .ipynb file
"""
imports = pipreqs.get_all_imports(self.project_with_notebooks, encoding="utf-8")
self.assertEqual(len(imports), 13)
for item in imports:
self.assertTrue(item.lower() in self.modules, "Import is missing: " + item)
self.assertFalse("time" in imports)
self.assertFalse("logging" in imports)
self.assertFalse("curses" in imports)
self.assertFalse("__future__" in imports)
self.assertFalse("django" in imports)
self.assertFalse("models" in imports)
self.assertFalse("FastAPI" in imports)
self.assertFalse("sklearn" in imports)
def test_invalid_notebook(self):
"""
Test that invalid notebook files cannot be imported.
"""
self.assertRaises(SyntaxError, pipreqs.get_all_imports, self.project_with_invalid_notebooks)
def test_ipynb_2_py(self):
"""
Test the function ipynb_2_py() which converts .ipynb file to .py format
"""
expected = pipreqs.get_all_imports(self.compatible_files["original"])
parsed = pipreqs.get_all_imports(self.compatible_files["notebook"])
self.assertEqual(expected, parsed)
parsed = pipreqs.get_all_imports(self.compatible_files["notebook"], encoding="utf-8")
self.assertEqual(expected, parsed)
def test_filter_ext(self):
"""
Test the function filter_ext()
"""
self.assertTrue(pipreqs.filter_ext("main.py", [".py"]))
self.assertTrue(pipreqs.filter_ext("main.py", [".py", ".ipynb"]))
self.assertFalse(pipreqs.filter_ext("main.py", [".ipynb"]))
def test_parse_requirements(self): def test_parse_requirements(self):
""" """
Test parse_requirements function Test parse_requirements function