This project explores various deep learning approaches to classify images from the Fashion MNIST dataset. We’ll compare different neural network architectures and regularization techniques to identify the most effective model.

import copy
import pathlib
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.ticker import MaxNLocator
from scipy.ndimage import zoom
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn.functional import relu
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

pd.set_option("display.max_columns", 10)
pd.set_option("display.precision", 3)
plt.style.use("default")

random_state = 42
batch_size = 256
max_epochs = 30

!command -v nvidia-smi &> /dev/null && nvidia-smi

Data Preparation

We start by loading the Fashion MNIST dataset and splitting it into training, validation, and test sets.

workdir = pathlib.Path(".")

train_file = workdir / "train.csv"
eval_file = workdir / "evaluate.csv"
res_file = workdir / "results.csv"

artifacts_path = workdir / "artifacts"
models_path = artifacts_path / "models"
logs_path = artifacts_path / "logs"
outputs_path = artifacts_path / "outputs"

artifacts_path.mkdir(parents=True, exist_ok=True)
models_path.mkdir(parents=True, exist_ok=True)
logs_path.mkdir(parents=True, exist_ok=True)
outputs_path.mkdir(parents=True, exist_ok=True)

df = pd.read_csv(train_file)
display(df.head(5))

cats = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
MAP_IL = {i: c for i, c in enumerate(cats)}
VEC_IL = np.vectorize(MAP_IL.get)

d = df.drop("label", axis=1)
fig, ax = plt.subplots(4, 8, figsize=(10, 6))
fig.suptitle("Sample from training data (inverted brightness)")
for j in range(4):
    for i in range(8):
        ax[j, i].imshow(d.iloc[8 * j + i].to_numpy().reshape((32, 32)), cmap="Greys")
        ax[j, i].set_title(MAP_IL[df.loc[8 * j + i, "label"]], fontsize=11)
        ax[j, i].set_axis_off()
plt.show()

|   | pix1 | pix2 | pix3 | pix4 | pix5 | ... | pix1021 | pix1022 | pix1023 | pix1024 | label |
|---|------|------|------|------|------|-----|---------|---------|---------|---------|-------|
| 0 | 0    | 0    | 0    | 0    | 0    | ... | 0       | 0       | 0       | 0       | 3     |
| 1 | 1    | 1    | 1    | 1    | 1    | ... | 1       | 1       | 1       | 1       | 3     |
| 2 | 1    | 1    | 1    | 1    | 1    | ... | 1       | 1       | 1       | 1       | 7     |
| 3 | 0    | 0    | 0    | 0    | 0    | ... | 0       | 0       | 0       | 0       | 9     |
| 4 | 1    | 1    | 1    | 1    | 1    | ... | 1       | 1       | 1       | 1       | 5     |

5 rows × 1025 columns

[Figure: Sample from training data (inverted brightness)]

The dataset consists of 32x32 pixel grayscale images, each belonging to one of 10 clothing categories:

  • T-shirt/top
  • Trouser
  • Pullover
  • Dress
  • Coat
  • Sandal
  • Shirt
  • Sneaker
  • Bag
  • Ankle boot

Train-Validation-Test Split

We split the data into training (60%), validation (20%), and test (20%) sets.

X_all = df.drop("label", axis=1).to_numpy(np.uint8)
y_all = df.loc[:, "label"].to_numpy(np.uint8)

X_train_val, X_test, y_train_val, y_test = train_test_split(
    X_all, y_all, test_size=0.2, random_state=random_state
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.25, random_state=random_state
)

d = pd.DataFrame(
    [
        ["train", X_train.shape[0], X_train.shape[0] / X_all.shape[0]],
        ["validation", X_val.shape[0], X_val.shape[0] / X_all.shape[0]],
        ["test", X_test.shape[0], X_test.shape[0] / X_all.shape[0]],
    ],
    columns=["dataset", "count", "relative count"],
)
display(d)

fig, ax = plt.subplots(figsize=(4, 4))
ax.pie(d["count"], labels=d["dataset"], autopct="%1.f%%", wedgeprops={"alpha": 0.95})
ax.set_title("Train-validation-test split")
plt.show()

|   | dataset    | count | relative count |
|---|------------|-------|----------------|
| 0 | train      | 31500 | 0.6            |
| 1 | validation | 10500 | 0.2            |
| 2 | test       | 10500 | 0.2            |

[Figure: Train-validation-test split (pie chart)]

Exploratory Data Analysis

Let’s examine the features (pixels) and target variable (class).

Features (Pixels)

d = pd.DataFrame(
    [
        ["count", X_train.shape[0]],
        ["size", X_train.shape[1]],
        ["min", X_train.min()],
        ["mean", X_train.mean().astype(int)],
        ["max", X_train.max()],
    ],
    columns=["stat", "value"],
)
display(d.set_index("stat").T)

d = X_train.flatten()
d2 = pd.DataFrame(
    [
        ["black pixel (lte 10)", d[d <= 10].shape[0]],
        ["non-black pixel (gt 10)", d[d > 10].shape[0]],
    ],
    columns=["color", "count"],
)

fig, ax = plt.subplots(figsize=(6, 3))
ax.bar(d2["color"], d2["count"])
ax.set_title("Pixel count: black v. non-black")
ax.set_ylabel("count")
plt.show()

d = X_train.flatten()

fig, ax = plt.subplots(figsize=(6, 4))
ax.hist(d[d > 10], bins=20)
ax.set_title("Histogram of gray level distribution (gt 10)")
ax.set_xlabel("grey level")
ax.set_ylabel("frequency")
plt.show()

| stat  | count | size | min | mean | max |
|-------|-------|------|-----|------|-----|
| value | 31500 | 1024 | 0   | 44   | 255 |

[Figure: Pixel count: black v. non-black]

[Figure: Histogram of gray level distribution (gt 10)]

The images are mostly black; non-black pixels most often have a gray level around 200.

Target variable (class)

d = []
for c in range(10):
    d.append(
        [
            f"{MAP_IL[c]}",
            sum(y_train == c),
            round(sum(y_train == c) / y_train.shape[0], 3),
        ]
    )
d = pd.DataFrame(d, columns=["label", "count", "relative count"])
display(d)

fig, ax = plt.subplots(figsize=(4, 4))
ax.pie(
    d.loc[:, "count"],
    labels=d.loc[:, "label"],
    autopct="%.0f%%",
    wedgeprops={"alpha": 0.8},
)
ax.set_title("Class distribution")
plt.show()

|   | label       | count | relative count |
|---|-------------|-------|----------------|
| 0 | T-shirt/top | 3103  | 0.099          |
| 1 | Trouser     | 3052  | 0.097          |
| 2 | Pullover    | 3170  | 0.101          |
| 3 | Dress       | 3179  | 0.101          |
| 4 | Coat        | 3067  | 0.097          |
| 5 | Sandal      | 3199  | 0.102          |
| 6 | Shirt       | 3137  | 0.100          |
| 7 | Sneaker     | 3183  | 0.101          |
| 8 | Bag         | 3225  | 0.102          |
| 9 | Ankle boot  | 3185  | 0.101          |

[Figure: Class distribution (pie chart)]

Class distribution is balanced. No resampling is needed.

for c in range(10):
    fig, ax = plt.subplots(1, 10, figsize=(10, 1.5))
    fig.suptitle(f"{MAP_IL[c]}")

    d = X_train[y_train == c]
    for i in range(10):
        ax[i].imshow(d[i].reshape((32, 32)), cmap="Greys")
        ax[i].set_axis_off()
    plt.show()

[Figures: Ten training samples for each of the 10 classes, one row per class]

Distinguishing between “T-shirt/top” and “Shirt”, or “Pullover” and “Coat”, or “Sneaker” and “Ankle boot” might be challenging for the model.


Data Preprocessing and Augmentation

Image Augmentation

To make the model robust to variations, we augment the data with horizontal flips and zooms.

def augment(x, k):
    x1 = np.fliplr(x)
    x2 = zoom(x[k : 32 - k, k : 32 - k], 32 / (32 - 2 * k), order=0)
    x3 = np.fliplr(x2)
    return x1, x2, x3


def data_augment(X, y):
    X_aug = np.empty((4 * X.shape[0], X.shape[1]), dtype=np.uint8)
    y_aug = np.repeat(y, 4)
    for i in range(X.shape[0]):
        x0 = X[i].reshape((32, 32))
        x1, x2, x3 = augment(x0, 2)
        X_aug[4 * i] = x0.flatten()
        X_aug[4 * i + 1] = x1.flatten()
        X_aug[4 * i + 2] = x2.flatten()
        X_aug[4 * i + 3] = x3.flatten()
    return X_aug, y_aug


new_X_train, new_y_train = data_augment(X_train, y_train)
new_X_val, new_y_val = data_augment(X_val, y_val)


fig, axes = plt.subplots(5, 4, figsize=(8, 8))
for i in range(5):
    x1 = X_train[i].reshape((32, 32))
    x2, x3, x4 = augment(x1, k=5)

    ax = axes[i]
    ax[0].imshow(x1, cmap="Greys")
    ax[1].imshow(x2, cmap="Greys")
    ax[2].imshow(x3, cmap="Greys")
    ax[3].imshow(x4, cmap="Greys")

    if i == 0:
        ax[0].set_title("original")
        ax[1].set_title("flipped")
        ax[2].set_title("cropped")
        ax[3].set_title("cropped + flipped")

    for j in range(4):
        ax[j].set_axis_off()

plt.show()

[Figure: Augmentation examples: original, flipped, cropped, cropped + flipped]

This augmentation helps the model become invariant to horizontal orientation and small changes in scale.
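
For reference, torchvision provides randomized versions of the same transforms, applied on the fly each epoch instead of enlarging the dataset fourfold. A minimal sketch, illustrative only and not used in this notebook; the crop scale is an assumption roughly matching k=2 above:

from torchvision import transforms

# Hypothetical on-the-fly alternative to the manual flip/zoom augmentation
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop(32, scale=(0.77, 1.0)),
    ]
)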

Data and Device Setup

We prepare the datasets for batch processing and identify the available device (CPU/GPU).

train_dataloader = DataLoader(
    TensorDataset(
        torch.Tensor(new_X_train).reshape((-1, 1, 32, 32)),
        torch.Tensor(new_y_train).type(torch.uint8),
    ),
    batch_size=batch_size,
)

val_dataloader = DataLoader(
    TensorDataset(
        torch.Tensor(new_X_val).reshape((-1, 1, 32, 32)),
        torch.Tensor(new_y_val).type(torch.uint8),
    ),
    batch_size=batch_size,
)

test_dataloader = DataLoader(
    TensorDataset(
        torch.Tensor(X_test).reshape((-1, 1, 32, 32)),
        torch.Tensor(y_test).type(torch.uint8),
    ),
    batch_size=batch_size,
)

for X, y in train_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using {device} device")
Shape of X [N, C, H, W]: torch.Size([256, 1, 32, 32])
Shape of y: torch.Size([256]) torch.uint8
Using cpu device

Feedforward Neural Networks

Here, we define a set of functions for training neural networks, incorporating early stopping for regularization.

class EarlyStopping:
    def __init__(self, tolerance, model):
        self.tolerance = tolerance
        self.counter = 0
        self.early_stop = False
        self.init = False
        self.max_acc = 0
        self.best_model = model

    def __call__(self, model, val_acc):
        if not self.init:
            self.init = True
            self.max_acc = val_acc
            self.best_model = copy.deepcopy(model)
            return

        if val_acc > self.max_acc:
            self.counter = 0
            self.max_acc = val_acc
            self.best_model = copy.deepcopy(model)
            return
        else:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
            return


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.train()

    total_loss, total_correct = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        with torch.no_grad():
            total_loss += loss.item()
            total_correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    total_loss /= num_batches
    total_correct /= size
    return total_correct, total_loss


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()

    loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    loss /= num_batches
    correct /= size

    print(
        f"Validation Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {loss:>8f} \n"
    )
    return correct, loss


def predict_proba(dataloader, model, out_path):
    if out_path.is_file():
        return pickle.load(open(out_path, "rb"))

    y_hat = np.empty((0, 10), dtype=np.float32)  # raw logits, one row per sample
    model.eval()
    with torch.no_grad():
        for X in dataloader:
            X = X[0].to(device)
            logits = model(X)
            y_hat = np.concatenate((y_hat, logits.cpu()))

    pickle.dump(y_hat, open(out_path, "wb"))
    return y_hat


def predict(dataloader, model, out_path):
    return np.argmax(predict_proba(dataloader, model, out_path), axis=1)


def load_model(file, m):
    model_path = models_path / (file + ".pt")
    m.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
    return m


def train_model(file, m, optimizer, early_stopping=3, force_train=False):
    model_path = models_path / (file + ".pt")
    log_path = logs_path / (file + ".pickle")

    if model_path.is_file() and log_path.is_file() and not force_train:
        m.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
        logs = pickle.load(open(log_path, "rb"))
    else:
        m = m.to(device)
        loss_fn = nn.CrossEntropyLoss()
        stopping = EarlyStopping(early_stopping, m)

        logs = []
        for t in range(max_epochs):
            print(f"Epoch {t+1}\n-------------------------------")
            tacc, tloss = train(train_dataloader, m, loss_fn, optimizer)
            vacc, vloss = test(val_dataloader, m, loss_fn)
            logs.append((tacc, tloss, vacc, vloss))

            stopping(m, vacc)
            if stopping.early_stop:
                break

        m = stopping.best_model
        torch.save(m.state_dict(), model_path)
        pickle.dump(logs, open(log_path, "wb"))

    stats = ["train_accuracy", "train_loss", "val_accuracy", "val_loss"]
    d = pd.DataFrame(logs, columns=stats)
    d.index += 1
    return d


def display_result(d):
    display(d.tail(5))

    fig, axes = plt.subplots(1, 2, figsize=(10, 4), layout="constrained")
    fig.suptitle("Learning curve")

    ax = axes[0]
    ax.set_title("Loss")
    ax.plot(d.index, d.loc[:, "train_loss"], label="train", linestyle="dashed")
    ax.plot(d.index, d.loc[:, "val_loss"], label="validation")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.legend()

    ax = axes[1]
    ax.set_title("Accuracy")
    ax.plot(d.index, d.loc[:, "train_accuracy"], label="train", linestyle="dashed")
    ax.plot(d.index, d.loc[:, "val_accuracy"], label="validation")
    ax.set_xlabel("epoch")
    ax.set_ylabel("accuracy")
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.legend()

    plt.show()

Model suitability

Fully connected (dense) neural networks are very flexible and should perform reasonably well on image data, although we expect convolutional networks to do significantly better. Compared with traditional models, neural networks can represent very complicated functions. That capacity is desirable here, since mapping pixels to clothing categories is certainly a complex function, but it also makes these models prone to overfitting and harder to train.

Base Network

This is our baseline fully connected neural network model.

class BaseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(32 * 32, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = BaseNetwork()
optimizer = torch.optim.Adam(model.parameters())
d_base = train_model("base", model, optimizer)
display_result(d_base)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 10 | 0.851          | 0.404      | 0.833        | 0.467    |
| 11 | 0.852          | 0.394      | 0.827        | 0.488    |
| 12 | 0.856          | 0.388      | 0.823        | 0.504    |
| 13 | 0.856          | 0.383      | 0.828        | 0.472    |
| 14 | 0.859          | 0.377      | 0.821        | 0.505    |

[Figure: Learning curve (base network)]

This base network serves as a reference for subsequent model modifications.
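
To put the network's flexibility into numbers, we can count its trainable parameters (a quick sanity check, not part of the original pipeline):

def n_params(m):
    return sum(p.numel() for p in m.parameters())


print(f"BaseNetwork parameters: {n_params(model):,}")  # 624,778 for this architecture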

Wide Layers

We explore the impact of wider hidden layers.

class WideNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(32 * 32, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = WideNetwork()
optimizer = torch.optim.Adam(model.parameters())
d_wide = train_model("wide", model, optimizer)
display_result(d_wide)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 12 | 0.853          | 0.396      | 0.839        | 0.478    |
| 13 | 0.855          | 0.386      | 0.833        | 0.470    |
| 14 | 0.858          | 0.378      | 0.836        | 0.475    |
| 15 | 0.861          | 0.371      | 0.835        | 0.495    |
| 16 | 0.864          | 0.365      | 0.837        | 0.495    |

[Figure: Learning curve (wide network)]

A slight improvement over the baseline is observed with wider layers.

Deep Network

Next, we evaluate a deeper network with more hidden layers.

class DeepNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(32 * 32, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = DeepNetwork()
optimizer = torch.optim.Adam(model.parameters())
d_deep = train_model("deep", model, optimizer)
display_result(d_deep)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 26 | 0.883          | 0.313      | 0.849        | 0.467    |
| 27 | 0.884          | 0.311      | 0.851        | 0.466    |
| 28 | 0.883          | 0.314      | 0.853        | 0.468    |
| 29 | 0.888          | 0.303      | 0.849        | 0.476    |
| 30 | 0.887          | 0.302      | 0.846        | 0.474    |

[Figure: Learning curve (deep network)]

A deeper network shows a significant improvement, though training also takes longer.

SGD Optimizer

We test the Stochastic Gradient Descent (SGD) optimizer.

model = BaseNetwork()
optimizer = torch.optim.SGD(model.parameters())
d_sgd = train_model("sgd", model, optimizer)
display_result(d_sgd)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 26 | 0.874          | 0.346      | 0.839        | 0.442    |
| 27 | 0.876          | 0.341      | 0.840        | 0.438    |
| 28 | 0.878          | 0.337      | 0.841        | 0.438    |
| 29 | 0.879          | 0.332      | 0.842        | 0.437    |
| 30 | 0.881          | 0.328      | 0.843        | 0.436    |

[Figure: Learning curve (SGD optimizer)]

SGD offers more stable training, though it converges more slowly; longer training might yield better results.
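
If we wanted to speed up plain SGD rather than train longer, momentum and a larger learning rate are the usual levers. A sketch only; the values below are illustrative, not tuned for this task:

optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)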

L2 Regularization

We apply L2 regularization to the base network.

model = BaseNetwork()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-2)
d_l2 = train_model("l2", model, optimizer)
display_result(d_l2)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 19 | 0.801          | 0.540      | 0.810        | 0.527    |
| 20 | 0.800          | 0.539      | 0.797        | 0.558    |
| 21 | 0.801          | 0.541      | 0.809        | 0.523    |
| 22 | 0.802          | 0.539      | 0.804        | 0.538    |
| 23 | 0.801          | 0.536      | 0.808        | 0.523    |

[Figure: Learning curve (L2 regularization)]

L2 regularization did not provide much benefit for this network. It might be more effective with larger models to prevent overfitting.
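
One caveat: passing weight_decay to Adam applies the L2 penalty through the adaptive update, coupling the decay to the gradient statistics. AdamW decouples the two and is often suggested as a drop-in replacement; we did not evaluate it here:

optimizer = torch.optim.AdamW(model.parameters(), weight_decay=1e-2)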

Batch Normalization

Batch normalization is introduced to improve training stability and performance.

class BatchNormNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(32 * 32, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = BatchNormNetwork()
optimizer = torch.optim.Adam(model.parameters())
d_batch_norm = train_model("batch_norm", model, optimizer)
display_result(d_batch_norm)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 12 | 0.911          | 0.236      | 0.843        | 0.495    |
| 13 | 0.919          | 0.217      | 0.841        | 0.524    |
| 14 | 0.924          | 0.202      | 0.840        | 0.545    |
| 15 | 0.930          | 0.187      | 0.836        | 0.575    |
| 16 | 0.934          | 0.174      | 0.836        | 0.585    |

[Figure: Learning curve (batch normalization)]

Batch normalization improves performance, but overfitting occurs with longer training, as indicated by the widening gap between training and validation curves.
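
The overfitting is easy to quantify from the training logs, since the gap between training and validation accuracy widens epoch over epoch:

gap = d_batch_norm["train_accuracy"] - d_batch_norm["val_accuracy"]
print(gap.tail())  # grows from roughly 0.07 to 0.10 over the last five epochs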

Dropout

We incorporate dropout layers to mitigate overfitting.

class DropoutNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(32 * 32, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = DropoutNetwork()
optimizer = torch.optim.Adam(model.parameters())
d_dropout = train_model("dropout", model, optimizer)
display_result(d_dropout)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 16 | 0.748          | 0.687      | 0.788        | 0.596    |
| 17 | 0.748          | 0.688      | 0.751        | 0.658    |
| 18 | 0.750          | 0.686      | 0.783        | 0.612    |
| 19 | 0.749          | 0.684      | 0.780        | 0.599    |
| 20 | 0.749          | 0.689      | 0.754        | 0.643    |

[Figure: Learning curve (dropout)]

Dropout leads to a performance drop but is expected to improve generalization, especially in larger models. Validation accuracy is higher than training accuracy because dropout is disabled during evaluation.
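
The train/eval asymmetry is easy to demonstrate in isolation: nn.Dropout zeroes activations (and rescales the survivors by 1/(1-p)) only in training mode:

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()
print(drop(x))  # roughly half the entries zeroed, survivors scaled to 2.0

drop.eval()
print(drop(x))  # identity: dropout is disabled during evaluation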

Summary of Feedforward Networks

def display_summary(logs, labels):
    accs = [l.loc[:, "val_accuracy"].max() for l in logs]

    fig, ax = plt.subplots(figsize=(10, 6))
    fig.suptitle("Comparison of networks")
    for d, label in zip(logs, labels):
        ax.plot(d.index, d.loc[:, "val_loss"], label=label, alpha=0.95)
    ax.set_title("Learning curve")
    ax.set_xlabel("epoch")
    ax.set_ylabel("validation loss")
    ax.legend()
    plt.show()

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 4))
    fig.subplots_adjust(hspace=0.05)
    for i, l in enumerate(labels):
        ax1.bar(l, accs[i], alpha=0.7)
        ax2.bar(l, accs[i], alpha=0.7)
    ax1.set_ylim(0.75, 0.9)
    ax2.set_ylim(0, 0.2)
    ax1.spines.bottom.set_visible(False)
    ax2.spines.top.set_visible(False)
    ax1.set_xticks([])
    ax2.set_yticks([0, 0.04, 0.10, 0.15])
    ax1.xaxis.tick_top()

    ax1.axhline(accs[0], alpha=0.4, c="black", linestyle="--", linewidth=0.51)

    d = 0.5
    kwargs = dict(
        marker=[(-1, -d), (1, d)],
        markersize=12,
        linestyle="none",
        color="k",
        mec="k",
        mew=1,
        clip_on=False,
    )
    ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
    ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

    ax1.set_title("Best accuracy")
    ax1.set_ylabel("validation accuracy")
    plt.show()


display_summary(
    [d_base, d_wide, d_deep, d_sgd, d_l2, d_batch_norm, d_dropout],
    ["base", "wide", "deep", "SGD", "L2", "batch norm", "dropout"],
)

[Figure: Comparison of networks: validation loss learning curves]

[Figure: Best validation accuracy per network]

Deeper and wider networks outperform the baseline. SGD provides stable training. Batch normalization yields good results, while L2 and dropout did not improve performance on these smaller models.


Convolutional Neural Networks (CNNs)

CNNs are generally better suited to image data, so we expect them to outperform the fully connected networks above.

Base CNN

This is our baseline CNN model.

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool = nn.MaxPool2d(2, 2)
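        # shape trace: 32 → 30 (conv1, 3×3) → 15 (pool) → 13 (conv2) → 6 (pool);
        # flattened: 64 channels × 6 × 6 = 2304 features, matching fc1 below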
        self.fc1 = nn.Linear(2304, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = CNN()
optimizer = torch.optim.Adam(model.parameters())
d_base = train_model("base_cnn", model, optimizer)
display_result(d_base)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 8  | 0.911          | 0.234      | 0.875        | 0.381    |
| 9  | 0.917          | 0.222      | 0.868        | 0.414    |
| 10 | 0.921          | 0.210      | 0.867        | 0.432    |
| 11 | 0.924          | 0.199      | 0.870        | 0.435    |
| 12 | 0.926          | 0.196      | 0.872        | 0.436    |

[Figure: Learning curve (base CNN)]

The base CNN already shows significant improvement over fully connected networks and converges quickly.
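
Part of that improvement comes cheaply: thanks to weight sharing in the convolutions, the CNN uses far fewer parameters than the dense networks. Reusing the n_params helper from earlier:

print(f"BaseNetwork: {n_params(BaseNetwork()):,}")  # 624,778
print(f"CNN:         {n_params(CNN()):,}")  # 166,986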

Wide CNN

class WideCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(2304, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = WideCNN()
optimizer = torch.optim.Adam(model.parameters())
d_wide = train_model("wide_cnn", model, optimizer)
display_result(d_wide)

|   | train_accuracy | train_loss | val_accuracy | val_loss |
|---|----------------|------------|--------------|----------|
| 4 | 0.880          | 0.321      | 0.862        | 0.380    |
| 5 | 0.889          | 0.295      | 0.853        | 0.410    |
| 6 | 0.898          | 0.272      | 0.849        | 0.448    |
| 7 | 0.904          | 0.256      | 0.854        | 0.444    |
| 8 | 0.911          | 0.238      | 0.857        | 0.444    |

[Figure: Learning curve (wide CNN)]

Widening the fully connected layer provides no improvement, and the network still converges quickly.

Deep CNN

class DeepCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3)
        self.conv2 = nn.Conv2d(64, 128, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(4608, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = relu(self.fc3(x))
        x = self.fc4(x)
        return x


model = DeepCNN()
optimizer = torch.optim.Adam(model.parameters())
d_deep = train_model("deep_cnn", model, optimizer)
display_result(d_deep)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 7  | 0.912          | 0.236      | 0.878        | 0.359    |
| 8  | 0.917          | 0.223      | 0.868        | 0.390    |
| 9  | 0.923          | 0.204      | 0.862        | 0.423    |
| 10 | 0.926          | 0.195      | 0.865        | 0.418    |
| 11 | 0.931          | 0.186      | 0.872        | 0.434    |

[Figure: Learning curve (deep CNN)]

A deeper CNN offers minor improvement, but training becomes more volatile and shows signs of overfitting.

SGD Optimizer for CNN

model = CNN()
optimizer = torch.optim.SGD(model.parameters())
d_sgd = train_model("sgd_cnn", model, optimizer)
display_result(d_sgd)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 26 | 0.908          | 0.256      | 0.864        | 0.369    |
| 27 | 0.909          | 0.253      | 0.865        | 0.367    |
| 28 | 0.911          | 0.249      | 0.866        | 0.363    |
| 29 | 0.912          | 0.246      | 0.868        | 0.359    |
| 30 | 0.913          | 0.242      | 0.869        | 0.357    |

[Figure: Learning curve (SGD CNN)]

Similar to the fully connected networks, SGD offers stable but slower convergence for CNNs. Further training could be beneficial.

L2 Regularization for CNN

model = CNN()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-2)
d_l2 = train_model("l2_cnn", model, optimizer)
display_result(d_l2)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 12 | 0.874          | 0.343      | 0.869        | 0.346    |
| 13 | 0.875          | 0.343      | 0.860        | 0.370    |
| 14 | 0.876          | 0.340      | 0.868        | 0.353    |
| 15 | 0.876          | 0.339      | 0.869        | 0.350    |
| 16 | 0.877          | 0.336      | 0.868        | 0.349    |

[Figure: Learning curve (L2 CNN)]

L2 regularization leads to more stable training for CNNs, even if it doesn’t always boost accuracy. This stability is crucial for larger models to generalize well.

Batch Normalization for CNN

class BatchNormCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(2304, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(relu(self.conv1_bn(self.conv1(x))))
        x = self.pool(relu(self.conv2_bn(self.conv2(x))))
        x = torch.flatten(x, 1)
        x = relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = BatchNormCNN()
optimizer = torch.optim.Adam(model.parameters())
d_batch_norm = train_model("batch_norm_cnn", model, optimizer)
display_result(d_batch_norm)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 9  | 0.925          | 0.203      | 0.881        | 0.337    |
| 10 | 0.931          | 0.188      | 0.877        | 0.368    |
| 11 | 0.935          | 0.176      | 0.880        | 0.376    |
| 12 | 0.940          | 0.163      | 0.881        | 0.386    |
| 13 | 0.945          | 0.151      | 0.879        | 0.399    |

[Figure: Learning curve (batch norm CNN)]

Batch normalization for CNNs results in a smoother training curve and a slight accuracy boost.

Dropout for CNN

class DropoutCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.dropout = nn.Dropout2d(0.5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(2304, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(self.dropout(relu(self.conv1(x))))
        x = self.pool(self.dropout(relu(self.conv2(x))))
        x = torch.flatten(x, 1)
        x = relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = DropoutCNN()
optimizer = torch.optim.Adam(model.parameters())
d_dropout = train_model("dropout_cnn", model, optimizer)
display_result(d_dropout)

|    | train_accuracy | train_loss | val_accuracy | val_loss |
|----|----------------|------------|--------------|----------|
| 21 | 0.858          | 0.379      | 0.874        | 0.332    |
| 22 | 0.861          | 0.375      | 0.865        | 0.350    |
| 23 | 0.860          | 0.376      | 0.872        | 0.338    |
| 24 | 0.861          | 0.373      | 0.871        | 0.332    |
| 25 | 0.862          | 0.373      | 0.870        | 0.341    |

[Figure: Learning curve (dropout CNN)]

Dropout works better with CNNs than with fully connected networks. It helps prevent overfitting, leading to better generalization.

Summary of CNNs

display_summary(
    [d_base, d_wide, d_deep, d_sgd, d_l2, d_batch_norm, d_dropout],
    ["base", "wide", "deep", "SGD", "L2", "batch norm", "dropout"],
)

[Figure: Comparison of CNNs: validation loss learning curves]

[Figure: Best validation accuracy per CNN]

CNNs generally outperform fully connected networks on image data. Batch normalization and deeper networks show improved performance. Regularization techniques like L2 and dropout help stabilize training and improve generalization.


Final Model

Our final model combines successful elements from previous experiments: a deep and wide CNN architecture with batch normalization, dropout, and L2 regularization. We will train it using Adam initially, then fine-tune with SGD for optimal accuracy and stability.

class FinalCNNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Conv2d(1, 64, 2)
        self.conv1_bn = nn.BatchNorm2d(64)
        self.conv1_do = nn.Dropout2d(0.1)

        self.conv2 = nn.Conv2d(64, 128, 2)
        self.conv2_bn = nn.BatchNorm2d(128)
        self.conv2_do = nn.Dropout2d(0.3)

        self.conv3 = nn.Conv2d(128, 256, 2)
        self.conv3_bn = nn.BatchNorm2d(256)
        self.conv3_do = nn.Dropout2d(0.5)

        self.conv4 = nn.Conv2d(256, 512, 4)
        self.conv4_bn = nn.BatchNorm2d(512)
        self.conv4_do = nn.Dropout2d(0.5)
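
        # shape trace: 32 → 31 (conv1, 2×2) → 15 (pool) → 14 (conv2) → 7 (pool)
        # → 6 (conv3, 2×2) → 3 (conv4, 4×4); flattened: 512 × 3 × 3 = 4608 (fc1 input)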

        self.fc1 = nn.Linear(4608, 512)
        self.fc1_bn = nn.BatchNorm1d(512)
        self.fc1_do = nn.Dropout(0.1)

        self.fc2 = nn.Linear(512, 256)
        self.fc2_bn = nn.BatchNorm1d(256)
        self.fc2_do = nn.Dropout(0.3)

        self.fc3 = nn.Linear(256, 128)
        self.fc3_bn = nn.BatchNorm1d(128)
        self.fc3_do = nn.Dropout(0.5)

        self.fc4 = nn.Linear(128, 64)
        self.fc4_bn = nn.BatchNorm1d(64)
        self.fc4_do = nn.Dropout(0.5)

        self.fc5 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(self.conv1_do(relu(self.conv1_bn(self.conv1(x)))))
        x = self.pool(self.conv2_do(relu(self.conv2_bn(self.conv2(x)))))
        x = self.conv3_do(relu(self.conv3_bn(self.conv3(x))))
        x = self.conv4_do(relu(self.conv4_bn(self.conv4(x))))

        x = torch.flatten(x, 1)
        x = self.fc1_do(relu(self.fc1_bn(self.fc1(x))))
        x = self.fc2_do(relu(self.fc2_bn(self.fc2(x))))
        x = self.fc3_do(relu(self.fc3_bn(self.fc3(x))))
        x = self.fc4_do(relu(self.fc4_bn(self.fc4(x))))
        x = self.fc5(x)
        return x


RETRAIN = False
max_epochs = 200
model = FinalCNNet()

optimizer0 = torch.optim.Adam(model.parameters(), weight_decay=1e-3)
optimizer1 = torch.optim.SGD(model.parameters(), weight_decay=1e-3)
optimizer2 = torch.optim.SGD(model.parameters())

d_final0 = train_model("final0", model, optimizer0, 4, RETRAIN)
d_final1 = train_model("final1", model, optimizer1, 4, RETRAIN)
d_final2 = train_model("final2", model, optimizer2, 4, RETRAIN)

display_result(pd.concat((d_final0, d_final1, d_final2)).reset_index())

|     | index | train_accuracy | train_loss | val_accuracy | val_loss |
|-----|-------|----------------|------------|--------------|----------|
| 149 | 39    | 0.909          | 0.282      | 0.914        | 0.246    |
| 150 | 40    | 0.908          | 0.283      | 0.914        | 0.246    |
| 151 | 41    | 0.909          | 0.281      | 0.914        | 0.246    |
| 152 | 42    | 0.908          | 0.284      | 0.914        | 0.246    |
| 153 | 43    | 0.910          | 0.281      | 0.914        | 0.246    |

[Figure: Learning curve (final model, three training stages concatenated)]

Evaluation

We now evaluate the final model’s performance on the test data.

final = load_model("final2", FinalCNNet()).to(device)

y_test_proba_pred_path = outputs_path / "y_test_proba_pred.pickle"
y_test_proba_hat = predict_proba(test_dataloader, final, y_test_proba_pred_path)
y_test_hat = np.argmax(y_test_proba_hat, axis=1)

acc = np.round(accuracy_score(y_test, y_test_hat), 3)
print(f"Test accuracy: {acc}")
Test accuracy: 0.916

The model achieved a test accuracy of 91.6%.

report = pd.DataFrame(
    classification_report(y_test, y_test_hat, target_names=cats, output_dict=True)
).T
display(report)

|              | precision | recall | f1-score | support   |
|--------------|-----------|--------|----------|-----------|
| T-shirt/top  | 0.858     | 0.887  | 0.872    | 1115.000  |
| Trouser      | 0.995     | 0.994  | 0.995    | 1101.000  |
| Pullover     | 0.882     | 0.871  | 0.877    | 1047.000  |
| Dress        | 0.892     | 0.911  | 0.901    | 1028.000  |
| Coat         | 0.861     | 0.867  | 0.864    | 1061.000  |
| Sandal       | 0.989     | 0.967  | 0.978    | 1024.000  |
| Shirt        | 0.771     | 0.743  | 0.756    | 1058.000  |
| Sneaker      | 0.951     | 0.977  | 0.964    | 998.000   |
| Bag          | 0.991     | 0.977  | 0.984    | 1044.000  |
| Ankle boot   | 0.978     | 0.974  | 0.976    | 1024.000  |
| accuracy     | 0.916     | 0.916  | 0.916    | 0.916     |
| macro avg    | 0.917     | 0.917  | 0.917    | 10500.000 |
| weighted avg | 0.916     | 0.916  | 0.916    | 10500.000 |

def display_confusion_matrix(ax, norm=None):
    ConfusionMatrixDisplay.from_predictions(
        VEC_IL(y_test),
        VEC_IL(y_test_hat),
        labels=cats,
        normalize=norm,
        xticks_rotation="vertical",
        colorbar=False,
        cmap="Blues",
        ax=ax,
    )
    ax.set_title("Confusion matrix (final model, test data)")
    ax.set_xlabel("Predicted class")
    ax.set_ylabel("True class", size=16)
    ax.tick_params(axis="both", which="major")
    ax.grid(False)


fig, ax = plt.subplots(1, 1, figsize=(8, 8), layout="constrained")
display_confusion_matrix(ax)
plt.show()

[Figure: Confusion matrix (final model, test data)]

The model frequently confuses “Shirt” with “T-shirt/top,” and also struggles with “Coat/Shirt” and “Pullover/Shirt” distinctions.
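
To rank the confusions instead of eyeballing the matrix, we can zero its diagonal and sort the remaining cells (a small helper, not part of the original analysis):

from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_test, y_test_hat)
np.fill_diagonal(cm, 0)  # drop the correct predictions
for idx in np.argsort(cm, axis=None)[::-1][:3]:
    i, j = np.unravel_index(idx, cm.shape)
    print(f"{cats[i]} predicted as {cats[j]}: {cm[i, j]} samples")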

fig, axes = plt.subplots(4, 3, figsize=(10, 10), layout="constrained")
fig.suptitle("ROC and AUC")
fig.supxlabel("False positive rate")
fig.supylabel("True positive rate")
for i, (cat, ax) in enumerate(zip(cats, fig.axes[:10])):
    disp = RocCurveDisplay.from_predictions(
        y_test == i, y_test_proba_hat[:, i], ax=ax, name=f"{cat} v rest"
    )
    ax.set(xlabel=None, ylabel=None)
for ax in fig.axes[10:]:
    ax.remove()
plt.show()

[Figure: ROC and AUC, one-vs-rest for each class]

The model performs exceptionally well for the “Trouser”, “Ankle boot”, “Bag”, “Sneaker”, and “Sandal” classes (AUC ≈ 1.00). However, it struggles more with “Shirt”, “Coat”, “Pullover”, and “T-shirt/top”.
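
The per-class AUC values read off the plots can also be computed directly. roc_auc_score is rank-based, so the raw logits in y_test_proba_hat serve as scores without applying a softmax:

from sklearn.metrics import roc_auc_score

for i, cat in enumerate(cats):
    auc = roc_auc_score(y_test == i, y_test_proba_hat[:, i])
    print(f"{cat:>12}: AUC = {auc:.3f}")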

Sample Predictions

edf = pd.read_csv(eval_file)
ids = edf.loc[:, "ID"]  # avoid shadowing the built-in id()
X_eval = edf.drop("ID", axis=1).to_numpy()

eval_dataloader = DataLoader(
    TensorDataset(torch.Tensor(X_eval).reshape((-1, 1, 32, 32))),
    batch_size=batch_size,
)

eval_pred_path = outputs_path / "eval_pred.pickle"
final = load_model("final2", FinalCNNet()).to(device)
y_eval = predict(eval_dataloader, final, eval_pred_path)

d = pd.DataFrame()
d["ID"] = id
d["label"] = y_eval
d.to_csv(res_file, index=False)

fig, ax = plt.subplots(4, 8, figsize=(10, 6))
fig.suptitle("Sample from evaluation.csv (inverted brightness)")
for j in range(4):
    for i in range(8):
        ax[j, i].imshow(X_eval[8 * j + i].reshape((32, 32)), cmap="Greys")
        ax[j, i].set_title(MAP_IL[d.loc[8 * j + i, "label"]], fontsize=9)
        ax[j, i].set_axis_off()
plt.show()

[Figure: Sample from evaluation.csv (inverted brightness)]

The first 32 predictions appear reasonable. Let’s visualize predictions by category.

for c in range(10):
    fig, axes = plt.subplots(2, 10, figsize=(10, 2))
    fig.suptitle(f"{MAP_IL[c]}")

    d = X_eval[y_eval == c]
    for i in range(10):
        axes[0][i].imshow(d[i].reshape((32, 32)), cmap="Greys")
        axes[0][i].set_axis_off()
        axes[1][i].imshow(d[i + 10].reshape((32, 32)), cmap="Greys")
        axes[1][i].set_axis_off()
    plt.show()

[Figures: Twenty predicted samples for each of the 10 classes, one figure per class]