Image source: DeepMind blog article

This scientific project develops an interpretable model for classifying protein sequences into the most common protein families found in the UniProt Knowledgebase. The study employs common NLP techniques and compares various machine learning models, such as k-nearest neighbors, decision trees, and random forests.

Preliminaries

Amino acids

Amino acids are the basic building blocks of proteins. With few exceptions, all proteins in all living organisms are composed of 19 types of primary amino acids and one secondary amino acid, proline (P).

Glycine (G), Alanine (A), Proline (P), Valine (V), Leucine (L)
Isoleucine (I), Methionine (M), Phenylalanine (F), Tyrosine (Y), Tryptophan (W)
Serine (S), Threonine (T), Cysteine (C), Asparagine (N), Glutamine (Q)
Lysine (K), Histidine (H), Arginine (R), Aspartate (D), Glutamate (E)

Proteins

Proteins are the basis of all known organisms and perform many functions in the body. Each protein is made up of a long chain of amino acid residues that uniquely determines the shape of the protein, which in turn determines its function.

Protein structures

Protein structure is described at four levels: primary, secondary, tertiary and, for some more complex proteins, quaternary. The primary structure is determined by the order of amino acids in the chain. It also uniquely determines all higher structures.

[Figure: primary, secondary, tertiary and quaternary structure]

Protein family

A protein family is a group of evolutionarily related proteins that typically have similar three-dimensional structure (tertiary structure), function, and significant sequence similarity.

Motivation

Experimental determination of 3D protein structures (for example by X-ray crystallography or nuclear magnetic resonance) is usually expensive and time-consuming.

Assigning proteins to their respective families allows researchers to infer the shape (and therefore the function) of proteins based only on shared features and evolutionary relationships.


Objective

Description

The goal is to develop a predictive model that classifies protein sequences into the eight most common protein families in the reviewed section of the UniProt database (Swiss-Prot).

Limitations: We will restrict ourselves to proteins of at most 1000 amino acids in length and those that have evidence of existence.

Data

The dataset uniprotkb.tsv comes from the UniProt Knowledgebase (UniProtKB). The file consists of a subset of UniProtKB/Swiss-Prot, which contains manually annotated records.

Due to the size of the original file, columns that were too large or irrelevant to the task were discarded during the data preprocessing.


import itertools
import pandas as pd
import numpy as np
import missingno as msno
import matplotlib.pyplot as plt
import dtreeviz

from collections import defaultdict
from collections import Counter
from Bio.SeqUtils import IUPACData
from faiss import IndexFlatL2
from sklearn.model_selection import train_test_split
from matplotlib.ticker import MaxNLocator
from sklearn.model_selection import ParameterGrid
from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.utils.validation import check_X_y
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.multiclass import unique_labels
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import PrecisionRecallDisplay
from sklearn.decomposition import PCA
from sklearn.feature_extraction.text import TfidfVectorizer

np.random.seed(42)

plt.style.use('ggplot')
red =  np.array((226, 74, 51))/255
blue = np.array((52, 138, 189))/255
blue2 = np.array((0, 150, 240))/255
grey = np.array((100, 100, 100))/255
cobalt = np.array((0, 71, 171))/255
knn_color = np.array((192, 64, 0))/255
tree_color = np.array((24, 168, 27))/255
forest_color = np.array((146, 55, 188))/255
main_color = blue2

Dataset

raw_data = pd.read_csv('uniprotkb.tsv', sep='\t', index_col=0)

Overview

The dataset has the following features:

| Feature | Description |
|---|---|
| Entry Name | Unique identifier for a UniProtKB entry |
| Protein names | Name(s) and taxonomy |
| Gene Names | Gene(s) that code the protein sequence |
| Organism | Name(s) of the organism that is the source of the protein |
| Protein existence | Type of evidence that supports the existence of the protein |
| AlphaFoldDB | ID of the structural prediction from AlphaFold database |
| Pfam | Protein family classification of the protein |
| SUPFAM | Superfamily to which the protein belongs |
| Length | Number of amino acids in the canonical sequence |
| Sequence | Canonical protein sequence (primary structure) |

pd.DataFrame([raw_data.count(), raw_data.nunique(), raw_data.dtypes],
             index=['Non-Null Count', 'Unique Values', 'Dtype']).T

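| Feature | Non-Null Count | Unique Values | Dtype |
|---|---|---|---|
| Entry Name | 570420 | 570420 | object |
| Protein names | 570420 | 166782 | object |
| Gene Names | 545939 | 493775 | object |
| Organism | 570420 | 14534 | object |
| Protein existence | 570420 | 5 | object |
| AlphaFoldDB | 546366 | 546366 | object |
| Pfam | 539105 | 20623 | object |
| SUPFAM | 459110 | 5587 | object |
| Length | 570420 | 3456 | int64 |
| Sequence | 570420 | 482285 | object |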

Necessary pre-processing

Some features pack several values into a single semicolon-delimited string instead of a list.

raw_data.iloc[[10, 42]].T

| Entry | A0A061AE05 | A0A087X1C5 |
|---|---|---|
| Entry Name | PAPSH_CAEEL | CP2D7_HUMAN |
| Protein names | Bifunctional 3'-phosphoadenosine 5'-phosphosul... | Putative cytochrome P450 2D7 (EC 1.14.14.1) |
| Gene Names | pps-1 T14G10.1 | CYP2D7 |
| Organism | Caenorhabditis elegans | Homo sapiens (Human) |
| Protein existence | Evidence at protein level | Uncertain |
| AlphaFoldDB | A0A061AE05; | A0A087X1C5; |
| Pfam | PF01583;PF01747;PF14306; | PF00067; |
| SUPFAM | SSF52374;SSF52540;SSF88697; | SSF48264; |
| Length | 654 | 515 |
| Sequence | MLTPRDENNEGDAMPMLKKPRYSSLSGQSTNITYQEHTISREERAA... | MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL... |

The sample shows that the AlphaFoldDB, Pfam and SUPFAM features need to be converted into proper lists.

data = raw_data.copy()
for col in ['AlphaFoldDB', 'Pfam', 'SUPFAM']:
    data[col] = data[col].str.strip(';')
    data[col] = data[col].str.split(';')
    data[col] = data[col].apply(lambda l: l if isinstance(l, list) else [])

data.iloc[[10, 42]].T

| Entry | A0A061AE05 | A0A087X1C5 |
|---|---|---|
| Entry Name | PAPSH_CAEEL | CP2D7_HUMAN |
| Protein names | Bifunctional 3'-phosphoadenosine 5'-phosphosul... | Putative cytochrome P450 2D7 (EC 1.14.14.1) |
| Gene Names | pps-1 T14G10.1 | CYP2D7 |
| Organism | Caenorhabditis elegans | Homo sapiens (Human) |
| Protein existence | Evidence at protein level | Uncertain |
| AlphaFoldDB | [A0A061AE05] | [A0A087X1C5] |
| Pfam | [PF01583, PF01747, PF14306] | [PF00067] |
| SUPFAM | [SSF52374, SSF52540, SSF88697] | [SSF48264] |
| Length | 654 | 515 |
| Sequence | MLTPRDENNEGDAMPMLKKPRYSSLSGQSTNITYQEHTISREERAA... | MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL... |

Exploratory data analysis

In the training phase of the classification model, only the Sequence feature is used to predict the target variable Pfam. Other features will be used to filter out unreliable or irrelevant entries.

Missing values

fig, axes = plt.subplots(2, 1, figsize=(12, 6), layout='constrained', sharex=True, height_ratios=[1, 2])
fig.suptitle('Missing values', fontsize=20)

ax = axes[0]
msno.matrix(raw_data, fontsize=14, sparkline=False, ax=ax, color=main_color)
ax.get_yaxis().set_visible(False)

ax = axes[1]
missings = raw_data.isna().sum()
miss_ticks = missings[(missings > 0) & (missings != 24054)]
bars = ax.bar(missings.index, missings, color=main_color)
ax.set_ylabel('Count', fontsize=14)
ax.tick_params(axis='y', which='major', labelsize=12)
ax.tick_params(axis='x', which='major', labelsize=12, rotation=-15)
for bar, count in zip(bars, missings):
    if count < 100:
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_y() + 2000, f'{count}', ha='center', fontsize=12)
ax.set_yticks([0, *miss_ticks])
plt.show()

[Figure: Missing values]

Some entries in the database are missing the Pfam feature. These records will be omitted.

Protein existence

| Type of evidence | Description |
|---|---|
| Protein uncertain | Uncertainty regarding the protein’s existence. |
| Protein predicted | No direct evidence for the protein’s existence at any level (protein, transcript, or homology). |
| Protein inferred by homology | Likely existence of the protein due to the presence of clear orthologs in related species. |
| Experimental evidence at transcript level | Evidence suggests the protein exists based on expression data for its transcript. |
| Experimental evidence at protein level | Clear experimental proof of the protein’s existence. |

existence = data['Protein existence'].value_counts().sort_values()

fig, ax = plt.subplots(1, 1, figsize=(12, 4), layout='constrained')
fig.suptitle('Protein existence', fontsize=20)

bars = ax.barh(existence.index, existence, color=main_color)
ax.set_xlabel('Count', fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=12)
for bar, count in zip(bars, existence):
    ax.text(bar.get_width() + ax.get_xlim()[1]*0.005, bar.get_y() + bar.get_height() / 2, f'{count}', va='center', fontsize=12)
ax.set_xlim(0, 430000)
plt.show()

[Figure: Protein existence]

Proteins in the categories Protein predicted and Protein uncertain are removed due to the lack of evidence.

Sequence length

fig, ax = plt.subplots(1, 1, figsize=(12, 3), layout='constrained')
fig.suptitle('Sequence length (all)', size=20)

medianprops = dict(linewidth=2.5, color=main_color)
flierprops = dict(marker='o', markerfacecolor='none', markersize=7, markeredgecolor=main_color)
ax.boxplot(data['Length'], vert=False, widths=[0.4], showfliers=True, flierprops=flierprops, medianprops=medianprops)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.get_yaxis().set_visible(False)
ax.set_xlabel('Length', size=14)

plt.show()

[Figure: Sequence length (all)]

Most of the sequences in the dataset are shorter than 1000 amino acids. We treat proteins with significantly longer sequences as outliers.

Example of a long protein sequence

data.loc[['Q9H195']].T

| Entry | Q9H195 |
|---|---|
| Entry Name | MUC3B_HUMAN |
| Protein names | Mucin-3B (MUC-3B) (Intestinal mucin-3B) |
| Gene Names | MUC3B |
| Organism | Homo sapiens (Human) |
| Protein existence | Evidence at transcript level |
| AlphaFoldDB | [Q9H195] |
| Pfam | [] |
| SUPFAM | [SSF82671] |
| Length | 13477 |
| Sequence | MQLLGLLSILWMLKSSPGATGTLSTATSTSHVTFPRAEATRTALSN... |

Proteins with sequence longer than 1000 will be removed from the dataset to reduce computational overhead.

Distribution of sequence lengths

flengths = data.loc[data['Length'] < 1000, ['Length']]

fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=True, layout='constrained')
fig.suptitle('Sequence length (< 1000)', size=20)

ax = axes[0]
ax.hist(flengths, bins=40, color=main_color)
ax.set_ylabel('Count', size=14)
ax.tick_params(labelbottom=False, labelsize=12)
ax.tick_params(axis='x', which='both', length=0)

medianprops = dict(linewidth=2.5, color=main_color)
flierprops = dict(marker='o', markerfacecolor='none', markersize=7, markeredgecolor=main_color, alpha=0.002)
ax = axes[1]
ax.boxplot(flengths, vert=False, widths=[0.4], showfliers=True, flierprops=flierprops, medianprops=medianprops)
ax.get_yaxis().set_visible(False)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.set_xlabel('Length', size=14)
ax.set_xticks(list(ax.get_xticks()) + [np.median(flengths)])

plt.show()

[Figure: Sequence length (< 1000)]

The distribution of sequence lengths resembles a log-normal distribution, which is commonly observed in proteins.
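One quick way to probe the log-normal claim: if lengths are log-normal, their logarithms are roughly normal, so the skewness of the log-lengths should be close to zero. The sketch below uses synthetic lengths (the `mean` and `sigma` values are assumptions for illustration, not fitted to the dataset); with the real data one would pass `data['Length']` instead, keeping in mind that the cut-off at 1000 truncates the right tail.

```python
import numpy as np

def skewness(values):
    """Sample skewness: the third standardized moment."""
    values = np.asarray(values, dtype=float)
    centered = values - values.mean()
    return float((centered ** 3).mean() / (centered ** 2).mean() ** 1.5)

# Synthetic stand-in for sequence lengths (assumed parameters, illustration only).
rng = np.random.default_rng(0)
lengths = rng.lognormal(mean=5.7, sigma=0.6, size=50_000)

raw_skew = skewness(lengths)          # clearly positive: long right tail
log_skew = skewness(np.log(lengths))  # near zero if the data are log-normal
print(f"skewness of lengths:      {raw_skew:.2f}")
print(f"skewness of log(lengths): {log_skew:.2f}")
```

A near-zero skewness of the log-lengths is only a necessary condition, so this is a sanity check rather than a formal test.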

Organism

organism = data['Organism'].value_counts().sort_values().tail(5)
organism = organism.rename({
    "Saccharomyces cerevisiae (strain ATCC 204508 / S288c) (Baker's yeast)" : "Saccharomyces cerevisiae (Baker's yeast)"
})

fig, ax = plt.subplots(1, 1, figsize=(12, 4), layout='constrained')
fig.suptitle('Most frequent organisms in Swiss-Prot', fontsize=20)

bars = ax.barh(organism.index, organism, color=main_color)
ax.set_xlabel('Count', fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.set_xticks([0, *organism[:3], *organism[4:]])
plt.show()

[Figure: Most frequent organisms in Swiss-Prot]

Pictured: human, mouse, mouse-ear cress, rat, baker’s yeast.

Protein family

fam_count = defaultdict(int)
for row in data['Pfam'].to_numpy():
    for el in row:
        fam_count[el] += 1
fams = pd.Series(dict(fam_count))
top_fams = fams.sort_values(ascending=True).tail(8)

fig, ax = plt.subplots(1, 1, figsize=(12, 4), layout='constrained')
fig.suptitle('8 most frequent protein families', fontsize=20)

bars = ax.barh(top_fams.index, top_fams, color=main_color)
ax.set_xlabel('Count', fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=12)
plt.show()

[Figure: 8 most frequent protein families]

| Protein family | Domain name |
|---|---|
| PF00009 | Elongation factor Tu GTP binding domain |
| PF00005 | ABC transporter |
| PF04055 | Radical SAM superfamily |
| PF00069 | Protein kinase domain |
| PF01926 | 50S ribosome-binding GTPase |
| PF03144 | Elongation factor Tu domain 2 |
| PF00271 | Helicase conserved C-terminal domain |
| PF00587 | tRNA synthetase class II core domain |

Amino acid sequence

# One-letter code -> (one-letter code, three-letter code, name), following IUPAC conventions
one2all = {'A': ('A', 'ALA', 'alanine'), 'R': ('R', 'ARG', 'arginine'), 'N': ('N', 'ASN', 'asparagine'), 'D': ('D', 'ASP', 'aspartic acid'),
           'C': ('C', 'CYS', 'cysteine'), 'Q': ('Q', 'GLN', 'glutamine'), 'E': ('E', 'GLU', 'glutamic acid'), 'G': ('G', 'GLY', 'glycine'),
           'H': ('H', 'HIS', 'histidine'), 'I': ('I', 'ILE', 'isoleucine'), 'L': ('L', 'LEU', 'leucine'), 'K': ('K', 'LYS', 'lysine'),
           'M': ('M', 'MET', 'methionine'), 'F': ('F', 'PHE', 'phenylalanine'), 'P': ('P', 'PRO', 'proline'), 'S': ('S', 'SER', 'serine'),
           'T': ('T', 'THR', 'threonine'), 'W': ('W', 'TRP', 'tryptophan'), 'Y': ('Y', 'TYR', 'tyrosine'), 'V': ('V', 'VAL', 'valine'),
           'X': ('X', 'XAA', 'any or unknown amino acid'), 'Z': ('Z', 'GLX', 'glutamic acid or glutamine'), 'J': ('J', 'XLE', 'leucine or isoleucine'), 'U': ('U', 'SEC', 'selenocysteine'),
           'O': ('O', 'PYL', 'pyrrolysine'), 'B': ('B', 'ASX', 'aspartic acid or asparagine')}

# For each letter, count the number of sequences in which it occurs at least once
counter = Counter()
for seq in data['Sequence']:
    counter.update(set(seq))

acids = pd.DataFrame(counter.keys(), columns=['Letter'])
acids['Code'] = acids['Letter'].apply(lambda l: one2all[l][1])
acids['Name'] = acids['Letter'].apply(lambda l: one2all[l][2])
acids['Total count'] = counter.values()
acids['Relative count'] = acids['Total count'] / acids['Total count'].sum()

acids.style.hide()

| Letter | Code | Name | Total count | Relative count |
|---|---|---|---|---|
| P | PRO | proline | 561293 | 0.050865 |
| H | HIS | histidine | 538618 | 0.048810 |
| Y | TYR | tyrosine | 550058 | 0.049847 |
| K | LYS | lysine | 563402 | 0.051056 |
| F | PHE | phenylalanine | 559457 | 0.050698 |
| N | ASN | asparagine | 560037 | 0.050751 |
| M | MET | methionine | 564368 | 0.051143 |
| E | GLU | glutamic acid | 561780 | 0.050909 |
| T | THR | threonine | 564617 | 0.051166 |
| V | VAL | valine | 566679 | 0.051353 |
| D | ASP | aspartic acid | 560915 | 0.050831 |
| R | ARG | arginine | 564233 | 0.051131 |
| Q | GLN | glutamine | 557412 | 0.050513 |
| W | TRP | tryptophan | 454108 | 0.041152 |
| S | SER | serine | 566731 | 0.051358 |
| L | LEU | leucine | 567243 | 0.051404 |
| G | GLY | glycine | 567343 | 0.051413 |
| A | ALA | alanine | 566448 | 0.051332 |
| I | ILE | isoleucine | 564607 | 0.051165 |
| C | CYS | cysteine | 472913 | 0.042856 |
| X | XAA | any or unknown amino acid | 2263 | 0.000205 |
| U | SEC | selenocysteine | 254 | 0.000023 |
| Z | GLX | glutamic acid or glutamine | 87 | 0.000008 |
| B | ASX | aspartic acid or asparagine | 113 | 0.000010 |
| O | PYL | pyrrolysine | 29 | 0.000003 |

The counts give the number of sequences in which each letter occurs at least once. Note that, in addition to the standard set of 20 amino acids, five other letters appear, all of them quite rare: two non-standard amino acids (selenocysteine U and pyrrolysine O) and three ambiguity codes (B, Z and X).


Data pre-processing

Data filtering based on EDA findings

To keep training and evaluation unambiguous, we do not consider proteins that belong to multiple families or domains. In addition, proteins that have insufficient evidence of existence or are longer than 1000 amino acids are removed. The reasons are discussed in the EDA.

datac = data.copy()
datac = datac[datac['Pfam'].str.len() == 1]
datac['Pfam'] = datac['Pfam'].str[0]
datac = datac[datac['Protein existence'].isin(['Evidence at protein level', 'Evidence at transcript level', 'Inferred from homology'])]
datac = datac[datac['Length'] <= 1000]

We also select the data belonging to the priority protein families, i.e. the eight most common ones. We extend them with sequences from other families so that the trained model can also determine whether or not a sequence belongs to the selected families.

top_fams = datac['Pfam'].value_counts()
top_fams = top_fams.sort_values(ascending=False).head(8)
intop_df = datac[datac['Pfam'].isin(top_fams.index)].copy()
other_df = datac[~datac['Pfam'].isin(top_fams.index)].copy()

indexes = other_df.sample(int(top_fams.max() * 1)).index.tolist()
indexes.extend(intop_df.index)

dataset = datac.loc[indexes]

We apply all these filters before splitting the data into training and test sets to meet the objectives specified in the problem statement.

Remapping the target variable

pfam_cat = pd.CategoricalDtype([*top_fams.index, 'OTHER'], ordered=False)
dataset['Pfam'] = dataset['Pfam'].apply(lambda fam: fam if fam in top_fams else 'OTHER')
dataset['Pfam'] = dataset['Pfam'].astype(pfam_cat)

filtered_fams = dataset['Pfam'].value_counts()

fig, ax = plt.subplots(1, 1, figsize=(12, 4), layout='constrained')
fig.suptitle('Filtered protein families', fontsize=20)

bars = ax.barh(filtered_fams.index, filtered_fams, color=main_color)
ax.set_xlabel('Count', fontsize=16)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.invert_yaxis()
plt.show()

[Figure: Filtered protein families]

The target variable of the filtered dataset has been remapped. The resulting classes are imbalanced: class PF00005 is represented twice as often as PF00012 or PF00696. We will address this issue by undersampling.

Train-validation-test split

X_all = dataset['Sequence']
y_all = dataset['Pfam']
X_train_val, X_test, y_train_val, y_test = train_test_split(X_all, y_all, test_size=0.24)

Undersampling unbalanced training data

We now undersample the training data so that every class is equally represented during training. We do this after setting the test set aside so that the integrity of the test set is not compromised.

indices = []
min_count = y_train_val.value_counts().min()
for fam in y_train_val.unique():
    fam_indices = y_train_val.loc[y_train_val == fam].index
    sample_indices = np.random.choice(fam_indices, min_count, replace=False)
    indices.extend(sample_indices)
y_train_val = y_train_val.loc[indices]
X_train_val = X_train_val.loc[indices]

X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.36)

filtered_fams = y_train_val.value_counts()
fig, ax = plt.subplots(1, 1, figsize=(12, 4), layout='constrained')
fig.suptitle('Balanced protein families (train + validation set)', fontsize=20)
bars = ax.barh(filtered_fams.index, filtered_fams, color=main_color)
ax.set_xlabel('Count', fontsize=16)
ax.tick_params(axis='both', which='major', labelsize=14)
ax.invert_yaxis()
plt.show()

[Figure: Balanced protein families (train + validation set)]

split_df = pd.DataFrame([X_train.shape[0], X_val.shape[0], X_test.shape[0]], index=['Train', 'Validation', 'Test'], columns=['Size'])
split_df['Relative size'] = split_df['Size'] / split_df['Size'].sum()
split_df['Relative size'] = split_df['Relative size'].round(2)
split_df

| Split | Size | Relative size |
|---|---|---|
| Train | 5466 | 0.45 |
| Validation | 3075 | 0.25 |
| Test | 3678 | 0.30 |

Sequence embedding

Now we convert the protein sequences into numerical vectors. This is done using TF-IDF embedding.

pd.DataFrame([X_train, y_train]).T.head()

| Entry | Sequence | Pfam |
|---|---|---|
| S7Z6X4 | MLVQYQNLPVQNIPVLLLSCGFLAILFRSLVLRVRYYRKAQAWGCK... | PF00067 |
| Q9Y5N1 | MERAPPDGPLNASGALAGEAAAAGGARGFSAAWTAVLAALMALLIV... | PF00001 |
| P39583 | MDLSVTHMDDLKTVMEDWKNELLVYKFALDALDTKFSIISQEYNLI... | OTHER |
| Q9R0M1 | MESSTAFYDYHDKLSLLCENNVIFFSTISTIVLYSLVFLLSLVGNS... | PF00001 |
| Q9A076 | MTLASQIATQLLDIKAVYLKPEDPFTWASGIKSPIYTDNRVTLSYP... | PF00156 |

TF-IDF

Wikipedia:

The term frequency–inverse document frequency is a measure of a word’s importance to a document within a collection.

  • term frequency \(\text{tf}(t,d)\) is the relative frequency of term \(t\) within document \(d\), and
  • inverse document frequency \(\text{idf}(t,\mathcal{D})\) is the logarithmically scaled inverse fraction of the documents \(\mathcal{D}\) that contain the term \(t\).
\[ \text{tf}(t,d) = \frac{f_{t,d}}{\sum_{t' \in d} f_{t',d}} \]

\[ \text{idf}(t,\mathcal{D}) = \log \left( {\frac{|\mathcal{D}|}{|\{ d \in \mathcal{D} \mid t \in d \}|}} \right) \]

Tf–idf is calculated as \(\text{tf}(t,d) \cdot \text{idf}(t,\mathcal{D})\).

In our context, the “word” represents an individual amino acid, the “document” represents the corresponding protein sequence and the “collection” represents all available proteins in the training set.
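As a worked illustration of the two formulas, the following minimal sketch computes tf, idf and their product for single amino acids on three toy sequences. The sequences and function names are hypothetical and exist only for this example. The sketch follows the textbook definitions above; scikit-learn's `TfidfVectorizer` by default uses raw counts, a smoothed idf and L2 normalization, so its numbers will differ.

```python
import math
from collections import Counter

def tf(term, doc):
    """Relative frequency of `term` among the characters of `doc`."""
    return Counter(doc)[term] / len(doc)

def idf(term, docs):
    """Log of (number of documents / number of documents containing `term`).

    Assumes `term` occurs in at least one document.
    """
    containing = sum(1 for d in docs if term in d)
    return math.log(len(docs) / containing)

def tfidf(term, doc, docs):
    return tf(term, doc) * idf(term, docs)

# Three toy "proteins" (hypothetical sequences, not taken from UniProt).
docs = ["GAVL", "GGKW", "AVVK"]

print(tfidf("G", docs[1], docs))  # G: frequent in this sequence, found in 2 of 3
print(tfidf("W", docs[1], docs))  # W: rarer in the collection, found in 1 of 3
```

Although G occurs twice as often as W in the second sequence, W receives the higher weight because it appears in fewer sequences of the collection.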

N-gram

Wikipedia:

An N-gram is a contiguous sequence of N items, such as words or characters, extracted from a text or speech. N-grams are commonly used in language modeling and statistical analysis to capture patterns and dependencies within a given dataset.

In our task, the “items” are amino acids.
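To make this concrete, here is a small sketch that lists the character n-grams of a short peptide. The helper `char_ngrams` is a hypothetical name introduced for this example; a character-level vectorizer performs an equivalent extraction internally.

```python
def char_ngrams(sequence, n):
    """Return all contiguous substrings of length n, in order of appearance."""
    return [sequence[i:i + n] for i in range(len(sequence) - n + 1)]

peptide = "GLFSVLG"  # a short toy peptide

print(char_ngrams(peptide, 1))  # unigrams: single amino acids
print(char_ngrams(peptide, 2))  # bigrams: overlapping pairs
print(char_ngrams(peptide, 3))  # trigrams: overlapping triples
```

A sequence of length L yields L - n + 1 n-grams, while the number of distinct possible n-grams grows exponentially with n, which is why larger n quickly becomes expensive.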

Vectorizer

Below are some examples of n-grams of amino acids.

def get_vectorizer(X, ngram_range):
    vectorizer = TfidfVectorizer(strip_accents='ascii', lowercase=False, analyzer='char', ngram_range=ngram_range)
    vectorizer.fit(X)
    return vectorizer

uni = get_vectorizer(X_train, (1, 1))
bi = get_vectorizer(X_train, (2, 2))
tri = get_vectorizer(X_train, (3, 3))

vect_dict = {}
for label, vectorizer in zip(['uni', 'bi', 'tri'], [uni, bi, tri]):
    vect_dict[label] = {
        'train': {'X': np.asarray(vectorizer.transform(X_train).todense()),
                  'y': y_train.cat.codes.to_numpy()},
        'val': {'X': np.asarray(vectorizer.transform(X_val).todense()),
                'y': y_val.cat.codes.to_numpy()},
        'test': {'X': np.asarray(vectorizer.transform(X_test).todense()),
                 'y': y_test.cat.codes.to_numpy()},
    }

pd.DataFrame.from_dict({
    'unigram_vectorizer': uni.get_feature_names_out()[:19],
    'bigram_vectorizer': bi.get_feature_names_out()[:19],
    'trigram_vectorizer': tri.get_feature_names_out()[:19],
}).T.rename(columns={i:'' for i in range(19)})

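unigram_vectorizer: A, C, D, E, F, G, H, I, K, L, M, N, P, Q, R, S, T, V, W
bigram_vectorizer: AA, AC, AD, AE, AF, AG, AH, AI, AK, AL, AM, AN, AP, AQ, AR, AS, AT, AV, AW
trigram_vectorizer: AAA, AAC, AAD, AAE, AAF, AAG, AAH, AAI, AAK, AAL, AAM, AAN, AAP, AAQ, AAR, AAS, AAT, AAV, AAW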

Due to the higher computational cost, we will limit the training to unigrams and bigrams.

Example of a vectorization

matrix = uni.transform([data.loc['P62547', 'Sequence']])
display(pd.DataFrame(data.loc['P62547', [
    'Entry Name', 'Organism', 'Protein existence', 'Pfam', 'Length', 'Sequence'
]]).T)
print(matrix.toarray()[0].round(2).tolist())

| Entry | Entry Name | Organism | Protein existence | Pfam | Length | Sequence |
|---|---|---|---|---|---|---|
| P62547 | CR16_RANCH | Ranoidea chloris (Red-eyed tree frog) (Litoria... | Evidence at protein level | [PF07440] | 25 | GLFSVLGAVAKHVLPHVVPVIAEKL |

[0.33, 0.0, 0.0, 0.11, 0.11, 0.22, 0.22, 0.11, 0.22, 0.44, 0.0, 0.0, 0.22, 0.0, 0.0, 0.11, 0.0, 0.67, 0.0, 0.0, 0.0]

| Entry | Sequence | Vectorizer | Unigram vector |
|---|---|---|---|
| P62547 | GLFSVLGAVAKHVLPHVVPVIAEKL | ──➤ | [0.33, 0.0, 0.0, 0.11, 0.11, 0.22, 0.23, 0.11, 0.22, 0.44, 0.0, 0.0, 0.0, 0.22, 0.0, 0.0, 0.11, 0.0, 0.0, 0.67, 0.0, 0.0, 0.0] |

Model training

We will train and compare the performance of three models:

  • 🟠 k-nearest neighbors,
  • 🟢 decision tree,
  • 🟣 random forest.

When comparing different models, we will primarily consider accuracy.

\[ \text{Accuracy} = \frac{\#\text{correct classifications}}{\#\text{all classifications}} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}_{\{\hat{y}_i = y_i\}} \]

Other metrics we will use for evaluation are precision, recall and F1 score.

| Precision \(P\) | Recall \(R\) | \(F_1\) score |
|---|---|---|
| \( \frac{TP}{TP + FP} \) | \( \frac{TP}{TP + FN} \) | \( 2 \frac{P \cdot R}{P + R} \) |
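The metrics above can be sketched in a few lines of NumPy. In the multiclass setting, precision, recall and \(F_1\) are computed per class by treating that class as the positive one. The toy labels and the helper `per_class_metrics` are hypothetical; in practice `sklearn.metrics.classification_report` produces the same quantities.

```python
import numpy as np

def per_class_metrics(y_true, y_pred, label):
    """Precision, recall and F1 for one class, treated as the positive class."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_pred == label) & (y_true == label))  # true positives
    fp = np.sum((y_pred == label) & (y_true != label))  # false positives
    fn = np.sum((y_pred != label) & (y_true == label))  # false negatives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return float(precision), float(recall), float(f1)

# Toy labels (hypothetical, three classes).
y_true = [0, 0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 0, 2]

accuracy = float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))
precision, recall, f1 = per_class_metrics(y_true, y_pred, label=0)
print(accuracy, precision, recall, f1)
```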

K-nearest neighbours

Custom implementation with faiss library

We implement a custom classifier on top of the faiss library, which provides efficient nearest-neighbor search.

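class KNeighborsClassifier(BaseEstimator, ClassifierMixin):
    """k-nearest-neighbors classifier backed by a faiss flat L2 index."""

    def __init__(self, k=5):
        self.index = None
        self.y = None
        self.k = k

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        self.classes_ = unique_labels(y)

        # faiss expects C-contiguous float32 arrays
        self.index = IndexFlatL2(X.shape[1])
        self.index.add(np.ascontiguousarray(X, dtype=np.float32))
        self.y = y
        return self  # scikit-learn convention: fit returns the estimator

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)

        # indices has shape (n_samples, k): the k nearest training points per row
        distances, indices = self.index.search(np.ascontiguousarray(X, dtype=np.float32), k=self.k)
        votes = self.y[indices]
        # majority vote; np.bincount requires non-negative integer labels
        predictions = np.array([np.argmax(np.bincount(x)) for x in votes])
        return predictions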
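For readers without faiss installed, the same idea (exact search by squared L2 distance followed by a majority vote) can be sketched with NumPy alone. This is an illustrative stand-in under that assumption, not the implementation used in this study; the function `knn_predict` and the toy points are hypothetical, and the brute-force distance matrix only suits small inputs.

```python
import numpy as np

def knn_predict(X_train, y_train, X_query, k):
    """Brute-force kNN: squared L2 distances, then a majority vote per query."""
    # Pairwise squared Euclidean distances, shape (n_query, n_train).
    diffs = X_query[:, None, :] - X_train[None, :, :]
    dists = np.sum(diffs ** 2, axis=2)
    # Indices of the k closest training points for every query row.
    nearest = np.argsort(dists, axis=1)[:, :k]
    votes = y_train[nearest]
    # bincount + argmax picks the most frequent label (lowest label wins ties).
    return np.array([np.argmax(np.bincount(row)) for row in votes])

# Two well-separated toy clusters with integer labels 0 and 1.
X_train = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [0.9, 1.0], [1.0, 0.9]])
y_train = np.array([0, 0, 1, 1, 1])
X_query = np.array([[0.05, 0.05], [0.95, 0.95]])

predictions = knn_predict(X_train, y_train, X_query, k=3)
print(predictions)
```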

Hyperparameter tuning

| Hyperparameter | Description |
|---|---|
| n_neighbors | Number of nearest neighbors |
| ngram | Configuration of n-gram vectorizer |

param_grid = ParameterGrid({
    'n_neighbors': range(1, 101, 9),
    'ngram': ['uni', 'bi'],
})

log_knn = pd.DataFrame(columns=['n_neighbors', 'ngram', 'train_accuracy', 'val_accuracy'])
estimators_knn = []
for params in param_grid:
    train_ngram = vect_dict[params['ngram']]['train']
    val_ngram = vect_dict[params['ngram']]['val']

    clf = KNeighborsClassifier(k=params['n_neighbors'])
    clf.fit(train_ngram['X'], train_ngram['y'])
    train_accuracy = accuracy_score(train_ngram['y'], clf.predict(train_ngram['X']))
    val_accuracy = accuracy_score(val_ngram['y'], clf.predict(val_ngram['X']))
    log_knn.loc[len(log_knn.index)] = [params['n_neighbors'], params['ngram'], train_accuracy, val_accuracy]
    estimators_knn.append(clf)

top_knns = log_knn.sort_values('val_accuracy', ascending=False)
best_knn = estimators_knn[top_knns.index[0]]
h = top_knns.head(10)
h.index = np.arange(1, len(h)+1)
h

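| Rank | n_neighbors | ngram | train_accuracy | val_accuracy |
|---|---|---|---|---|
| 1 | 1 | bi | 1.000000 | 0.872520 |
| 2 | 1 | uni | 1.000000 | 0.867967 |
| 3 | 10 | uni | 0.851628 | 0.842927 |
| 4 | 19 | uni | 0.824369 | 0.828943 |
| 5 | 28 | uni | 0.814307 | 0.824065 |
| 6 | 10 | bi | 0.830772 | 0.818211 |
| 7 | 37 | uni | 0.801866 | 0.814309 |
| 8 | 46 | uni | 0.793999 | 0.805203 |
| 9 | 55 | uni | 0.787413 | 0.799024 |
| 10 | 19 | bi | 0.802049 | 0.795122 |
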
log_knn.sort_index(inplace=True)
halpha = 0.6

fig, axes = plt.subplots(1, 2, figsize=(14, 6), layout='constrained', sharey=True)
fig.suptitle('KNN learning curve', size=20)
fig.supxlabel('Number of neighbors', size=16)
fig.supylabel('Accuracy', size=16)

df = log_knn[log_knn['ngram'] == 'bi'].copy()
ax = axes[0]
ax.set_title('bigram', size=16)
ax.plot(df['n_neighbors'], df['train_accuracy'], label='train', color='black', linewidth=2, alpha=0.4)
ax.plot(df['n_neighbors'], df['val_accuracy'], label='validation', color=knn_color, linewidth=2.5)
ax.axhline(df['val_accuracy'].max(), color='black', linestyle='--', linewidth=1, alpha=halpha)
ax.axhline(df['val_accuracy'].min(), color='black', linestyle='--', linewidth=1, alpha=halpha)
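boundaries = [df['val_accuracy'].min().round(4), df['val_accuracy'].max().round(4)]
ticks = [i/100 for i in range(75, 101, 5)]
# drop regular ticks that would overlap the min/max labels
# (built as a new list instead of removing items while iterating)
ticks = [tick for tick in ticks if all(abs(tick - b) >= 0.02 for b in boundaries)]
ax.set_yticks(boundaries + ticks)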
for i, tick in enumerate(ax.yaxis.get_major_ticks()):
    if i in [0, 1]:
        tick.label1.set_size(14)
        tick.label1.set_weight(1000)

df = log_knn[log_knn['ngram'] == 'uni'].copy()
ax = axes[1]
ax.set_title('unigram', size=16)
ax.plot(df['n_neighbors'], df['train_accuracy'], label='train', color='black', linewidth=2, alpha=0.4)
ax.plot(df['n_neighbors'], df['val_accuracy'], label='validation', color=knn_color, linewidth=2.5)
ax.axhline(df['val_accuracy'].max(), color='black', linestyle='--', linewidth=1, alpha=halpha)
ax.axhline(df['val_accuracy'].min(), color='black', linestyle='--', linewidth=1, alpha=halpha)

ax.legend(loc='upper right', fontsize=14)

plt.show()

[Figure: KNN learning curve]

Model evaluation

bi_val_vec = vect_dict['bi']['val']
y_true = bi_val_vec['y']
y_hat = best_knn.predict(bi_val_vec['X']).astype(np.int8)

Confusion matrix

disp = ConfusionMatrixDisplay.from_estimator(
    best_knn, bi_val_vec['X'], bi_val_vec['y'], cmap=plt.cm.Oranges,
    display_labels=pfam_cat.categories, normalize='true'
)
disp.figure_.set_size_inches((12, 12))
ax = disp.ax_
ax.set_title('Confusion matrix (best kNN)', size=20)
ax.set_xlabel('Predicted family', size=16)
ax.set_ylabel('True family', size=16)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.grid(False)

[Figure: Confusion matrix (best kNN)]

Precision, recall, F1 score

report = classification_report(y_true, y_hat, target_names=pfam_cat.categories, output_dict=True)
pd.DataFrame(report).T.round(2).iloc[:-3]

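| Class | precision | recall | f1-score | support |
|---|---|---|---|---|
| PF00005 | 0.94 | 0.91 | 0.93 | 338.0 |
| PF00069 | 0.77 | 0.94 | 0.84 | 330.0 |
| PF00067 | 0.91 | 0.88 | 0.89 | 362.0 |
| PF00156 | 1.00 | 0.93 | 0.96 | 329.0 |
| PF00001 | 0.98 | 0.95 | 0.96 | 346.0 |
| PF07690 | 0.88 | 0.98 | 0.93 | 341.0 |
| PF00012 | 0.69 | 1.00 | 0.81 | 367.0 |
| PF00696 | 0.98 | 0.94 | 0.96 | 356.0 |
| OTHER | 0.82 | 0.25 | 0.38 | 306.0 |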

Best kNN model

| Hyperparameter | Value |
|---|---|
| n_neighbors | 1 |
| ngram | bi |

knn_eval = pd.DataFrame([
   accuracy_score(y_true, y_hat),
], index=['Accuracy'], columns=['KNN'])
knn_eval

| Metric | KNN |
|---|---|
| Accuracy | 0.87252 |

Decision tree

Hyperparameter tuning

| Hyperparameter | Description |
|---|---|
| max_depth | Maximum depth of the tree |
| criterion | Function to measure the quality of a split |
| ngram | Configuration of n-gram vectorizer |

| Criterion | Formula |
|---|---|
| gini | \( H(\mathcal{D}) = \sum_{k} p_k \left( 1 - p_k \right) \) |
| entropy | \( H(\mathcal{D}) = - \sum_{k} p_k \log(p_k) \) |

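To give a feel for the two impurity criteria, the sketch below evaluates both on three hypothetical class distributions. The formula leaves the base of the logarithm open; base 2 is assumed here, so entropy is measured in bits.

```python
import numpy as np

def gini(p):
    """Gini impurity: sum over classes of p_k * (1 - p_k)."""
    p = np.asarray(p, dtype=float)
    return float(np.sum(p * (1.0 - p)))

def entropy(p):
    """Shannon entropy in bits (base-2 logarithm is an assumption)."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]  # the term 0 * log(0) is taken as 0
    return float(-np.sum(p * np.log2(p)))

# Hypothetical class proportions in a tree node.
pure = [1.0, 0.0, 0.0]       # a single class: no impurity
mixed = [0.5, 0.25, 0.25]
uniform = [1 / 3, 1 / 3, 1 / 3]  # maximally mixed for three classes

for name, p in [("pure", pure), ("mixed", mixed), ("uniform", uniform)]:
    print(f"{name:8s} gini={gini(p):.3f} entropy={entropy(p):.3f}")
```

Both measures are zero for a pure node and largest for the uniform distribution, so a split is preferred when it moves the child nodes toward purity.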
param_grid = ParameterGrid({
    'max_depth': range(1, 32, 6),
    'criterion': ['gini', 'entropy'],
    'ngram': ['uni', 'bi'],
})

log_tree = pd.DataFrame(columns=['criterion', 'max_depth', 'ngram', 'train_accuracy', 'val_accuracy'])
estimators_tree = []
for params in param_grid:
    train_ngram = vect_dict[params['ngram']]['train']
    val_ngram = vect_dict[params['ngram']]['val']

    clf = DecisionTreeClassifier(criterion=params['criterion'], max_depth=params['max_depth'])
    clf.fit(train_ngram['X'], train_ngram['y'])
    train_accuracy = accuracy_score(train_ngram['y'], clf.predict(train_ngram['X']))
    val_accuracy = accuracy_score(val_ngram['y'], clf.predict(val_ngram['X']))
    log_tree.loc[len(log_tree.index)] = [params['criterion'], params['max_depth'], params['ngram'], train_accuracy, val_accuracy]
    estimators_tree.append(clf)

top_trees = log_tree.sort_values('val_accuracy', ascending=False)
best_tree = estimators_tree[top_trees.index[0]]
h = top_trees.head(10)
h.index = np.arange(1, len(h)+1)
h

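| Rank | criterion | max_depth | ngram | train_accuracy | val_accuracy |
|---|---|---|---|---|---|
| 1 | gini | 25 | uni | 1.000000 | 0.757398 |
| 2 | gini | 13 | uni | 0.977863 | 0.755122 |
| 3 | gini | 19 | uni | 0.998353 | 0.754797 |
| 4 | gini | 31 | uni | 1.000000 | 0.754472 |
| 5 | entropy | 13 | uni | 0.989572 | 0.745691 |
| 6 | entropy | 25 | bi | 1.000000 | 0.743089 |
| 7 | entropy | 25 | uni | 1.000000 | 0.742764 |
| 8 | entropy | 31 | uni | 1.000000 | 0.741789 |
| 9 | entropy | 31 | bi | 1.000000 | 0.741138 |
| 10 | entropy | 19 | bi | 1.000000 | 0.741138 |
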
log_tree = log_tree[log_tree['criterion'] == 'gini']
log_tree.sort_index(inplace=True)
halpha = 0.5

fig, axes = plt.subplots(1, 2, figsize=(14, 6), layout='constrained', sharey=True)
fig.suptitle('Decision tree learning curve', size=20)
fig.supxlabel('Max depth', size=16)
fig.supylabel('Accuracy', size=16)

df = log_tree[log_tree['ngram'] == 'uni'].copy()
ax = axes[0]
ax.set_title('unigram', size=16)
ax.plot(df['max_depth'], df['train_accuracy'], label='train', color='black', linewidth=2, alpha=0.4)
ax.plot(df['max_depth'], df['val_accuracy'], label='validation', color=tree_color, linewidth=2.5)
ax.axhline(df['val_accuracy'].max(), color='black', linestyle='--', linewidth=1, alpha=halpha)
ax.axhline(df['val_accuracy'].min(), color='black', linestyle='--', linewidth=1, alpha=halpha)
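boundaries = [df['val_accuracy'].min().round(4), df['val_accuracy'].max().round(4)]
ticks = [i/100 for i in range(0, 101, 10)]
# drop regular ticks that would overlap the min/max labels
# (built as a new list instead of removing items while iterating)
ticks = [tick for tick in ticks if all(abs(tick - b) >= 0.04 for b in boundaries)]
ax.set_yticks(boundaries + ticks)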
for i, tick in enumerate(ax.yaxis.get_major_ticks()):
    if i in [0, 1]:
        tick.label1.set_size(14)
        tick.label1.set_weight(1000)

df = log_tree[log_tree['ngram'] == 'bi'].copy()
ax = axes[1]
ax.set_title('bigram', size=16)
ax.plot(df['max_depth'], df['train_accuracy'], label='train', color='black', linewidth=2, alpha=0.4)
ax.plot(df['max_depth'], df['val_accuracy'], label='validation', color=tree_color, linewidth=2.5)
ax.axhline(df['val_accuracy'].max(), color='black', linestyle='--', linewidth=1, alpha=halpha)
ax.axhline(df['val_accuracy'].min(), color='black', linestyle='--', linewidth=1, alpha=halpha)

# df = log_tree[log_tree['ngram'] == 'uni'].copy()
# ax = axes[2]
# ax.set_title('unigram', size=16)
# ax.plot(df['max_depth'], df['train_accuracy'], label='train', color='black', linewidth=2, alpha=0.4)
# ax.plot(df['max_depth'], df['val_accuracy'], label='validation', color=tree_color, linewidth=2.5)
# ax.axhline(df['val_accuracy'].max(), color='black', linestyle='--', linewidth=1, alpha=halpha)
# ax.axhline(df['val_accuracy'].min(), color='black', linestyle='--', linewidth=1, alpha=halpha)
ax.legend(loc='center right', fontsize=14)

plt.show()

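[Figure: decision tree learning curve; train and validation accuracy versus max depth, with unigram and bigram panels]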

Example of a classification using a decision tree

Here, we visualize a decision process for classifying the family of the Chaperone protein DnaK using a decision tree of depth 4. The process is shown in the following diagram.

Sequence ─➤ Vectorizer ─➤ Vector ─➤ Decision tree ─➤ Prediction
MAKVIGIDLGTTNSCV…DDNTKKSA ───────➤ [0.33, 0.0, 0.0, 0.11, 0.11, 0.22, 0.23, 0.11, 0.22, 0.44, 0.0, 0.0, 0.0, 0.22, 0.0, 0.0, 0.11, 0.0, 0.0, 0.67, 0.0, 0.0, 0.0] ────────➤ PF00012

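The vectorize-then-classify flow above can be sketched in plain Python. This is only an illustration under assumptions: `unigram_frequencies` uses simple relative amino-acid frequencies (the notebook's own `vectorizer` is defined earlier and may weight features differently), and `TOY_TREE` with its thresholds is a hypothetical hand-written tree, not the fitted model.

```python
from collections import Counter

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard one-letter codes

def unigram_frequencies(sequence):
    """Return the relative frequency of each standard amino acid in the sequence."""
    counts = Counter(sequence)
    total = sum(counts[aa] for aa in AMINO_ACIDS) or 1
    return {aa: counts[aa] / total for aa in AMINO_ACIDS}

# Hypothetical depth-2 tree: internal nodes compare one feature with a threshold,
# leaves hold a family label. The thresholds are made up for illustration.
TOY_TREE = {
    "feature": "G", "threshold": 0.08,
    "left": {"label": "OTHER"},
    "right": {
        "feature": "W", "threshold": 0.01,
        "left": {"label": "PF00012"},
        "right": {"label": "OTHER"},
    },
}

def predict(tree, features):
    """Walk from the root to a leaf, going left when the feature value is <= threshold."""
    path = []
    node = tree
    while "label" not in node:
        go_left = features[node["feature"]] <= node["threshold"]
        path.append((node["feature"], node["threshold"], "left" if go_left else "right"))
        node = node["left"] if go_left else node["right"]
    return node["label"], path

# First 16 residues of the DnaK example sequence.
features = unigram_frequencies("MAKVIGIDLGTTNSCV")
label, path = predict(TOY_TREE, features)
print(label, path)
```

The recorded `path` plays the same role as the highlighted decision path in the visualization: it lists which feature was tested at each node and which branch was taken.
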
B9JZ87_sequence = 'MAKVIGIDLGTTNSCVSVMDGKDAKVIENSEGARTTPSMVAFS' + \
'DDGERLVGQPAKRQAVTNPTNTLFAVKRLIGRRYEDPTVEKDKGLVPFPIIKGDNGDAWVE' + \
'AQGKGYSPAQISAMILQKMKETAEAYLGEKVEKAVITVPAYFNDAQRQATKDAGRIAGLEV' + \
'LRIINEPTAAALAYGLDKTEGKTIAVYDLGGGTFDISVLEIGDGVFEVKSTNGDTFLGGED' + \
'FDMRLVEYLAAEFKKEQGIELKNDKLALQRLKEAAEKAKIELSSSQQTEINLPFITADASG' + \
'PKHLTMKLTRAKFENLVDDLVQRTVAPCKAALKDAGVTAADIDEVVLVGGMSRMPKVQEVV' + \
'KQLFGKEPHKGVNPDEVVAMGAAIQAGVLQGDVKDVLLLDVTPLSLGIETLGGVFTRLIDR' + \
'NTTIPTKKSQVFSTADDNQQAVTIRVSQGEREMAQDNKLLGQFDLVGLPPSPRGVPQIEVT' + \
'FDIDANGIVQVSAKDKGTGKEQQIRIQASGGLSDADIEKMVKDAEANAEADKNRRAVVEAK' + \
'NQAESLIHSTEKSVKDYGDKVSADDRKAIEDAIAALKSSIETSEPNAEDIQAKTQTLMEVS' + \
'MKLGQAIYESQQAEGGAEGGPSGHHDDGIVDADYEEVKDDNTKKSA'

vect = vectorizer.transform([B9JZ87_sequence])
np_vect = np.asarray(vect.todense())[0]

tree_train_vec = vect_dict['uni']['train']
tree = DecisionTreeClassifier(max_depth=4)
tree.fit(tree_train_vec['X'], tree_train_vec['y'])
viz_model = dtreeviz.model(tree, X_train=tree_train_vec['X'], y_train=tree_train_vec['y'],
                           feature_names=vectorizer.get_feature_names_out(),
                           class_names=pfam_cat.categories.tolist())

viz_model.view(scale=2, orientation='TD', x=np_vect, fancy=True, show_just_path=True, instance_orientation='LR')

[Figure: decision path through the depth-4 tree for the DnaK sequence]

viz_model.view(scale=1.6, orientation='LR', x=np_vect, fancy=False, leaftype='barh', instance_orientation='TD')

[Figure: the full depth-4 decision tree with bar-chart leaves and the DnaK instance]

Model evaluation

uni_val_vec = vect_dict['uni']['val']
y_true = uni_val_vec['y']
y_hat = best_tree.predict(uni_val_vec['X']).astype(np.int8)
y_hat_proba = best_tree.predict_proba(uni_val_vec['X'])

Confusion matrix

disp = ConfusionMatrixDisplay.from_estimator(
    best_tree, uni_val_vec['X'], uni_val_vec['y'], cmap=plt.cm.Greens,
    display_labels=pfam_cat.categories, normalize='true'
)
disp.figure_.set_size_inches((12, 12))
ax = disp.ax_
ax.set_title('Confusion matrix (best decision tree)', size=20)
ax.set_xlabel('Predicted family', size=16)
ax.set_ylabel('True family', size=16)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.grid(False)

[Figure: row-normalized confusion matrix of the best decision tree on the validation set]

Precision, recall, F1 score

report = classification_report(y_true, y_hat, target_names=pfam_cat.categories, output_dict=True)
pd.DataFrame(report).T.round(2).iloc[:-3]

         precision  recall  f1-score  support
PF00005  0.64       0.69    0.66      338.0
PF00069  0.67       0.64    0.65      330.0
PF00067  0.77       0.77    0.77      362.0
PF00156  0.77       0.75    0.76      329.0
PF00001  0.90       0.90    0.90      346.0
PF07690  0.89       0.89    0.89      341.0
PF00012  0.93       0.90    0.92      367.0
PF00696  0.82       0.79    0.80      356.0
OTHER    0.42       0.44    0.43      306.0

Best decision tree model

criterion  gini
max_depth  25
ngram      uni

tree_eval = pd.DataFrame([
    accuracy_score(y_true, y_hat),
], index=['Accuracy'], columns=['Decision tree'])
tree_eval

          Decision tree
Accuracy  0.757398

Random forest

Hyperparameter tuning

Hyperparameter  Description
n_estimators    Number of trees in the forest
max_samples     Fraction of the training set drawn (with replacement) to fit each tree
criterion       Function to measure the quality of a split
max_features    Number of features to consider when looking for the best split
max_depth       Maximum depth of each tree
ngram           Configuration of the n-gram vectorizer (unigram or bigram)

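To get a feel for the size of the search, here is a minimal standard-library sketch of how a grid like the one used in this section expands into individual configurations. The values are copied from the grid in this section; `expand_grid` is a hypothetical helper written for illustration and is not part of scikit-learn.

```python
import itertools

# Values copied from the grid searched in this section.
grid = {
    'n_estimators': list(itertools.chain(range(1, 20, 6), range(20, 110, 15))),
    'max_samples': [i / 10 for i in range(1, 6)],
    'max_depth': list(range(1, 6)),
    'ngram': ['uni', 'bi'],
    'criterion': ['gini'],
    'max_features': ['log2'],
}

def expand_grid(grid):
    """Yield one dict per combination of hyperparameter values (Cartesian product)."""
    keys = sorted(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))

configs = list(expand_grid(grid))
print(grid['n_estimators'])
print(len(configs))
```

With 10 values for `n_estimators`, 5 for `max_samples`, 5 for `max_depth` and 2 n-gram settings, the grid contains 10 × 5 × 5 × 2 = 500 configurations, each of which trains one forest.
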
param_grid = ParameterGrid({
    'n_estimators': list(itertools.chain(range(1, 20, 6), range(20, 110, 15))),
    'max_samples': [i/10 for i in range(1, 6)],
    'max_depth': range(1, 6),
    'ngram': ['uni', 'bi'],

    'criterion': ['gini'],
    'max_features': ['log2'],
})

log_forest = pd.DataFrame(columns=['n_estimators', 'max_samples', 'max_features', 'criterion', 'max_depth', 'ngram', 'train_accuracy', 'val_accuracy'])
estimators_forest = []
for params in param_grid:
    train_ngram = vect_dict[params['ngram']]['train']
    val_ngram = vect_dict[params['ngram']]['val']
    ngram = params['ngram']
    del params['ngram']

    clf = RandomForestClassifier(n_jobs=-1, **params)
    clf.fit(train_ngram['X'], train_ngram['y'])
    train_accuracy = accuracy_score(train_ngram['y'], clf.predict(train_ngram['X']))
    val_accuracy = accuracy_score(val_ngram['y'], clf.predict(val_ngram['X']))
    log_forest.loc[len(log_forest.index)] = [params['n_estimators'], params['max_samples'], params['max_features'], params['criterion'], params['max_depth'],
                                             ngram,  train_accuracy, val_accuracy]
    estimators_forest.append(clf)

top_forests = log_forest.sort_values('val_accuracy', ascending=False)
best_forest = estimators_forest[top_forests.index[0]]
h = top_forests.head(10)
h.index = np.arange(1, len(h)+1)
h

    n_estimators  max_samples  max_features  criterion  max_depth  ngram  train_accuracy  val_accuracy
1   80            0.4          log2          gini       5          bi     0.915477        0.909268
2   95            0.5          log2          gini       5          bi     0.913099        0.907642
3   95            0.4          log2          gini       5          bi     0.915843        0.906016
4   95            0.2          log2          gini       5          bi     0.916392        0.902439
5   65            0.4          log2          gini       5          bi     0.906147        0.900813
6   65            0.5          log2          gini       5          bi     0.907062        0.899187
7   80            0.3          log2          gini       5          bi     0.914014        0.897886
8   80            0.2          log2          gini       5          bi     0.912367        0.897236
9   95            0.3          log2          gini       5          bi     0.915112        0.896911
10  80            0.5          log2          gini       5          bi     0.914929        0.896585

log_forest.sort_index(inplace=True)

fig, axes = plt.subplots(1, 2, figsize=(12, 6), layout='constrained', sharey=True)
fig.suptitle('Random forest learning curve', size=20)
fig.supxlabel('Number of estimators', size=16)
fig.supylabel('Accuracy', size=16)

df = log_forest[(log_forest['max_samples'] == 0.4) &
                (log_forest['max_depth'] == 5) &
                (log_forest['ngram'] == 'bi')]
ax = axes[0]
ax.set_title('bigram', size=16)
ax.plot(df['n_estimators'], df['train_accuracy'], label='train', color='black', linewidth=3, alpha=0.5)
ax.plot(df['n_estimators'], df['val_accuracy'], label='validation', color=forest_color, linewidth=3)
ax.axhline(df['val_accuracy'].max(), color='black', linestyle='--', linewidth=1, alpha=0.4)
ax.axhline(df['val_accuracy'].min(), color='black', linestyle='--', linewidth=1, alpha=0.4)
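boundaries = [df['val_accuracy'].min().round(3), df['val_accuracy'].max().round(3)]
ticks = [i/100 for i in range(50, 101, 5)]
# Drop regular ticks that would overlap the min/max validation-accuracy labels.
# (Filtering into a new list avoids removing items from a list while iterating over it.)
ticks = [t for t in ticks if all(abs(t - b) >= 0.02 for b in boundaries)]
ax.set_yticks(boundaries + ticks)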
for i, tick in enumerate(ax.yaxis.get_major_ticks()):
    if i in [0, 1]:
        tick.label1.set_size(14)
        tick.label1.set_weight(1000)

df = log_forest[(log_forest['max_samples'] == 0.4) &
                (log_forest['max_depth'] == 5) &
                (log_forest['ngram'] == 'uni')]
ax = axes[1]
ax.set_title('unigram', size=16)
ax.plot(df['n_estimators'], df['train_accuracy'], label='train', color='black', linewidth=3, alpha=0.5)
ax.plot(df['n_estimators'], df['val_accuracy'], label='validation', color=forest_color, linewidth=3)
ax.axhline(df['val_accuracy'].max(), color='black', linestyle='--', linewidth=1, alpha=0.4)
ax.axhline(df['val_accuracy'].min(), color='black', linestyle='--', linewidth=1, alpha=0.4)
ax.legend(loc='upper right')

plt.show()

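[Figure: random forest learning curve; train and validation accuracy versus number of estimators, with bigram and unigram panels]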

import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'

df = log_forest.copy()
fig = px.parallel_coordinates(df, color='val_accuracy', range_color=[0, 1],
                              title='Hyperparameter Parallel Coordinates',
                              color_continuous_scale=px.colors.sequential.Inferno,
                              dimensions=['max_depth', 'max_samples', 'n_estimators','val_accuracy'],
                              template='ggplot2')
fig.show()

[Figure: parallel coordinates plot of max_depth, max_samples, n_estimators and val_accuracy]

Model evaluation

bi_val_vec = vect_dict['bi']['val']
y_true = bi_val_vec['y']
y_hat = best_forest.predict(bi_val_vec['X']).astype(np.int8)
y_hat_proba = best_forest.predict_proba(bi_val_vec['X'])

Confusion matrix

disp = ConfusionMatrixDisplay.from_estimator(
    best_forest, bi_val_vec['X'], bi_val_vec['y'], cmap=plt.cm.Purples,
    display_labels=pfam_cat.categories, normalize='true'
)
disp.figure_.set_size_inches((12, 12))
ax = disp.ax_
ax.set_title('Confusion matrix (best random forest)', size=20)
ax.set_xlabel('Predicted family', size=16)
ax.set_ylabel('True family', size=16)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.grid(False)

[Figure: row-normalized confusion matrix of the best random forest on the validation set]

Precision, recall, F1 score

report = classification_report(y_true, y_hat, target_names=pfam_cat.categories, output_dict=True)
pd.DataFrame(report).T.round(2).iloc[:-3]

         precision  recall  f1-score  support
PF00005  0.91       0.98    0.94      338.0
PF00069  0.84       0.90    0.87      330.0
PF00067  0.88       0.95    0.91      362.0
PF00156  0.88       0.97    0.92      329.0
PF00001  0.92       0.98    0.95      346.0
PF07690  0.93       0.99    0.95      341.0
PF00012  0.99       0.93    0.96      367.0
PF00696  0.93       0.96    0.95      356.0
OTHER    0.88       0.48    0.62      306.0

Best random forest model

n_estimators  80
max_samples   0.4
max_features  log2
criterion     gini
max_depth     5
ngram         bi

forest_eval = pd.DataFrame([
    accuracy_score(y_true, y_hat),
], index=['Accuracy'], columns=['Random forest'])
forest_eval

          Random forest
Accuracy  0.909268

Final model

Comparing models

# Named model_eval to avoid shadowing Python's built-in eval()
model_eval = pd.concat([forest_eval.T, knn_eval.T, tree_eval.T]).round(3)
model_eval

               Accuracy
Random forest  0.909
KNN            0.873
Decision tree  0.757

fig, ax = plt.subplots(1, 1, figsize=(12, 5))
ax.set_title('Accuracy')
bars = ax.bar(model_eval.index, model_eval.loc[:, 'Accuracy'], color=[forest_color, knn_color, tree_color], width=0.5)
ax.set_ylim(0, 1)
ax.bar_label(bars)

plt.show()

[Figure: bar chart of validation accuracy for random forest, KNN and decision tree]

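The random forest achieves the highest validation accuracy (0.909), ahead of KNN (0.873) and the decision tree (0.757), so it is selected as the final model.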

best_model = best_forest

Evaluation on the test set

To evaluate the model on the imbalanced test set, we use the weighted \(F_1\) score.

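\[ \text{F}_{1, \text{weighted}}(y, \hat{y}) = \sum_{k} \frac{|\{y_i \in y \mid y_i = k\}|}{|\{y_i \in y\}|} \left( 2 \frac{P_k \cdot R_k}{P_k + R_k} \right) \]
where \(P_k\) and \(R_k\) are the precision and recall of class \(k\), so each class contributes its own \(F_1\) score weighted by its share of the true labels.
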
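The definition above can be written out in plain Python. This is a minimal sketch for illustration, not the notebook's code: `weighted_f1` is a hypothetical helper, and the toy labels are made up. It follows the same support-weighted averaging that `f1_score(..., average='weighted')` applies in the evaluation cell.

```python
from collections import Counter

def weighted_f1(y_true, y_pred):
    """Average the per-class F1 scores, weighting each class by its share of y_true."""
    support = Counter(y_true)
    n = len(y_true)
    total = 0.0
    for k, n_k in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == k and p == k)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != k and p == k)
        fn = n_k - tp
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        total += n_k / n * f1
    return total

# Toy labels (made up): three classes with supports 2, 2 and 1.
toy_true = [0, 0, 1, 1, 2]
toy_pred = [0, 1, 1, 1, 2]
print(weighted_f1(toy_true, toy_pred))
```

In the toy data the per-class F1 scores are 2/3, 0.8 and 1.0, so the weighted score is (2/5)(2/3) + (2/5)(0.8) + (1/5)(1.0) = 59/75.
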
bi_test_vec = vect_dict['bi']['test']
y_true = bi_test_vec['y']
y_hat = best_model.predict(bi_test_vec['X']).astype(np.int8)
y_hat_proba = best_model.predict_proba(bi_test_vec['X'])

pd.DataFrame([accuracy_score(y_true, y_hat),
              f1_score(y_true, y_hat, average='weighted')],
              index=['Test accuracy', 'Weighted F1 score'],
              columns=[''])

Test accuracy      0.880642
Weighted F1 score  0.870774

Confusion matrix

disp = ConfusionMatrixDisplay.from_estimator(
    best_model, bi_test_vec['X'], y_true, cmap=plt.cm.Blues,
    display_labels=pfam_cat.categories, normalize='true'
)
disp.figure_.set_size_inches((12, 12))
ax = disp.ax_
ax.set_title('Confusion matrix (best model, test data)', size=20)
ax.set_xlabel('Predicted family', size=16)
ax.set_ylabel('True family', size=16)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.grid(False)

[Figure: row-normalized confusion matrix of the best model on the test set]

The model is weakest when classifying the “other” class, which is often confused with one of the priority families.
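Because the matrix is drawn with `normalize='true'`, each row is divided by the number of true samples of that family, so the diagonal entry is that family's recall and the off-diagonal entries show where its samples end up. A minimal plain-Python sketch of this normalization follows; `normalized_confusion_matrix` is a hypothetical helper and the labels are toy data, not the notebook's predictions.

```python
def normalized_confusion_matrix(y_true, y_pred, labels):
    """Return a row-normalized confusion matrix: rows are true labels, columns predictions."""
    index = {label: i for i, label in enumerate(labels)}
    counts = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        counts[index[t]][index[p]] += 1
    matrix = []
    for row in counts:
        row_total = sum(row)
        matrix.append([c / row_total if row_total else 0.0 for c in row])
    return matrix

# Toy data (made up): half of the OTHER samples are predicted as a named family.
labels = ['PF00005', 'PF00069', 'OTHER']
toy_true = ['PF00005', 'PF00005', 'PF00069', 'PF00069', 'OTHER', 'OTHER', 'OTHER', 'OTHER']
toy_pred = ['PF00005', 'PF00005', 'PF00069', 'OTHER', 'OTHER', 'PF00005', 'PF00069', 'OTHER']
cm = normalized_confusion_matrix(toy_true, toy_pred, labels)
for label, row in zip(labels, cm):
    print(label, row)
```
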

Precision, recall, F1 score

Precision \(P\): \[ \frac{T_p}{T_p + F_p} \]
Recall \(R\): \[ \frac{T_p}{T_p + F_n} \]
\(F_1\) score: \[ 2 \frac{P \cdot R}{P + R} \]

report = classification_report(y_true, y_hat, target_names=pfam_cat.categories, output_dict=True)
pd.DataFrame(report).T.round(2).iloc[:-3]

         precision  recall  f1-score  support
PF00005  0.93       0.96    0.95      562.0
PF00069  0.79       0.89    0.84      426.0
PF00067  0.87       0.94    0.90      427.0
PF00156  0.81       0.95    0.88      390.0
PF00001  0.93       0.95    0.94      430.0
PF07690  0.90       1.00    0.94      339.0
PF00012  0.97       0.95    0.96      282.0
PF00696  0.87       0.96    0.92      322.0
OTHER    0.86       0.45    0.59      500.0

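As a quick arithmetic check of the report above, the F1 column can be recomputed from the precision and recall columns. The sketch below uses three rows copied from the table; `f1_from` is a hypothetical helper. Since the table's precision and recall are already rounded to two decimals, a recomputed value can differ from the reported one in the last digit.

```python
def f1_from(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# (precision, recall, reported f1-score) copied from the test-set report.
rows = {
    'PF00069': (0.79, 0.89, 0.84),
    'PF00012': (0.97, 0.95, 0.96),
    'OTHER': (0.86, 0.45, 0.59),
}
recomputed = {family: f1_from(p, r) for family, (p, r, _) in rows.items()}
for family, (p, r, reported) in rows.items():
    print(family, round(recomputed[family], 3), reported)
```

For OTHER, the high precision (0.86) cannot compensate for the low recall (0.45): the harmonic mean is pulled toward the smaller of the two values.
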
fig, axes = plt.subplots(3, 3, figsize=(12, 10), layout='constrained')
fig.suptitle('Precision recall curve', fontsize=20)
fig.supxlabel('Recall', size=16)
fig.supylabel('Precision', size=16)
for i, (cat, ax) in enumerate(zip(pfam_cat.categories, fig.axes)):
    disp = PrecisionRecallDisplay.from_predictions(y_true == i, y_hat_proba[:, i], ax=ax, name=cat, linewidth=1.5, alpha=1, color=main_color)
    ax.set(xlabel=None, ylabel=None)

plt.show()

[Figure: precision-recall curves, one panel per family]

ROC and AUC

The ROC curve plots the true positive rate (TPR) against the false positive rate (FPR) as the decision threshold varies. The AUC is the area under this curve.

$$ \text{TPR} = \frac{\text{TP}}{\text{TP + FN}} \qquad \text{FPR} = \frac{\text{FP}}{\text{FP + TN}} $$
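To make the construction concrete, here is a minimal plain-Python sketch that sweeps the threshold over the predicted scores, records one (FPR, TPR) point per distinct score, and integrates the curve with the trapezoidal rule. `roc_points` and `auc` are hypothetical helpers written for illustration, the scores are toy values, and the sketch assumes both classes occur in the labels.

```python
def roc_points(labels, scores):
    """labels: 1 for positive, 0 for negative. Returns (FPR, TPR) points, threshold high to low."""
    positives = sum(labels)
    negatives = len(labels) - positives
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Admit every sample tied at this score before recording a point.
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1]:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / negatives, tp / positives))
    return points

def auc(points):
    """Area under a piecewise-linear curve given as (x, y) points, by the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))

# Toy one-vs-rest example: two negatives and two positives.
toy_labels = [0, 0, 1, 1]
toy_scores = [0.1, 0.4, 0.35, 0.8]
curve = roc_points(toy_labels, toy_scores)
print(curve)
print(auc(curve))
```

A classifier that ranks every positive above every negative reaches the point (0, 1) and an AUC of 1, while random scores give a curve close to the diagonal and an AUC near 0.5.
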

fig, axes = plt.subplots(3, 3, figsize=(12, 10), layout='constrained')
fig.suptitle('ROC curve', fontsize=20)
fig.supxlabel('False positive rate', size=16)
fig.supylabel('True positive rate', size=16)
for i, (cat, ax) in enumerate(zip(pfam_cat.categories, fig.axes)):
    disp = RocCurveDisplay.from_predictions(y_true == i, y_hat_proba[:, i], ax=ax, name=f'{cat} vs the rest', linewidth=1.5, color=main_color)
    # Chance-level diagonal from (0, 0) to (1, 1)
    ax.plot([0, 1], [0, 1], color='k', alpha=0.5, linestyle='--', linewidth=0.8)
    ax.set(xlabel=None, ylabel=None)
    
plt.show()

[Figure: ROC curves, one panel per family versus the rest]

Visualization of the classification on sample data

In the following example, we visualize the model's predictions on a random sample of 100 test sequences. The feature vectors are projected onto their first two principal components (PCA) for plotting.

pca = PCA(n_components=2)
pca_X = pca.fit_transform(bi_test_vec['X'])

fig, axes = plt.subplots(1, 2, figsize=(12, 5), layout='constrained', sharey=True)
fig.suptitle('Test data prediction (sample)', size=20)
fig.supxlabel('First component', size=14)
fig.supylabel('Second component', size=14)

rs = np.random.RandomState(2)
sample = rs.choice(np.arange(len(y_true)), 100, replace=False)

ax = axes[0]
ax.scatter(pca_X[sample, 0], pca_X[sample, 1], c=(y_true[sample]), cmap='Set1')
ax.set_title('True labels', size=16)

ax = axes[1]
scatter = ax.scatter(pca_X[sample, 0], pca_X[sample, 1], c=(y_hat[sample]), cmap='Set1')
ax.set_title('Predicted labels', size=16)
legend1 = ax.legend(scatter.legend_elements()[0], pfam_cat.categories.tolist(), ncol=5, loc='upper center',
                    fancybox=True, framealpha=1, bbox_to_anchor=(-0.09, 1), fontsize=10)
ax.add_artist(legend1)

ax.set_ylim(-0.3, 0.7)

plt.show()

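[Figure: PCA projection of 100 sampled test points, colored by true labels (left) and predicted labels (right)]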

pca = PCA(n_components=2)
pca_X = pca.fit_transform(bi_test_vec['X'])

fig, ax = plt.subplots(1, 1, figsize=(12, 8), layout='constrained', sharey=True)
fig.supxlabel('First component', size=14)
fig.supylabel('Second component', size=14)

correct = sample[y_true[sample] == y_hat[sample]]
incorrect = sample[y_true[sample] != y_hat[sample]]
ax.scatter(pca_X[correct, 0], pca_X[correct, 1], c='grey', alpha=0.5, label='correct', s=80)
ax.scatter(pca_X[incorrect, 0], pca_X[incorrect, 1], c='red', alpha=0.8, label='incorrect', s=80)
ax.set_title('Evaluation', size=16)
ax.legend(loc='upper right', fontsize=12)

ax.set_ylim(-0.3, 0.55)

plt.show()

[Figure: PCA projection of the same sample, marking correct and incorrect predictions]