commit 55cfa53e9b6fba8069534efa79756cec6939aebe Author: Suherdy SYC. Yacob Date: Fri Aug 22 16:30:56 2025 +0700 first commit diff --git a/README.md b/README.md new file mode 100644 index 0000000..b65e4cb --- /dev/null +++ b/README.md @@ -0,0 +1,183 @@ +# Odoo AI Model Trainer + +A comprehensive Python project for training AI models on Odoo documentation using Unsloth, optimized for RTX3070 8GB VRAM. The project scrapes both English and Indonesian Odoo documentation and fine-tunes the unsloth/Qwen3-8B-bnb-4bit model. + +## Features + +- 🌐 **Bilingual Support**: Scrapes both English and Indonesian Odoo documentation +- 🚀 **Optimized Training**: Uses Unsloth for 2x faster training and 70% less memory +- 🎯 **RTX3070 Optimized**: Configured for 8GB VRAM with memory-efficient settings +- 📊 **Data Pipeline**: Complete pipeline from data collection to model training +- 🔧 **Modular Design**: Separate scripts for scraping, preprocessing, and training +- 📈 **Progress Tracking**: Built-in statistics and progress monitoring + +## Requirements + +### Hardware +- NVIDIA RTX3070 (8GB VRAM) or better +- 16GB+ RAM recommended +- 50GB+ free disk space + +### Software +- Python 3.8+ +- CUDA 11.8+ +- PyTorch with CUDA support + +## Installation + +1. **Clone or download this project** + +2. **Install dependencies**: + ```bash + pip install -r requirements.txt + ``` + +3. **Verify CUDA installation**: + ```bash + python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')" + ``` + +## Usage + +### Full Pipeline (Recommended) +Run the complete training pipeline: +```bash +python main.py +``` + +### Step-by-Step Execution + +1. **Data Collection Only**: + ```bash + python main.py --only-collection + ``` + +2. **Data Preprocessing Only**: + ```bash + python main.py --only-preprocessing + ``` + +3. 
**Model Training Only**: + ```bash + python main.py --only-training + ``` + +### Skip Specific Steps +```bash +# Skip data collection if you already have data +python main.py --skip-collection + +# Skip preprocessing if you already have training data +python main.py --skip-preprocessing + +# Skip training for testing other components +python main.py --skip-training +``` + +### Individual Scripts +You can also run individual scripts directly: + +```bash +# Scrape Odoo documentation +python data_scraper.py + +# Preprocess the scraped data +python data_preprocessor.py + +# Train the model +python train_model.py +``` + +## Project Structure + +``` +. +├── main.py # Main orchestrator script +├── data_scraper.py # Web scraping for Odoo docs +├── data_preprocessor.py # Data cleaning and formatting +├── train_model.py # Model training with Unsloth +├── requirements.txt # Python dependencies +├── README.md # This file +├── odoo_docs_data.csv # Scraped raw data (generated) +├── training_data.json # Processed training data (generated) +└── odoo_model_output/ # Trained model (generated) +``` + +## Output Files + +- **odoo_docs_data.csv**: Raw scraped documentation +- **training_data.json**: Processed training data in instruction format +- **odoo_model_output/**: Directory containing the fine-tuned model +- **odoo_model_output_gguf/**: GGUF quantized model for deployment + +## Configuration + +### Memory Optimization for RTX3070 +The training is configured with: +- Batch size: 1 (per device) +- Gradient accumulation: 4 (effective batch size: 4) +- Max sequence length: 2048 tokens +- 4-bit quantization to save VRAM +- Gradient checkpointing enabled + +### Training Parameters +- Learning rate: 2e-4 +- Max steps: 100 (increase for production) +- Warmup steps: 5 +- LoRA rank: 16 +- LoRA alpha: 16 + +## Troubleshooting + +### CUDA Out of Memory +If you encounter CUDA OOM errors: +1. Reduce batch size in `train_model.py` +2. Increase gradient accumulation steps +3. 
Reduce max sequence length +4. Restart your Python session + +### Data Collection Issues +- Check internet connection +- Odoo website may block rapid requests - the script includes delays +- If Indonesian docs fail, they may be at a different URL + +### Training Issues +- Ensure CUDA is properly installed +- Check that your GPU drivers are up to date +- Verify PyTorch CUDA compatibility + +## Model Usage + +After training, you can use the model for Odoo-related questions: + +```python +from train_model import OdooModelTrainer + +trainer = OdooModelTrainer() +trainer.load_model() + +# Load your trained model +# trainer.model = ... (load from odoo_model_output) + +response = trainer.generate_response("How do I install Odoo?") +print(response) +``` + +## Performance Notes + +- **Training Time**: ~30-60 minutes for 100 steps on RTX3070 +- **Memory Usage**: ~6-7GB VRAM during training +- **Data Size**: ~20-50MB of documentation data +- **Model Size**: ~4-5GB for the fine-tuned model + +## Contributing + +Feel free to submit issues and enhancement requests! + +## License + +This project is open source. Please check individual component licenses for details. + +## Disclaimer + +This project is for educational and research purposes. Ensure compliance with Odoo's terms of service when scraping documentation. 
class DataPreprocessor:
    """Convert scraped Odoo documentation rows into instruction-tuning samples.

    The pipeline reads the scraper's CSV, cleans each page's text, splits it
    into overlapping word-based chunks and wraps every chunk in an
    instruction/input/output record.
    """

    def __init__(self):
        self.max_length = 2048  # Chunk size in whitespace-separated words (suitable for Qwen model)
        self.overlap = 200  # Words shared between consecutive chunks

    @staticmethod
    def _as_text(value) -> str:
        """Coerce a CSV cell to ``str``; missing values (None/NaN/NA) become ""."""
        if value is None or value is pd.NA:
            return ""
        # pandas represents empty CSV cells as float NaN, which is truthy and
        # would otherwise slip past ``if not text`` checks.
        if isinstance(value, float) and value != value:
            return ""
        return value if isinstance(value, str) else str(value)

    def clean_text(self, text: str) -> str:
        """Collapse whitespace and strip characters outside basic punctuation.

        Args:
            text: Raw page text; ``None`` and NaN are treated as empty.

        Returns:
            The cleaned text, or "" when there is nothing to clean.
        """
        text = self._as_text(text)
        if not text:
            return ""

        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text.strip())

        # Remove special characters but keep basic punctuation
        text = re.sub(r'[^\w\s\.\,\!\?\-\:\;\(\)]', '', text)

        return text.strip()

    def chunk_text(self, text: str, title: str = "", language: str = "en") -> List[str]:
        """Split text into overlapping word-based chunks suitable for training.

        Args:
            text: Cleaned page text.
            title: Optional page title, prepended as context; missing values
                (None/NaN) are ignored.
            language: "id" selects the Indonesian context labels, anything
                else the English ones.

        Returns:
            Chunks of at most ``self.max_length`` words; chunks shorter than
            50 words are dropped.
        """
        text = self._as_text(text)
        if not text:
            return []

        # Add title as context if available
        title = self._as_text(title)
        if title:
            if language == "id":
                context = f"Judul: {title}\n\nKonten:\n\n"
            else:
                context = f"Title: {title}\n\nContent:\n\n"
            text = context + text

        words = text.split()
        stride = self.max_length - self.overlap
        chunks = []

        for start in range(0, len(words), stride):
            chunk_words = words[start:start + self.max_length]

            if len(chunk_words) >= 50:  # Only keep substantial chunks
                chunks.append(' '.join(chunk_words))

            # Once a window reaches the end of the text, any later window is
            # a pure suffix of it and would only duplicate training data.
            if start + self.max_length >= len(words):
                break

        return chunks

    def create_training_format(self, chunk: str, language: str = "en") -> Dict:
        """Wrap a chunk in an instruction-tuning record.

        Returns:
            A dict with "instruction", "input" (the chunk, truncated to 500
            characters plus "..." when longer), "output" and "language".
        """
        if language == "id":
            instruction = "Jelaskan dan berikan informasi tentang topik berikut berdasarkan dokumentasi Odoo:"
            response_format = f"Berdasarkan dokumentasi Odoo:\n\n{chunk}"
        else:
            instruction = "Explain and provide information about the following topic based on Odoo documentation:"
            response_format = f"Based on Odoo documentation:\n\n{chunk}"

        return {
            "instruction": instruction,
            # Truncate input for instruction
            "input": (chunk[:500] + "...") if len(chunk) > 500 else chunk,
            "output": response_format,
            "language": language
        }

    def process_csv_data(self, input_file: str = 'odoo_docs_data.csv') -> List[Dict]:
        """Load the scraped CSV and turn every usable row into training records.

        Rows whose content is missing or empty after cleaning are skipped.

        Returns:
            The list of training records; empty when the file does not exist.
        """
        if not os.path.exists(input_file):
            print(f"Input file {input_file} not found!")
            return []

        print(f"Loading data from {input_file}")
        df = pd.read_csv(input_file)

        training_data = []

        for _, row in df.iterrows():
            content = self.clean_text(row.get('content', ''))
            if not content:
                continue

            # Empty cells arrive as NaN; normalise them before use.
            title = self._as_text(row.get('title', ''))
            language = self._as_text(row.get('language', 'en')) or 'en'

            # Create chunks from the content and convert each to training format
            for chunk in self.chunk_text(content, title, language):
                training_data.append(self.create_training_format(chunk, language))

        print(f"Processed {len(training_data)} training samples")
        return training_data

    def save_training_data(self, training_data: List[Dict], output_file: str = 'training_data.json'):
        """Write the training records as JSON, plus a sample of up to 100 records.

        Nothing is written when ``training_data`` is empty.
        """
        if not training_data:
            print("No training data to save!")
            return

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(training_data, f, ensure_ascii=False, indent=2)

        print(f"Saved {len(training_data)} samples to {output_file}")

        # Also create a smaller sample for testing
        sample_size = min(100, len(training_data))
        sample_data = training_data[:sample_size]

        # Derive the name from the extension so the sample can never collide
        # with (and overwrite) the main output file.
        root, ext = os.path.splitext(output_file)
        sample_file = f"{root}_sample{ext}"
        with open(sample_file, 'w', encoding='utf-8') as f:
            json.dump(sample_data, f, ensure_ascii=False, indent=2)

        print(f"Saved sample of {sample_size} items to {sample_file}")

    def get_statistics(self, training_data: List[Dict]) -> Dict:
        """Summarise sample count, language distribution and output lengths.

        Returns:
            An empty dict for empty input; otherwise a dict with
            "total_samples", "language_distribution", "average_length",
            "max_length" and "min_length" (lengths in characters of "output").
        """
        if not training_data:
            return {}

        languages = {}
        for item in training_data:
            lang = item.get('language', 'unknown')
            languages[lang] = languages.get(lang, 0) + 1

        lengths = [len(item.get('output', '')) for item in training_data]

        return {
            'total_samples': len(training_data),
            'language_distribution': languages,
            'average_length': sum(lengths) / len(lengths),
            'max_length': max(lengths),
            'min_length': min(lengths)
        }


