diff --git a/.travis.yml b/.travis.yml index d4d5fedc6cd735045bf3917c63b4299d6897e364..0d0bf9ffb3a1638db905352b96c3a05b821d2d6a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -45,6 +45,6 @@ script: # download data and models, then run tests - mhcflurry-downloads fetch - mhcflurry-downloads info # just to test this command works - - nosetests test --with-coverage --cover-package=mhcflurry && ./lint.sh + - nosetests test -sv --with-coverage --cover-package=mhcflurry && ./lint.sh after_success: coveralls diff --git a/test/test_speed.py b/test/test_speed.py new file mode 100644 index 0000000000000000000000000000000000000000..b84e1f37ea4214d5421ca6ee030d6213cdbaa6ba --- /dev/null +++ b/test/test_speed.py @@ -0,0 +1,58 @@ +import numpy +numpy.random.seed(0) +import time +import cProfile +import pstats + +import pandas + +from mhcflurry import Class1AffinityPredictor +from mhcflurry.common import random_peptides + +NUM = 10000 + +DOWNLOADED_PREDICTOR = Class1AffinityPredictor.load() + + +def test_speed(profile=False): + starts = {} + timings = {} + profilers = {} + + def start(name): + starts[name] = time.time() + if profile: + profilers[name] = cProfile.Profile() + profilers[name].enable() + + def end(name): + timings[name] = time.time() - starts[name] + if profile: + profilers[name].disable() + + start("first") + DOWNLOADED_PREDICTOR.predict(["SIINFEKL"], allele="HLA-A*02:01") + end("first") + + peptides = random_peptides(NUM) + start("pred_%d" % NUM) + DOWNLOADED_PREDICTOR.predict(peptides, allele="HLA-A*02:01") + end("pred_%d" % NUM) + + print("SPEED BENCHMARK") + print("Results:\n%s" % str(pandas.Series(timings))) + + return dict( + (key, pstats.Stats(value)) for (key, value) in profilers.items()) + + +if __name__ == '__main__': + # If run directly from python, do profiling and leave the user in a shell + # to explore results. + + result = test_speed(profile=True) + result["pred_%d" % NUM].sort_stats("cumtime").reverse_order().print_stats() + + # Leave in ipython + locals().update(result) + import ipdb ; ipdb.set_trace()