Nel mondo dell'intelligenza artificiale (IA) e del machine learning, spesso si sente parlare di modelli grandi e modelli piccoli, con il termine "distillazione" che viene usato per descrivere un processo che rende un modello più potente e pesante in qualcosa di più piccolo e veloce. Se sei alle prime armi con l'IA, o se hai appena iniziato a esplorare il mondo dei modelli di linguaggio, potresti sentirti un po' confuso. Come funziona davvero la distillazione? E come può un modello più piccolo imparare da uno più grande?
In questo articolo, partiremo da una spiegazione semplice, adatta a chi non ha una conoscenza tecnica approfondita, per poi passare a una spiegazione più dettagliata e tecnica. Vedremo anche un esempio pratico di distillazione utilizzando BERT, uno dei modelli di linguaggio più famosi. Se sei un principiante, non preoccuparti: partiamo dalle basi, e successivamente approfondiremo i dettagli.
Immagina di avere un esperto che sa molto su un argomento. Questo esperto può essere un grande modello di intelligenza artificiale, come BERT o GPT. Questo esperto ha una memoria impressionante e può rispondere a quasi ogni domanda. Ma c'è un problema: l'esperto è così grande che non può essere utilizzato facilmente su dispositivi piccoli, come smartphone, o in situazioni in cui la velocità è importante. Inoltre, se vuoi che l'esperto risponda rapidamente, la sua lentezza può essere un problema.
Ora, pensa che tu voglia creare una versione più piccola e veloce dell'esperto, che però sappia comunque abbastanza per fare il lavoro. Questa versione più piccola si chiama un "modello distillato", ed è esattamente ciò che ottieni attraverso il processo di distillazione.
Come funziona la distillazione?
Abbiamo un modello teacher, che è il nostro esperto superintelligente. Questo modello ha imparato da tantissimi dati ed è in grado di rispondere a domande con grande precisione. Quando gli fai una domanda, il teacher non si limita a darti una risposta, ma fornisce un insieme di informazioni dettagliate, quasi come se dicesse: “Questa è la mia risposta principale, ma considero anche queste alternative e la loro probabilità”.
Dall’altra parte, abbiamo un modello student, che è il nostro "alunno". Questo modello è più piccolo e meno complesso, e non ha avuto accesso agli stessi dati e risorse del teacher. Per diventare bravo, però, lo student cerca di copiare il comportamento del teacher. Non si limita a imparare le risposte corrette come in una normale lezione, ma imita il modo in cui il teacher ragiona e distribuisce le sue probabilità sulle risposte. È come se lo student dicesse: “Non mi interessa solo sapere la risposta giusta, voglio anche capire come pensi”.
La distillazione vera e propria
Il cuore della distillazione sta in questo processo di imitazione, basato su un concetto chiamato soft target. Quando il modello teacher risponde a una domanda, non fornisce solo una risposta definitiva, ma una "distribuzione di probabilità" su tutte le possibili risposte. Questo significa che invece di dire "la risposta è A", il teacher potrebbe dire qualcosa come: “Credo che A sia corretta al 70%, B al 20%, e C al 10%”. Questa sfumatura è ciò che chiamiamo soft target, ed è un elemento chiave che lo student deve imparare a replicare.
Lo student non si limita a cercare la risposta esatta, ma impara a distribuire le probabilità in modo simile al teacher, cercando di "pensare" come lui. Allo stesso tempo, però, ha accesso anche alle risposte esatte, chiamate hard target, che sono la verità assoluta su cosa sia corretto. Combinando i soft target del teacher e gli hard target dei dati, lo student riesce a migliorare progressivamente le sue capacità, fino a diventare un’imitazione del teacher, ma molto più leggera e veloce.
Un esempio concreto: BERT come Teacher
Per rendere tutto più concreto, prendiamo un modello di linguaggio noto, come BERT (Bidirectional Encoder Representations from Transformers). BERT è uno dei modelli più avanzati per comprendere il linguaggio naturale. È estremamente bravo a capire il contesto delle frasi, a rispondere alle domande, a completare testi e persino a rilevare sfumature nel significato delle parole. Tuttavia, ha un grande problema: è enorme.
Le dimensioni di BERT (in termini di parametri e potenza di calcolo necessaria) lo rendono poco pratico per dispositivi con risorse limitate, come smartphone, smartwatch o qualsiasi dispositivo che non abbia accesso a un'infrastruttura potente come un server cloud. Per esempio, provare a usare BERT in tempo reale su un telefono comporterebbe ritardi significativi e un consumo eccessivo di batteria.
Allora, come possiamo portare l'intelligenza di BERT su dispositivi più leggeri? È qui che entra in gioco la distillazione.
Immaginiamo di voler creare una versione "mini" di BERT, qualcosa che possiamo utilizzare su dispositivi con poche risorse, senza perdere troppo in qualità delle risposte. Per farlo, prendiamo BERT come teacher, cioè il modello esperto, e creiamo un modello più piccolo e leggero che sarà il nostro student.
BERT, in quanto teacher, non si limita a dare risposte definitive, ma fornisce una distribuzione di probabilità su tutte le possibili risposte. Per esempio, se la domanda è "Qual è la capitale d’Italia?", invece di rispondere semplicemente "Roma", il modello assegna delle percentuali di probabilità a ciascuna opzione plausibile. Potrebbe dire:
- Roma: 90%
- Milano: 5%
- Firenze: 3%
- Torino: 2%
Questo approccio si chiama soft target. Non ci si concentra solo sulla risposta corretta, ma anche sul modo in cui il modello arriva a quella risposta, includendo le alternative e il grado di fiducia in ciascuna.
Grazie alla distillazione, lo student impara a imitare il pensiero del teacher, non solo replicando la risposta esatta, ma anche comprendendo il ragionamento dietro di essa.
Il modello student, invece, non si limita a imparare la risposta corretta, o hard target, ma cerca di imitare il modo in cui il teacher "pensa". Il suo obiettivo è replicare la distribuzione di probabilità fornita dal modello più grande, cercando di avvicinarsi non solo alle risposte, ma anche al processo decisionale che il teacher utilizza per arrivare a quelle risposte. In questo modo, il student non diventa un semplice ripetitore, ma un modello che emula il ragionamento del teacher.
Durante l’addestramento, il student utilizza sia i hard target (le risposte esatte fornite dal dataset) che i soft target (le probabilità distribuite sulle possibili risposte fornite dal modello teacher). Questa combinazione consente al modello più piccolo di bilanciare la precisione con la capacità di generalizzazione. Il risultato finale è un modello significativamente più leggero, ma che conserva buona parte della qualità del modello originale. Ad esempio, un modello distillato come DistilBERT riesce a mantenere circa il 97% delle performance di BERT, ma con la metà dei parametri e una velocità di esecuzione molto maggiore.
import torch
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, DistilBertForSequenceClassification
from transformers import BertTokenizer, DistilBertTokenizer
from datasets import load_dataset
# Carica il modello teacher (BERT)
teacher_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
teacher_model.eval()
# Carica il modello student (DistilBERT)
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
# Tokenizer per entrambi i modelli
teacher_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
student_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Carica un dataset di esempio
dataset = load_dataset("imdb", split="train[:1%]") # Carica 1% del dataset IMDb
train_texts = dataset["text"]
train_labels = dataset["label"]
# Tokenizza i dati
teacher_encodings = teacher_tokenizer(train_texts, truncation=True, padding=True, max_length=128, return_tensors="pt")
student_encodings = student_tokenizer(train_texts, truncation=True, padding=True, max_length=128, return_tensors="pt")
# Carica i dati in DataLoader
class IMDbDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}, torch.tensor(self.labels[idx])
train_dataset = IMDbDataset(teacher_encodings, train_labels)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# Funzione di perdita per la distillazione (combina soft e hard target)
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
# Soft target loss (Kullback-Leibler divergence)
teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
student_probs = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
soft_loss = torch.nn.functional.kl_div(student_probs, teacher_probs, reduction="batchmean") * (temperature ** 2)
# Hard target loss (CrossEntropy)
hard_loss = torch.nn.functional.cross_entropy(student_logits, labels)
# Combina le perdite
return alpha * soft_loss + (1 - alpha) * hard_loss
# Ottimizzatore
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)
# Ciclo di addestramento
student_model.train()
for epoch in range(3): # Addestramento per 3 epoche
for batch, labels in train_loader:
optimizer.zero_grad()
# Forward pass per teacher e student
with torch.no_grad():
teacher_outputs = teacher_model(**batch)
teacher_logits = teacher_outputs.logits
student_outputs = student_model(**batch)
student_logits = student_outputs.logits
# Calcolo della perdita
loss = distillation_loss(student_logits, teacher_logits, labels)
# Backpropagation
loss.backward()
optimizer.step()
print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")
# Salvataggio del modello distillato
student_model.save_pretrained("distilled_student_model")
student_tokenizer.save_pretrained("distilled_student_model")
Conclusione
La distillazione dei modelli è una tecnica che permette di sfruttare al massimo le capacità dei modelli di machine learning, riducendone il peso e migliorandone l'efficienza. Questo approccio è fondamentale in un'epoca in cui i modelli di linguaggio stanno diventando sempre più complessi e computazionalmente costosi.
Grazie alla distillazione, possiamo ottenere modelli più piccoli e leggeri senza sacrificare troppo in termini di accuratezza. Questi modelli sono ideali per applicazioni real-time, dispositivi mobili e ambienti con risorse limitate. Inoltre, il processo insegna al modello student non solo a rispondere correttamente, ma anche a pensare come un modello teacher, replicandone le distribuzioni di probabilità e comportamenti.
Nell'esempio di BERT e DistilBERT, vediamo come questa tecnica riesca a produrre un modello che mantiene gran parte delle prestazioni del modello originale, ma con una significativa riduzione dei parametri e del tempo di inferenza. È un compromesso ideale per applicazioni pratiche, come chatbot, motori di ricerca, traduzioni in tempo reale, e molto altro.
La distillazione non è solo un'alternativa per ottimizzare modelli esistenti; rappresenta anche una direzione strategica per rendere l'intelligenza artificiale più accessibile e sostenibile. Mentre i modelli teacher come LLAMA e GPT continuano a spingere i limiti dell'innovazione, i modelli distillati portano queste capacità avanzate nelle mani di sviluppatori e utenti in tutto il mondo.
In futuro, possiamo aspettarci che la distillazione si evolva ulteriormente, combinandosi con altre tecniche di compressione e apprendimento per creare modelli ancora più performanti e leggeri. È un passo essenziale verso un'IA inclusiva e applicabile in tutti i contesti.