if __name__ == "__main__":
    preprocessor = DataPreprocessor()

    # Process the scraped data
    training_data = preprocessor.process_csv_data()

    if training_data:
        # Save the training data
        preprocessor.save_training_data(training_data)

        # Print statistics
        stats = preprocessor.get_statistics(training_data)
        print("\nTraining Data Statistics:")
        print(f"Total samples: {stats['total_samples']}")
        print(f"Language distribution: {stats['language_distribution']}")
        print(f"Average length: {stats['average_length']:.2f}")
        print(f"Max length: {stats['max_length']}")
        print(f"Min length: {stats['min_length']}")
    else:
        print("No training data was generated!")
'div.content', + 'main', + 'article' + ] + + content_text = "" + for selector in content_selectors: + content_div = soup.select_one(selector) + if content_div: + # Remove script and style elements + for script in content_div(["script", "style"]): + script.decompose() + + # Extract text + text = content_div.get_text(separator=' ', strip=True) + if len(text) > 100: # Only keep substantial content + content_text = text + break + + return { + 'url': url, + 'language': lang, + 'title': soup.title.string if soup.title else '', + 'content': content_text + } + + except Exception as e: + print(f"Error scraping {url}: {e}") + return None + + def get_main_pages(self, lang): + """Get main documentation pages to scrape""" + base_url = self.base_urls[lang] + main_pages = [] + + try: + response = self.session.get(base_url, timeout=30) + response.raise_for_status() + soup = BeautifulSoup(response.content, 'html.parser') + + # Look for navigation links + nav_selectors = [ + 'nav a[href]', + '.toctree a[href]', + 'ul li a[href]' + ] + + for selector in nav_selectors: + links = soup.select(selector) + for link in links: + href = link.get('href') + if href and not href.startswith('#') and not href.startswith('mailto:'): + full_url = urljoin(base_url, href) + if full_url.startswith(base_url) and full_url not in main_pages: + main_pages.append(full_url) + + # Limit to first 20 pages per language to avoid overwhelming + return main_pages[:20] + + except Exception as e: + print(f"Error getting main pages for {lang}: {e}") + return [base_url] # Fallback to base URL + + def scrape_documentation(self): + """Scrape documentation from both languages""" + all_data = [] + + for lang in ['en', 'id']: + print(f"Scraping {lang} documentation...") + pages = self.get_main_pages(lang) + + for i, page_url in enumerate(pages): + print(f"Scraping page {i+1}/{len(pages)}: {page_url}") + page_data = self.get_page_content(page_url, lang) + if page_data and page_data['content']: + all_data.append(page_data) 
def run_data_preprocessing():
    """Step 2: Preprocess and format the collected data.

    Skips the work when 'training_data.json' already exists.

    Returns:
        True when training data is available afterwards (freshly generated or
        already present), False when the raw CSV is missing, no samples could
        be produced, or preprocessing raised.
    """
    print("\n=== Step 2: Data Preprocessing ===")

    if not os.path.exists('odoo_docs_data.csv'):
        print("No raw data found. Please run data collection first.")
        return False

    if os.path.exists('training_data.json'):
        print("Training data already exists. Skipping preprocessing.")
        print("To reprocess data, delete 'training_data.json' and run again.")
        return True

    try:
        preprocessor = DataPreprocessor()
        training_data = preprocessor.process_csv_data()

        # With no samples get_statistics() returns {} and the lookups below
        # would fail with an opaque KeyError; report the real cause instead.
        if not training_data:
            print("No training data was generated!")
            return False

        preprocessor.save_training_data(training_data)

        stats = preprocessor.get_statistics(training_data)
        print("\nTraining Data Statistics:")
        print(f"Total samples: {stats['total_samples']}")
        print(f"Language distribution: {stats['language_distribution']}")
        print(f"Average length: {stats['average_length']:.2f}")

        return True
    except Exception as e:
        # Top-level step boundary: report and signal failure to the orchestrator.
        print(f"Error during data preprocessing: {e}")
        return False
