Pre-requisites: Training in SG, Training with Configuration Files
Knowledge distillation is a technique in deep learning that aims to transfer the knowledge of a large, pre-trained neural network (the "teacher") to a smaller, more computationally efficient model (the "student"). This is accomplished by training the student to mimic the teacher's predictions while also fitting the ground-truth labels. The student can have a different architecture from the teacher, making it possible to distill the knowledge of a complex teacher network into a lighter and faster student network for deployment in real-world applications.
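To make this concrete, here is a minimal, framework-agnostic sketch of the classic distillation objective (not SG's implementation): a weighted sum of the cross-entropy with the ground-truth labels and a temperature-softened KL divergence between student and teacher logits. The temperature and alpha values below are illustrative defaults.

import torch
import torch.nn.functional as F

def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor,
                      targets: torch.Tensor,
                      temperature: float = 4.0,
                      alpha: float = 0.5) -> torch.Tensor:
    # Hard-label term: standard cross-entropy against the ground truth.
    hard_loss = F.cross_entropy(student_logits, targets)
    # Soft-label term: KL divergence between temperature-softened distributions.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    ) * (temperature ** 2)  # rescale so gradients keep a comparable magnitude
    return alpha * soft_loss + (1 - alpha) * hard_loss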
The training flow with knowledge distillation in SG is similar to regular training. For standard training, we use SG's Trainer class, which is in charge of training the model, evaluating test data, making predictions, and saving checkpoints. Equivalently, for knowledge distillation we use the KDTrainer class, which inherits from Trainer.
If, for regular training with Trainer, the general flow is:
...
trainer = Trainer("my_experiment")
train_dataloader = ...
valid_dataloader = ...
model = ...
train_params = {...}
trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
Then for training with knowledge distillation, the general flow is:
from super_gradients.training.kd_trainer import KDTrainer
...
kd_trainer = KDTrainer("my_experiment")
train_dataloader = ...
valid_dataloader = ...
student_model = ...
teacher_model = ...
train_params = {...}
kd_trainer.train(student=student_model, teacher=teacher_model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
Check out our knowledge distillation tutorial notebook to see a practical example.
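In practice, the student_model and teacher_model placeholders above are typically created with SG's models.get factory. A minimal sketch, assuming ImageNet-style classification; the architecture names and arguments below are illustrative, not a prescribed pairing:

from super_gradients.training import models

# Teacher: a larger network loaded with pretrained weights.
teacher_model = models.get("resnet50", num_classes=1000, pretrained_weights="imagenet")
# Student: a smaller network, trained from scratch under the teacher's supervision.
student_model = models.get("resnet18", num_classes=1000)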
The most apparent difference in the training flow using knowledge distillation is that it requires two networks: the "teacher" and the "student".
The relation between the two is also configurable - for example, we may decide that the teacher model should preprocess the inputs differently.
To that end, SG introduces a new torch.nn.Module that wraps both the student and the teacher models: KDModule.
Upon calling KDTrainer.train(), the teacher and student models are passed, along with the kd_arch_params, to initialize a KDModule instance.
It is also possible to pass a KDModule instance explicitly to KDTrainer.train() through the model argument instead of separate student and teacher models, which gives users the option to customize KD to their needs.
A high-level example of KD customization:
import torch.nn

from super_gradients.training.kd_trainer import KDTrainer
...  # remaining imports (e.g. of KDModule and KDOutput) elided


class MyKDModule(KDModule):
    ...

    def forward(self, x: torch.Tensor) -> KDOutput:
        # extract_intermediate_output is a hypothetical method of the wrapped models, used here to
        # illustrate distilling intermediate feature maps rather than final logits only.
        intermediate_output_student = self.student.extract_intermediate_output(x, layer_ids=[1, 3, -1])
        intermediate_output_teacher = self.teacher.extract_intermediate_output(x, layer_ids=[1, 3, -1])
        return KDOutput(student_output=intermediate_output_student,
                        teacher_output=intermediate_output_teacher)


class MyKDLoss(torch.nn.Module):
    ...

    def forward(self, preds: KDOutput, target: torch.Tensor):
        # does something with the intermediate outputs in preds.student_output / preds.teacher_output
        ...


kd_trainer = KDTrainer("my_customized_kd_experiment")
train_dataloader = ...
valid_dataloader = ...
student_model = ...
teacher_model = ...
kd_model = MyKDModule(student=student_model, teacher=teacher_model)
train_params = {'loss': MyKDLoss(),
                # ... the rest of the usual training hyper-parameters ...
                }
kd_trainer.train(model=kd_model, training_params=train_params,
                 train_loader=train_dataloader, valid_loader=valid_dataloader)
KDOutput defines the structure of the output of KDModule and has two self-explanatory attributes: student_output and teacher_output.
KDTrainer uses these attributes behind the scenes to perform the usual operations of regular training, such as metrics calculations.
This means that when customizing KD, it's essential for the custom KDModule to stick to this output format.
Currently, KDLogitsLoss is the only supported loss function in SG's KD losses bank, but more are to come.
Note that during KD training, the KDModule outputs (which are KDOutput instances) are passed to the loss's forward method as the predictions.
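For illustration, a custom loss therefore reads its inputs off that KDOutput. The sketch below is not SG's KDLogitsLoss; it is a minimal stand-in that mixes a cross-entropy task term with a temperature-softened KL distillation term, with hyper-parameter names of our own choosing:

import torch
import torch.nn.functional as F

class SimpleKDLoss(torch.nn.Module):
    """Illustrative loss (not SG's KDLogitsLoss): consumes the KDOutput that
    KDTrainer passes as `preds` and mixes a task term with a distillation term."""

    def __init__(self, distillation_coeff: float = 0.5, temperature: float = 4.0):
        super().__init__()
        self.distillation_coeff = distillation_coeff
        self.temperature = temperature

    def forward(self, preds: "KDOutput", target: torch.Tensor) -> torch.Tensor:
        student_logits = preds.student_output
        teacher_logits = preds.teacher_output
        # Task term against the ground-truth labels.
        task_loss = F.cross_entropy(student_logits, target)
        # Distillation term against the (softened) teacher predictions.
        distillation_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction="batchmean",
        ) * self.temperature ** 2
        return (1 - self.distillation_coeff) * task_loss + self.distillation_coeff * distillation_loss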
Checkpointing during KD training is generally the same as checkpointing without KD. Nevertheless, there are a few differences worth mentioning:
- ckpt_latest.pth contains the state dict of the entire KDModule.
- ckpt_best.pth contains the state dict of the student only.
- ckpt_best.pth's net entry holds the EMA network.

As done when training without knowledge distillation, to train with configuration files we call the KDTrainer.train_from_config method, which assumes a specific configuration structure.
When training with KD, the same structure and required fields hold, but we introduce a few additions:
arch_params are passed to the KDModule constructor. For example, in our ResNet50 KD training on ImageNet, the teacher expects differently normalized inputs, so we handle this difference in preprocessing by passing the KDModule a normalization adapter module:
# super_gradients/recipes/imagenet_resnet50_kd.yaml
...
arch_params:
  teacher_input_adapter:
    _target_: super_gradients.training.utils.kd_trainer_utils.NormalizationAdapter
    mean_original: [0.485, 0.456, 0.406]
    std_original: [0.229, 0.224, 0.225]
    mean_required: [0.5, 0.5, 0.5]
    std_required: [0.5, 0.5, 0.5]
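Conceptually, such an adapter takes a batch that was already normalized with the original dataset statistics and re-normalizes it to the statistics the teacher expects. A rough, hypothetical equivalent (SG ships its own NormalizationAdapter at the _target_ path shown above):

import torch

class NormalizationAdapterSketch(torch.nn.Module):
    """Hypothetical re-implementation for illustration only."""

    def __init__(self, mean_original, std_original, mean_required, std_required):
        super().__init__()
        self.register_buffer("mean_original", torch.tensor(mean_original).view(1, -1, 1, 1))
        self.register_buffer("std_original", torch.tensor(std_original).view(1, -1, 1, 1))
        self.register_buffer("mean_required", torch.tensor(mean_required).view(1, -1, 1, 1))
        self.register_buffer("std_required", torch.tensor(std_required).view(1, -1, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Undo the dataloader's normalization, then apply the teacher's.
        x = x * self.std_original + self.mean_original
        return (x - self.mean_required) / self.std_required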
Warning: Remember to distinguish the arch_params passed to the KDModule constructor from the student's own arch_params (student_arch_params below).
student_architecture, teacher_architecture, student_arch_params, student_checkpoint_params, teacher_arch_params, and teacher_checkpoint_params play the same role that architecture, arch_params, and checkpoint_params play when instantiating our model in non-KD training, and are passed to models.get(...) to instantiate the teacher and the student:
...
student_architecture: resnet50
teacher_architecture: beit_base_patch16_224
student_arch_params:
  num_classes: 1000

teacher_arch_params:
  num_classes: 1000
  image_size: [224, 224]
  patch_size: [16, 16]

teacher_checkpoint_params:
  ...
  pretrained_weights: imagenet

student_checkpoint_params:
  ...
Any KD recipe can be launched with our train_from_kd_recipe_example script.
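Under the hood, such a launcher boils down to composing the recipe with Hydra and handing the resulting config to KDTrainer.train_from_config. A minimal sketch in that spirit (the config_path and config_name below are illustrative and depend on where the recipe YAMLs live in your installation):

import hydra
from omegaconf import DictConfig

from super_gradients.training.kd_trainer import KDTrainer


@hydra.main(config_path="recipes", config_name="imagenet_resnet50_kd")
def main(cfg: DictConfig) -> None:
    # Hand the composed recipe to the KD trainer, which builds the student,
    # the teacher, the dataloaders, and the KDModule from the config.
    KDTrainer.train_from_config(cfg)


if __name__ == "__main__":
    main()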