227 lines
7.2 KiB
Python
227 lines
7.2 KiB
Python
import torch
|
|
import json
|
|
import os
|
|
from typing import Dict, List
|
|
from datasets import Dataset
|
|
from data_preprocessor import DataPreprocessor
|
|
import gc
|
|
|
|
# Import Unsloth components
|
|
from unsloth import FastLanguageModel
|
|
from unsloth.chat_templates import get_chat_template
|
|
from transformers import TrainingArguments
|
|
from trl import SFTTrainer
|
|
|
|
class OdooModelTrainer:
    """Fine-tune a 4-bit Qwen3 model on Odoo documentation with Unsloth + LoRA.

    Intended call order: ``load_model()`` -> ``prepare_data()`` ->
    ``train(dataset)`` -> ``generate_response(prompt)``.
    """

    # Shared by training-time formatting and inference so both use the same prompt.
    SYSTEM_PROMPT = "You are a helpful assistant specialized in Odoo documentation."

    def __init__(self):
        self.model_name = "unsloth/Qwen3-8B-bnb-4bit"
        self.max_seq_length = 2048
        self.load_in_4bit = True
        # Populated by load_model(); train() replaces model with its LoRA-wrapped version.
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the Qwen model and tokenizer with Unsloth optimizations."""
        print("Loading model...")

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.model_name,
            max_seq_length=self.max_seq_length,
            dtype=None,  # Auto detection
            load_in_4bit=self.load_in_4bit,
        )

        # Trade compute for activation memory; train() additionally requests
        # gradient checkpointing when it attaches the LoRA adapters.
        self.model.gradient_checkpointing_enable()

        print("Model loaded successfully!")

    def prepare_data(self, data_file: str = 'training_data.json') -> "Dataset":
        """Load training records from JSON into a HuggingFace ``Dataset``.

        If ``data_file`` does not exist, it is first produced by
        ``DataPreprocessor.process_csv_data()`` and
        ``DataPreprocessor.save_training_data()``.

        Args:
            data_file: Path to a JSON file containing a list of objects with
                ``instruction`` and ``output`` keys and an optional ``input`` key.

        Returns:
            A Dataset with columns ``instruction``, ``input`` and ``output``;
            a missing or null ``input`` becomes ``""``.

        Raises:
            KeyError: if a record lacks ``instruction`` or ``output``.
        """
        if not os.path.exists(data_file):
            print(f"Data file {data_file} not found. Running data preprocessing...")
            preprocessor = DataPreprocessor()
            training_data = preprocessor.process_csv_data()
            preprocessor.save_training_data(training_data, data_file)

        print(f"Loading data from {data_file}")
        with open(data_file, 'r', encoding='utf-8') as f:
            records = json.load(f)

        # Column-oriented mapping as expected by Dataset.from_dict.
        dataset = Dataset.from_dict({
            'instruction': [item['instruction'] for item in records],
            'input': [item.get('input') or '' for item in records],
            'output': [item['output'] for item in records],
        })
        print(f"Prepared dataset with {len(dataset)} samples")

        return dataset

    @staticmethod
    def _build_user_content(instruction: str, input_text) -> str:
        """Join the instruction and optional input, omitting an empty input."""
        # Skip the blank-line separator when there is no input, so training
        # prompts look like the bare prompts generate_response() sends.
        if input_text is None or not str(input_text).strip():
            return instruction
        return f"{instruction}\n\n{input_text}"

    def format_chat_template(self, example: Dict[str, str]) -> Dict[str, str]:
        """Render one instruction/input/output record with the tokenizer's chat template.

        Args:
            example: Mapping with ``instruction`` and ``output`` keys and an
                optional ``input`` key (one row of ``prepare_data``'s dataset).

        Returns:
            ``{"text": ...}`` holding the system/user/assistant conversation as a
            string, rendered without a trailing generation prompt.
        """
        messages = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": self._build_user_content(
                example['instruction'], example.get('input'))},
            {"role": "assistant", "content": example['output']},
        ]

        formatted_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

        return {"text": formatted_text}

    def train(self, dataset: "Dataset", output_dir: str = './odoo_model_output',
              max_steps: int = 100):
        """Attach LoRA adapters, fine-tune with ``SFTTrainer`` and save the results.

        The trained model is saved to ``output_dir`` and a q4_k_m GGUF export is
        written next to it as ``<output_dir>_gguf``.

        Args:
            dataset: Dataset as returned by ``prepare_data``.
            output_dir: Directory for checkpoints and the final model.
            max_steps: Number of optimizer steps; the default of 100 is a short
                test run and should be increased for real training.

        Raises:
            RuntimeError: if ``load_model`` has not been called first.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded; call load_model() before train().")

        # Adds a "text" column rendered with the chat template.
        formatted_dataset = dataset.map(self.format_chat_template)

        # Configure LoRA for efficient training
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=16,  # LoRA rank
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing=True,
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )

        # Settings aimed at an 8 GB GPU (RTX 3070): batch size 1 with 4-step
        # gradient accumulation gives an effective batch size of 4.
        bf16_supported = torch.cuda.is_bf16_supported()
        training_args = TrainingArguments(
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            max_steps=max_steps,
            learning_rate=2e-4,
            fp16=not bf16_supported,
            bf16=bf16_supported,
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir=output_dir,
            save_steps=50,
            save_total_limit=2,
            report_to="none",  # Disable wandb/tensorboard for simplicity
        )

        # NOTE(review): newer TRL releases move dataset_text_field / max_seq_length /
        # packing into SFTConfig and rename `tokenizer` to `processing_class`; this
        # call depends on the installed TRL/Unsloth versions accepting them here.
        trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=formatted_dataset,
            dataset_text_field="text",
            max_seq_length=self.max_seq_length,
            dataset_num_proc=2,
            packing=False,  # Can make training faster for short sequences
            args=training_args,
        )

        # Free unreferenced Python objects and cached CUDA blocks before training.
        gc.collect()
        torch.cuda.empty_cache()

        print("Starting training...")
        trainer.train()

        print(f"Saving model to {output_dir}")
        trainer.save_model(output_dir)

        # normpath drops a trailing separator so "out/" exports to "out_gguf"
        # rather than to a "_gguf" folder inside "out/".
        self.model.save_pretrained_gguf(
            os.path.normpath(output_dir) + "_gguf",
            self.tokenizer,
            quantization_method="q4_k_m",  # 4-bit quantization
        )

        print("Training completed!")

    def generate_response(self, prompt: str, max_new_tokens: int = 256) -> str:
        """Generate a reply to ``prompt`` with the loaded (or fine-tuned) model.

        Args:
            prompt: The user message; sent together with the same system prompt
                used during training.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            Only the generated assistant text (prompt tokens and special tokens
            removed), or ``""`` if no model has been loaded.
        """
        if self.model is None or self.tokenizer is None:
            print("Model not loaded!")
            return ""

        # Enable faster inference
        FastLanguageModel.for_inference(self.model)

        messages = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.model.device)

        outputs = self.model.generate(
            input_ids=input_ids,
            # A single unpadded sequence: every position is real input.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new_tokens,
            use_cache=True,
            # temperature/min_p are only honoured when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            min_p=0.1,
        )

        # generate() returns prompt + continuation; keep only the new tokens.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
|
|
|
|
def main() -> int:
    """Run the pipeline: check the GPU, load the model, train, then smoke-test it.

    Returns:
        Process exit status: 0 on success; 1 if CUDA is unavailable, no training
        data is available, or any step raised an exception.
    """
    import traceback

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA is not available. Please ensure you have a CUDA-compatible GPU.")
        return 1

    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    try:
        trainer = OdooModelTrainer()
        trainer.load_model()

        dataset = trainer.prepare_data()
        if len(dataset) == 0:
            print("No training data available!")
            return 1

        trainer.train(dataset)

        # Quick sanity check of the fine-tuned model on a fixed prompt.
        print("\nTesting the trained model:")
        test_prompt = "How do I install Odoo?"
        response = trainer.generate_response(test_prompt)
        print(f"Prompt: {test_prompt}")
        print(f"Response: {response}")
    except Exception as e:
        # Top-level boundary: report the failure with its traceback and signal
        # it through the exit status instead of finishing as a success.
        print(f"Error during training: {e}")
        traceback.print_exc()
        return 1

    return 0
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status; SystemExit(None)
    # exits with 0, so a main() that returns nothing still succeeds.
    raise SystemExit(main())