Skip to content
Snippets Groups Projects
generate.py 6.21 KiB
Newer Older
"""
Generate certain RST files used in documentation.
"""

import sys
import argparse
Tim O'Donnell's avatar
Tim O'Donnell committed
import json
from textwrap import wrap
Tim O'Donnell's avatar
Tim O'Donnell committed
from collections import OrderedDict

import pypandoc
import pandas
from keras.utils.vis_utils import plot_model
Tim O'Donnell's avatar
Tim O'Donnell committed
from tabulate import tabulate

from mhcflurry import __version__
from mhcflurry.downloads import get_path
from mhcflurry.class1_affinity_predictor import Class1AffinityPredictor

# Command-line interface for the documentation generator. The module
# docstring is used as the usage text. Every --out-* option is optional and
# has no default, so an output is only requested when its option is given.
parser = argparse.ArgumentParser(usage=__doc__)

# Inputs. test_exists=False is passed so that building the parser does not
# require the default paths to exist.
parser.add_argument(
    "--cv-summary-csv",
    metavar="FILE.csv",
    default=get_path(
        "cross_validation_class1", "summary.all.csv", test_exists=False),
    help="Cross validation scores summary. Default: %(default)s",
)
parser.add_argument(
    "--class1-models-dir",
    metavar="DIR",
    default=get_path(
        "models_class1", "models", test_exists=False),
    help="Class1 models. Default: %(default)s",
)

# Outputs.
parser.add_argument(
    "--out-models-cv-rst",
    metavar="FILE.rst",
    help="rst output file",
)
parser.add_argument(
    "--out-models-info-rst",
    metavar="FILE.rst",
    help="rst output file",
)
parser.add_argument(
    "--out-models-architecture-png",
    metavar="FILE.png",
    help="png output file",
)
parser.add_argument(
    "--out-models-supported-alleles-rst",
    # This option names an RST file; it was previously advertised as a PNG.
    metavar="FILE.rst",
    help="rst output file",
)


def _write_hyperparameters(fd, hyperparameters):
    """Write a grid-format table of hyperparameters to the open file `fd`.

    Rows are sorted by hyperparameter name; each value is JSON-encoded. A
    newline is written before the table and none after it.
    """
    rows = [
        (key, json.dumps(hyperparameters[key]))
        for key in sorted(hyperparameters)
    ]
    fd.write("\n")
    fd.write(tabulate(rows, ["Hyperparameter", "Value"], tablefmt="grid"))


def _write_supported_alleles_rst(predictor, path):
    """Write the supported peptide lengths and allele list as RST to `path`."""
    with open(path, "w") as fd:
        fd.write(
            "Models released with the current version of MHCflurry (%s) "
            "support peptides of "
            "length %d-%d and the following %d alleles:\n\n::\n\n\t%s\n\n" % (
                __version__,
                predictor.supported_peptide_lengths[0],
                predictor.supported_peptide_lengths[1],
                len(predictor.supported_alleles),
                "\n\t".join(
                    wrap(", ".join(predictor.supported_alleles)))))
    # Report only after the file has been closed.
    print("Wrote: %s" % path)


def _write_architecture_png(predictor, path):
    """Plot the first network of `predictor` to the PNG file `path`.

    Raises ValueError if the predictor has no neural networks.
    """
    networks = predictor.neural_networks
    if not networks:
        raise ValueError(
            "Predictor has no neural networks; cannot plot an architecture")
    plot_model(
        networks[0].network(),
        to_file=path,
        show_layer_names=True,
        show_shapes=True)
    print("Wrote: %s" % path)


