diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index bf3f880..c1548dd 100755 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -18,6 +18,7 @@ class TestPipreqs(unittest.TestCase): self.modules = ['flask', 'requests', 'sqlalchemy', 'docopt', 'ujson', 'nonexistendmodule'] self.project = os.path.join(os.path.dirname(__file__), "_data") self.requirements_path = os.path.join(self.project, "requirements.txt") + self.alt_requirement_path = os.path.join(self.project, "requirements2.txt") def test_get_all_imports(self): imports = pipreqs.get_all_imports(self.project) @@ -52,11 +53,31 @@ class TestPipreqs(unittest.TestCase): for item in self.modules[:-1]: self.assertTrue(item in data) + def test_init_local_only(self): + pipreqs.init({'': self.project, '--savepath': None,'--use-local':True}) + assert os.path.exists(self.requirements_path) == 1 + with open(self.requirements_path, "r") as f: + data = f.readlines() + self.assertEqual(len(data), 2, 'Only two local packages should be found') + + def test_init_savepath(self): + pipreqs.init({'': self.project, '--savepath': self.alt_requirement_path,'--use-local':None}) + assert os.path.exists(self.alt_requirement_path) == 1 + with open(self.alt_requirement_path, "r") as f: + data = f.read() + for item in self.modules[:-1]: + self.assertTrue(item in data) + + def tearDown(self): try: os.remove(self.requirements_path) except OSError: pass + try: + os.remove(self.alt_requirement_path) + except OSError: + pass if __name__ == '__main__':