diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index a9ead36..3d4afa0 100755 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -34,9 +34,11 @@ def get_all_imports(start_path): for root, dirs, files in os.walk(start_path): packages.append(os.path.basename(root)) files = filter(lambda fn:os.path.splitext(fn)[1] == ".py", files) + packages += map(lambda fn:os.path.splitext(fn)[0], files) for file_name in files: with open(os.path.join(root, file_name), "r") as file_object: - for line in file_object: + lines = filter(lambda l:len(l) > 0, map(lambda l:l.strip(), file_object)) + for line in lines: if line[0] == "#": continue if "(" in line: diff --git a/tests/_data/test.py b/tests/_data/test.py index 58a8764..b033ef7 100644 --- a/tests/_data/test.py +++ b/tests/_data/test.py @@ -13,7 +13,12 @@ import nonexistendmodule # import django import flask.ext.somext from sqlalchemy import model -import ujson as json +try: + import ujson as json +except ImportError: + import json + +import models def main(): pass diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index 79992d2..a2001dd 100755 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -29,6 +29,7 @@ class TestPipreqs(unittest.TestCase): self.assertFalse("curses" in imports) self.assertFalse("__future__" in imports) self.assertFalse("django" in imports) + self.assertFalse("models" in imports) def test_get_imports_info(self): imports = pipreqs.get_all_imports(self.project)