def main(argv=None):
    """Parse command-line flags and run the requested pipeline steps.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    The ``--only-*`` modes run a single step and exit with status 0 or 1.
    In full-pipeline mode the steps run in order and stop at the first
    failure; a failure exits with status 1 so shells and CI can detect it,
    matching the ``--only-*`` modes.
    """
    parser = argparse.ArgumentParser(description='Odoo AI Model Trainer')
    parser.add_argument('--skip-collection', action='store_true',
                        help='Skip data collection step')
    parser.add_argument('--skip-preprocessing', action='store_true',
                        help='Skip data preprocessing step')
    parser.add_argument('--skip-training', action='store_true',
                        help='Skip model training step')
    parser.add_argument('--only-collection', action='store_true',
                        help='Only run data collection')
    parser.add_argument('--only-preprocessing', action='store_true',
                        help='Only run data preprocessing')
    parser.add_argument('--only-training', action='store_true',
                        help='Only run model training')

    args = parser.parse_args(argv)

    print("🚀 Odoo AI Model Trainer")
    print("=" * 50)

    # Check for specific modes
    if args.only_collection:
        success = run_data_collection()
        sys.exit(0 if success else 1)

    if args.only_preprocessing:
        success = run_data_preprocessing()
        sys.exit(0 if success else 1)

    if args.only_training:
        success = run_model_training()
        sys.exit(0 if success else 1)

    # Full pipeline mode
    steps = []
    if not args.skip_collection:
        steps.append(("Data Collection", run_data_collection))
    if not args.skip_preprocessing:
        steps.append(("Data Preprocessing", run_data_preprocessing))
    if not args.skip_training:
        steps.append(("Model Training", run_model_training))

    if not steps:
        print("No steps to run. Use --help to see available options.")
        return

    success_count = 0
    for step_name, step_func in steps:
        if step_func():
            success_count += 1
            print(f"✅ {step_name} completed successfully")
        else:
            print(f"❌ {step_name} failed")
            break  # Later steps depend on the output of earlier ones

    print("\n=== Final Results ===")
    print(f"Completed steps: {success_count}/{len(steps)}")

    if success_count == len(steps):
        print("🎉 All steps completed successfully!")
        print("\nNext steps:")
        print("1. Check the 'odoo_model_output' directory for trained model")
        print("2. Use the model for Odoo-related questions")
    else:
        print("❌ Some steps failed. Check the output above for details.")
        # Report the failure through the exit status as well as the console.
        sys.exit(1)
