"""
|
|
Utility functions for AI Trainer
|
|
Memory management, logging, and helper functions optimized for RTX3070 8GB VRAM
|
|
"""
|
|
|
|
import gc
|
|
import logging
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Dict, Optional, Tuple, Any
|
|
|
|
import torch
|
|
import psutil
|
|
from colorama import init, Fore, Back, Style
|
|
|
|
# Initialize colorama for cross-platform colored output
|
|
init(autoreset=True)


def setup_logging(log_level: str = "INFO", log_file: Optional[str] = None) -> logging.Logger:
    """
    Setup logging configuration with colored console output

    Args:
        log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
        log_file: Optional log file path

    Returns:
        Configured logger
    """
    # Create formatter with colors
    class ColoredFormatter(logging.Formatter):
        COLORS = {
            'DEBUG': Fore.CYAN,
            'INFO': Fore.GREEN,
            'WARNING': Fore.YELLOW,
            'ERROR': Fore.RED,
            'CRITICAL': Fore.RED + Back.WHITE
        }

        def format(self, record):
            # Add color to the level name
            if record.levelname in self.COLORS:
                record.levelname = f"{self.COLORS[record.levelname]}{record.levelname}{Style.RESET_ALL}"
            return super().format(record)

    # Create root logger; clear any handlers from previous calls so repeated
    # setup does not produce duplicate log lines
    logger = logging.getLogger()
    logger.setLevel(getattr(logging, log_level.upper()))
    logger.handlers.clear()

    # Console handler with colors
    console_handler = logging.StreamHandler(sys.stdout)
    console_formatter = ColoredFormatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)

    # File handler if specified
    if log_file:
        log_path = Path(log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)

        file_handler = logging.FileHandler(log_path)
        file_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    return logger
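

# Typical usage sketch (the log path below is illustrative, not part of the project):
#     logger = setup_logging("DEBUG", log_file="logs/trainer.log")
#     logger.info("Trainer initialized")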


def check_gpu_memory() -> Dict[str, Any]:
    """
    Check GPU memory status and availability

    Returns:
        Dictionary with GPU memory information
    """
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    try:
        device = torch.cuda.current_device()
        total_memory = torch.cuda.get_device_properties(device).total_memory
        allocated_memory = torch.cuda.memory_allocated(device)
        reserved_memory = torch.cuda.memory_reserved(device)
        free_memory = total_memory - allocated_memory

        return {
            "device": torch.cuda.get_device_name(device),
            "device_id": device,
            "total_memory_gb": round(total_memory / (1024**3), 2),
            "allocated_memory_gb": round(allocated_memory / (1024**3), 2),
            "reserved_memory_gb": round(reserved_memory / (1024**3), 2),
            "free_memory_gb": round(free_memory / (1024**3), 2),
            "memory_utilization": round((allocated_memory / total_memory) * 100, 2),
            "cuda_version": torch.version.cuda,
            "cudnn_version": torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"
        }
    except Exception as e:
        return {"error": f"Failed to get GPU info: {str(e)}"}


def get_memory_usage() -> Dict[str, Any]:
    """
    Get system memory usage

    Returns:
        Dictionary with memory usage information
    """
    try:
        # GPU memory
        gpu_memory = check_gpu_memory()

        # System memory
        system_memory = psutil.virtual_memory()

        return {
            "gpu_total_gb": gpu_memory.get("total_memory_gb", 0),
            "gpu_allocated_gb": gpu_memory.get("allocated_memory_gb", 0),
            "gpu_free_gb": gpu_memory.get("free_memory_gb", 0),
            "system_total_gb": round(system_memory.total / (1024**3), 2),
            "system_available_gb": round(system_memory.available / (1024**3), 2),
            "system_used_gb": round(system_memory.used / (1024**3), 2),
            "system_memory_percent": system_memory.percent
        }
    except Exception as e:
        return {"error": f"Failed to get memory usage: {str(e)}"}


def clear_gpu_cache():
    """Clear GPU cache and perform garbage collection"""
    try:
        # Collect unreferenced Python objects first so their CUDA blocks can be
        # released before the cache is emptied
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    except Exception as e:
        print(f"Warning: Failed to clear GPU cache: {str(e)}")


def optimize_memory_settings():
    """Apply memory optimization settings for the RTX 3070"""
    try:
        if torch.cuda.is_available():
            # Cap per-process usage at 85% of GPU memory to leave headroom
            # for the display driver and reduce out-of-memory errors
            torch.cuda.set_per_process_memory_fraction(0.85)

            # Enable TF32 for better matmul/conv performance on Ampere
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

            # Reduce allocator fragmentation; note this only takes effect if the
            # CUDA caching allocator has not been initialized yet (i.e., before
            # the first CUDA allocation in this process)
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

    except Exception as e:
        print(f"Warning: Failed to optimize memory settings: {str(e)}")


def format_bytes(bytes_value: float) -> str:
    """
    Format bytes into human readable format

    Args:
        bytes_value: Number of bytes

    Returns:
        Formatted string (e.g., "1.5 GB")
    """
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if bytes_value < 1024.0:
            return f"{bytes_value:.1f} {unit}"
        bytes_value /= 1024.0
    return f"{bytes_value:.1f} PB"


def print_system_info():
    """Print comprehensive system information"""
    print(f"\n{Fore.CYAN}{'='*60}{Style.RESET_ALL}")
    print(f"{Fore.CYAN}SYSTEM INFORMATION{Style.RESET_ALL}")
    print(f"{Fore.CYAN}{'='*60}{Style.RESET_ALL}")

    # GPU Information
    gpu_info = check_gpu_memory()
    if "error" not in gpu_info:
        print(f"\n{Fore.GREEN}GPU Information:{Style.RESET_ALL}")
        print(f"  Device: {gpu_info['device']}")
        print(f"  CUDA Version: {gpu_info['cuda_version']}")
        print(f"  Total Memory: {gpu_info['total_memory_gb']} GB")
        print(f"  Allocated Memory: {gpu_info['allocated_memory_gb']} GB")
        print(f"  Free Memory: {gpu_info['free_memory_gb']} GB")
        print(f"  Memory Utilization: {gpu_info['memory_utilization']}%")
    else:
        print(f"\n{Fore.RED}GPU Information: {gpu_info['error']}{Style.RESET_ALL}")

    # System Memory
    system_memory = psutil.virtual_memory()
    print(f"\n{Fore.GREEN}System Memory:{Style.RESET_ALL}")
    print(f"  Total: {format_bytes(system_memory.total)}")
    print(f"  Available: {format_bytes(system_memory.available)}")
    print(f"  Used: {format_bytes(system_memory.used)}")
    print(f"  Usage: {system_memory.percent}%")

    # CPU Information
    print(f"\n{Fore.GREEN}CPU Information:{Style.RESET_ALL}")
    print(f"  Cores: {psutil.cpu_count(logical=False)} physical, {psutil.cpu_count(logical=True)} logical")
    print(f"  CPU Usage: {psutil.cpu_percent()}%")

    print(f"\n{Fore.CYAN}{'='*60}{Style.RESET_ALL}")


def validate_environment():
    """Validate that the environment is suitable for training"""
    issues = []

    # Check CUDA availability
    if not torch.cuda.is_available():
        issues.append("CUDA is not available. A CUDA-compatible GPU is required.")

    # Check GPU memory
    if torch.cuda.is_available():
        gpu_info = check_gpu_memory()
        if "total_memory_gb" in gpu_info:
            total_memory = gpu_info["total_memory_gb"]
            if total_memory < 8:
                issues.append(f"GPU memory ({total_memory} GB) may be insufficient. Recommended: 8 GB+")

    # Check required Python modules ('git' is provided by the GitPython package)
    required_modules = ['torch', 'transformers', 'datasets', 'git']
    for module in required_modules:
        try:
            __import__(module)
        except ImportError:
            issues.append(f"Required module '{module}' is not installed.")

    if issues:
        print(f"\n{Fore.YELLOW}Environment Validation Issues:{Style.RESET_ALL}")
        for issue in issues:
            print(f"  - {issue}")
        return False

    print(f"\n{Fore.GREEN}Environment validation passed!{Style.RESET_ALL}")
    return True


def create_training_summary(config, training_time: float, final_model_path: str) -> str:
    """
    Create a summary of the training session

    Args:
        config: Training configuration
        training_time: Training time in seconds
        final_model_path: Path to the saved model

    Returns:
        Formatted summary string
    """
    summary = f"""
{Fore.CYAN}{'='*60}{Style.RESET_ALL}
TRAINING SUMMARY
{Fore.CYAN}{'='*60}{Style.RESET_ALL}

Configuration:
  Model: {config.model.name}
  Epochs: {config.training.num_train_epochs}
  Batch Size: {config.training.per_device_train_batch_size}
  Gradient Accumulation: {config.training.gradient_accumulation_steps}
  Learning Rate: {config.training.learning_rate}
  Max Sequence Length: {config.model.max_seq_length}

Performance:
  Training Time: {training_time:.2f} seconds ({training_time/3600:.2f} hours)
  Effective Batch Size: {config.training.per_device_train_batch_size * config.training.gradient_accumulation_steps}

Output:
  Model Saved To: {final_model_path}

Memory Settings:
  Gradient Checkpointing: {config.training.use_gradient_checkpointing}
  CPU Offloading: {config.training.offload_to_cpu}
  BF16 Enabled: {config.training.bf16}

{Fore.CYAN}{'='*60}{Style.RESET_ALL}
"""

    return summary


def safe_import(module_name: str, fallback: Any = None):
    """
    Safely import a module with fallback

    Args:
        module_name: Name of the module to import
        fallback: Fallback value if import fails

    Returns:
        Imported module or fallback
    """
    try:
        return __import__(module_name)
    except ImportError:
        return fallback


# Initialize memory optimization settings on import
try:
    optimize_memory_settings()
except Exception:
    pass  # Ignore errors during initialization
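

if __name__ == "__main__":
    # Minimal smoke-test sketch: run this module directly to inspect the
    # environment. Purely illustrative; not referenced by the trainer itself.
    setup_logging("INFO")
    print_system_info()
    if validate_environment():
        logging.getLogger(__name__).info("Memory snapshot: %s", get_memory_usage())
    clear_gpu_cache()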