|
@@ -10,10 +10,11 @@ from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
import torch
|
|
import torch
|
|
from super_gradients.common.abstractions.abstract_logger import get_logger
|
|
from super_gradients.common.abstractions.abstract_logger import get_logger
|
|
|
|
+from super_gradients.training import models
|
|
from super_gradients.training.utils.callbacks import Phase, PhaseCallback, PhaseContext
|
|
from super_gradients.training.utils.callbacks import Phase, PhaseCallback, PhaseContext
|
|
import os
|
|
import os
|
|
from enum import Enum
|
|
from enum import Enum
|
|
-from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model, read_ckpt_state_dict
|
|
|
|
|
|
+from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model
|
|
from super_gradients.training.utils import get_param
|
|
from super_gradients.training.utils import get_param
|
|
from super_gradients.training.utils.distributed_training_utils import get_local_rank, \
|
|
from super_gradients.training.utils.distributed_training_utils import get_local_rank, \
|
|
get_world_size
|
|
get_world_size
|
|
@@ -261,21 +262,6 @@ class QATCallback(PhaseCallback):
|
|
|
|
|
|
def __call__(self, context: PhaseContext):
|
|
def __call__(self, context: PhaseContext):
|
|
if context.epoch == self.start_epoch:
|
|
if context.epoch == self.start_epoch:
|
|
- # SET CHECKPOINT PARAMS SO WE LOAD THE BEST CHECKPOINT SO FAR
|
|
|
|
- checkpoint_params_qat = context.checkpoint_params.to_dict()
|
|
|
|
- checkpoint_params_qat['ckpt_name'] = 'ckpt_best.pth'
|
|
|
|
-
|
|
|
|
- if self.calibrated_model_path is not None:
|
|
|
|
- checkpoint_params_qat['external_checkpoint_path'] = self.calibrated_model_path
|
|
|
|
- checkpoint_params_qat['load_ema_as_net'] = 'ema_net' in read_ckpt_state_dict(self.calibrated_model_path).keys()
|
|
|
|
- checkpoint_params_qat['load_checkpoint'] = True
|
|
|
|
- elif self.start_epoch > 0:
|
|
|
|
- checkpoint_params_qat['load_ema_as_net'] = context.training_params.ema
|
|
|
|
- checkpoint_params_qat['load_checkpoint'] = True
|
|
|
|
- if checkpoint_params_qat['load_ema_as_net']:
|
|
|
|
- logger.warning("EMA net loaded from best checkpoint, continuing QAT without EMA.")
|
|
|
|
- context.context_methods.set_ema(False)
|
|
|
|
-
|
|
|
|
# REMOVE REFERENCES TO NETWORK AND CLEAN GPU MEMORY BEFORE BUILDING THE NEW NET
|
|
# REMOVE REFERENCES TO NETWORK AND CLEAN GPU MEMORY BEFORE BUILDING THE NEW NET
|
|
context.context_methods.set_net(None)
|
|
context.context_methods.set_net(None)
|
|
context.net = None
|
|
context.net = None
|
|
@@ -283,9 +269,14 @@ class QATCallback(PhaseCallback):
|
|
|
|
|
|
# BUILD THE SAME MODEL BUT WITH FAKE QUANTIZED MODULES, AND LOAD BEST CHECKPOINT TO IT
|
|
# BUILD THE SAME MODEL BUT WITH FAKE QUANTIZED MODULES, AND LOAD BEST CHECKPOINT TO IT
|
|
self._initialize_quant_modules()
|
|
self._initialize_quant_modules()
|
|
- context.context_methods.build_model(architecture=context.architecture,
|
|
|
|
- arch_params=context.arch_params.to_dict(),
|
|
|
|
- checkpoint_params=checkpoint_params_qat)
|
|
|
|
|
|
+
|
|
|
|
+ if self.calibrated_model_path is not None:
|
|
|
|
+ checkpoint_path = self.calibrated_model_path
|
|
|
|
+ elif self.start_epoch > 0:
|
|
|
|
+ checkpoint_path = os.path.join(context.ckpt_dir, 'ckpt_best.pth')
|
|
|
|
+
|
|
|
|
+ qat_net = models.get(context.architecture, arch_params=context.arch_params.to_dict(), checkpoint_path=checkpoint_path)
|
|
|
|
+
|
|
_deactivate_quant_modules_wrapping()
|
|
_deactivate_quant_modules_wrapping()
|
|
|
|
|
|
# UPDATE CONTEXT'S NET REFERENCE
|
|
# UPDATE CONTEXT'S NET REFERENCE
|
|
@@ -300,6 +291,9 @@ class QATCallback(PhaseCallback):
|
|
# SET NEW FILENAME FOR THE BEST CHECKPOINT SO WE DON'T OVERRIDE THE PREVIOUS ONES
|
|
# SET NEW FILENAME FOR THE BEST CHECKPOINT SO WE DON'T OVERRIDE THE PREVIOUS ONES
|
|
context.context_methods.set_ckpt_best_name('qat_ckpt_best.pth')
|
|
context.context_methods.set_ckpt_best_name('qat_ckpt_best.pth')
|
|
|
|
|
|
|
|
+ # FINALLY, SET THE QAT NET TO CONTINUE TRAINING
|
|
|
|
+ context.context_methods.set_net(qat_net)
|
|
|
|
+
|
|
def _calibrate_model(self, context: PhaseContext):
|
|
def _calibrate_model(self, context: PhaseContext):
|
|
"""
|
|
"""
|
|
Performs model calibration (collecting stats and setting amax for the fake quantized modules)
|
|
Performs model calibration (collecting stats and setting amax for the fake quantized modules)
|