def _write_models_info_rst(predictor, path):
    """Write hyperparameters and layer summaries of each distinct network.

    Networks are grouped by their hyperparameters; the first network seen
    for each distinct configuration represents it. Hyperparameters whose
    value is the same across all representatives are written once at the
    top; the remaining ones are written per architecture, followed by the
    network's summary() output as an RST literal block.

    Raises ValueError if the predictor has no neural networks or if the
    networks do not all define the same hyperparameter names.
    """
    representative_networks = OrderedDict()
    for network in predictor.neural_networks:
        # sort_keys makes the grouping independent of dict key order, so
        # equal configurations are not listed twice.
        config = json.dumps(network.hyperparameters, sort_keys=True)
        representative_networks.setdefault(config, network)

    if not representative_networks:
        raise ValueError(
            "Predictor has no neural networks; cannot describe architectures")

    all_hyperparameters = [
        network.hyperparameters
        for network in representative_networks.values()
    ]
    reference = all_hyperparameters[0]
    # Explicit check rather than assert, which is stripped under -O.
    if any(set(h) != set(reference) for h in all_hyperparameters):
        raise ValueError(
            "Networks do not all define the same hyperparameter names")

    constant_hyperparameters = OrderedDict(
        (key, reference[key])
        for key in sorted(reference)
        if all(h[key] == reference[key] for h in all_hyperparameters)
    )

    total = len(representative_networks)
    with open(path, "w") as fd:
        fd.write("Hyperparameters shared by all %d architectures:\n" % total)
        _write_hyperparameters(fd, constant_hyperparameters)
        fd.write("\n")
        for (i, network) in enumerate(representative_networks.values()):
            # Capture the lines summary() would otherwise print.
            lines = []
            network.network().summary(print_fn=lines.append)

            fd.write("Architecture %d / %d:\n" % (i + 1, total))
            fd.write("+" * 40)
            fd.write("\n")
            _write_hyperparameters(
                fd,
                dict(
                    (key, value)
                    for (key, value) in network.hyperparameters.items()
                    if key not in constant_hyperparameters))
            fd.write("\n\n::\n\n")
            for line in lines:
                fd.write("    ")
                fd.write(line)
                fd.write("\n")
    print("Wrote: %s" % path)


def _write_models_cv_rst(cv_summary_csv, path):
    """Write the cross-validation table for "ensemble" rows as RST.

    Reads `cv_summary_csv`, keeps rows whose `kind` column equals
    "ensemble", sorts them by `allele`, renders five display columns as an
    HTML table and converts that to RST with pypandoc.
    """
    df = pandas.read_csv(cv_summary_csv)
    sub_df = df.loc[
        df["kind"] == "ensemble"
    ].sort_values("allele").copy().reset_index(drop=True)

    sub_df["Allele"] = sub_df["allele"]
    sub_df["CV Training Size"] = sub_df["train_size"].astype(int)
    sub_df["AUC"] = sub_df["auc"]
    sub_df["F1"] = sub_df["f1"]
    sub_df["Kendall Tau"] = sub_df["tau"]
    # Select by name rather than by position: taking "the last five
    # columns" picks the wrong ones if the CSV already has a column with
    # one of these names.
    sub_df = sub_df[["Allele", "CV Training Size", "AUC", "F1", "Kendall Tau"]]

    html = sub_df.to_html(
        index=False,
        float_format=lambda v: "%0.3f" % v,
        justify="left")
    rst = pypandoc.convert_text(html, format="html", to="rst")

    with open(path, "w") as fd:
        fd.write(
            "Showing estimated performance for %d alleles." % len(sub_df))
        fd.write("\n\n")
        fd.write(rst)
    print("Wrote: %s" % path)


def go(argv):
    """Generate each documentation artifact requested on the command line.

    Parameters
    ----------
    argv : list of str
        Command-line arguments, excluding the program name.

    The predictor is loaded from --class1-models-dir at most once, and only
    when at least one of the outputs that needs it was requested.
    """
    args = parser.parse_args(argv)

    predictor = None
    if (args.out_models_supported_alleles_rst
            or args.out_models_architecture_png
            or args.out_models_info_rst):
        predictor = Class1AffinityPredictor.load(args.class1_models_dir)

    if args.out_models_supported_alleles_rst:
        _write_supported_alleles_rst(
            predictor, args.out_models_supported_alleles_rst)

    if args.out_models_architecture_png:
        _write_architecture_png(predictor, args.out_models_architecture_png)

    if args.out_models_info_rst:
        _write_models_info_rst(predictor, args.out_models_info_rst)

    if args.out_models_cv_rst:
        _write_models_cv_rst(args.cv_summary_csv, args.out_models_cv_rst)

if __name__ == "__main__":
    # Run as a script: pass everything after the program name to go().
    go(argv=sys.argv[1:])