Merge pull request #2 from michael-borisov/master

Fix indentation, pep8. Optimize imports and adding missing import for sy...
This commit is contained in:
Vadim Kravcenko 2015-04-25 11:38:59 +02:00
commit f219e3b105
2 changed files with 100 additions and 90 deletions

View File

@ -3,90 +3,99 @@
"""pipreqs - Generate pip requirements.txt file based on imports """pipreqs - Generate pip requirements.txt file based on imports
Usage: Usage:
pipreqs <path> pipreqs <path>
pipreqs <path>[options] pipreqs <path>[options]
Options: Options:
--debug prints debug information. --debug prints debug information.
--savepath path to requirements.txt (Optional) --savepath path to requirements.txt (Optional)
""" """
from __future__ import print_function from __future__ import print_function
import os, re, logging import os
import sys
import re
import logging
from docopt import docopt from docopt import docopt
import yarg import yarg
from yarg.exceptions import HTTPError from yarg.exceptions import HTTPError
REGEXP = [ REGEXP = [
re.compile(r'^import (.+)$'), re.compile(r'^import (.+)$'),
re.compile(r'from (.*?) import (?:.*)') re.compile(r'from (.*?) import (?:.*)')
] ]
def get_all_imports(start_path): def get_all_imports(start_path):
imports = [] imports = []
packages = [] packages = []
logging.debug('Traversing tree, start: %s', start_path) logging.debug('Traversing tree, start: %s', start_path)
for root, dirs, files in os.walk(start_path): for root, dirs, files in os.walk(start_path):
path = root.split('/') packages.append(os.path.basename(root))
packages.append(os.path.basename(root)) for file_name in files:
for file in files: if file_name[-3:] != ".py":
if file[-3:] != ".py": continue
continue
for rex in REGEXP: with open(os.path.join(root, file_name), "r") as file_object:
with open(os.path.join(root,file), "r") as f: for line in file_object:
lines = f.readlines() if line[0] == "#":
for line in lines: continue
if line[0] == "#": if "(" in line:
continue break
if "(" in line: for rex in REGEXP:
break s = rex.match(line)
s = rex.match(line) if not s:
if not s: continue
continue for item in s.groups():
for item in s.groups(): if "," in item:
if "," in item: for match in item.split(","):
for match in item.split(","): imports.append(match.strip())
imports.append(match.strip()) else:
else: to_append = item if "." not in item else item.split(".")[0]
to_append = item if "." not in item else item.split(".")[0] imports.append(to_append.strip())
imports.append(to_append.strip()) third_party_packages = set(imports) - set(set(packages) & set(imports))
third_party_packages = set(imports) - set(set(packages) & set(imports)) logging.debug('Found third-party packages: %s', third_party_packages)
logging.debug('Found third-party packages: %s', third_party_packages) with open(os.path.join(os.path.dirname(__file__), "stdlib"), "r") as f:
with open(os.path.join(os.path.dirname(__file__), "stdlib"), "r") as f: data = [x.strip() for x in f.readlines()]
data = [x.strip() for x in f.readlines()] return list(set(third_party_packages) - set(data))
return list(set(third_party_packages) - set(data))
def generate_requirements_file(path, imports): def generate_requirements_file(path, imports):
with open(path, "w") as ff: with open(path, "w") as ff:
logging.debug('Writing requirements to file %s', path) logging.debug('Writing requirements to file %s', path)
for item in imports: for item in imports:
ff.write(item['name'] + "==" + item['version']) ff.write(item['name'] + "==" + item['version'])
ff.write("\n") ff.write("\n")
def get_imports_info(imports): def get_imports_info(imports):
result = [] result = []
for item in imports: for item in imports:
try: try:
data = yarg.get(item) data = yarg.get(item)
except HTTPError: except HTTPError:
logging.debug('Package does not exist or network problems') logging.debug('Package does not exist or network problems')
continue continue
if not data or len(data.release_ids) < 1: if not data or len(data.release_ids) < 1:
continue continue
last_release = data.release_ids[-1] last_release = data.release_ids[-1]
result.append({'name':item,'version':last_release}) result.append({'name': item, 'version': last_release})
return result return result
def init(args): def init(args):
print ("Looking for imports") print("Looking for imports")
imports = get_all_imports(args['<path>']) imports = get_all_imports(args['<path>'])
print ("Getting latest version of packages information from PyPi") print("Getting latest version of packages information from PyPi")
imports_with_info = get_imports_info(imports) imports_with_info = get_imports_info(imports)
print ("Found third-party imports: " + ", ".join(imports)) print("Found third-party imports: " + ", ".join(imports))
path = args["--savepath"] if args["--savepath"] else os.path.join(args['<path>'],"requirements.txt") path = args["--savepath"] if args["--savepath"] else os.path.join(args['<path>'], "requirements.txt")
generate_requirements_file(path, imports_with_info) generate_requirements_file(path, imports_with_info)
print ("Successfuly saved requirements file in: " + path) print("Successfuly saved requirements file in: " + path)
def main(): # pragma: no cover
def main(): # pragma: no cover
args = docopt(__doc__, version='xstat 0.1') args = docopt(__doc__, version='xstat 0.1')
log_level = logging.WARNING log_level = logging.WARNING
if args['--debug']: if args['--debug']:
@ -98,5 +107,6 @@ def main(): # pragma: no cover
except KeyboardInterrupt: except KeyboardInterrupt:
sys.exit(0) sys.exit(0)
if __name__ == '__main__': if __name__ == '__main__':
main() # pragma: no cover main() # pragma: no cover

