1. Set the check interval (60 seconds by default).
2. check(): if it's time for a check, call _run_checks().
3. _run_checks():
   if memory_used > 90% * memory_total:
       _preemptive_memory_cleanup()
   if temp > 85°C:
       _throttle_training()
4. _preemptive_memory_cleanup():
   save_checkpoint()
   reduce_batch_size()
   torch.cuda.empty_cache()
What is a checkpoint?
A snapshot of the model's complete state at a specific point in training.
After an OOM crash we restart from the last checkpoint (see the resume sketch after the example below).
checkpoint = {                                        # a checkpoint includes:
    'epoch': 10,                                      # current epoch
    'model_state_dict': model.state_dict(),           # model weights
    'optimizer_state_dict': optimizer.state_dict(),   # optimizer state
    'loss': 0.056,                                    # training metrics (loss, accuracy)
    'batch_index': 3250                               # position within the epoch
}
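A minimal resume sketch, assuming the keys above and model/optimizer objects that already exist (load_checkpoint is an illustrative helper, not part of these notes' code):

import torch

def load_checkpoint(path, model, optimizer):
    # Restore the state saved by save_checkpoint() below
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Resume training from the epoch recorded in the snapshot
    return checkpoint['epoch'], checkpoint['loss']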
5. _throttle_training(): inserts a delay into the training loop.
- Reduces GPU utilization by spacing out workloads
- Allows cooling between computations
- Adjusts automatically based on temperature readings
import time

batch_delay = 0  # seconds to sleep between batches

def train_loop():
    global batch_delay
    for batch in dataloader:
        # Process the batch
        train_step(batch)
        # Apply the dynamic delay (HealthChecker raises this value when the GPU runs hot)
        time.sleep(batch_delay)
import logging
from datetime import datetime

def save_checkpoint():
    # Snapshot everything needed to resume training after a failure
    checkpoint = {
        'epoch': current_epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': current_loss,
        'timestamp': datetime.now().isoformat()
    }
    torch.save(checkpoint, f"checkpoint_epoch{current_epoch}.pt")
    logging.info(f"Checkpoint saved at epoch {current_epoch}")
class HealthChecker:
    def __init__(self, interval=60):  # step 1: set the check interval (seconds)
        self.interval = interval
        self.last_check = time.time()

    def check(self):  # step 2: cheap to call often; runs checks at most once per interval
        now = time.time()
        if now - self.last_check > self.interval:
            self._run_checks()
            self.last_check = now

    def _run_checks(self):  # step 3
        # Memory health
        mem_used = torch.cuda.memory_allocated()
        mem_total = torch.cuda.get_device_properties(0).total_memory
        if mem_used > 0.9 * mem_total:
            self._preemptive_memory_cleanup()
        # Temperature check
        temp = get_gpu_temperature()
        if temp > 85:
            self._throttle_training()

    def _preemptive_memory_cleanup(self):  # step 4
        # Save a checkpoint, then restart with a smaller batch
        save_checkpoint()
        reduce_batch_size()
        torch.cuda.empty_cache()

    def _throttle_training(self):  # step 5
        # Insert cooling delays between batches
        global batch_delay
        batch_delay += 0.1
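get_gpu_temperature() and reduce_batch_size() are assumed helpers. One possible sketch: the temperature read uses NVML via the pynvml package, and the batch-size logic here is a simplified assumption (the new size takes effect when training restarts from the checkpoint):

import pynvml

pynvml.nvmlInit()
_handle = pynvml.nvmlDeviceGetHandleByIndex(0)

def get_gpu_temperature():
    # Core GPU temperature in °C, read through NVML
    return pynvml.nvmlDeviceGetTemperature(_handle, pynvml.NVML_TEMPERATURE_GPU)

batch_size = 64  # assumed global, consumed when the dataloader is rebuilt

def reduce_batch_size():
    # Halve the batch size, never below 1
    global batch_size
    batch_size = max(1, batch_size // 2)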
h = HealthChecker(interval=60)
h.check()  # call from inside the training loop
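Putting it together, one way to wire the checker into the loop from step 5 (a sketch; check() returns immediately until the interval has elapsed, so calling it every batch is cheap):

h = HealthChecker(interval=60)

def train_loop():
    global batch_delay
    for batch in dataloader:
        train_step(batch)
        h.check()                # memory/temperature checks, at most once per interval
        time.sleep(batch_delay)  # grows by 0.1 s per throttle event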