Salta al contenuto principale

Introduzione a PyTorch e TorchVision per il caricamento dei dati

Profile picture for user luca77king

Spesso la fase di caricamento e pre-elaborazione dei dati rappresenta una parte cruciale e talvolta complessa di qualsiasi progetto di machine learning. Fortunatamente, PyTorch, in combinazione con la libreria TorchVision, semplifica notevolmente questo processo, fornendo un'infrastruttura robusta e facilmente adattabile a diverse esigenze. In questo articolo, esploreremo come PyTorch e TorchVision gestiscono il caricamento dei dati, focalizzandoci su concetti chiave e best practice.

Il caricamento dei dati in PyTorch si basa principalmente sulla classe DataLoader. Questa classe agisce come un'interfaccia elegante e efficiente per iterare sui dati, garantendo che vengano caricati in batch di dimensioni opportune, ottimizzando così l'utilizzo della memoria e accelerando l'addestramento del modello. Un DataLoader richiede due elementi fondamentali: un dataset e un sampler.

Il dataset rappresenta la collezione di dati grezzi, e in PyTorch è rappresentato da una classe che eredita dalla classe base Dataset. Questa classe necessita di implementare due metodi: __len__, che restituisce il numero totale di esempi nel dataset, e __getitem__, che restituisce un singolo esempio dato un indice. Questo approccio è estremamente versatile, consentendo di creare dataset personalizzati per qualsiasi tipo di dato, da immagini a testo, a dati tabulari. La semplicità di questa interfaccia è uno dei punti di forza di PyTorch: la flessibilità è garantita dalla possibilità di definire la logica di accesso ai dati internamente alla classe Dataset.

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        """
        Constructor for the CustomDataset class.
        :param data: List or numpy array containing the data (e.g., images, text, etc.)
        :param labels: List or numpy array containing the labels corresponding to the data
        """
        self.data = data
        self.labels = labels

    def __len__(self):
        """
        Returns the total number of examples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Returns a single data point (and its label) from the dataset.
        :param idx: The index of the data point to return
        :return: A tuple containing the data and its label
        """
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# Example usage
data = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), torch.tensor([5.0, 6.0])]  # Sample data
labels = [0, 1, 0]  # Corresponding labels

# Create a dataset instance
dataset = CustomDataset(data, labels)

# Accessing data from the dataset
print(f"Total number of samples: {len(dataset)}")
sample, label = dataset[0]
print(f"First sample: {sample}, Label: {label}")

Il sampler, invece, controlla l'ordine in cui gli esempi vengono prelevati dal dataset. Il sampler più semplice è il SequentialSampler, che itera sugli esempi in ordine sequenziale. Tuttavia, per l'addestramento di modelli di deep learning, è spesso preferibile utilizzare un RandomSampler, che mescola casualmente l'ordine degli esempi ad ogni epoca, prevenendo un bias dovuto all'ordine dei dati e migliorando la generalizzazione del modello. Per problemi di classificazione con classi fortemente sbilanciate, si potrebbero usare sampler più sofisticati che sovracampionano le classi minoritarie o sottocampionano le classi maggioritarie per bilanciare il dataset.

sequential_sampler = torch.utils.data.SequentialSampler(dataset)
sequential_loader = DataLoader(dataset, sampler=sequential_sampler, batch_size=2)

print("Sequential Sampling:")
for batch_data, batch_labels in sequential_loader:
    print(batch_data, batch_labels)

# 2. RandomSampler: Campiona i dati in modo casuale
random_sampler = torch.utils.data.RandomSampler(dataset)
random_loader = DataLoader(dataset, sampler=random_sampler, batch_size=2)

print("\nRandom Sampling:")
for batch_data, batch_labels in random_loader:
    print(batch_data, batch_labels)

# 3. WeightedRandomSampler: Per affrontare classi sbilanciate
# Supponiamo di avere un dataset sbilanciato con molte più istanze della classe 0 rispetto alla classe 1
class_weights = [0.5 if label == 0 else 1.5 for label in labels]  # Pesi maggiori per la classe 1
weights = torch.tensor(class_weights, dtype=torch.float)

weighted_sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
weighted_loader = DataLoader(dataset, sampler=weighted_sampler, batch_size=2)

print("\nWeighted Random Sampling (for imbalanced classes):")
for batch_data, batch_labels in weighted_loader:
    print(batch_data, batch_labels)

TorchVision, una libreria costruita sopra PyTorch, si integra perfettamente con il DataLoader fornendo dataset già pronti all'uso per le attività di computer vision. Questi dataset, come ImageFolder, CIFAR10, MNIST, ecc., semplificano enormemente il processo di caricamento delle immagini, gestendo automaticamente il parsing delle immagini dai file, la trasformazione in tensori PyTorch e la gestione delle label. ImageFolder, ad esempio, si aspetta una struttura di cartelle in cui ogni sottocartella rappresenta una classe, e le immagini all'interno di ciascuna sottocartella appartengono a quella classe. Questa semplice convenzione permette di caricare un dataset di immagini in modo estremamente efficiente senza bisogno di scrivere codice complesso per la gestione dei file.

L'utilizzo di trasformazioni (transforms) è un altro aspetto fondamentale per il pre-processing dei dati. Le trasformazioni sono funzioni che vengono applicate agli esempi durante il caricamento, consentendo di normalizzare i dati, aumentarli e prepararli per l'addestramento del modello. In TorchVision, sono disponibili numerosi transforms predefiniti, come la ridimensionamento delle immagini (Resize), la conversione in scala di grigi (Grayscale), le rotazioni (RandomRotation), i ritagli (RandomCrop), e l'applicazione di flip orizzontali e verticali (RandomHorizontalFlip, RandomVerticalFlip). Combinando queste trasformazioni, è possibile creare una pipeline di pre-processing personalizzata per ottimizzare le prestazioni del modello. L'utilizzo di trasformazioni durante il caricamento è particolarmente efficiente in quanto distribuisce il carico computazionale su più core, sfruttando la potenza di elaborazione parallela.

Infine, la possibilità di utilizzare i worker num_workers nel DataLoader permette di caricare i dati in parallelo su più processi, ulteriormente accelerando il processo di caricamento. Questo parametro specifica il numero di processi worker che vengono utilizzati per caricare i dati in background mentre il modello viene addestrato, migliorando notevolmente l'efficienza, soprattutto con dataset di grandi dimensioni.

In conclusione, PyTorch e TorchVision forniscono un ecosistema completo e flessibile per il caricamento e la pre-elaborazione dei dati. L'utilizzo di DataLoader, Dataset, sampler, e delle trasformazioni di TorchVision, permette di creare pipeline di caricamento efficienti e personalizzate, adattate alle specifiche esigenze di ogni progetto. La semplicità e la potenza di questi strumenti consentono di dedicare più tempo allo sviluppo e all'ottimizzazione del modello, piuttosto che alla gestione dei dati stessi, un aspetto cruciale per il successo di qualsiasi progetto di deep learning. La comprensione approfondita di questi concetti rappresenta un passo fondamentale per padroneggiare PyTorch e affrontare con successo progetti di machine learning complessi.