# ai_html_document_trainer/train_model.py

import torch
import json
import os
from typing import Dict, List
from datasets import Dataset
from data_preprocessor import DataPreprocessor
import gc
# Import Unsloth components
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from transformers import TrainingArguments
from trl import SFTTrainer


class OdooModelTrainer:
    def __init__(self):
        self.model_name = "unsloth/Qwen3-8B-bnb-4bit"
        self.max_seq_length = 2048
        self.load_in_4bit = True
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the Qwen model with Unsloth optimizations"""
        print("Loading model...")
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.model_name,
            max_seq_length=self.max_seq_length,
            dtype=None,  # Auto detection
            load_in_4bit=self.load_in_4bit,
        )
        # Enable gradient checkpointing for memory efficiency
        self.model.gradient_checkpointing_enable()
        print("Model loaded successfully!")

    def prepare_data(self, data_file: str = 'training_data.json') -> Dataset:
        """Prepare training data"""
        if not os.path.exists(data_file):
            print(f"Data file {data_file} not found. Running data preprocessing...")
            preprocessor = DataPreprocessor()
            training_data = preprocessor.process_csv_data()
            preprocessor.save_training_data(training_data, data_file)
        print(f"Loading data from {data_file}")
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Convert to HuggingFace Dataset format
        dataset_dict = {
            'instruction': [item['instruction'] for item in data],
            'input': [item['input'] for item in data],
            'output': [item['output'] for item in data]
        }
        dataset = Dataset.from_dict(dataset_dict)
        print(f"Prepared dataset with {len(dataset)} samples")
        return dataset
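
    # For reference, each entry in training_data.json is assumed to look like the
    # sketch below (the keys match the fields read in prepare_data; the values are
    # illustrative only):
    #
    # {
    #     "instruction": "Answer the question using the Odoo documentation.",
    #     "input": "How do I install Odoo?",
    #     "output": "Odoo can be installed via ..."
    # }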

    def format_chat_template(self, example):
        """Format data for chat template"""
        messages = [
            {"role": "system", "content": "You are a helpful assistant specialized in Odoo documentation."},
            {"role": "user", "content": f"{example['instruction']}\n\n{example['input']}"},
            {"role": "assistant", "content": example['output']}
        ]
        # Apply chat template
        formatted_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        return {"text": formatted_text}

    def train(self, dataset: Dataset, output_dir: str = './odoo_model_output'):
        """Train the model"""
        print("Starting training...")
        # Apply chat template to dataset
        formatted_dataset = dataset.map(self.format_chat_template)
        # Configure LoRA for efficient training
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=16,  # LoRA rank
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj"
            ],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing=True,
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )
        # Configure training arguments optimized for an RTX 3070 with 8 GB VRAM
        training_args = TrainingArguments(
            per_device_train_batch_size=1,  # Very small batch size for 8 GB VRAM
            gradient_accumulation_steps=4,  # Effective batch size of 4
            warmup_steps=5,
            max_steps=100,  # Limit steps for testing; increase as needed
            learning_rate=2e-4,
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir=output_dir,
            save_steps=50,
            save_total_limit=2,
            report_to="none",  # Disable wandb/tensorboard for simplicity
        )
        # Initialize trainer
        trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=formatted_dataset,
            dataset_text_field="text",
            max_seq_length=self.max_seq_length,
            dataset_num_proc=2,
            packing=False,  # Sequence packing (packing=True) can speed up training on short sequences; disabled here for simplicity
            args=training_args,
        )
        # Clear cache before training
        gc.collect()
        torch.cuda.empty_cache()
        print("Starting training...")
        trainer.train()
        # Save the model
        print(f"Saving model to {output_dir}")
        trainer.save_model(output_dir)
        # Save in GGUF format for compatibility
        self.model.save_pretrained_gguf(
            output_dir + "_gguf",
            self.tokenizer,
            quantization_method="q4_k_m"  # 4-bit quantization
        )
        print("Training completed!")

    def generate_response(self, prompt: str, max_new_tokens: int = 256) -> str:
        """Generate response from the trained model"""
        if not self.model:
            print("Model not loaded!")
            return ""
        # Enable faster inference
        FastLanguageModel.for_inference(self.model)
        messages = [
            {"role": "system", "content": "You are a helpful assistant specialized in Odoo documentation."},
            {"role": "user", "content": prompt}
        ]
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to("cuda")
        outputs = self.model.generate(
            input_ids=inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            temperature=0.7,
            min_p=0.1
        )
        # Decode only the newly generated tokens so the prompt and special tokens
        # are not echoed back in the returned response
        response = self.tokenizer.decode(
            outputs[0][inputs.shape[-1]:], skip_special_tokens=True
        )
        return response


def main():
    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA is not available. Please ensure you have a CUDA-compatible GPU.")
        return
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    # Initialize trainer
    trainer = OdooModelTrainer()
    try:
        # Load model
        trainer.load_model()
        # Prepare data
        dataset = trainer.prepare_data()
        if len(dataset) == 0:
            print("No training data available!")
            return
        # Train model
        trainer.train(dataset)
        # Test the model
        print("\nTesting the trained model:")
        test_prompt = "How do I install Odoo?"
        response = trainer.generate_response(test_prompt)
        print(f"Prompt: {test_prompt}")
        print(f"Response: {response}")
    except Exception as e:
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
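

# Minimal reload sketch (an assumption, not exercised by the script above): Unsloth can
# typically point FastLanguageModel.from_pretrained at the saved adapter directory to
# run inference later without retraining.
#
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name="./odoo_model_output",
#     max_seq_length=2048,
#     load_in_4bit=True,
# )
# FastLanguageModel.for_inference(model)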