Skip to content
Snippets Groups Projects
class1_allele_specific_models.ipynb 80.5 KiB
Newer Older
Tim O'Donnell's avatar
Tim O'Donnell committed
       "      <th>model_layer_sizes</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>HLA-A3301</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2026</td>\n",
       "      <td>1014</td>\n",
       "      <td>None</td>\n",
       "      <td>0.710233</td>\n",
       "      <td>0.989589</td>\n",
       "      <td>0.902256</td>\n",
       "      <td>0.429803</td>\n",
       "      <td>...</td>\n",
       "      <td>0.1</td>\n",
       "      <td>True</td>\n",
       "      <td>0.0</td>\n",
       "      <td>glorot_uniform</td>\n",
       "      <td>tanh</td>\n",
       "      <td>128</td>\n",
       "      <td>False</td>\n",
       "      <td>9</td>\n",
       "      <td>50000.0</td>\n",
       "      <td>[8]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>HLA-A3301</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2026</td>\n",
       "      <td>1014</td>\n",
       "      <td>None</td>\n",
       "      <td>0.747597</td>\n",
       "      <td>0.993938</td>\n",
       "      <td>0.919708</td>\n",
       "      <td>0.425610</td>\n",
       "      <td>...</td>\n",
       "      <td>0.1</td>\n",
       "      <td>True</td>\n",
       "      <td>0.0</td>\n",
       "      <td>glorot_uniform</td>\n",
       "      <td>tanh</td>\n",
       "      <td>128</td>\n",
       "      <td>False</td>\n",
       "      <td>9</td>\n",
       "      <td>50000.0</td>\n",
       "      <td>[12]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>HLA-A3301</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2027</td>\n",
       "      <td>1013</td>\n",
       "      <td>None</td>\n",
       "      <td>0.705507</td>\n",
       "      <td>0.990185</td>\n",
       "      <td>0.882466</td>\n",
       "      <td>0.430678</td>\n",
       "      <td>...</td>\n",
       "      <td>0.1</td>\n",
       "      <td>True</td>\n",
       "      <td>0.0</td>\n",
       "      <td>glorot_uniform</td>\n",
       "      <td>tanh</td>\n",
       "      <td>128</td>\n",
       "      <td>False</td>\n",
       "      <td>9</td>\n",
       "      <td>50000.0</td>\n",
       "      <td>[8]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>HLA-A3301</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2027</td>\n",
       "      <td>1013</td>\n",
       "      <td>None</td>\n",
       "      <td>0.745532</td>\n",
       "      <td>0.993875</td>\n",
       "      <td>0.924812</td>\n",
       "      <td>0.395103</td>\n",
       "      <td>...</td>\n",
       "      <td>0.1</td>\n",
       "      <td>True</td>\n",
       "      <td>0.0</td>\n",
       "      <td>glorot_uniform</td>\n",
       "      <td>tanh</td>\n",
       "      <td>128</td>\n",
       "      <td>False</td>\n",
       "      <td>9</td>\n",
       "      <td>50000.0</td>\n",
       "      <td>[12]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>HLA-A3301</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>2027</td>\n",
       "      <td>1013</td>\n",
       "      <td>None</td>\n",
       "      <td>0.709275</td>\n",
       "      <td>0.992395</td>\n",
       "      <td>0.894531</td>\n",
       "      <td>0.441365</td>\n",
       "      <td>...</td>\n",
       "      <td>0.1</td>\n",
       "      <td>True</td>\n",
       "      <td>0.0</td>\n",
       "      <td>glorot_uniform</td>\n",
       "      <td>tanh</td>\n",
       "      <td>128</td>\n",
       "      <td>False</td>\n",
       "      <td>9</td>\n",
       "      <td>50000.0</td>\n",
       "      <td>[8]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>HLA-A3301</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>2027</td>\n",
       "      <td>1013</td>\n",
       "      <td>None</td>\n",
       "      <td>0.743498</td>\n",
       "      <td>0.994674</td>\n",
       "      <td>0.873518</td>\n",
       "      <td>0.439221</td>\n",
       "      <td>...</td>\n",
       "      <td>0.1</td>\n",
       "      <td>True</td>\n",
       "      <td>0.0</td>\n",
       "      <td>glorot_uniform</td>\n",
       "      <td>tanh</td>\n",
       "      <td>128</td>\n",
       "      <td>False</td>\n",
       "      <td>9</td>\n",
       "      <td>50000.0</td>\n",
       "      <td>[12]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6 rows × 31 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      allele  fold_num  model_num  train_size  test_size imputed_train_size  \\\n",
       "0  HLA-A3301         0          0        2026       1014               None   \n",
       "1  HLA-A3301         0          1        2026       1014               None   \n",
       "2  HLA-A3301         1          0        2027       1013               None   \n",
       "3  HLA-A3301         1          1        2027       1013               None   \n",
       "4  HLA-A3301         2          0        2027       1013               None   \n",
       "5  HLA-A3301         2          1        2027       1013               None   \n",
       "\n",
       "   train_tau  train_auc  train_f1  test_tau        ...         \\\n",
       "0   0.710233   0.989589  0.902256  0.429803        ...          \n",
       "1   0.747597   0.993938  0.919708  0.425610        ...          \n",
       "2   0.705507   0.990185  0.882466  0.430678        ...          \n",
       "3   0.745532   0.993875  0.924812  0.395103        ...          \n",
       "4   0.709275   0.992395  0.894531  0.441365        ...          \n",
       "5   0.743498   0.994674  0.873518  0.439221        ...          \n",
       "\n",
       "   model_fraction_negative  model_batch_normalization  \\\n",
       "0                      0.1                       True   \n",
       "1                      0.1                       True   \n",
       "2                      0.1                       True   \n",
       "3                      0.1                       True   \n",
       "4                      0.1                       True   \n",
       "5                      0.1                       True   \n",
       "\n",
       "  model_dropout_probability      model_init  model_activation  \\\n",
       "0                       0.0  glorot_uniform              tanh   \n",
       "1                       0.0  glorot_uniform              tanh   \n",
       "2                       0.0  glorot_uniform              tanh   \n",
       "3                       0.0  glorot_uniform              tanh   \n",
       "4                       0.0  glorot_uniform              tanh   \n",
       "5                       0.0  glorot_uniform              tanh   \n",
       "\n",
       "  model_batch_size model_impute  model_kmer_size  model_max_ic50  \\\n",
       "0              128        False                9         50000.0   \n",
       "1              128        False                9         50000.0   \n",
       "2              128        False                9         50000.0   \n",
       "3              128        False                9         50000.0   \n",
       "4              128        False                9         50000.0   \n",
       "5              128        False                9         50000.0   \n",
       "\n",
       "  model_layer_sizes  \n",
       "0               [8]  \n",
       "1              [12]  \n",
       "2               [8]  \n",
       "3              [12]  \n",
       "4               [8]  \n",
       "5              [12]  \n",
       "\n",
       "[6 rows x 31 columns]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_df = mhcflurry.class1_allele_specific.train.train_across_models_and_folds(\n",
    "    folds,\n",
    "    models_to_search,\n",
    "    return_predictors=True)\n",
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    Class1BindingPredictor(name=None, max_ic50=500...\n",
       "1    Class1BindingPredictor(name=None, max_ic50=500...\n",
       "2    Class1BindingPredictor(name=None, max_ic50=500...\n",
       "3    Class1BindingPredictor(name=None, max_ic50=500...\n",
       "4    Class1BindingPredictor(name=None, max_ic50=500...\n",
       "5    Class1BindingPredictor(name=None, max_ic50=500...\n",
       "Name: predictor, dtype: object"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The trained predictors are in the 'predictor' column\n",
    "results_df.predictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "model_num\n",
       "0    0.859859\n",
       "1    0.847004\n",
       "Name: test_auc, dtype: float64"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Which model had the best average AUC across folds?\n",
    "results_df.groupby(\"model_num\").test_auc.mean()"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [py3k]",
   "language": "python",
   "name": "Python [py3k]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}