
Data Purchasing Challenge 2022

[Utils] Round 2 Final Model compilation

Compilation of the Last Stage's 5 final models

santiactis

Hi, you! Welcome! 🙌


 

This notebook provides a quick **plug-and-play** method to test the 5 models from the final stage. You will find code chunks for the local_evaluation.py, model.py, and trainer.py files that you can copy into your repository. This will allow you to evaluate your purchasing strategy with all five models in one run.

Keep in mind that this is not an interactive Notebook. You need to copy each code chunk into the corresponding .py file in your local repository.

 

I hope this helps you speed up your tests. If you find it useful, please remember to leave a 💗!

Round 2 Final Models compilation


1) Model class

This code chunk must be copied into the "YourRepository/evaluator/model.py" file. It contains all 5 models for the final stage, so you can run every one of them and compare the results of the different strategies you are implementing.

In [ ]:
import torch
import torchvision
from torchvision import transforms as T

##################################################################################################################################
##### Final model Trainer No.1
##################################################################################################################################


class ZEWDPCModel_1(torch.nn.Module):
    """
    A basic model based on EfficientNet_B4 which is used to re-train the 
    overall dataset containing the available training set, 
    and the purchased dataset. 
    
    Inputs:
        use_pretrained: Whether to use pretrained weights or not.
    """

    def __init__(self, num_classes=6, use_pretrained=True):
        super().__init__()
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.required_transforms = [
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            if use_pretrained
            else T.Lambda(lambda x: x)
        ]
        self.activation = torch.nn.Sigmoid()

        self.init_network()

    def init_network(self):
        # Setup Base Model - EfficientNet_b4
        self.base_model = torchvision.models.efficientnet_b4(
            pretrained=self.use_pretrained,
        )
        # Freeze feature extraction layers
        for param in self.base_model.features.parameters():
            param.requires_grad = False

        # Replace the final FC layer to support
        # the required number of classes
        in_features = list(self.base_model.classifier.children())[-1].in_features
        self.base_model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.4, inplace=True),
            torch.nn.Linear(
                in_features=in_features, out_features=self.num_classes, bias=True
            ),
        )

    def forward(self, x):
        """
        Forward function of the ZEWDPCModel

        Inputs:
            x: The batched images input
        """
        output = self.base_model(x)
        # output = self.activation(output) # Not needed for BCE with logits, better training
        return output

##################################################################################################################################
##### Final model Trainer No.2
##################################################################################################################################

class ZEWDPCModel_2(torch.nn.Module):
    """
    A basic model based on EfficientNet_B4 which is used to re-train the 
    overall dataset containing the available training set, 
    and the purchased dataset. 
    
    Inputs:
        use_pretrained: Whether to use pretrained weights or not.
    """

    def __init__(self, num_classes=6, use_pretrained=True):
        super().__init__()
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.required_transforms = [
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            if use_pretrained
            else T.Lambda(lambda x: x)
        ]
        self.activation = torch.nn.Sigmoid()

        self.init_network()

    def init_network(self):
        # Setup Base Model - EfficientNet_b4
        self.base_model = torchvision.models.efficientnet_b4(
            pretrained=self.use_pretrained,
        )
        # Freeze feature extraction layers
        for param in self.base_model.features.parameters():
            param.requires_grad = False

        # Replace the final FC layer to support
        # the required number of classes
        in_features = list(self.base_model.classifier.children())[-1].in_features
        self.base_model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.4, inplace=True),
            torch.nn.Linear(
                in_features=in_features, out_features=self.num_classes, bias=True
            ),
        )

    def forward(self, x):
        """
        Forward function of the ZEWDPCModel

        Inputs:
            x: The batched images input
        """
        output = self.base_model(x)
        # output = self.activation(output) # Not needed for BCE with logits, better training
        return output


##################################################################################################################################
##### Final model Trainer No.3
##################################################################################################################################

class ZEWDPCModel_3(torch.nn.Module):
    """
    A basic model based on EfficientNet_B4 which is used to re-train the 
    overall dataset containing the available training set, 
    and the purchased dataset. 
    
    Inputs:
        use_pretrained: Whether to use pretrained weights or not.
    """

    def __init__(self, num_classes=6, use_pretrained=True):
        super().__init__()
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.required_transforms = [
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            if use_pretrained
            else T.Lambda(lambda x: x)
        ]
        self.activation = torch.nn.Sigmoid()

        self.init_network()

    def init_network(self):
        # Setup Base Model - EfficientNet_b4
        self.base_model = torchvision.models.efficientnet_b4(
            pretrained=self.use_pretrained,
        )

        # Replace the final FC layer to support
        # the required number of classes
        in_features = list(self.base_model.classifier.children())[-1].in_features
        self.base_model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.4, inplace=True),
            torch.nn.Linear(
                in_features=in_features, out_features=self.num_classes, bias=True
            ),
        )

    def forward(self, x):
        """
        Forward function of the ZEWDPCModel

        Inputs:
            x: The batched images input
        """
        output = self.base_model(x)
        # output = self.activation(output) # Not needed for BCE with logits, better training
        return output


##################################################################################################################################
##### Final model Trainer No.4
##################################################################################################################################

class ZEWDPCModel_4(torch.nn.Module):
    """
    A basic model based on EfficientNet_B4 which is used to re-train the 
    overall dataset containing the available training set, 
    and the purchased dataset. 
    
    Inputs:
        use_pretrained: Whether to use pretrained weights or not.
    """

    def __init__(self, num_classes=6, use_pretrained=True):
        super().__init__()
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.required_transforms = [
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            if use_pretrained
            else T.Lambda(lambda x: x)
        ]
        self.activation = torch.nn.Sigmoid()

        self.init_network()

    def init_network(self):
        # Setup Base Model - EfficientNet_b4
        self.base_model = torchvision.models.efficientnet_b4(
            pretrained=self.use_pretrained,
        )

        # Replace the final FC layer to support
        # the required number of classes
        in_features = list(self.base_model.classifier.children())[-1].in_features
        self.base_model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.4, inplace=True),
            torch.nn.Linear(
                in_features=in_features, out_features=self.num_classes, bias=True
            ),
        )

    def forward(self, x):
        """
        Forward function of the ZEWDPCModel

        Inputs:
            x: The batched images input
        """
        output = self.base_model(x)
        # output = self.activation(output) # Not needed for BCE with logits, better training
        return output


##################################################################################################################################
##### Final model Trainer No.5
##################################################################################################################################

class ZEWDPCModel_5(torch.nn.Module):
    """
    A basic model based on EfficientNet_B0 which is used to re-train the 
    overall dataset containing the available training set, 
    and the purchased dataset. 
    
    Inputs:
        use_pretrained: Whether to use pretrained weights or not.
    """

    def __init__(self, num_classes=6, use_pretrained=True):
        super().__init__()
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.required_transforms = [
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            if use_pretrained
            else T.Lambda(lambda x: x)
        ]
        self.activation = torch.nn.Sigmoid()

        self.init_network()

    def init_network(self):
        # Setup Base Model - EfficientNet_b0
        self.base_model = torchvision.models.efficientnet_b0(
            pretrained=self.use_pretrained,
        )

        # Replace the final FC layer to support
        # the required number of classes
        in_features = list(self.base_model.classifier.children())[-1].in_features
        self.base_model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.4, inplace=True),
            torch.nn.Linear(
                in_features=in_features, out_features=self.num_classes, bias=True
            ),
        )

    def forward(self, x):
        """
        Forward function of the ZEWDPCModel

        Inputs:
            x: The batched images input
        """
        output = self.base_model(x)
        # output = self.activation(output) # Not needed for BCE with logits, better training
        return output



#### DEBUG

if __name__ == "__main__":

    model = ZEWDPCModel_1()

    input = torch.rand(1, 3, 224, 224)
    output = model(input)
    print(output)
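
A quick note on why every `forward` above leaves the sigmoid commented out: the trainers use `torch.nn.BCEWithLogitsLoss`, which expects raw logits and applies the sigmoid internally. The plain-Python sketch below is only an illustration of that idea, not part of the repository. The helper names (`sigmoid`, `bce_on_probability`, `bce_with_logits`) are made up for this example, and the stable formula is the standard textbook one rather than a copy of PyTorch's internals.

```python
import math


def sigmoid(x):
    """Logistic function mapping a raw logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def bce_on_probability(p, y):
    """Binary cross-entropy computed from a probability p in (0, 1)."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def bce_with_logits(x, y):
    """Binary cross-entropy computed directly from a raw logit x.

    Uses the rearranged form max(x, 0) - x*y + log(1 + exp(-|x|)),
    which never takes the log of a value that rounds to zero.
    """
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


# Illustrative logits and labels (not taken from the challenge data)
logits = [-3.0, -0.5, 0.0, 2.0]
labels = [0, 1, 1, 0]

for x, y in zip(logits, labels):
    two_step = bce_on_probability(sigmoid(x), y)
    one_step = bce_with_logits(x, y)
    print(f"logit={x:+.1f} label={y} two_step={two_step:.6f} one_step={one_step:.6f}")

# A very confident wrong prediction still gives a finite loss
print(bce_with_logits(40.0, 0))
```

For moderate logits both routes give the same loss. The one-step version also stays finite for extreme logits, which is the usual reason to keep the activation out of the model and apply it only when computing F1 or predictions.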

2) Trainer class

This code chunk must be copied into the "YourRepository/evaluator/trainer.py" file. It contains the trainer classes for all 5 models for the final stage, so you can run every one of them and compare the results of the different strategies you are implementing.
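
Before you copy the cell, it may help to see what the learning rate does during a run. The trainers in this file pair Adam (default learning rate 1e-3) with `CosineAnnealingLR`, using `T_max=num_epochs` and `eta_min=1e-5`, and step the scheduler once after each validation pass. The sketch below evaluates the closed-form cosine curve in plain Python for the default 20 epochs. It assumes one scheduler step per epoch (the default `validation_interval` of 1), and `cosine_lr` is a hypothetical helper written for this illustration, not something from the repository.

```python
import math


def cosine_lr(step, base_lr=1e-3, eta_min=1e-5, t_max=20):
    """Closed-form cosine annealing value after `step` scheduler steps."""
    cos_term = 1 + math.cos(math.pi * step / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * cos_term


# Learning rate at the start of training and after each of the 20 epochs
schedule = [cosine_lr(step) for step in range(21)]

for step in (0, 5, 10, 15, 20):
    print(f"after {step:2d} steps: lr = {schedule[step]:.6f}")
```

The rate starts at the base value, passes the midpoint between the two bounds halfway through, and ends at `eta_min`. If you raise `validation_interval`, the scheduler is stepped less often, so the real curve would decay more slowly than this sketch suggests.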

In [ ]:
#!/usr/bin/env python

import os
import tqdm
import copy
import numpy as np

from evaluator.dataset import (
    ZEWDPCBaseDataset,
    ZEWDPCProtectedDataset,
    ZEWDPCRuntimeDataset,
)

from evaluator.model import ZEWDPCModel_1, ZEWDPCModel_2, ZEWDPCModel_3, ZEWDPCModel_4, ZEWDPCModel_5

from evaluator.utils import (
    instantiate_purchased_dataset,
    AverageMeter,
)

from evaluator.exceptions import OutOfBudetException
from evaluator.evaluation_metrics import get_zew_dpc_metrics

import torch
import torchvision
from torchvision import transforms as T
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

import torchmetrics

# os.getenv returns a string when the variable is set, so cast to int
SEED = int(os.getenv("AICROWD_ZEWDPC_SEED", 42))

####################################################################################
####################################################################################
#
# ZEW DPC Trainer Class
# 
# This class takes an aggregated dataset (purchased dataset + training dataset), 
# and trains an EfficientNet_b4 network for 10 epochs.
# This is used to evaluate the "performance" of these purchased labels.
####################################################################################

##################################################################################################################################
##### Final model Trainer No.1
##################################################################################################################################

class ZEWDPCTrainer_1:
    def __init__(
        self,
        num_classes=6,
        use_pretrained=True,
        hparams={},
        seed=42,
        sync_with_server=lambda x: x,
    ):
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.DEFAULT_HYPERPARAMETERS = (
            {  # A single dictionary to hold all configurable params.
                "learning_rate": 1e-3,
                "validation_interval": 1,
                "LR Scheduler Patience": 5,
                "LR Scheduler Factor": 0.2,
            }
        )
        self.hparams = self.DEFAULT_HYPERPARAMETERS
        self.hparams.update(hparams)
        self.seed(seed)
        self.sync_with_server = sync_with_server

        # A state dictionary which can hold any run_time metrics that need to be relayed to the user
        self.setup_runtime_metrics()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.init_model()

    def seed(self, seed):
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(seed)
        self.seed = seed
        

    def init_model(self):
        """Initialize the base model used to train the network"""
        self.model = ZEWDPCModel_1(
            num_classes=self.num_classes, use_pretrained=self.use_pretrained
        )
        self.model = self.model.to(self.device)
        self.best_model = None

    def setup_transforms(self, training_dataset):
        """
        Sets up the necessary transforms for the training_dataset
        """
        ## Setup necessary Transformations
        train_transform = T.Compose(
            [
                T.ToTensor(),  # Converts image to [0, 1]
                T.RandomVerticalFlip(p=0.5),
                T.RandomHorizontalFlip(p=0.5),
                T.GaussianBlur(kernel_size=3),
                T.ColorJitter(brightness=0.2, contrast=0.2),
                # TODO: Could agree with participants on the final list of transforms
                *self.model.required_transforms,
            ]
        )

        if isinstance(training_dataset, ZEWDPCBaseDataset):
            training_dataset.set_transform(train_transform)
        elif isinstance(training_dataset, torch.utils.data.ConcatDataset):
            for dataset in training_dataset.datasets:
                if isinstance(dataset, ZEWDPCRuntimeDataset):
                    dataset.set_transform(train_transform)
                elif isinstance(dataset, ZEWDPCBaseDataset):
                    dataset.set_transform(train_transform)
                else:
                    raise NotImplementedError()
        else:
            raise NotImplementedError()

    def setup_validation_set(self, training_dataset, validation_percentage=0.05):
        """
        Creates a Validation Set from the Training Dataset
        """
        assert (
            0 < validation_percentage < 1
        ), "Expected : validation_percentage ∈ (0, 1). Received validation_percentage = {}".format(
            validation_percentage
        )

        validation_size = int(validation_percentage * len(training_dataset))
        training_dataset, validation_dataset = torch.utils.data.random_split(
            training_dataset,
            [
                len(training_dataset) - validation_size,
                validation_size,
            ],
            generator=torch.Generator().manual_seed(self.seed),
        )
        return training_dataset, validation_dataset

    def setup_dataloaders(self, training_dataset, validation_dataset, batch_size=32):
        """
        Sets up necessary dataloader
        """
        train_dataloader = torch.utils.data.DataLoader(
            training_dataset, batch_size=batch_size, shuffle=True
        )
        val_dataloader = torch.utils.data.DataLoader(
            validation_dataset, batch_size=batch_size, shuffle=True
        )

        return train_dataloader, val_dataloader

    def train(
        self, training_dataset, num_epochs=20, validation_percentage=0.1, batch_size=32
    ):
        # Setup Transforms
        self.setup_transforms(training_dataset)

        # Prepare Validation Set
        training_dataset, validation_dataset = self.setup_validation_set(
            training_dataset, validation_percentage=validation_percentage
        )

        # Setup Dataloaders
        train_dataloader, val_dataloader = self.setup_dataloaders(
            training_dataset, validation_dataset, batch_size=batch_size
        )

        # Setup Criterion & Optimizer
        criterion = torch.nn.BCEWithLogitsLoss()
        lr = self.hparams["learning_rate"]
        optimizer = torch.optim.Adam(params=self.model.parameters(), lr=lr)

        lr_sched = CosineAnnealingLR(optimizer, num_epochs, eta_min=1e-5, last_epoch=-1, verbose=True)

        # Setup Metric Meters
        val_loss_avg_meter = AverageMeter()
        train_loss_avg_meter = AverageMeter()
        val_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")
        train_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")

        # Setup references for runtime-bests
        best_train_loss = float("inf")
        best_val_loss = float("inf")

        ########################################################################
        ########################################################################
        #
        # Iterate over Epochs
        ########################################################################
        for epoch in range(num_epochs):
            self.model.train()
            train_loss_avg_meter.reset()
            train_f1.reset()

            tqdm_iter = tqdm.tqdm(train_dataloader, total=len(train_dataloader))
            tqdm_iter.set_description(f"Epoch {epoch}")
            
            for sample in tqdm_iter:
                # Reset Optimizer Gradients
                optimizer.zero_grad()

                # Gather Data Sample
                idx = sample["idx"].to(self.device)
                image = sample["image"].to(self.device)
                label = torch.vstack(sample["label"]).T

                # Forward Pass
                output = self.model(image)
                # Compute Loss
                loss = criterion(output, label.to(self.device).float())

                # Update Metric Meters
                train_loss_avg_meter.update(loss.item(), image.shape[0])
                output_with_activation = self.model.activation(output.detach()).cpu()
                train_f1.update(output_with_activation, label)
                tqdm_iter.set_postfix(
                    iter_train_loss=loss.item(), avg_train_loss=train_loss_avg_meter.avg
                )

                # Backpropagate
                loss.backward()
                optimizer.step()

            print(
                "Epoch %d - Average Train Loss: %.5f \t Train F1: %.5f"
                % (epoch, train_loss_avg_meter.avg, train_f1.compute().item())
            )
            if train_loss_avg_meter.avg < best_train_loss:
                best_train_loss = train_loss_avg_meter.avg
                if validation_dataset is None:
                    self.best_model = copy.deepcopy(self.model)

            ####################################################################################
            ####################################################################################
            #
            # Validation
            ####################################################################################
            VALIDATION_INTERVAL = self.hparams["validation_interval"]
            if (
                validation_dataset is not None
                and (epoch + 1) % VALIDATION_INTERVAL == 0
            ):
                self.model.eval()
                val_loss_avg_meter.reset()
                val_f1.reset()
                tqdm_iter = tqdm.tqdm(val_dataloader, total=len(val_dataloader))
                tqdm_iter.set_description(f"Validation at Epoch {epoch}")

                for sample in tqdm_iter:
                    with torch.no_grad():

                        idx = sample["idx"].to(self.device)
                        image = sample["image"].to(self.device)
                        label = torch.vstack(sample["label"]).T

                        output = self.model(image)
                        loss = criterion(output, label.to(self.device).float())
                        output_with_activation = self.model.activation(
                            output.detach()
                        ).cpu()
                        val_f1.update(output_with_activation, label)

                        val_loss_avg_meter.update(loss.item(), image.shape[0])
                        tqdm_iter.set_postfix(avg_val_loss=val_loss_avg_meter.avg)

                lr_sched.step()

                print(
                    "Epoch %d - Average Val Loss: %.5f \t Val F1: %.5f \t Learning Rate %0.5f"
                    % (
                        epoch,
                        val_loss_avg_meter.avg,
                        val_f1.compute().item(),
                        optimizer.param_groups[0]["lr"],
                    )
                )
                if val_loss_avg_meter.avg < best_val_loss:
                    best_val_loss = val_loss_avg_meter.avg
                    self.best_model = copy.deepcopy(self.model)

                    

            train_metrics = {"f1": train_f1.compute().item()}
            val_metrics = {"f1": val_f1.compute().item()}
            info = {"learning_rate": optimizer.param_groups[0]["lr"]}
            self.update_runtime_metrics(
                epoch=epoch,
                total_epochs=num_epochs,
                train_loss=train_loss_avg_meter.avg,
                val_loss=val_loss_avg_meter.avg,
                train_metrics=train_metrics,
                val_metrics=val_metrics,
                info=info,
            )

        # Load the best model if available
        self.load_best_model()
        
        return self.model
    
    def load_best_model(self):
        """
        Helper Function to load the best model (if available)
        """
        if self.best_model is not None:
            # Save the best_model as the internal model instance
            self.model = self.best_model.to(self.device)

    def predict(self, test_dataset, batch_size=32):
        # Load best model
        self.load_best_model()
        
        self.model.eval()

        transform = T.Compose(
            [
                T.ToTensor(),
                *self.model.required_transforms,
            ]
        )
        test_dataset.set_transform(transform)
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )
        outputs = []

        self.model.to(self.device)

        batches_processed = 0
        total_batches = len(dataloader)
        for data in tqdm.tqdm(dataloader, total=len(dataloader)):
            with torch.no_grad():
                image = data["image"].to(self.device)
                output = self.model(image)
                output_with_activation = self.model.activation(output).cpu().numpy()
                outputs.append(output_with_activation)

            batches_processed += 1
            prediction_progress = batches_processed / total_batches
            self.run_time_metrics["prediction_progress"] = prediction_progress
            if batches_processed % 5 == 0:
                # Avoid too many updates with the server
                self.sync_with_server(self.run_time_metrics)

        outputs = np.concatenate(outputs, axis=0)
        outputs = outputs > 0.5

        # Final sync with the server
        self.sync_with_server(self.run_time_metrics)
        return outputs

    def setup_runtime_metrics(self):

        self.metric_names = ["f1"]
        self.info_names = ["learning_rate"]

        def get_metric_dict():
            metric_dict = {}
            metric_dict["loss"] = []
            for name in self.metric_names:
                metric_dict[name] = []

            return metric_dict

        self.run_time_metrics = {
            "epoch": [],
            "train": get_metric_dict(),
            "validation": get_metric_dict(),
            "training_progress": 0.0,
            "prediction_progress": 0.0,
        }
        for name in self.info_names:
            self.run_time_metrics[name] = []

    def update_runtime_metrics(
        self,
        epoch,
        total_epochs,
        train_loss,
        val_loss,
        train_metrics,
        val_metrics,
        info,
    ):

        self.run_time_metrics["epoch"].append(epoch)

        self.run_time_metrics["train"]["loss"].append(train_loss)
        self.run_time_metrics["validation"]["loss"].append(val_loss)

        for mn in self.metric_names:
            self.run_time_metrics["train"][mn].append(train_metrics[mn])
            self.run_time_metrics["validation"][mn].append(val_metrics[mn])

        for ik, iv in info.items():
            self.run_time_metrics[ik].append(iv)

        progress = (epoch + 1) / total_epochs
        self.run_time_metrics["training_progress"] = progress
        # Sync run_time_metrics with the server
        self.sync_with_server(self.run_time_metrics)


class ZEWDPCDebugTrainer_1:
    """
    A Debug Trainer class, which can be used during development.
    It is not supposed to do anything meaningful, other than help
    with integration testing, and spit out random predictions
    """

    def __init__(
        self,
        num_classes=6,
        use_pretrained=True,
        hparams={},
        seed=42,
        sync_with_server=lambda x: x,
    ):
        print("Running ZEWDPCDebugTrainer !!!!")

        # A state dictionary which can hold any run_time metrics that need to be relayed to the user
        self.metric_names = ["f1"]

        def get_metric_dict():
            metric_dict = {}
            metric_dict["loss"] = list(np.random.rand(3))
            for name in self.metric_names:
                metric_dict[name] = list(np.random.rand(3))

            return metric_dict

        self.run_time_metrics = {
            "epoch": list(range(3)),
            "train": get_metric_dict(),
            "validation": get_metric_dict(),
        }

    def train(
        self, aggregated_dataset, num_epochs=1, validation_percentage=0.1, batch_size=32
    ):
        for sample in tqdm.tqdm(aggregated_dataset):
            pass
        return True

    def predict(self, test_dataset, batch_size=32):
        size_test_set = len(test_dataset)

        return torch.randint(low=0, high=2, size=(size_test_set, 6))


##################################################################################################################################
##### Final model Trainer No.2
##################################################################################################################################

class ZEWDPCTrainer_2:
    def __init__(
        self,
        num_classes=6,
        use_pretrained=True,
        hparams={},
        seed=42,
        sync_with_server=lambda x: x,
    ):
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.DEFAULT_HYPERPARAMETERS = (
            {  # A single dictionary to hold all configurable params.
                "learning_rate": 1e-3,
                "validation_interval": 1,
                "LR Scheduler Patience": 5,
                "LR Scheduler Factor": 0.2,
            }
        )
        self.hparams = self.DEFAULT_HYPERPARAMETERS
        self.hparams.update(hparams)
        self.seed(seed)
        self.sync_with_server = sync_with_server

        # A state dictionary which can hold any run_time metrics that need to be relayed to the user
        self.setup_runtime_metrics()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.init_model()

    def seed(self, seed):
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(seed)
        self.seed = seed
        

    def init_model(self):
        """Initialize the base model used to train the network"""
        self.model = ZEWDPCModel_2(
            num_classes=self.num_classes, use_pretrained=self.use_pretrained
        )
        self.model = self.model.to(self.device)
        self.best_model = None

    def setup_transforms(self, training_dataset):
        """
        Sets up the necessary transforms for the training_dataset
        """
        ## Setup necessary Transformations
        train_transform = T.Compose(
            [
                T.ToTensor(),  # Converts image to [0, 1]
                T.RandomVerticalFlip(p=0.5),
                T.RandomHorizontalFlip(p=0.5),
                T.ColorJitter(brightness=0.2, contrast=0.2),
                # TODO: Could agree with participants on the final list of transforms
                *self.model.required_transforms,
            ]
        )

        if isinstance(training_dataset, ZEWDPCBaseDataset):
            training_dataset.set_transform(train_transform)
        elif isinstance(training_dataset, torch.utils.data.ConcatDataset):
            for dataset in training_dataset.datasets:
                if isinstance(dataset, ZEWDPCRuntimeDataset):
                    dataset.set_transform(train_transform)
                elif isinstance(dataset, ZEWDPCBaseDataset):
                    dataset.set_transform(train_transform)
                else:
                    raise NotImplementedError()
        else:
            raise NotImplementedError()

    def setup_validation_set(self, training_dataset, validation_percentage=0.05):
        """
        Creates a Validation Set from the Training Dataset
        """
        assert (
            0 < validation_percentage < 1
        ), "Expected : validation_percentage ∈ (0, 1). Received validation_percentage = {}".format(
            validation_percentage
        )

        validation_size = int(validation_percentage * len(training_dataset))
        training_dataset, validation_dataset = torch.utils.data.random_split(
            training_dataset,
            [
                len(training_dataset) - validation_size,
                validation_size,
            ],
            generator=torch.Generator().manual_seed(self.seed),
        )
        return training_dataset, validation_dataset

    def setup_dataloaders(self, training_dataset, validation_dataset, batch_size=32):
        """
        Sets up necessary dataloader
        """
        train_dataloader = torch.utils.data.DataLoader(
            training_dataset, batch_size=batch_size, shuffle=True
        )
        val_dataloader = torch.utils.data.DataLoader(
            validation_dataset, batch_size=batch_size, shuffle=True
        )

        return train_dataloader, val_dataloader

    def train(
        self, training_dataset, num_epochs=20, validation_percentage=0.1, batch_size=32
    ):
        # Setup Transforms
        self.setup_transforms(training_dataset)

        # Prepare Validation Set
        training_dataset, validation_dataset = self.setup_validation_set(
            training_dataset, validation_percentage=validation_percentage
        )

        # Setup Dataloaders
        train_dataloader, val_dataloader = self.setup_dataloaders(
            training_dataset, validation_dataset, batch_size=batch_size
        )

        # Setup Criterion & Optimizer
        criterion = torch.nn.BCEWithLogitsLoss()
        lr = self.hparams["learning_rate"]
        optimizer = torch.optim.Adam(params=self.model.parameters(), lr=lr)

        lr_sched = CosineAnnealingLR(optimizer, num_epochs, eta_min=1e-5, last_epoch=-1, verbose=True)

        # Setup Metric Meters
        val_loss_avg_meter = AverageMeter()
        train_loss_avg_meter = AverageMeter()
        val_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")
        train_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")

        # Setup references for runtime-bests
        best_train_loss = float("inf")
        best_val_loss = float("inf")

        ########################################################################
        ########################################################################
        #
        # Iterate over Epochs
        ########################################################################
        for epoch in range(num_epochs):
            self.model.train()
            train_loss_avg_meter.reset()
            train_f1.reset()

            tqdm_iter = tqdm.tqdm(train_dataloader, total=len(train_dataloader))
            tqdm_iter.set_description(f"Epoch {epoch}")
            
            for sample in tqdm_iter:
                # Reset Optimizer Gradients
                optimizer.zero_grad()

                # Gather Data Sample
                idx = sample["idx"].to(self.device)
                image = sample["image"].to(self.device)
                label = torch.vstack(sample["label"]).T

                # Forward Pass
                output = self.model(image)
                # Compute Loss
                loss = criterion(output, label.to(self.device).float())

                # Update Metric Meters
                train_loss_avg_meter.update(loss.item(), image.shape[0])
                output_with_activation = self.model.activation(output.detach()).cpu()
                train_f1.update(output_with_activation, label)
                tqdm_iter.set_postfix(
                    iter_train_loss=loss.item(), avg_train_loss=train_loss_avg_meter.avg
                )

                # Backpropagate
                loss.backward()
                optimizer.step()

            print(
                "Epoch %d - Average Train Loss: %.5f \t Train F1: %.5f"
                % (epoch, train_loss_avg_meter.avg, train_f1.compute().item())
            )
            if train_loss_avg_meter.avg < best_train_loss:
                best_train_loss = train_loss_avg_meter.avg
                if validation_dataset is None:
                    self.best_model = copy.deepcopy(self.model)

            ####################################################################################
            ####################################################################################
            #
            # Validation
            ####################################################################################
            VALIDATION_INTERVAL = self.hparams["validation_interval"]
            if (
                validation_dataset is not None
                and (epoch + 1) % VALIDATION_INTERVAL == 0
            ):
                self.model.eval()
                val_loss_avg_meter.reset()
                val_f1.reset()
                tqdm_iter = tqdm.tqdm(val_dataloader, total=len(val_dataloader))
                tqdm_iter.set_description(f"Validation at Epoch {epoch}")

                for sample in tqdm_iter:
                    with torch.no_grad():

                        idx = sample["idx"].to(self.device)
                        image = sample["image"].to(self.device)
                        label = torch.vstack(sample["label"]).T

                        output = self.model(image)
                        loss = criterion(output, label.to(self.device).float())
                        output_with_activation = self.model.activation(
                            output.detach()
                        ).cpu()
                        val_f1.update(output_with_activation, label)

                        val_loss_avg_meter.update(loss.item(), image.shape[0])
                        tqdm_iter.set_postfix(avg_val_loss=val_loss_avg_meter.avg)

                lr_sched.step()

                print(
                    "Epoch %d - Average Val Loss: %.5f \t Val F1: %.5f \t Learning Rate %0.5f"
                    % (
                        epoch,
                        val_loss_avg_meter.avg,
                        val_f1.compute().item(),
                        optimizer.param_groups[0]["lr"],
                    )
                )
                if val_loss_avg_meter.avg < best_val_loss:
                    best_val_loss = val_loss_avg_meter.avg
                    self.best_model = copy.deepcopy(self.model)

                    

            train_metrics = {"f1": train_f1.compute().item()}
            val_metrics = {"f1": val_f1.compute().item()}
            info = {"learning_rate": optimizer.param_groups[0]["lr"]}
            self.update_runtime_metrics(
                epoch=epoch,
                total_epochs=num_epochs,
                train_loss=train_loss_avg_meter.avg,
                val_loss=val_loss_avg_meter.avg,
                train_metrics=train_metrics,
                val_metrics=val_metrics,
                info=info,
            )

        # Load the best model if available
        self.load_best_model()
        
        return self.model
    
    def load_best_model(self):
        """
        Helper Function to load the best model (if available)
        """
        if self.best_model is not None:
            # Save the best_model as the internal model instance
            self.model = self.best_model.to(self.device)

    def predict(self, test_dataset, batch_size=32):
        # Load best model
        self.load_best_model()
        
        self.model.eval()

        transform = T.Compose(
            [
                T.ToTensor(),
                *self.model.required_transforms,
            ]
        )
        test_dataset.set_transform(transform)
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )
        outputs = []

        self.model.to(self.device)

        batches_processed = 0
        total_batches = len(dataloader)
        for data in tqdm.tqdm(dataloader, total=len(dataloader)):
            with torch.no_grad():
                image = data["image"].to(self.device)
                output = self.model(image)
                output_with_activation = self.model.activation(output).cpu().numpy()
                outputs.append(output_with_activation)

            batches_processed += 1
            prediction_progress = batches_processed / total_batches
            self.run_time_metrics["prediction_progress"] = prediction_progress
            if batches_processed % 5 == 0:
                # Avoid too many updates with the server
                self.sync_with_server(self.run_time_metrics)

        outputs = np.concatenate(outputs, axis=0)
        outputs = outputs > 0.5

        # Final sync with the server
        self.sync_with_server(self.run_time_metrics)
        return outputs

    def setup_runtime_metrics(self):

        self.metric_names = ["f1"]
        self.info_names = ["learning_rate"]

        def get_metric_dict():
            metric_dict = {}
            metric_dict["loss"] = []
            for name in self.metric_names:
                metric_dict[name] = []

            return metric_dict

        self.run_time_metrics = {
            "epoch": [],
            "train": get_metric_dict(),
            "validation": get_metric_dict(),
            "training_progress": 0.0,
            "prediction_progress": 0.0,
        }
        for name in self.info_names:
            self.run_time_metrics[name] = []

    def update_runtime_metrics(
        self,
        epoch,
        total_epochs,
        train_loss,
        val_loss,
        train_metrics,
        val_metrics,
        info,
    ):

        self.run_time_metrics["epoch"].append(epoch)

        self.run_time_metrics["train"]["loss"].append(train_loss)
        self.run_time_metrics["validation"]["loss"].append(val_loss)

        for mn in self.metric_names:
            self.run_time_metrics["train"][mn].append(train_metrics[mn])
            self.run_time_metrics["validation"][mn].append(val_metrics[mn])

        for ik, iv in info.items():
            self.run_time_metrics[ik].append(iv)

        progress = (epoch + 1) / total_epochs
        self.run_time_metrics["training_progress"] = progress
        # Sync run_time_metrics with the server
        self.sync_with_server(self.run_time_metrics)
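

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainers; safe to delete).
# The trainer above decays its learning rate with
# CosineAnnealingLR(optimizer, num_epochs, eta_min=1e-5). Assuming the usual
# closed-form cosine schedule without restarts, the learning rate after
# `step` scheduler steps should be close to the value computed below.
# `_cosine_annealed_lr` is a hypothetical helper name; no trainer calls it,
# and the defaults simply mirror learning_rate=1e-3 and num_epochs=20.
# ----------------------------------------------------------------------------
import math


def _cosine_annealed_lr(step, base_lr=1e-3, eta_min=1e-5, t_max=20):
    """Return the cosine-annealed learning rate after `step` scheduler steps."""
    cosine = math.cos(math.pi * step / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)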
        
        
##################################################################################################################################
##### Final model Trainer No.3
##################################################################################################################################

class ZEWDPCTrainer_3:
    def __init__(
        self,
        num_classes=6,
        use_pretrained=True,
        hparams={},
        seed=42,
        sync_with_server=lambda x: x,
    ):
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.DEFAULT_HYPERPARAMETERS = (
            {  # A single dictionary to hold all configurable params.
                "learning_rate": 1e-3,
                "validation_interval": 1,
                "LR Scheduler Patience": 5,
                "LR Scheduler Factor": 0.2,
            }
        )
        self.hparams = self.DEFAULT_HYPERPARAMETERS
        self.hparams.update(hparams)
        self.seed(seed)
        self.sync_with_server = sync_with_server

        # A state dictionary which can hold any run_time metrics that need to be relayed to the user
        self.setup_runtime_metrics()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.init_model()

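    def seed(self, seed):
        """Seeds PyTorch and enables deterministic algorithms."""
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(seed)
        # Rebinds self.seed from this method to the integer seed,
        # which setup_validation_set later passes to manual_seed.
        self.seed = seed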
        

    def init_model(self):
        """Initialize the base model used to train the network"""
        self.model = ZEWDPCModel_3(
            num_classes=self.num_classes, use_pretrained=self.use_pretrained
        )
        self.model = self.model.to(self.device)
        self.best_model = None

    def setup_transforms(self, training_dataset):
        """
        Sets up the necessary transforms for the training_dataset
        """
        ## Setup necessary Transformations
        train_transform = T.Compose(
            [
                T.ToTensor(),  # Converts image to [0, 1]
                T.RandomVerticalFlip(p=0.5),
                T.RandomHorizontalFlip(p=0.5),
                T.ColorJitter(brightness=0.2, contrast=0.2),
                # TODO: Could agree with participants on the final list of transforms
                *self.model.required_transforms,
            ]
        )

        if isinstance(training_dataset, ZEWDPCBaseDataset):
            training_dataset.set_transform(train_transform)
        elif isinstance(training_dataset, torch.utils.data.ConcatDataset):
            for dataset in training_dataset.datasets:
                if isinstance(dataset, ZEWDPCRuntimeDataset):
                    dataset.set_transform(train_transform)
                elif isinstance(dataset, ZEWDPCBaseDataset):
                    dataset.set_transform(train_transform)
                else:
                    raise NotImplementedError()
        else:
            raise NotImplementedError()

    def setup_validation_set(self, training_dataset, validation_percentage=0.05):
        """
        Creates a Validation Set from the Training Dataset
        """
        assert (
            0 < validation_percentage < 1
        ), "Expected: validation_percentage ∈ (0, 1). Received validation_percentage = {}".format(
            validation_percentage
        )

        validation_size = int(validation_percentage * len(training_dataset))
        training_dataset, validation_dataset = torch.utils.data.random_split(
            training_dataset,
            [
                len(training_dataset) - validation_size,
                validation_size,
            ],
            generator=torch.Generator().manual_seed(self.seed),
        )
        return training_dataset, validation_dataset

    def setup_dataloaders(self, training_dataset, validation_dataset, batch_size=32):
        """
        Sets up the training and validation dataloaders
        """
        train_dataloader = torch.utils.data.DataLoader(
            training_dataset, batch_size=batch_size, shuffle=True
        )
        val_dataloader = torch.utils.data.DataLoader(
            validation_dataset, batch_size=batch_size, shuffle=True
        )

        return train_dataloader, val_dataloader

    def train(
        self,
        training_dataset,
        num_epochs=20,
        validation_percentage=0.1,
        batch_size=16,  # Default batch size for this trainer
    ):
        # Setup Transforms
        self.setup_transforms(training_dataset)

        # Prepare Validation Set
        training_dataset, validation_dataset = self.setup_validation_set(
            training_dataset, validation_percentage=validation_percentage
        )

        # Setup Dataloaders
        train_dataloader, val_dataloader = self.setup_dataloaders(
            training_dataset, validation_dataset, batch_size=batch_size
        )

        # Setup Criterion & Optimizer
        criterion = torch.nn.BCEWithLogitsLoss()
        lr = self.hparams["learning_rate"]
        optimizer = torch.optim.Adam(params=self.model.parameters(), lr=lr)

        # Setup Metric Meters
        val_loss_avg_meter = AverageMeter()
        train_loss_avg_meter = AverageMeter()
        val_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")
        train_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")

        # Setup references for runtime-bests
        best_train_loss = float("inf")
        best_val_loss = float("inf")

        ########################################################################
        ########################################################################
        #
        # Iterate over Epochs
        ########################################################################
        for epoch in range(num_epochs):
            self.model.train()
            train_loss_avg_meter.reset()
            train_f1.reset()

            tqdm_iter = tqdm.tqdm(train_dataloader, total=len(train_dataloader))
            tqdm_iter.set_description(f"Epoch {epoch}")
            
            for sample in tqdm_iter:
                # Reset Optimizer Gradients
                optimizer.zero_grad()

                # Gather Data Sample
                idx = sample["idx"].to(self.device)
                image = sample["image"].to(self.device)
                label = torch.vstack(sample["label"]).T

                # Forward Pass
                output = self.model(image)
                # Compute Loss
                loss = criterion(output, label.to(self.device).float())

                # Update Metric Meters
                train_loss_avg_meter.update(loss.item(), image.shape[0])
                output_with_activation = self.model.activation(output.detach()).cpu()
                train_f1.update(output_with_activation, label)
                tqdm_iter.set_postfix(
                    iter_train_loss=loss.item(), avg_train_loss=train_loss_avg_meter.avg
                )

                # Backpropagate
                loss.backward()
                optimizer.step()

            print(
                "Epoch %d - Average Train Loss: %.5f \t Train F1: %.5f"
                % (epoch, train_loss_avg_meter.avg, train_f1.compute().item())
            )
            if train_loss_avg_meter.avg < best_train_loss:
                best_train_loss = train_loss_avg_meter.avg
                if validation_dataset is None:
                    self.best_model = copy.deepcopy(self.model)

            ####################################################################################
            ####################################################################################
            #
            # Validation
            ####################################################################################
            VALIDATION_INTERVAL = self.hparams["validation_interval"]
            if (
                validation_dataset is not None
                and (epoch + 1) % VALIDATION_INTERVAL == 0
            ):
                self.model.eval()
                val_loss_avg_meter.reset()
                val_f1.reset()
                tqdm_iter = tqdm.tqdm(val_dataloader, total=len(val_dataloader))
                tqdm_iter.set_description(f"Validation at Epoch {epoch}")

                for sample in tqdm_iter:
                    with torch.no_grad():

                        idx = sample["idx"].to(self.device)
                        image = sample["image"].to(self.device)
                        label = torch.vstack(sample["label"]).T

                        output = self.model(image)
                        loss = criterion(output, label.to(self.device).float())
                        output_with_activation = self.model.activation(
                            output.detach()
                        ).cpu()
                        val_f1.update(output_with_activation, label)

                        val_loss_avg_meter.update(loss.item(), image.shape[0])
                        tqdm_iter.set_postfix(avg_val_loss=val_loss_avg_meter.avg)

                print(
                    "Epoch %d - Average Val Loss: %.5f \t Val F1: %.5f \t Learning Rate %0.5f"
                    % (
                        epoch,
                        val_loss_avg_meter.avg,
                        val_f1.compute().item(),
                        optimizer.param_groups[0]["lr"],
                    )
                )
                if val_loss_avg_meter.avg < best_val_loss:
                    best_val_loss = val_loss_avg_meter.avg
                    self.best_model = copy.deepcopy(self.model)

            train_metrics = {"f1": train_f1.compute().item()}
            val_metrics = {"f1": val_f1.compute().item()}
            info = {"learning_rate": optimizer.param_groups[0]["lr"]}
            self.update_runtime_metrics(
                epoch=epoch,
                total_epochs=num_epochs,
                train_loss=train_loss_avg_meter.avg,
                val_loss=val_loss_avg_meter.avg,
                train_metrics=train_metrics,
                val_metrics=val_metrics,
                info=info,
            )

        # Load the best model if available
        self.load_best_model()
        
        return self.model
    
    def load_best_model(self):
        """
        Helper Function to load the best model (if available)
        """
        if self.best_model is not None:
            # Save the best_model as the internal model instance
            self.model = self.best_model.to(self.device)

    def predict(self, test_dataset, batch_size=32):
        # Load best model
        self.load_best_model()
        
        self.model.eval()

        transform = T.Compose(
            [
                T.ToTensor(),
                *self.model.required_transforms,
            ]
        )
        test_dataset.set_transform(transform)
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )
        outputs = []

        self.model.to(self.device)

        batches_processed = 0
        total_batches = len(dataloader)
        for data in tqdm.tqdm(dataloader, total=len(dataloader)):
            with torch.no_grad():
                image = data["image"].to(self.device)
                output = self.model(image)
                output_with_activation = self.model.activation(output).cpu().numpy()
                outputs.append(output_with_activation)

            batches_processed += 1
            prediction_progress = batches_processed / total_batches
            self.run_time_metrics["prediction_progress"] = prediction_progress
            if batches_processed % 5 == 0:
                # Avoid too many updates with the server
                self.sync_with_server(self.run_time_metrics)

        outputs = np.concatenate(outputs, axis=0)
        outputs = outputs > 0.5

        # Final sync with the server
        self.sync_with_server(self.run_time_metrics)
        return outputs

    def setup_runtime_metrics(self):

        self.metric_names = ["f1"]
        self.info_names = ["learning_rate"]

        def get_metric_dict():
            metric_dict = {}
            metric_dict["loss"] = []
            for name in self.metric_names:
                metric_dict[name] = []

            return metric_dict

        self.run_time_metrics = {
            "epoch": [],
            "train": get_metric_dict(),
            "validation": get_metric_dict(),
            "training_progress": 0.0,
            "prediction_progress": 0.0,
        }
        for name in self.info_names:
            self.run_time_metrics[name] = []

    def update_runtime_metrics(
        self,
        epoch,
        total_epochs,
        train_loss,
        val_loss,
        train_metrics,
        val_metrics,
        info,
    ):

        self.run_time_metrics["epoch"].append(epoch)

        self.run_time_metrics["train"]["loss"].append(train_loss)
        self.run_time_metrics["validation"]["loss"].append(val_loss)

        for mn in self.metric_names:
            self.run_time_metrics["train"][mn].append(train_metrics[mn])
            self.run_time_metrics["validation"][mn].append(val_metrics[mn])

        for ik, iv in info.items():
            self.run_time_metrics[ik].append(iv)

        progress = (epoch + 1) / total_epochs
        self.run_time_metrics["training_progress"] = progress
        # Sync run_time_metrics with the server
        self.sync_with_server(self.run_time_metrics)
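

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainers; safe to delete).
# Every trainer in this file calls AverageMeter.reset(),
# AverageMeter.update(loss.item(), image.shape[0]) and reads .avg, but the
# class itself is defined elsewhere in the repository. Assuming it keeps a
# sample-weighted running mean, a minimal stand-in could look like the class
# below. `_SketchAverageMeter` is a hypothetical name chosen so that it does
# not shadow the real AverageMeter, whose implementation may differ.
# ----------------------------------------------------------------------------
class _SketchAverageMeter:
    """Tracks a running mean where each update is weighted by a sample count."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the accumulated sum, the sample count and the mean.
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        # `value` is a per-batch mean; `n` is the number of samples in the batch.
        self.sum += value * n
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count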



##################################################################################################################################
##### Final model Trainer No.4
##################################################################################################################################

class ZEWDPCTrainer_4:
    def __init__(
        self,
        num_classes=6,
        use_pretrained=True,
        hparams={},
        seed=42,
        sync_with_server=lambda x: x,
    ):
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.DEFAULT_HYPERPARAMETERS = (
            {  # A single dictionary to hold all configurable params.
                "learning_rate": 1e-3,
                "validation_interval": 1,
                "LR Scheduler Patience": 5,
                "LR Scheduler Factor": 0.2,
            }
        )
        self.hparams = self.DEFAULT_HYPERPARAMETERS
        self.hparams.update(hparams)
        self.seed(seed)
        self.sync_with_server = sync_with_server

        # A state dictionary which can hold any run_time metrics that need to be relayed to the user
        self.setup_runtime_metrics()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.init_model()

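    def seed(self, seed):
        """Seeds PyTorch and enables deterministic algorithms."""
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(seed)
        # Rebinds self.seed from this method to the integer seed,
        # which setup_validation_set later passes to manual_seed.
        self.seed = seed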
        

    def init_model(self):
        """Initialize the base model used to train the network"""
        self.model = ZEWDPCModel_4(
            num_classes=self.num_classes, use_pretrained=self.use_pretrained
        )
        self.model = self.model.to(self.device)
        self.best_model = None

    def setup_transforms(self, training_dataset):
        """
        Sets up the necessary transforms for the training_dataset
        """
        ## Setup necessary Transformations
        train_transform = T.Compose(
            [
                T.ToTensor(),  # Converts image to [0, 1]
                T.RandomVerticalFlip(p=0.5),
                T.RandomHorizontalFlip(p=0.5),
                T.GaussianBlur(kernel_size=3),
                T.ColorJitter(brightness=0.2, contrast=0.2),
                # TODO: Could agree with participants on the final list of transforms
                *self.model.required_transforms,
            ]
        )

        if isinstance(training_dataset, ZEWDPCBaseDataset):
            training_dataset.set_transform(train_transform)
        elif isinstance(training_dataset, torch.utils.data.ConcatDataset):
            for dataset in training_dataset.datasets:
                if isinstance(dataset, ZEWDPCRuntimeDataset):
                    dataset.set_transform(train_transform)
                elif isinstance(dataset, ZEWDPCBaseDataset):
                    dataset.set_transform(train_transform)
                else:
                    raise NotImplementedError()
        else:
            raise NotImplementedError()

    def setup_validation_set(self, training_dataset, validation_percentage=0.05):
        """
        Creates a Validation Set from the Training Dataset
        """
        assert (
            0 < validation_percentage < 1
        ), "Expected: validation_percentage ∈ (0, 1). Received validation_percentage = {}".format(
            validation_percentage
        )

        validation_size = int(validation_percentage * len(training_dataset))
        training_dataset, validation_dataset = torch.utils.data.random_split(
            training_dataset,
            [
                len(training_dataset) - validation_size,
                validation_size,
            ],
            generator=torch.Generator().manual_seed(self.seed),
        )
        return training_dataset, validation_dataset

    def setup_dataloaders(self, training_dataset, validation_dataset, batch_size=32):
        """
        Sets up the training and validation dataloaders
        """
        train_dataloader = torch.utils.data.DataLoader(
            training_dataset, batch_size=batch_size, shuffle=True
        )
        val_dataloader = torch.utils.data.DataLoader(
            validation_dataset, batch_size=batch_size, shuffle=True
        )

        return train_dataloader, val_dataloader

    def train(
        self,
        training_dataset,
        num_epochs=20,
        validation_percentage=0.1,
        batch_size=16,  # Default batch size for this trainer
    ):
        # Setup Transforms
        self.setup_transforms(training_dataset)

        # Prepare Validation Set
        training_dataset, validation_dataset = self.setup_validation_set(
            training_dataset, validation_percentage=validation_percentage
        )

        # Setup Dataloaders
        train_dataloader, val_dataloader = self.setup_dataloaders(
            training_dataset, validation_dataset, batch_size=batch_size
        )

        # Setup Criterion & Optimizer
        criterion = torch.nn.BCEWithLogitsLoss()
        lr = self.hparams["learning_rate"]
        optimizer = torch.optim.Adam(params=self.model.parameters(), lr=lr)

        # Setup Metric Meters
        val_loss_avg_meter = AverageMeter()
        train_loss_avg_meter = AverageMeter()
        val_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")
        train_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")

        # Setup references for runtime-bests
        best_train_loss = float("inf")
        best_val_loss = float("inf")

        ########################################################################
        ########################################################################
        #
        # Iterate over Epochs
        ########################################################################
        for epoch in range(num_epochs):
            self.model.train()
            train_loss_avg_meter.reset()
            train_f1.reset()

            tqdm_iter = tqdm.tqdm(train_dataloader, total=len(train_dataloader))
            tqdm_iter.set_description(f"Epoch {epoch}")
            
            for sample in tqdm_iter:
                # Reset Optimizer Gradients
                optimizer.zero_grad()

                # Gather Data Sample
                idx = sample["idx"].to(self.device)
                image = sample["image"].to(self.device)
                label = torch.vstack(sample["label"]).T

                # Forward Pass
                output = self.model(image)
                # Compute Loss
                loss = criterion(output, label.to(self.device).float())

                # Update Metric Meters
                train_loss_avg_meter.update(loss.item(), image.shape[0])
                output_with_activation = self.model.activation(output.detach()).cpu()
                train_f1.update(output_with_activation, label)
                tqdm_iter.set_postfix(
                    iter_train_loss=loss.item(), avg_train_loss=train_loss_avg_meter.avg
                )

                # Backpropagate
                loss.backward()
                optimizer.step()

            print(
                "Epoch %d - Average Train Loss: %.5f \t Train F1: %.5f"
                % (epoch, train_loss_avg_meter.avg, train_f1.compute().item())
            )
            if train_loss_avg_meter.avg < best_train_loss:
                best_train_loss = train_loss_avg_meter.avg
                if validation_dataset is None:
                    self.best_model = copy.deepcopy(self.model)

            ####################################################################################
            ####################################################################################
            #
            # Validation
            ####################################################################################
            VALIDATION_INTERVAL = self.hparams["validation_interval"]
            if (
                validation_dataset is not None
                and (epoch + 1) % VALIDATION_INTERVAL == 0
            ):
                self.model.eval()
                val_loss_avg_meter.reset()
                val_f1.reset()
                tqdm_iter = tqdm.tqdm(val_dataloader, total=len(val_dataloader))
                tqdm_iter.set_description(f"Validation at Epoch {epoch}")

                for sample in tqdm_iter:
                    with torch.no_grad():

                        idx = sample["idx"].to(self.device)
                        image = sample["image"].to(self.device)
                        label = torch.vstack(sample["label"]).T

                        output = self.model(image)
                        loss = criterion(output, label.to(self.device).float())
                        output_with_activation = self.model.activation(
                            output.detach()
                        ).cpu()
                        val_f1.update(output_with_activation, label)

                        val_loss_avg_meter.update(loss.item(), image.shape[0])
                        tqdm_iter.set_postfix(avg_val_loss=val_loss_avg_meter.avg)

                print(
                    "Epoch %d - Average Val Loss: %.5f \t Val F1: %.5f \t Learning Rate %0.5f"
                    % (
                        epoch,
                        val_loss_avg_meter.avg,
                        val_f1.compute().item(),
                        optimizer.param_groups[0]["lr"],
                    )
                )
                if val_loss_avg_meter.avg < best_val_loss:
                    best_val_loss = val_loss_avg_meter.avg
                    self.best_model = copy.deepcopy(self.model)

            train_metrics = {"f1": train_f1.compute().item()}
            val_metrics = {"f1": val_f1.compute().item()}
            info = {"learning_rate": optimizer.param_groups[0]["lr"]}
            self.update_runtime_metrics(
                epoch=epoch,
                total_epochs=num_epochs,
                train_loss=train_loss_avg_meter.avg,
                val_loss=val_loss_avg_meter.avg,
                train_metrics=train_metrics,
                val_metrics=val_metrics,
                info=info,
            )

        # Load the best model if available
        self.load_best_model()
        
        return self.model
    
    def load_best_model(self):
        """
        Helper Function to load the best model (if available)
        """
        if self.best_model is not None:
            # Save the best_model as the internal model instance
            self.model = self.best_model.to(self.device)

    def predict(self, test_dataset, batch_size=32):
        # Load best model
        self.load_best_model()
        
        self.model.eval()

        transform = T.Compose(
            [
                T.ToTensor(),
                *self.model.required_transforms,
            ]
        )
        test_dataset.set_transform(transform)
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )
        outputs = []

        self.model.to(self.device)

        batches_processed = 0
        total_batches = len(dataloader)
        for data in tqdm.tqdm(dataloader, total=len(dataloader)):
            with torch.no_grad():
                image = data["image"].to(self.device)
                output = self.model(image)
                output_with_activation = self.model.activation(output).cpu().numpy()
                outputs.append(output_with_activation)

            batches_processed += 1
            prediction_progress = batches_processed / total_batches
            self.run_time_metrics["prediction_progress"] = prediction_progress
            if batches_processed % 5 == 0:
                # Avoid too many updates with the server
                self.sync_with_server(self.run_time_metrics)

        outputs = np.concatenate(outputs, axis=0)
        outputs = outputs > 0.5

        # Final sync with the server
        self.sync_with_server(self.run_time_metrics)
        return outputs

    def setup_runtime_metrics(self):

        self.metric_names = ["f1"]
        self.info_names = ["learning_rate"]

        def get_metric_dict():
            metric_dict = {}
            metric_dict["loss"] = []
            for name in self.metric_names:
                metric_dict[name] = []

            return metric_dict

        self.run_time_metrics = {
            "epoch": [],
            "train": get_metric_dict(),
            "validation": get_metric_dict(),
            "training_progress": 0.0,
            "prediction_progress": 0.0,
        }
        for name in self.info_names:
            self.run_time_metrics[name] = []

    def update_runtime_metrics(
        self,
        epoch,
        total_epochs,
        train_loss,
        val_loss,
        train_metrics,
        val_metrics,
        info,
    ):

        self.run_time_metrics["epoch"].append(epoch)

        self.run_time_metrics["train"]["loss"].append(train_loss)
        self.run_time_metrics["validation"]["loss"].append(val_loss)

        for mn in self.metric_names:
            self.run_time_metrics["train"][mn].append(train_metrics[mn])
            self.run_time_metrics["validation"][mn].append(val_metrics[mn])

        for ik, iv in info.items():
            self.run_time_metrics[ik].append(iv)

        progress = (epoch + 1) / total_epochs
        self.run_time_metrics["training_progress"] = progress
        # Sync run_time_metrics with the server
        self.sync_with_server(self.run_time_metrics)


##################################################################################################################################
##### Final model Trainer No.5
##################################################################################################################################

class ZEWDPCTrainer_5:
    def __init__(
        self,
        num_classes=6,
        use_pretrained=True,
        hparams={},
        seed=42,
        sync_with_server=lambda x: x,
    ):
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained

        self.DEFAULT_HYPERPARAMETERS = (
            {  # A single dictionary to hold all configurable params.
                "learning_rate": 1e-3,
                "validation_interval": 1,
                "LR Scheduler Patience": 5,
                "LR Scheduler Factor": 0.2,
            }
        )
        self.hparams = self.DEFAULT_HYPERPARAMETERS
        self.hparams.update(hparams)
        self.seed(seed)
        self.sync_with_server = sync_with_server

        # A state dictionary which can hold any run_time metrics that need to be relayed to the user
        self.setup_runtime_metrics()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.init_model()

    def seed(self, seed):
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(seed)
        self.seed = seed
        

    def init_model(self):
        """Initialize the base model used to train the network"""
        self.model = ZEWDPCModel_5(
            num_classes=self.num_classes, use_pretrained=self.use_pretrained
        )
        self.model = self.model.to(self.device)
        self.best_model = None

    def setup_transforms(self, training_dataset):
        """
        Sets up the necessary transforms for the training_dataset
        """
        ## Setup necessary Transformations
        train_transform = T.Compose(
            [
                T.ToTensor(),  # Converts image to [0, 1]
                T.RandomVerticalFlip(p=0.5),
                T.RandomHorizontalFlip(p=0.5),
                T.GaussianBlur(kernel_size=3),
                T.ColorJitter(brightness=0.2, contrast=0.2),
                # TODO: Could agree with participants on the final list of transforms
                *self.model.required_transforms,
            ]
        )

        if isinstance(training_dataset, ZEWDPCBaseDataset):
            training_dataset.set_transform(train_transform)
        elif isinstance(training_dataset, torch.utils.data.ConcatDataset):
            for dataset in training_dataset.datasets:
                if isinstance(dataset, ZEWDPCRuntimeDataset):
                    dataset.set_transform(train_transform)
                elif isinstance(dataset, ZEWDPCBaseDataset):
                    dataset.set_transform(train_transform)
                else:
                    raise NotImplementedError()
        else:
            raise NotImplementedError()

    def setup_validation_set(self, training_dataset, validation_percentage=0.05):
        """
        Creates a Validation Set from the Training Dataset
        """
        assert (
            0 < validation_percentage < 1
        ), "Expected: validation_percentage ∈ (0, 1). Received validation_percentage = {}".format(
            validation_percentage
        )

        validation_size = int(validation_percentage * len(training_dataset))
        training_dataset, validation_dataset = torch.utils.data.random_split(
            training_dataset,
            [
                len(training_dataset) - validation_size,
                validation_size,
            ],
            generator=torch.Generator().manual_seed(self.seed),
        )
        return training_dataset, validation_dataset

    def setup_dataloaders(self, training_dataset, validation_dataset, batch_size=32):
        """
        Sets up necessary dataloader
        """
        train_dataloader = torch.utils.data.DataLoader(
            training_dataset, batch_size=batch_size, shuffle=True
        )
        val_dataloader = torch.utils.data.DataLoader(
            validation_dataset, batch_size=batch_size, shuffle=True
        )

        return train_dataloader, val_dataloader

    def train(
        self, training_dataset, num_epochs=20, validation_percentage=0.1, batch_size=32  ############ BATCH_SIZE
    ):
        # Setup Transforms
        self.setup_transforms(training_dataset)

        # Prepare Validation Set
        training_dataset, validation_dataset = self.setup_validation_set(
            training_dataset, validation_percentage=validation_percentage
        )

        # Setup Dataloaders
        train_dataloader, val_dataloader = self.setup_dataloaders(
            training_dataset, validation_dataset, batch_size=batch_size
        )

        # Setup Criterion & Optimizer
        criterion = torch.nn.BCEWithLogitsLoss()
        lr = self.hparams["learning_rate"]
        optimizer = torch.optim.Adam(params=self.model.parameters(), lr=lr)

        # Setup Metric Meters
        val_loss_avg_meter = AverageMeter()
        train_loss_avg_meter = AverageMeter()
        val_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")
        train_f1 = torchmetrics.F1Score(num_classes=self.num_classes, average="macro")

        # Setup references for runtime-bests
        best_train_loss = float("inf")
        best_val_loss = float("inf")

        ########################################################################
        ########################################################################
        #
        # Iterate over Epochs
        ########################################################################
        for epoch in range(num_epochs):
            self.model.train()
            train_loss_avg_meter.reset()
            train_f1.reset()

            tqdm_iter = tqdm.tqdm(train_dataloader, total=len(train_dataloader))
            tqdm_iter.set_description(f"Epoch {epoch}")
            
            for sample in tqdm_iter:
                # Reset Optimizer Gradients
                optimizer.zero_grad()

                # Gather Data Sample
                idx = sample["idx"].to(self.device)
                image = sample["image"].to(self.device)
                label = torch.vstack(sample["label"]).T

                # Forward Pass
                output = self.model(image)
                # Compute Loss
                loss = criterion(output, label.to(self.device).float())

                # Update Metric Meters
                train_loss_avg_meter.update(loss.item(), image.shape[0])
                output_with_activation = self.model.activation(output.detach()).cpu()
                train_f1.update(output_with_activation, label)
                tqdm_iter.set_postfix(
                    iter_train_loss=loss.item(), avg_train_loss=train_loss_avg_meter.avg
                )

                # Backpropagate
                loss.backward()
                optimizer.step()

            print(
                "Epoch %d - Average Train Loss: %.5f \t Train F1: %.5f"
                % (epoch, train_loss_avg_meter.avg, train_f1.compute().item())
            )
            if train_loss_avg_meter.avg < best_train_loss:
                best_train_loss = train_loss_avg_meter.avg
                if validation_dataset is None:
                    self.best_model = copy.deepcopy(self.model)

            ####################################################################################
            ####################################################################################
            #
            # Validation
            ####################################################################################
            VALIDATION_INTERVAL = self.hparams["validation_interval"]
            if (
                validation_dataset is not None
                and (epoch + 1) % VALIDATION_INTERVAL == 0
            ):
                self.model.eval()
                val_loss_avg_meter.reset()
                val_f1.reset()
                tqdm_iter = tqdm.tqdm(val_dataloader, total=len(val_dataloader))
                tqdm_iter.set_description(f"Validation at Epoch {epoch}")

                for sample in tqdm_iter:
                    with torch.no_grad():

                        idx = sample["idx"].to(self.device)
                        image = sample["image"].to(self.device)
                        label = torch.vstack(sample["label"]).T

                        output = self.model(image)
                        loss = criterion(output, label.to(self.device).float())
                        output_with_activation = self.model.activation(
                            output.detach()
                        ).cpu()
                        val_f1.update(output_with_activation, label)

                        val_loss_avg_meter.update(loss.item(), image.shape[0])
                        tqdm_iter.set_postfix(avg_val_loss=val_loss_avg_meter.avg)

                print(
                    "Epoch %d - Average Val Loss: %.5f \t Val F1: %.5f \t Learning Rate %0.5f"
                    % (
                        epoch,
                        val_loss_avg_meter.avg,
                        val_f1.compute().item(),
                        optimizer.param_groups[0]["lr"],
                    )
                )
                if val_loss_avg_meter.avg < best_val_loss:
                    best_val_loss = val_loss_avg_meter.avg
                    self.best_model = copy.deepcopy(self.model)

            train_metrics = {"f1": train_f1.compute().item()}
            val_metrics = {"f1": val_f1.compute().item()}
            info = {"learning_rate": optimizer.param_groups[0]["lr"]}
            self.update_runtime_metrics(
                epoch=epoch,
                total_epochs=num_epochs,
                train_loss=train_loss_avg_meter.avg,
                val_loss=val_loss_avg_meter.avg,
                train_metrics=train_metrics,
                val_metrics=val_metrics,
                info=info,
            )

        # Load the best model if available
        self.load_best_model()
        
        return self.model
    
    def load_best_model(self):
        """
        Helper Function to load the best model (if available)
        """
        if self.best_model is not None:
            # Save the best_model as the internal model instance
            self.model = self.best_model.to(self.device)

    def predict(self, test_dataset, batch_size=32):
        # Load best model
        self.load_best_model()
        
        self.model.eval()

        transform = T.Compose(
            [
                T.ToTensor(),
                *self.model.required_transforms,
            ]
        )
        test_dataset.set_transform(transform)
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )
        outputs = []

        self.model.to(self.device)

        batches_processed = 0
        total_batches = len(dataloader)
        for data in tqdm.tqdm(dataloader, total=len(dataloader)):
            with torch.no_grad():
                image = data["image"].to(self.device)
                output = self.model(image)
                output_with_activation = self.model.activation(output).cpu().numpy()
                outputs.append(output_with_activation)

            batches_processed += 1
            prediction_progress = batches_processed / total_batches
            self.run_time_metrics["prediction_progress"] = prediction_progress
            if batches_processed % 5 == 0:
                # Avoid too many updates with the server
                self.sync_with_server(self.run_time_metrics)

        outputs = np.concatenate(outputs, axis=0)
        outputs = outputs > 0.5

        # Final sync with the server
        self.sync_with_server(self.run_time_metrics)
        return outputs

    def setup_runtime_metrics(self):

        self.metric_names = ["f1"]
        self.info_names = ["learning_rate"]

        def get_metric_dict():
            metric_dict = {}
            metric_dict["loss"] = []
            for name in self.metric_names:
                metric_dict[name] = []

            return metric_dict

        self.run_time_metrics = {
            "epoch": [],
            "train": get_metric_dict(),
            "validation": get_metric_dict(),
            "training_progress": 0.0,
            "prediction_progress": 0.0,
        }
        for name in self.info_names:
            self.run_time_metrics[name] = []

    def update_runtime_metrics(
        self,
        epoch,
        total_epochs,
        train_loss,
        val_loss,
        train_metrics,
        val_metrics,
        info,
    ):

        self.run_time_metrics["epoch"].append(epoch)

        self.run_time_metrics["train"]["loss"].append(train_loss)
        self.run_time_metrics["validation"]["loss"].append(val_loss)

        for mn in self.metric_names:
            self.run_time_metrics["train"][mn].append(train_metrics[mn])
            self.run_time_metrics["validation"][mn].append(val_metrics[mn])

        for ik, iv in info.items():
            self.run_time_metrics[ik].append(iv)

        progress = (epoch + 1) / total_epochs
        self.run_time_metrics["training_progress"] = progress
        # Sync run_time_metrics with the server
        self.sync_with_server(self.run_time_metrics)


#### DEBUG

if __name__ == "__main__":

    DATASET_SHUFFLE_SEED = 1022022
    training_dataset = ZEWDPCBaseDataset(
        images_dir="./data/v0.2-rc4/dataset_debug/training/images",
        labels_path="./data/v0.2-rc4/dataset_debug/training/labels.csv",
        shuffle_seed=DATASET_SHUFFLE_SEED,
    )
    unlabelled_dataset = ZEWDPCProtectedDataset(
        images_dir="./data/v0.2-rc4/dataset_debug/unlabelled/images",
        labels_path="./data/v0.2-rc4/dataset_debug/unlabelled/labels.csv",
        purchase_budget=50,
        shuffle_seed=DATASET_SHUFFLE_SEED,
    )
    test_dataset = ZEWDPCBaseDataset(
        images_dir="./data/v0.2-rc4/dataset_debug/validation/images",
        labels_path="./data/v0.2-rc4/dataset_debug/validation/labels.csv",
        drop_labels=False,
        shuffle_seed=DATASET_SHUFFLE_SEED,
    )

    purchased_labels = {}
    for sample in tqdm.tqdm(unlabelled_dataset):
        idx = sample["idx"]
        try:
            label = unlabelled_dataset.purchase_label(idx)
            purchased_labels[idx] = label
        except OutOfBudetException:
            break

    # Create a runtime instance of the purchased dataset with the right labels
    purchased_dataset = instantiate_purchased_dataset(
        unlabelled_dataset, purchased_labels
    )

    aggregated_dataset = torch.utils.data.ConcatDataset(
        [training_dataset, purchased_dataset]
    )

    print("Training Dataset Size : ", len(training_dataset))
    print("Purchased Dataset Size : ", len(purchased_dataset))
    print("Aggregated Dataset Size : ", len(aggregated_dataset))

    DEBUG_MODE = os.getenv("AICROWD_DEBUG_MODE", False)
    if DEBUG_MODE:
        TRAINER_CLASS = ZEWDPCDebugTrainer_1
    else:
        TRAINER_CLASS = ZEWDPCTrainer_1

    trainer = TRAINER_CLASS(num_classes=6, use_pretrained=True)
    trainer.train(
        aggregated_dataset, num_epochs=1, validation_percentage=0.1, batch_size=5
    )
    # Predict once on the held-out split
    y_pred = trainer.predict(test_dataset)
    y_true = [sample["label"] for sample in test_dataset]

    ### Metrics
    metrics = get_zew_dpc_metrics(y_true=y_true, y_pred=y_pred)
    print(metrics)
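
One thing to keep in mind about the debug block above: `os.getenv("AICROWD_DEBUG_MODE", False)` returns a string whenever the variable is set, so values such as "0" or "false" still count as truthy and select the debug trainer. Below is a minimal sketch of a stricter parser, assuming you only want a few explicit values to enable debug mode. The helper name `env_flag` and the accepted spellings are my own assumptions, not part of the starter kit.

```python
import os


def env_flag(name, default=False):
    """Interpret an environment variable as a boolean flag (hypothetical helper)."""
    raw = os.getenv(name)
    if raw is None:
        # Variable not set at all: fall back to the default
        return default
    # Only these spellings enable the flag; anything else disables it
    return raw.strip().lower() in {"1", "true", "yes", "on"}


DEBUG_MODE = env_flag("AICROWD_DEBUG_MODE")
```

With this in place, `AICROWD_DEBUG_MODE=0` would leave debug mode off instead of silently turning it on.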

3) Local evaluation class

This code chunk must be copied into the "YourRepository/local_evaluation.py" file. It runs the local evaluation for all 5 final models in a single run, so you can compare how each purchasing strategy you implement performs across them.

If you are on a low-budget laptop/PC like me, I have included 2 parameters for you. LOW_MEMORY switches to a low memory mode that uses a batch size of 32 instead of 64, and LOW_MEMORY_MAX enables a maximum low memory mode that skips the unfrozen EfficientNet B4 models (#3 and #4) altogether. Keep an eye out for these parameters and modify them according to your setup.
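
As a rough sketch of how the two flags interact, the helper below maps them to a batch size and to the list of final models that would be trained. The names `FINAL_MODELS` and `plan_runs` are hypothetical and not part of the starter kit. The descriptions and epoch counts mirror the labels printed by the script in this section.

```python
# Hypothetical summary table: one entry per final model, with the
# description and epoch count taken from the script's print labels.
FINAL_MODELS = [
    {"id": 1, "desc": "B4 + CosineLR + 25 epochs", "epochs": 25, "unfrozen_b4": False},
    {"id": 2, "desc": "B4 + CosineLR + 25 epochs + No Blur", "epochs": 25, "unfrozen_b4": False},
    {"id": 3, "desc": "B4 + No Scheduler + 10 epochs + Unfreezed + No Blur", "epochs": 10, "unfrozen_b4": True},
    {"id": 4, "desc": "B4 + No Scheduler + 10 epochs + Unfreezed", "epochs": 10, "unfrozen_b4": True},
    {"id": 5, "desc": "B0 + No Scheduler + 25 epochs + Unfreezed", "epochs": 25, "unfrozen_b4": False},
]


def plan_runs(low_memory, low_memory_max):
    """Return (batch_size, ids of the models that would be trained)."""
    # Low memory mode halves the batch size
    batch_size = 32 if low_memory else 64
    # Maximum low memory mode drops the unfrozen B4 models
    model_ids = [
        m["id"]
        for m in FINAL_MODELS
        if not (low_memory_max and m["unfrozen_b4"])
    ]
    return batch_size, model_ids


# Example: the default configuration of the script (both flags on)
batch_size, model_ids = plan_runs(low_memory=True, low_memory_max=True)
```

With both flags on, only models #1, #2 and #5 are trained, each with a batch size of 32.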

In [ ]:
##
## NOTE: The code below is only for illustration purposes and local runs.
##       ANY CHANGES HERE WILL NOT BE REFLECTED IN THE EVALUATION SETUP
##       PLEASE MAKE YOUR CHANGES IN `run.py`.
##
import os
import tempfile
import time
import numpy as np
from tqdm.auto import tqdm

from evaluator.dataset import ZEWDPCBaseDataset, ZEWDPCProtectedDataset
from evaluator.utils import instantiate_purchased_dataset
from evaluator.trainer import ZEWDPCTrainer_1, ZEWDPCTrainer_2, ZEWDPCTrainer_3, ZEWDPCTrainer_4, ZEWDPCTrainer_5, ZEWDPCDebugTrainer_1

from run import ZEWDPCBaseRun

import torch

####################################################################################
####################################################################################
##
## Dataset Initialization
## We have initialised it with debug images for faster onboarding, you can change
## the location to respective dataset splits after downloading them.
####################################################################################

####################################################################################
## DEFINING IF LOW MEMORY 
####################################################################################

LOW_MEMORY = True ########### CHANGE THIS TO FALSE IF YOU HAVE ENOUGH MEMORY TO RUN ON BATCHSIZE = 64

if LOW_MEMORY:
    BATCHSIZE = 32
    LOW_MEMORY_MAX = True ########### CHANGE THIS TO FALSE IF YOU CAN RUN B4 UNFROZEN MODELS ON YOUR SETUP
else:
    BATCHSIZE = 64
    LOW_MEMORY_MAX = False  ########### CHANGE THIS TO TRUE IF YOU CAN RUN BATCHSIZE=64 BUT NOT B4 UNFROZEN MODELS ON YOUR SETUP


####################################################################################
## RUN LOOP FOR LOCAL 
####################################################################################

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html

DATASET_SHUFFLE_SEED = 1022022

PURCHASE_BUDGET = 500
COMPUTE_BUDGET = 60 * 60 # 1 hour


datafolder = './data/v0.2-rc4'

# Instantiate Training Dataset
training_dataset = ZEWDPCBaseDataset(
    images_dir=f"{datafolder}/training/images",
    labels_path=f"{datafolder}/training/labels.csv",
    shuffle_seed=DATASET_SHUFFLE_SEED,
)
# Instantiate Unlabelled Dataset
unlabelled_dataset = ZEWDPCProtectedDataset(
    images_dir=f"{datafolder}/unlabelled/images",
    labels_path=f"{datafolder}/unlabelled/labels.csv",
    purchase_budget=PURCHASE_BUDGET,  # Configurable Parameter
    shuffle_seed=DATASET_SHUFFLE_SEED,
)
# Instantiate Validation Dataset
val_dataset = ZEWDPCBaseDataset(
    images_dir=f"{datafolder}/validation/images",
    labels_path=f"{datafolder}/validation/labels.csv",
    drop_labels=True,
    shuffle_seed=DATASET_SHUFFLE_SEED,
)
# A second instantiation of the validation set with the labels present
#       - helpful later, when computing the scores.
val_dataset_gt = ZEWDPCBaseDataset(
    images_dir=f"{datafolder}/validation/images",
    labels_path=f"{datafolder}/validation/labels.csv",
    drop_labels=False,
    shuffle_seed=DATASET_SHUFFLE_SEED,
)

# Location to save your checkpoint
checkpoint_folder_path = tempfile.TemporaryDirectory().name
### NOTE: This folder does not clean itself up.
###       You are responsible for cleaning up the contents of this folder after
###       the desired usage.


####################################################################################
####################################################################################
##
## Setup Compute & Purchase Budgets
####################################################################################
time_started = time.time()

####################################################################################
####################################################################################
##
## Phase 1 : Pre-Training Phase
####################################################################################
run = ZEWDPCBaseRun()
run.pre_training_phase(training_dataset, compute_budget=COMPUTE_BUDGET)

run.save_checkpoint(checkpoint_folder_path)
# NOTE: It is critical that the checkpointing works in a self-contained way,
#       as the evaluators might choose to run the different phases separately.

del run
time_available = COMPUTE_BUDGET - (time.time() - time_started)

####################################################################################
####################################################################################
##
## Phase 2 : Purchase Phase
####################################################################################
run = ZEWDPCBaseRun()
run.load_checkpoint(checkpoint_folder_path)

run.purchase_phase(
    unlabelled_dataset, training_dataset, purchase_budget=PURCHASE_BUDGET, compute_budget=time_available
)

run.save_checkpoint(checkpoint_folder_path)
del run

####################################################################################
####################################################################################
##
## Phase 3 : Post Purchase Training Phase
####################################################################################

# Create a runtime instance of the purchased dataset with the right labels
purchased_dataset = instantiate_purchased_dataset(unlabelled_dataset)
aggregated_dataset = torch.utils.data.ConcatDataset(
    [training_dataset, purchased_dataset]
)
print("Training Dataset Size : ", len(training_dataset))
print("Purchased Dataset Size : ", len(purchased_dataset))
print("Aggregated Dataset Size : ", len(aggregated_dataset))

DEBUG_MODE = os.getenv("AICROWD_DEBUG_MODE", False)
if DEBUG_MODE:
    TRAINER_CLASS = ZEWDPCDebugTrainer_1
else:
    TRAINER_CLASS = ZEWDPCTrainer_1

####################################################################################
####################################################################################
##
## Train for Final Model #1
####################################################################################

print('###########################\nUsing BATCHSIZE=', BATCHSIZE,'\nUsing LOW_MEMORY_MAX=',LOW_MEMORY_MAX,'\n###########################')


print('###########################\nTraining for final Model #1\n(B4 + CosineLR + 25 epochs)\n###########################')

trainer = ZEWDPCTrainer_1(num_classes=6, use_pretrained=True)
trainer.train(
    aggregated_dataset, num_epochs=25, validation_percentage=0.1, batch_size=BATCHSIZE ## BATCHSIZE
)

y_pred = trainer.predict(val_dataset)
y_true = val_dataset_gt._get_all_labels()

####################################################################################
####################################################################################
##
## Phase 4 : Evaluation Phase
####################################################################################
from evaluator.evaluation_metrics import get_zew_dpc_metrics

metrics = get_zew_dpc_metrics(y_true, y_pred)

f1_score = metrics["F1_score_macro"]
accuracy_score = metrics["accuracy_score"]
hamming_loss_score = metrics["hamming_loss"]

print('##########################\nResults for final Model #1\n(B4 + CosineLR + 25 epochs)\n##########################')

print("F1 Score : ", f1_score)
print("Accuracy Score : ", accuracy_score)
print("Hamming Loss : ", hamming_loss_score)

del trainer, y_pred, f1_score, accuracy_score, hamming_loss_score

####################################################################################
####################################################################################
##
## Train for Final Model #2
####################################################################################

print('###########################\nTraining for final Model #2\n(B4 + CosineLR + 25 epochs + No Blur)\n###########################')

trainer = ZEWDPCTrainer_2(num_classes=6, use_pretrained=True)
trainer.train(
    aggregated_dataset, num_epochs=25, validation_percentage=0.1, batch_size=BATCHSIZE
)

y_pred = trainer.predict(val_dataset)
y_true = val_dataset_gt._get_all_labels()

####################################################################################
####################################################################################
##
## Phase 4 : Evaluation Phase
####################################################################################
from evaluator.evaluation_metrics import get_zew_dpc_metrics

metrics = get_zew_dpc_metrics(y_true, y_pred)

f1_score = metrics["F1_score_macro"]
accuracy_score = metrics["accuracy_score"]
hamming_loss_score = metrics["hamming_loss"]

print('##########################\nResults for final Model #2\n(B4 + CosineLR + 25 epochs + No Blur)\n##########################')

print("F1 Score : ", f1_score)
print("Accuracy Score : ", accuracy_score)
print("Hamming Loss : ", hamming_loss_score)

del trainer, y_pred, f1_score, accuracy_score, hamming_loss_score

####################################################################################
# LOW MEMORY REQUIREMENT?
####################################################################################

if LOW_MEMORY_MAX:
    print('###############################\nSkipping Unfreezed B4 Models due to low memory\n##############################')
else:
    ####################################################################################
    ####################################################################################
    ##
    ## Train for Final Model #3
    ####################################################################################

    print('###########################\nTraining for final Model #3\n(B4 + No Scheduler + 10 epochs + Unfreezed + No Blur)\n###########################')

    trainer = ZEWDPCTrainer_3(num_classes=6, use_pretrained=True)
    trainer.train(
        aggregated_dataset, num_epochs=10, validation_percentage=0.1, batch_size=BATCHSIZE
    )

    y_pred = trainer.predict(val_dataset)
    y_true = val_dataset_gt._get_all_labels()

    ####################################################################################
    ####################################################################################
    ##
    ## Phase 4 : Evaluation Phase
    ####################################################################################
    from evaluator.evaluation_metrics import get_zew_dpc_metrics

    metrics = get_zew_dpc_metrics(y_true, y_pred)

    f1_score = metrics["F1_score_macro"]
    accuracy_score = metrics["accuracy_score"]
    hamming_loss_score = metrics["hamming_loss"]

    print('##########################\nResults for final Model #3\n(B4 + No Scheduler + 10 epochs + Unfreezed + No Blur)\n##########################')

    print("F1 Score : ", f1_score)
    print("Accuracy Score : ", accuracy_score)
    print("Hamming Loss : ", hamming_loss_score)

    del trainer, y_pred, f1_score, accuracy_score, hamming_loss_score

    ####################################################################################
    ####################################################################################
    ##
    ## Train for Final Model #4
    ####################################################################################

    print('###########################\nTraining for final Model #4\n(B4 + No Scheduler + 10 epochs + Unfreezed)\n###########################')

    trainer = ZEWDPCTrainer_4(num_classes=6, use_pretrained=True)
    trainer.train(
        aggregated_dataset, num_epochs=10, validation_percentage=0.1, batch_size=BATCHSIZE
    )

    y_pred = trainer.predict(val_dataset)
    y_true = val_dataset_gt._get_all_labels()

    ####################################################################################
    ####################################################################################
    ##
    ## Phase 4 : Evaluation Phase
    ####################################################################################
    from evaluator.evaluation_metrics import get_zew_dpc_metrics

    metrics = get_zew_dpc_metrics(y_true, y_pred)

    f1_score = metrics["F1_score_macro"]
    accuracy_score = metrics["accuracy_score"]
    hamming_loss_score = metrics["hamming_loss"]

    print('##########################\nResults for final Model #4\n(B4 + No Scheduler + 10 epochs + Unfreezed)\n##########################')

    print("F1 Score : ", f1_score)
    print("Accuracy Score : ", accuracy_score)
    print("Hamming Loss : ", hamming_loss_score)

    del trainer, y_pred, f1_score, accuracy_score, hamming_loss_score

####################################################################################
####################################################################################
##
## Train for Final Model #5
####################################################################################

print('###########################\nTraining for final Model #5\n(B0 + No Scheduler + 25 epochs + Unfreezed)\n###########################')

trainer = ZEWDPCTrainer_5(num_classes=6, use_pretrained=True)
trainer.train(
    aggregated_dataset, num_epochs=25, validation_percentage=0.1, batch_size=BATCHSIZE
)

y_pred = trainer.predict(val_dataset)
y_true = val_dataset_gt._get_all_labels()

####################################################################################
####################################################################################
##
## Phase 4 : Evaluation Phase
####################################################################################
from evaluator.evaluation_metrics import get_zew_dpc_metrics

metrics = get_zew_dpc_metrics(y_true, y_pred)

f1_score = metrics["F1_score_macro"]
accuracy_score = metrics["accuracy_score"]
hamming_loss_score = metrics["hamming_loss"]

print('##########################\nResults for final Model #5\n(B0 + No Scheduler + 25 epochs + Unfreezed)\n##########################')

print("F1 Score : ", f1_score)
print("Accuracy Score : ", accuracy_score)
print("Hamming Loss : ", hamming_loss_score)

del trainer, y_pred, f1_score, accuracy_score, hamming_loss_score
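
The train/predict/evaluate block above is repeated once per model, so if you prefer a shorter script, it could be folded into a loop. The following is only a sketch, not part of the starter kit: `run_final_models` and `format_results` are hypothetical helper names, and the code assumes each trainer exposes `train(...)` and `predict(...)` exactly as they are called above, and that the metrics function returns the same three keys used above.

```python
from typing import Callable, Dict, List, Tuple


def run_final_models(
    configs: List[Tuple[str, Callable[[], object], int]],
    train_dataset,
    val_dataset,
    y_true,
    metrics_fn: Callable,
    batch_size: int = 8,
    validation_percentage: float = 0.1,
) -> Dict[str, Dict[str, float]]:
    """Train and score each (label, trainer_factory, num_epochs) entry in turn.

    Hypothetical helper: the trainer is assumed to offer the same
    train()/predict() interface used in the script above.
    """
    results = {}
    for label, make_trainer, num_epochs in configs:
        trainer = make_trainer()
        trainer.train(
            train_dataset,
            num_epochs=num_epochs,
            validation_percentage=validation_percentage,
            batch_size=batch_size,
        )
        y_pred = trainer.predict(val_dataset)
        metrics = metrics_fn(y_true, y_pred)
        # Keep only the three scores that the script above prints.
        results[label] = {
            "F1_score_macro": metrics["F1_score_macro"],
            "accuracy_score": metrics["accuracy_score"],
            "hamming_loss": metrics["hamming_loss"],
        }
        # Drop references before the next trainer is built.
        del trainer, y_pred
    return results


def format_results(results: Dict[str, Dict[str, float]]) -> str:
    """Return one header line plus one line per model, for side-by-side reading."""
    rows = [f"{'Model':<12}{'F1 (macro)':>12}{'Accuracy':>12}{'Hamming':>12}"]
    for label, m in results.items():
        rows.append(
            f"{label:<12}{m['F1_score_macro']:>12.4f}"
            f"{m['accuracy_score']:>12.4f}{m['hamming_loss']:>12.4f}"
        )
    return "\n".join(rows)
```

In local_evaluation.py, an entry in `configs` might look like `("Model #5", lambda: ZEWDPCTrainer_5(num_classes=6, use_pretrained=True), 25)`, with `get_zew_dpc_metrics` passed as `metrics_fn` and `BATCHSIZE` as `batch_size`. Printing `format_results(...)` at the end then shows all models in one table instead of one block per model.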

That's all folks 🐰

That is all for today. I hope you find these classes useful!

If you run into any issues or have any questions, do not hesitate to reach out!



If you found this insightful, please remember to leave a 💝, and comment if you have any suggestions.

