diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 57147e2..66993a5 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -46,6 +46,7 @@ from docopt import docopt import requests from yarg import json2package from yarg.exceptions import HTTPError +from nbconvert import PythonExporter from pipreqs import __version__ @@ -113,13 +114,21 @@ def get_all_imports( dirs[:] = [d for d in dirs if d not in ignore_dirs] candidates.append(os.path.basename(root)) - files = [fn for fn in files if os.path.splitext(fn)[1] == ".py"] + files = [fn for fn in files if filter_ext(fn, [".py", ".ipynb"])] + + candidates += [ + os.path.splitext(fn)[0] + for fn in files if filter_ext(fn, [".py"]) + ] - candidates += [os.path.splitext(fn)[0] for fn in files] for file_name in files: file_name = os.path.join(root, file_name) - with open_func(file_name, "r", encoding=encoding) as f: - contents = f.read() + contents = '' + if filter_ext(file_name, [".py"]): + with open_func(file_name, "r", encoding=encoding) as f: + contents = f.read() + elif filter_ext(file_name, [".ipynb"]): + contents = ipynb_2_py(file_name, encoding=encoding) try: tree = ast.parse(contents) for node in ast.walk(tree): @@ -135,6 +144,10 @@ def get_all_imports( continue else: logging.error("Failed on file: %s" % file_name) + if filter_ext(file_name, [".ipynb"]): + logging.error( + "Magic command without % might be failed" + ) raise exc # Clean up imports @@ -161,6 +174,29 @@ def filter_line(line): return len(line) > 0 and line[0] != "#" +def filter_ext(file_name, acceptable): + """Return True when file_name's extension is one of acceptable.""" + return os.path.splitext(file_name)[1] in acceptable + + +def ipynb_2_py(file_name, encoding="utf-8"): + """Export a Jupyter notebook to Python source with nbconvert. + + Args: + file_name (str): notebook file path to parse as python script + encoding (str): encoding applied to the exported source (utf-8 if None) + + Returns: + bytes: exported Python source, encoded with ``encoding`` + + """ + + exporter = PythonExporter() + (body, _) = exporter.from_filename(file_name) + + return body.encode(encoding if encoding is not None else "utf-8") + + def 
generate_requirements_file(path, imports): with _open(path, "w") as out_file: logging.debug('Writing {num} requirements: {imports} to {file}'.format( diff --git a/requirements.txt b/requirements.txt index 959d1b7..469b3f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ wheel==0.23.0 Yarg==0.1.9 -docopt==0.6.2 \ No newline at end of file +docopt==0.6.2 +nbconvert==5.4.1 +ipython==5.4.1 diff --git a/tests/_data_notebook/markdown_test.ipynb b/tests/_data_notebook/markdown_test.ipynb new file mode 100644 index 0000000..54712d3 --- /dev/null +++ b/tests/_data_notebook/markdown_test.ipynb @@ -0,0 +1,37 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Markdown test\n", + "import sklearn\n", + "\n", + "```python\n", + "import FastAPI\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/_data_notebook/models.py b/tests/_data_notebook/models.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/_data_notebook/test.ipynb b/tests/_data_notebook/test.ipynb new file mode 100644 index 0000000..16c07d9 --- /dev/null +++ b/tests/_data_notebook/test.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"unused import\"\"\"\n", + "# pylint: disable=undefined-all-variable, import-error, no-absolute-import, too-few-public-methods, missing-docstring\n", + "import xml.etree # [unused-import]\n", + "import xml.sax # [unused-import]\n", + "import os.path as test # [unused-import]\n", + "from sys import argv as test2 # 
[unused-import]\n", + "from sys import flags # [unused-import]\n", + "# +1:[unused-import,unused-import]\n", + "from collections import deque, OrderedDict, Counter\n", + "# All imports above should be ignored\n", + "import requests # [unused-import]\n", + "\n", + "# setuptools\n", + "import zipimport # command/easy_install.py\n", + "\n", + "# twisted\n", + "from importlib import invalidate_caches # python/test/test_deprecate.py\n", + "\n", + "# astroid\n", + "import zipimport # manager.py\n", + "# IPython\n", + "from importlib.machinery import all_suffixes # core/completerlib.py\n", + "import importlib # html/notebookapp.py\n", + "\n", + "from IPython.utils.importstring import import_item # Many files\n", + "\n", + "# pyflakes\n", + "# test/test_doctests.py\n", + "from pyflakes.test.test_imports import Test as TestImports\n", + "\n", + "# Nose\n", + "from nose.importer import Importer, add_path, remove_path # loader.py\n", + "\n", + "import atexit\n", + "from __future__ import print_function\n", + "from docopt import docopt\n", + "import curses, logging, sqlite3\n", + "import logging\n", + "import os\n", + "import sqlite3\n", + "import time\n", + "import sys\n", + "import signal\n", + "import bs4\n", + "import nonexistendmodule\n", + "import boto as b, peewee as p\n", + "# import django\n", + "import flask.ext.somext # # #\n", + "from sqlalchemy import model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import ujson as json\n", + "except ImportError:\n", + " import json\n", + "\n", + "import models\n", + "\n", + "\n", + "def main():\n", + " pass\n", + "\n", + "import after_method_is_valid_even_if_not_pep8" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": 
"python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/_invalid_data_notebook/invalid.ipynb b/tests/_invalid_data_notebook/invalid.ipynb new file mode 100644 index 0000000..cacff3f --- /dev/null +++ b/tests/_invalid_data_notebook/invalid.ipynb @@ -0,0 +1,34 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cd ." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index dcd75c5..eef998b 100755 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -27,9 +27,13 @@ class TestPipreqs(unittest.TestCase): self.project_invalid = os.path.join(os.path.dirname(__file__), "_invalid_data") self.project_with_ignore_directory = os.path.join(os.path.dirname(__file__), "_data_ignore") self.project_with_duplicated_deps = os.path.join(os.path.dirname(__file__), "_data_duplicated_deps") + self.project_with_notebooks = os.path.join(os.path.dirname(__file__), "_data_notebook") + self.project_with_invalid_notebooks = os.path.join(os.path.dirname(__file__), "_invalid_data_notebook") self.requirements_path = os.path.join(self.project, "requirements.txt") self.alt_requirement_path = os.path.join( self.project, "requirements2.txt") + self.compatible_files_path = {"original": os.path.join(os.path.dirname(__file__), "_data/test.py"), + "notebook": os.path.join(os.path.dirname(__file__), "_data_notebook/test.ipynb")} def test_get_all_imports(self): imports = 
pipreqs.get_all_imports(self.project) @@ -200,6 +204,46 @@ class TestPipreqs(unittest.TestCase): for item in ['beautifulsoup4==4.8.1', 'boto==2.49.0']: self.assertFalse(item.lower() in data) + def test_import_notebooks(self): + """ + Test the function get_all_imports() using .ipynb file + """ + imports = pipreqs.get_all_imports(self.project_with_notebooks, encoding="utf-8") + self.assertEqual(len(imports), 13) + for item in imports: + self.assertTrue( + item.lower() in self.modules, "Import is missing: " + item) + self.assertFalse("time" in imports) + self.assertFalse("logging" in imports) + self.assertFalse("curses" in imports) + self.assertFalse("__future__" in imports) + self.assertFalse("django" in imports) + self.assertFalse("models" in imports) + self.assertFalse("FastAPI" in imports) + self.assertFalse("sklearn" in imports) + + def test_invalid_notebook(self): + """ + Test that invalid notebook files cannot be imported. + """ + self.assertRaises(SyntaxError, pipreqs.get_all_imports, self.project_with_invalid_notebooks) + + def test_ipynb_2_py(self): + """ + Test the function ipynb_2_py() which converts .ipynb file to .py format + """ + expected = pipreqs.get_all_imports(self.compatible_files_path["original"]) + parsed = pipreqs.get_all_imports(self.compatible_files_path["notebook"]) + self.assertEqual(expected, parsed) + + def test_filter_ext(self): + """ + Test the function filter_ext() + """ + self.assertTrue(pipreqs.filter_ext("main.py", [".py"])) + self.assertTrue(pipreqs.filter_ext("main.py", [".py", ".ipynb"])) + self.assertFalse(pipreqs.filter_ext("main.py", [".ipynb"])) + def tearDown(self): """ Remove requiremnts.txt files that were written