View File

@ -14,43 +14,43 @@ from pipreqs import pipreqs
class TestPipreqs(unittest.TestCase): class TestPipreqs(unittest.TestCase):
def setUp(self): def setUp(self):
self.modules = ['flask', 'requests', 'sqlalchemy', 'docopt', 'nonexistendmodule'] self.modules = ['flask', 'requests', 'sqlalchemy', 'docopt', 'nonexistendmodule']
self.project = os.path.join(os.path.dirname(__file__),"_data") self.project = os.path.join(os.path.dirname(__file__), "_data")
self.requirements_path = os.path.join(self.project,"requirements.txt") self.requirements_path = os.path.join(self.project, "requirements.txt")
def test_get_all_imports(self): def test_get_all_imports(self):
imports = pipreqs.get_all_imports(self.project) imports = pipreqs.get_all_imports(self.project)
self.assertEqual(len(imports),5, "Incorrect Imports array length") self.assertEqual(len(imports), 5, "Incorrect Imports array length")
for item in imports: for item in imports:
self.assertTrue(item in self.modules, "Import is missing") self.assertTrue(item in self.modules, "Import is missing")
self.assertFalse("time" in imports) self.assertFalse("time" in imports)
self.assertFalse("logging" in imports) self.assertFalse("logging" in imports)
self.assertFalse("curses" in imports) self.assertFalse("curses" in imports)
self.assertFalse("__future__" in imports) self.assertFalse("__future__" in imports)
def test_get_imports_info(self): def test_get_imports_info(self):
imports = pipreqs.get_all_imports(self.project) imports = pipreqs.get_all_imports(self.project)
with_info = pipreqs.get_imports_info(imports) with_info = pipreqs.get_imports_info(imports)
# Should contain only 4 Elements without the "nonexistendmodule" # Should contain only 4 Elements without the "nonexistendmodule"
self.assertEqual(len(with_info),4, "Length of imports array with info is wrong") self.assertEqual(len(with_info), 4, "Length of imports array with info is wrong")
for item in with_info: for item in with_info:
self.assertTrue(item['name'] in self.modules, "Import item appears to be missing") self.assertTrue(item['name'] in self.modules, "Import item appears to be missing")
def test_init(self): def test_init(self):
pipreqs.init({'<path>':self.project, '--savepath':None}) pipreqs.init({'<path>': self.project, '--savepath': None})
assert os.path.exists(self.requirements_path) == 1 assert os.path.exists(self.requirements_path) == 1
with open(self.requirements_path, "r") as f: with open(self.requirements_path, "r") as f:
data = f.read() data = f.read()
for item in self.modules[:-1]: for item in self.modules[:-1]:
self.assertTrue(item in data) self.assertTrue(item in data)
def tearDown(self): def tearDown(self):
try: try:
os.remove(self.requirements_path) os.remove(self.requirements_path)
except OSError: except OSError:
pass pass
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()