[] + + for package in required_packages: + try: + if package == 'beautifulsoup4': + import bs4 + elif package == 'unsloth': + try: + import unsloth + except ImportError: + print(f"⚠️ {package} not available (this is normal if not installed)") + continue + else: + __import__(package) + print(f"✅ {package}") + except ImportError: + print(f"❌ {package}") + failed_imports.append(package) + + return len(failed_imports) == 0 + +def test_file_structure(): + """Test if all required files exist""" + print("\nTesting file structure...") + + required_files = [ + 'main.py', + 'data_scraper.py', + 'data_preprocessor.py', + 'train_model.py', + 'requirements.txt', + 'README.md' + ] + + missing_files = [] + + for file in required_files: + try: + with open(file, 'r') as f: + pass + print(f"✅ {file}") + except FileNotFoundError: + print(f"❌ {file}") + missing_files.append(file) + + return len(missing_files) == 0 + +def main(): + """Run all tests""" + print("🚀 Odoo AI Model Trainer - Setup Test") + print("=" * 50) + + tests = [ + ("CUDA Setup", test_cuda), + ("Package Imports", test_imports), + ("File Structure", test_file_structure) + ] + + results = [] + + for test_name, test_func in tests: + print(f"\n{test_name}:") + print("-" * 30) + result = test_func() + results.append((test_name, result)) + + print("\n" + "=" * 50) + print("📋 Test Summary:") + + all_passed = True + for test_name, result in results: + status = "✅ PASS" if result else "❌ FAIL" + print(f" {test_name}: {status}") + if not result: + all_passed = False + + print("\n" + "=" * 50) + if all_passed: + print("🎉 All tests passed! Your setup is ready.") + print("\nNext steps:") + print("1. Run 'python main.py' to start the full pipeline") + print("2. Or run 'python data_scraper.py' to start with data collection") + else: + print("❌ Some tests failed. Please fix the issues above.") + print("\nCommon solutions:") + print("1. Install missing packages: pip install -r requirements.txt") + print("2. 
class OdooModelTrainer:
    """Fine-tune a 4-bit Qwen model on Odoo documentation with Unsloth and LoRA."""

    def __init__(self):
        self.model_name = "unsloth/Qwen3-8B-bnb-4bit"
        self.max_seq_length = 2048
        self.load_in_4bit = True
        # Both are populated by load_model().
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the Qwen model and tokenizer with Unsloth optimizations."""
        print("Loading model...")

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.model_name,
            max_seq_length=self.max_seq_length,
            dtype=None,  # Auto detection
            load_in_4bit=self.load_in_4bit,
        )

        # Enable gradient checkpointing for memory efficiency
        self.model.gradient_checkpointing_enable()

        print("Model loaded successfully!")

    def prepare_data(self, data_file: str = 'training_data.json') -> Dataset:
        """Load the training records into a HuggingFace ``Dataset``.

        When ``data_file`` is missing, preprocessing is run first to create
        it. If preprocessing yields no samples the file is still absent
        (``save_training_data`` writes nothing for empty input), so an empty
        dataset is returned instead of failing on ``open()``.

        Returns:
            A dataset with "instruction", "input" and "output" columns.
        """
        if not os.path.exists(data_file):
            print(f"Data file {data_file} not found. Running data preprocessing...")
            preprocessor = DataPreprocessor()
            training_data = preprocessor.process_csv_data()
            preprocessor.save_training_data(training_data, data_file)

        if os.path.exists(data_file):
            print(f"Loading data from {data_file}")
            with open(data_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        else:
            print(f"No training data could be generated for {data_file}")
            data = []

        # Convert to HuggingFace Dataset format
        dataset_dict = {
            'instruction': [item['instruction'] for item in data],
            'input': [item['input'] for item in data],
            'output': [item['output'] for item in data]
        }

        dataset = Dataset.from_dict(dataset_dict)
        print(f"Prepared dataset with {len(dataset)} samples")

        return dataset

    def format_chat_template(self, example):
        """Render one record as a system/user/assistant conversation.

        Returns:
            A dict with a single "text" key holding the templated string.
        """
        messages = [
            {"role": "system", "content": "You are a helpful assistant specialized in Odoo documentation."},
            {"role": "user", "content": f"{example['instruction']}\n\n{example['input']}"},
            {"role": "assistant", "content": example['output']}
        ]

        # Apply chat template
        formatted_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

        return {"text": formatted_text}

    def train(self, dataset: Dataset, output_dir: str = './odoo_model_output'):
        """Attach LoRA adapters, run supervised fine-tuning and save the result.

        The model is saved to ``output_dir`` and additionally exported in
        GGUF format to ``output_dir + "_gguf"``.
        """
        print("Starting training...")

        # Apply chat template to dataset
        formatted_dataset = dataset.map(self.format_chat_template)

        # Configure LoRA for efficient training
        self.model = FastLanguageModel.get_peft_model(
            self.model,
            r=16,  # LoRA rank
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj"
            ],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing=True,
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )

        # Configure training arguments optimized for RTX3070 8GB VRAM
        training_args = TrainingArguments(
            per_device_train_batch_size=1,  # Very small batch size for 8GB VRAM
            gradient_accumulation_steps=4,  # Effective batch size of 4
            warmup_steps=5,
            max_steps=100,  # Limit steps for testing, increase as needed
            learning_rate=2e-4,
            # Use bf16 when the GPU supports it, otherwise fall back to fp16.
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir=output_dir,
            save_steps=50,
            save_total_limit=2,
            report_to="none",  # Disable wandb/tensorboard for simplicity
        )

        # Initialize trainer
        trainer = SFTTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            train_dataset=formatted_dataset,
            dataset_text_field="text",
            max_seq_length=self.max_seq_length,
            dataset_num_proc=2,
            packing=False,  # Can make training faster for short sequences
            args=training_args,
        )

        # Clear cache before training
        gc.collect()
        torch.cuda.empty_cache()

        print("Starting training...")
        trainer.train()

        # Save the model
        print(f"Saving model to {output_dir}")
        trainer.save_model(output_dir)

        # Save in GGUF format for compatibility
        self.model.save_pretrained_gguf(
            output_dir + "_gguf",
            self.tokenizer,
            quantization_method="q4_k_m"  # 4-bit quantization
        )

        print("Training completed!")

    def generate_response(self, prompt: str, max_new_tokens: int = 256) -> str:
        """Generate a reply to ``prompt`` with the currently loaded model.

        Returns:
            The first decoded sequence from ``batch_decode``, or "" when no
            model has been loaded.
        """
        # Compare against None explicitly rather than relying on the model
        # object's truthiness.
        if self.model is None:
            print("Model not loaded!")
            return ""

        # Enable faster inference
        FastLanguageModel.for_inference(self.model)

        messages = [
            {"role": "system", "content": "You are a helpful assistant specialized in Odoo documentation."},
            {"role": "user", "content": prompt}
        ]

        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to("cuda")

        # NOTE(review): temperature/min_p presumably only take effect when
        # sampling is enabled in the generation config - confirm.
        outputs = self.model.generate(
            input_ids=inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            temperature=0.7,
            min_p=0.1
        )

        response = self.tokenizer.batch_decode(outputs)[0]
        return response
not torch.cuda.is_available(): + print("CUDA is not available. Please ensure you have a CUDA-compatible GPU.") + return + + print(f"Using GPU: {torch.cuda.get_device_name(0)}") + print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB") + + # Initialize trainer + trainer = OdooModelTrainer() + + try: + # Load model + trainer.load_model() + + # Prepare data + dataset = trainer.prepare_data() + + if len(dataset) == 0: + print("No training data available!") + return + + # Train model + trainer.train(dataset) + + # Test the model + print("\nTesting the trained model:") + test_prompt = "How do I install Odoo?" + response = trainer.generate_response(test_prompt) + print(f"Prompt: {test_prompt}") + print(f"Response: {response}") + + except Exception as e: + print(f"Error during training: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() \ No newline at end of file