Data Purchasing Challenge 2022

A New Baseline With 0.71 accuracy on LB

One small change to the previous baseline gives a new baseline with 0.71 accuracy on the leaderboard (LB).


How to use this notebook 📝

You can use this notebook as a stand-alone Colab solution; however, I have not written the code to make a submission.

Content of this notebook

  1. Download relevant data
  2. Install requirements
  3. Initialize datasets
  4. Set up wandb (shown in the cells below)
  5. Baseline model (EfficientNet-B1 => 0.71 accuracy) in run.sh
  6. Execution pipeline. You can copy the relevant code into your codebase when making a submission.

Note: I have written code for both CPU and CUDA. Search for the key "CHANGE CPU CUDA HERE" and adjust the code for your machine; you only have to comment and uncomment a few lines.

1) Log in to AIcrowd 🤩

In [ ]:
#@title Login to AIcrowd
!pip install -U aicrowd-cli > /dev/null
!aicrowd login 2> /dev/null

2) Set up magically: run the cell below 😉

In [ ]:
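#@title Magic Box ⬛ { vertical-output: true, display-mode: "form" }
# Only download and extract the data on the first run of this cell.
try:
  import os
  if first_run and os.path.exists("/content/data-purchasing-challenge-2022-starter-kit/data/training"):
    first_run = False
except NameError:
  # first_run is undefined the very first time the cell runs
  first_run = True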

if first_run:
  %cd /content/
  !git clone http://gitlab.aicrowd.com/zew/data-purchasing-challenge-2022-starter-kit.git > /dev/null
  %cd data-purchasing-challenge-2022-starter-kit
  !aicrowd dataset list -c data-purchasing-challenge-2022
  !aicrowd dataset download -c data-purchasing-challenge-2022
  !mkdir -p data/
  !mv *.tar.gz data/ && cd data && echo "Extracting dataset" && ls *.tar.gz | xargs -n1 -I{} bash -c "tar -xvf {} > /dev/null"

def run_pre_training_phase():
  from run import ZEWDPCBaseRun
  run = ZEWDPCBaseRun()
  run.pre_training_phase = pre_training_phase
  run.pre_training_phase(self=run, training_dataset=training_dataset)
  # NOTE: It is critical that checkpointing works in a self-contained way,
  #       as the evaluators might choose to run the different phases separately.

def run_purchase_phase():
  from run import ZEWDPCBaseRun
  run = ZEWDPCBaseRun()
  run.pre_training_phase = pre_training_phase
  run.purchase_phase = purchase_phase
  # Hacky way to make it work in notebook
  unlabelled_dataset.purchases = set()
  run.purchase_phase(self=run, unlabelled_dataset=unlabelled_dataset, training_dataset=training_dataset, budget=3000)
  del run

def run_prediction_phase():
  from run import ZEWDPCBaseRun
  run = ZEWDPCBaseRun()
  run.pre_training_phase = pre_training_phase
  run.purchase_phase = purchase_phase
  run.prediction_phase = prediction_phase
  run.prediction_phase(self=run, test_dataset=val_dataset)
  del run
Datasets for challenge #1024
│ # │ Title             │ Description              │ Size    │
│ 0 │ validation.tar.gz │ Validation dataset       │ 182 MiB │
│ 1 │ unlabelled.tar.gz │ Unlabelled image dataset │ 609 MiB │
│ 2 │ training.tar.gz   │ Training data            │ 304 MiB │
│ 3 │ debug.tar.gz      │ Debug dataset            │ 6.1 MiB │

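The Magic Box extracts every `*.tar.gz` in `data/` with a shell one-liner. If you would rather keep this step in Python (for example, to skip archives that were already extracted on a re-run), a minimal sketch with the standard-library `tarfile` module could look like this. The function name `extract_archives` is hypothetical, and it assumes that `<name>.tar.gz` unpacks into `data/<name>/` (e.g. `training.tar.gz` into `data/training/`, which is what the Magic Box checks for); verify that against your own `data/` folder.

```python
import tarfile
from pathlib import Path


def extract_archives(data_dir: str) -> list:
    """Extract every *.tar.gz in data_dir, skipping archives already extracted.

    Assumption (hypothetical): "<name>.tar.gz" unpacks into "<data_dir>/<name>/".
    Returns the names of the archives that were actually extracted.
    """
    root = Path(data_dir)
    extracted = []
    for archive in sorted(root.glob("*.tar.gz")):
        name = archive.name[: -len(".tar.gz")]
        if (root / name).exists():
            continue  # target folder exists, so a previous run already extracted it
        with tarfile.open(archive, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Newer Pythons: refuse absolute paths and other unsafe members
                tar.extractall(path=root, filter="data")
            else:
                tar.extractall(path=root)
        extracted.append(name)
    return extracted
```

Calling `extract_archives("data")` twice extracts on the first call and does nothing on the second, which mirrors the `first_run` guard.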
3) Writing your code implementation! ✍️

a) Runtime Packages

In [ ]:
#@title a) Runtime Packages<br/><small>Important: Add the packages required by your code here. (space separated)</small> { run: "auto", display-mode: "form" }
apt_packages = "build-essential vim" #@param {type:"string"}
pip_packages = "scikit-image pandas timeout-decorator==0.5.0 numpy wandb" #@param {type:"string"}

!apt install -y $apt_packages git-lfs
!pip install $pip_packages

b) Load Dataset

The directory structure at this point looks like this:

Quick preview of images and labels.csv is as follows:

Let's initialise dataset instances.

In [ ]:
from evaluator.dataset import ZEWDPCBaseDataset, ZEWDPCProtectedDataset

# Instantiate Training Dataset
training_dataset = ZEWDPCBaseDataset(
# Instantiate Unlabelled Dataset
unlabelled_dataset = ZEWDPCProtectedDataset(
    budget=3000,  # Configurable Parameter
# Instantiate Validation Dataset
val_dataset = ZEWDPCBaseDataset(
val_dataset_gt = ZEWDPCBaseDataset(

c) pre_training_phase

Pre-train your model on the available labelled dataset here.

Hook for the Pre-Training Phase of the Competition, where you have access to a training_dataset, an instance of the ZEWDPCBaseDataset class (see dataset.py for more details).

You are allowed to pre-train on this data while you prepare for the purchase phase of the competition.

If you train some models, you can instantiate them as self.model, as long as you implement self-contained checkpointing in the self.save_checkpoint and self.load_checkpoint hooks, because the hooks for the different phases of the competition can be called in separate executions of the BaseRun.
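"Self-contained" means a brand-new run object, possibly in another process, must be able to resume from `checkpoint_path` alone. The toy sketch below illustrates that contract with the standard-library `pickle` module instead of `torch.save`; `ToyRun` and its `weights` list are hypothetical stand-ins for `ZEWDPCBaseRun` and `self.model.state_dict()`, not the real classes.

```python
import pickle


class ToyRun:
    """Hypothetical stand-in for ZEWDPCBaseRun: all state lives on the instance."""

    def __init__(self):
        self.epoch = 0
        self.weights = [0.0, 0.0, 0.0, 0.0]  # stand-in for model.state_dict()

    def pre_training_phase(self, num_epochs: int):
        for _ in range(num_epochs):
            self.weights = [w + 0.5 for w in self.weights]  # pretend training step
            self.epoch += 1

    def save_checkpoint(self, checkpoint_path: str):
        # Persist everything a fresh instance needs in order to resume.
        with open(checkpoint_path, "wb") as f:
            pickle.dump({"epoch": self.epoch, "weights": self.weights}, f)

    def load_checkpoint(self, checkpoint_path: str):
        # Restore the full state, not just metadata such as the epoch.
        with open(checkpoint_path, "rb") as f:
            state = pickle.load(f)
        self.epoch = state["epoch"]
        self.weights = state["weights"]
```

Usage follows the same pattern as the execution pipeline: train and save in one instance, `del` it, then create a fresh instance and call `load_checkpoint` before the next phase. A fresh instance that skips `load_checkpoint` silently starts from untrained weights.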

Base code

In [ ]:
import torch
from torch import nn
from torchvision import models
from torch.optim import Adam, SGD, lr_scheduler
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import abc
import datetime
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from sklearn.metrics import hamming_loss
import wandb

from evaluator.dataset import ZEWDPCBaseDataset, ZEWDPCProtectedDataset

Training class
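Two shape details in the class are easy to get wrong. With PyTorch's default collation, a per-sample label stored as a list of 4 ints usually arrives as a list of 4 tensors of shape `(batch,)`, one per class, which is why the training loop concatenates and reshapes `y`; the reshape alone gives `(NUM_CLASSES, batch)`, while `nn.BCEWithLogitsLoss` needs float targets of the same `(batch, NUM_CLASSES)` shape as the logits, so a transpose is required. On the output side, the prediction phase thresholds the sigmoid of each logit independently. The NumPy sketch below mirrors both steps; it assumes labels are collated per class as described, so check `batch["label"]` in your own run.

```python
import numpy as np


def collated_labels_to_matrix(per_class: list) -> np.ndarray:
    """Turn [class_0 (B,), ..., class_{C-1} (B,)] into a (B, C) float target matrix.

    Mirrors torch.cat(y, dim=0).reshape(C, B).T from the training loop.
    """
    num_classes, batch = len(per_class), len(per_class[0])
    flat = np.concatenate(per_class, axis=0)  # shape (C * B,), class-major order
    return flat.reshape(num_classes, batch).T.astype(np.float32)  # shape (B, C)


def logits_to_labels(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Multi-label decision: every class is predicted independently."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # element-wise sigmoid
    return (probs > threshold).astype(int)
```

The cat/reshape/transpose chain is equivalent to stacking along axis 1, so in PyTorch `torch.stack(y, dim=1).float()` should yield the same target. With the default threshold of 0.5, the sigmoid test is the same as checking `logit > 0`.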

In [ ]:
class ZEWDPCBaseRun:

    def __init__(self):
        self.evaluation_state = {}
        # Model parameters
        self.BATCH_SIZE = 32
        self.NUM_WORKERS = 2
        self.LEARNING_RATE = 0.001
        self.NUM_CLASSES = 4
        self.TOPK= 3
        self.THRESHOLD = 0.5
        self.NUM_EPOCS = 50
        self.EVAL_FREQ = 5

        self.model = models.efficientnet_b1(num_classes = self.NUM_CLASSES)
        ## CHANGE CPU CUDA HERE
        self.model.cuda()
        # self.model.cpu()

        self.trainable_parameters = filter(lambda param: param.requires_grad, self.model.parameters())
        self.optimizer = Adam(self.trainable_parameters, lr=self.LEARNING_RATE)
        self.epoch = 0

        self.lr_scheduler_ = lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', patience=2, verbose=True
        )
        self.criterion = nn.BCEWithLogitsLoss()

        # WandB setup
        model_name = "EfficientNet-B1"
        wandb_config = {
            "model": model_name,
            "learning_rate": self.LEARNING_RATE,
            "samples_per_gpu": self.BATCH_SIZE,
            "workers_per_gpu": self.NUM_WORKERS,
            "augmentations": "NO",
            "description": "Pre-Training. No augmentation"
        }
        wandb.init(project="<PROJECT_NAME>", entity="<USERNAME>", name=model_name, config=wandb_config)

    def pre_training_phase(
        self, training_dataset: ZEWDPCBaseDataset, register_progress=lambda x: False
    ):
        print("\n================> Pre-Training Phase\n")
        # Creating transformations
        train_transform = transforms.Compose([
        train_loader = DataLoader(

        def run_epoch():
            self.model.train()
            for _, batch in enumerate(train_loader):

                ## CHANGE CPU CUDA HERE
                x, y = batch["image"].cuda(), batch["label"]
                # x, y = batch["image"].cpu(), batch["label"]

                pred_y = self.model(x)
                # Reshape the true labels to (batch, NUM_CLASSES). Use pred_y.shape[0]
                # instead of BATCH_SIZE because the last batch can contain fewer images.
                y = torch.cat(y, dim=0).reshape(
                    self.NUM_CLASSES, pred_y.shape[0]
                ).T.float()
                ## CHANGE CPU CUDA HERE. Comment for CPU
                y = y.cuda()
                loss = self.criterion(pred_y, y)

                # Backpropagate and update the weights
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # global_step counts samples (it grows by BATCH_SIZE per batch),
                # so this logs every 13 batches: 416 = BATCH_SIZE * 13
                if self.global_step % 416 == 0:
                    wandb.log({
                        "Step": (self.epoch*5000)+self.global_step,
                        "Loss": loss.item(),
                    })
                    print("[{}] Training [epoch {}, step {}], loss: {:.4f}".format(
                        datetime.datetime.now(), self.epoch, self.global_step, loss.item()))
                self.global_step += self.BATCH_SIZE
        epoch_range = tqdm(range(self.epoch, self.NUM_EPOCS))
        for i in epoch_range:
            epoch_range.set_description(f"Epoch: {i}")
            self.global_step = 0
            register_progress(i) # Epoch as progress
            run_epoch()
            if (i+1)%self.EVAL_FREQ == 0:
                predictions = self.prediction_phase(val_dataset)
                self.evaluation(predictions)
            self.epoch += 1
        print("Execution Complete of Training Phase.")

    def purchase_phase(
        self,
        unlabelled_dataset: ZEWDPCProtectedDataset,
        training_dataset: ZEWDPCBaseDataset,
        budget: int,
        register_progress=lambda x: False,
    ):
        """
        # Purchase Phase
        In this phase of the competition, you have access to
        the unlabelled_dataset (an instance of `ZEWDPCProtectedDataset`)
        and the training_dataset (an instance of `ZEWDPCBaseDataset`)
        {see dataset.py for more details}, and a purchase budget.

        You can iterate over both the datasets and access the images without restrictions.
        However, you can probe the labels of the unlabelled_dataset only until you
        run out of the label purchasing budget.

        PARTICIPANT_TODO: Add your code here
        """
        print("\n================> Purchase Phase | Budget = {}\n".format(budget))

        register_progress(0.0) #Register Progress
        for sample in tqdm(unlabelled_dataset):
            idx = sample["idx"]
            # image = unlabelled_dataset.__getitem__(idx)
            # print(image)

            # Budgeting & Purchasing Labels
            if budget > 0:
                label = unlabelled_dataset.purchase_label(idx)

            budget -= 1
        register_progress(1.0) #Register Progress
        print("Execution Complete of Purchase Phase.")

    def prediction_phase(
        self,
        test_dataset: ZEWDPCBaseDataset,
        register_progress=lambda x: False,
    ):
        """
        # Prediction Phase
        In this phase of the competition, you have access to the test dataset, and you
        are supposed to make predictions using your trained models.

        Returns:
            np.ndarray of shape (n, 4)
                where n is the number of samples in the test set
                and 4 refers to the 4 labels to be predicted for each sample
                for the multi-label classification problem.

        PARTICIPANT_TODO: Add your code here
        """
        print(
            "\n================> Prediction Phase : - on {} images\n".format(
                len(test_dataset)
            )
        )
        test_transform = transforms.Compose([
        test_loader = DataLoader(
        def convert_to_label(preds):
            return np.array((torch.sigmoid(preds) > self.THRESHOLD), dtype=int).tolist()

        self.model.eval()  # disable dropout and use running BatchNorm statistics
        predictions = []
        with torch.no_grad():
            for _, batch in enumerate(test_loader):
                ## CHANGE CPU CUDA HERE
                # X= batch['image'].cpu()
                X = batch['image'].cuda()

                pred_y = self.model(X)

                # Convert to labels
                pred_y_labels = []
                for arr in pred_y:
                    ## CHANGE CPU CUDA HERE
                    pred_y_labels.append(convert_to_label(arr.cpu())) # For CUDA
                    # pred_y_labels.append(convert_to_label(arr)) # For CPU

                # Save the results
                predictions.extend(pred_y_labels)

        predictions = np.array(predictions)  # shape (n, NUM_CLASSES), entries 0/1
        print("Execution Complete of Prediction Phase.")
        return predictions

    def evaluation(self, predictions):
        from evaluator.evaluation_metrics import accuracy_score, hamming_loss, exact_match_ratio

        y_true = val_dataset_gt._get_all_labels()
        y_pred = predictions

        accuracy = accuracy_score(y_true, y_pred)
        hamming_loss_score = hamming_loss(y_true, y_pred)
        exact_match_ratio_score = exact_match_ratio(y_true, y_pred)

        wandb.log({
            "Epoch": self.epoch+1,
            "Accuracy": accuracy,
            "Hamming Loss": hamming_loss_score,
            "Match ratio": exact_match_ratio_score
        })

        print("Accuracy Score : ", accuracy)
        print("Hamming Loss : ", hamming_loss_score)
        print("Exact Match Ratio : ", exact_match_ratio_score)

    def save_checkpoint(self, checkpoint_path):
        """
        Saves the model and optimizer state, plus the epoch counter, to the file at checkpoint_path.
        """
        save_dict = {
            'epoch': self.epoch + 1,
            'model_state_dict': self.model.state_dict(),
            'optim_state_dict': self.optimizer.state_dict(),
        }
        torch.save(save_dict, checkpoint_path)
        print(f"Checkpoint epoch:{self.epoch} Model saved at {checkpoint_path}")

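    def load_checkpoint(self, checkpoint_path):
        """
        Loads the model and optimizer state written by save_checkpoint.
        """
        ## CHANGE CPU CUDA HERE
        checkpoint_model = torch.load(checkpoint_path, map_location="cuda:0")
        # checkpoint_model = torch.load(checkpoint_path, map_location="cpu")
        self.model.load_state_dict(checkpoint_model['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint_model['optim_state_dict'])
        self.latest_epoch = checkpoint_model['epoch']
        print('loading checkpoint success (epoch {})'.format(self.latest_epoch))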
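The baseline's purchase_phase buys labels for the first `budget` samples in iteration order and does not use the labels it gets back. A minimal sketch of a budgeted loop that ranks the pool first and keeps what it buys is below. `FakeProtectedDataset` is a hypothetical stand-in that mimics only the two things the baseline relies on (iteration yielding dicts with an `"idx"` key, and `purchase_label(idx)`); it is not the real `ZEWDPCProtectedDataset`. The `score` callable is also an assumption: a placeholder for whatever selection criterion you choose.

```python
from typing import Callable, Dict, List


class FakeProtectedDataset:
    """Hypothetical stand-in for ZEWDPCProtectedDataset (not the real class)."""

    def __init__(self, labels: List[List[int]], budget: int):
        self._labels = labels
        self.budget = budget

    def __iter__(self):
        for idx in range(len(self._labels)):
            yield {"idx": idx}

    def purchase_label(self, idx: int) -> List[int]:
        if self.budget <= 0:
            raise RuntimeError("label purchasing budget exhausted")
        self.budget -= 1
        return self._labels[idx]


def purchase_top_k(dataset, budget: int, score: Callable[[dict], float]) -> Dict[int, List[int]]:
    """Buy labels for the `budget` highest-scoring samples and return {idx: label}."""
    # Score every sample once, keeping only (score, idx) pairs in memory.
    scored = [(score(sample), sample["idx"]) for sample in dataset]
    # Highest score first; the sort is stable, so ties keep iteration order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    purchased = {}
    for _, idx in scored[:budget]:
        purchased[idx] = dataset.purchase_label(idx)
    return purchased
```

The sketch only handles the bookkeeping; a useful `score` (for instance, the pre-trained model's uncertainty on each image) is where an actual purchasing strategy would go.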
In [ ]:
import tempfile
checkpoint_path = tempfile.NamedTemporaryFile(delete=False).name

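## Pre-training phase
run = ZEWDPCBaseRun()
run.pre_training_phase(training_dataset)
run.save_checkpoint(checkpoint_path)
del run

## Purchasing phase
run = ZEWDPCBaseRun()
run.load_checkpoint(checkpoint_path)
run.purchase_phase(unlabelled_dataset, training_dataset, budget=3000)
del run

## Prediction phase
run = ZEWDPCBaseRun()
run.load_checkpoint(checkpoint_path)
predictions = run.prediction_phase(val_dataset)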
assert type(predictions) == np.ndarray
assert predictions.shape == (len(val_dataset), 4)

## Evaluation Phase
from evaluator.evaluation_metrics import accuracy_score, hamming_loss, exact_match_ratio

y_true = val_dataset_gt._get_all_labels()
y_pred = predictions

accuracy = accuracy_score(y_true, y_pred)
hamming_loss_score = hamming_loss(y_true, y_pred)
exact_match_ratio_score = exact_match_ratio(y_true, y_pred)

print("Accuracy Score : ", accuracy)
print("Hamming Loss : ", hamming_loss_score)
print("Exact Match Ratio : ", exact_match_ratio_score)
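The metric functions come from `evaluator/evaluation_metrics.py`, which is not reproduced in this notebook. For intuition, the NumPy sketch below implements the common textbook definitions of two of them: Hamming loss (the fraction of individual labels predicted wrongly) and exact match ratio (the fraction of samples whose whole label vector is correct). Treat it as an assumption about the evaluator's behaviour, not a copy of its code; the `_sketch` suffix keeps these hypothetical helpers from shadowing the imported functions.

```python
import numpy as np


def hamming_loss_sketch(y_true, y_pred) -> float:
    """Fraction of (sample, label) entries where prediction and truth disagree."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(y_true != y_pred))


def exact_match_ratio_sketch(y_true, y_pred) -> float:
    """Fraction of samples whose entire label vector is predicted correctly."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(np.all(y_true == y_pred, axis=1)))
```

For two samples with four labels each and a single wrong label in the second sample, the Hamming loss is 1/8 = 0.125 while the exact match ratio is 0.5, which shows how much stricter the second metric is.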


Over 2 years ago

It took me a while to figure out the small change.
