GPU Health Check

- If GPU memory usage exceeds 90% of total memory, release memory.
- If the GPU temperature exceeds 85, add a delay between batches.
1. Set the check interval (default: 60 seconds).

2. `check()`: if the check interval has elapsed, call `_run_checks()`.

3. `_run_checks()`:

if memory_used > 90%*memory_total:
    _preemptive_memory_cleanup()

if temp > 85:
    throttle_training()


4. `_preemptive_memory_cleanup()`:

save_checkpoint()
reduce_batch_size()
torch.cuda.empty_cache()

What is a checkpoint?
  A snapshot of the model's complete training state at a specific point in training. After an out-of-memory (OOM) error, training restarts from the last checkpoint.

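A checkpoint includes (example values):
{
    'epoch': 10,                                        # current epoch
    'model_state_dict': model.state_dict(),             # model weights
    'optimizer_state_dict': optimizer.state_dict(),     # optimizer state
    'loss': 0.056,                                      # training metric (e.g. loss)
    'batch_index': 3250                                 # batch index
}
Note: `save_checkpoint()` below stores the model and optimizer state under the keys 'model' and 'optimizer', adds a 'timestamp', and does not store 'batch_index'.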


5. `_throttle_training()`: insert a delay between training batches.
- Reduces GPU utilization by spacing out workloads
- Allows cooling between computations
- Automatically adjusts based on temperature readings

batch_delay = 0  # Seconds slept after each batch; modified by HealthChecker.

def train_loop(health_checker=None):
    """Iterate once over ``dataloader``, training on each batch.

    After every batch the loop sleeps ``batch_delay`` seconds. That module-level
    value is re-read each iteration, so changes made by the health checker
    take effect on the next batch.

    Args:
        health_checker: optional object with a ``check()`` method (e.g. a
            ``HealthChecker``). It is called once per batch, before the delay,
            so its periodic GPU checks actually run during training. With the
            default ``None`` the loop behaves exactly as before.
    """
    # ``batch_delay`` is only read here, so no ``global`` statement is needed.
    for batch in dataloader:
        train_step(batch)
        if health_checker is not None:
            health_checker.check()
        # Apply the (possibly just updated) cooling delay.
        time.sleep(batch_delay)

def save_checkpoint():
    """Write the current training state to ``checkpoint_epoch<epoch>.pt``.

    Reads the module-level ``current_epoch``, ``model``, ``optimizer`` and
    ``current_loss``, serializes them with ``torch.save`` together with an
    ISO-8601 timestamp, and logs the epoch at INFO level.
    """
    epoch = current_epoch
    target_path = f"checkpoint_epoch{epoch}.pt"

    # NOTE(review): the keys here ('model', 'optimizer') differ from the
    # 'model_state_dict' / 'optimizer_state_dict' names in the design notes;
    # whatever reloads the checkpoint must use these keys.
    # NOTE(review): datetime.now() is naive local time — confirm that is intended.
    snapshot = dict(
        epoch=epoch,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        loss=current_loss,
        timestamp=datetime.now().isoformat(),
    )

    torch.save(snapshot, target_path)
    logging.info(f"Checkpoint saved at epoch {epoch}")

class HealthChecker:
    """Periodically check GPU memory and temperature and react to pressure.

    Call ``check()`` frequently (e.g. once per batch). The actual checks run
    at most once every ``interval`` seconds.

    Reactions (see ``_run_checks``):
      * allocated memory on the current CUDA device above
        ``memory_threshold * total_memory`` -> save a checkpoint, reduce the
        batch size and empty the CUDA cache;
      * ``get_gpu_temperature()`` above ``temp_threshold`` -> increase the
        module-level ``batch_delay`` by ``delay_step`` seconds. Otherwise the
        delay is decreased by the same step, never below 0, so training
        speeds back up once the GPU has cooled.

    Args:
        interval: minimum number of seconds between two rounds of checks.
        memory_threshold: fraction of total device memory that triggers cleanup.
        temp_threshold: temperature above which training is throttled (same
            unit as ``get_gpu_temperature()`` returns).
        delay_step: seconds added to / removed from ``batch_delay`` per check.
    """

    def __init__(self, interval=60, memory_threshold=0.9, temp_threshold=85,
                 delay_step=0.1):
        self.interval = interval
        self.memory_threshold = memory_threshold
        self.temp_threshold = temp_threshold
        self.delay_step = delay_step
        # Monotonic clock: elapsed-time measurement that is unaffected by
        # wall-clock adjustments (NTP sync, manual changes), unlike time.time().
        self.last_check = time.monotonic()

    def check(self):
        """Run ``_run_checks()`` if more than ``interval`` seconds have passed
        since the last run (or since construction)."""
        now = time.monotonic()
        if now - self.last_check > self.interval:
            # Record the attempt first so that a check which raises is not
            # re-attempted on every subsequent call.
            self.last_check = now
            self._run_checks()

    def _run_checks(self):
        """Evaluate memory and temperature once and trigger the reactions."""
        # Memory health: measure usage and capacity on the same device.
        # Previously, capacity always came from device 0.
        device = torch.cuda.current_device()
        mem_used = torch.cuda.memory_allocated(device)
        mem_total = torch.cuda.get_device_properties(device).total_memory
        if mem_used > self.memory_threshold * mem_total:
            self._preemptive_memory_cleanup()

        # Temperature check: throttle when hot, relax the throttle when not.
        temp = get_gpu_temperature()
        if temp > self.temp_threshold:
            self._throttle_training()
        else:
            self._relax_throttle()

    def _preemptive_memory_cleanup(self):
        """Checkpoint, shrink the batch size and release cached CUDA memory."""
        save_checkpoint()
        reduce_batch_size()
        # NOTE(review): empty_cache() returns cached-but-unused blocks to the
        # driver; it does not lower torch.cuda.memory_allocated(), which is
        # the quantity checked above.
        torch.cuda.empty_cache()

    def _throttle_training(self):
        """Lengthen the sleep inserted between batches by ``delay_step``."""
        # Modifies ``batch_delay`` in the module that defines this class; the
        # training loop must read that same module-level name.
        global batch_delay
        batch_delay += self.delay_step

    def _relax_throttle(self):
        """Shorten the inter-batch sleep by ``delay_step``, never below 0."""
        global batch_delay
        batch_delay = max(0.0, batch_delay - self.delay_step)

# Create a module-level checker (default 60 s interval) and perform an initial
# check. The interval has not yet elapsed here, so this first call does not
# run the GPU checks.
h = HealthChecker()
h.check()