Source code for gaia.callbacks

"""GAIA Training Callbacks"""

from typing import List, Callable, Any, Dict
import logging

logger = logging.getLogger(__name__)

[docs] class CallbackManager: """Manages training callbacks for GAIA trainer"""
[docs] def __init__(self, callbacks: List[Callable] = None): self.callbacks = callbacks or []
[docs] def on_train_begin(self, logs: Dict[str, Any] = None): """Called at the beginning of training""" for callback in self.callbacks: if hasattr(callback, 'on_train_begin'): callback.on_train_begin(logs)
[docs] def on_train_end(self, logs: Dict[str, Any] = None): """Called at the end of training""" for callback in self.callbacks: if hasattr(callback, 'on_train_end'): callback.on_train_end(logs)
[docs] def on_train_error(self, logs: Dict[str, Any] = None, error: Exception = None): """Called when training encounters an error""" for callback in self.callbacks: if hasattr(callback, 'on_train_error'): callback.on_train_error(logs, error)
[docs] def on_epoch_begin(self, epoch: int, logs: Dict[str, Any] = None): """Called at the beginning of each epoch""" for callback in self.callbacks: if hasattr(callback, 'on_epoch_begin'): callback.on_epoch_begin(epoch, logs)
[docs] def on_epoch_end(self, epoch: int, logs: Dict[str, Any] = None): """Called at the end of each epoch""" for callback in self.callbacks: if hasattr(callback, 'on_epoch_end'): callback.on_epoch_end(epoch, logs)
[docs] def on_batch_begin(self, batch: int, logs: Dict[str, Any] = None): """Called at the beginning of each batch""" for callback in self.callbacks: if hasattr(callback, 'on_batch_begin'): callback.on_batch_begin(batch, logs)
[docs] def on_batch_end(self, batch: int, logs: Dict[str, Any] = None): """Called at the end of each batch""" for callback in self.callbacks: if hasattr(callback, 'on_batch_end'): callback.on_batch_end(batch, logs)
[docs] class EarlyStopping: """Early stopping callback"""
[docs] def __init__(self, patience: int = 10, min_delta: float = 0.0, monitor: str = 'val_loss'): self.patience = patience self.min_delta = min_delta self.monitor = monitor self.best_score = None self.counter = 0 self.early_stop = False
[docs] def on_epoch_end(self, epoch: int, logs: Dict[str, Any] = None): if logs is None: return current_score = logs.get(self.monitor) if current_score is None: return if self.best_score is None: self.best_score = current_score elif current_score < self.best_score - self.min_delta: self.best_score = current_score self.counter = 0 else: self.counter += 1 if self.counter >= self.patience: self.early_stop = True logger.info(f"Early stopping triggered at epoch {epoch}")
__all__ = ['CallbackManager', 'EarlyStopping']