1. add two data processor types: standard and synthetic
2. add DatasetProcessorSynthetic class to format GitHub repos into QA ChatML format
parent 43d6f0e98a
commit aaa0f1b51e

@ -46,6 +46,64 @@ dataset = processor.process_github_repos(
print(f"Dataset processed successfully with {len(dataset)} samples")
```

## Using the DatasetProcessorSynthetic Class

The `DatasetProcessorSynthetic` class in `src/dataset_processor_synthetic.py` turns GitHub repositories into training datasets in QA ChatML format, using a local AI model served by Ollama to generate the answer for each sample.

### Example Usage

```python
from src.dataset_processor_synthetic import DatasetProcessorSynthetic
from src.config import AppConfig, ModelConfig, TrainingConfig, DatasetConfig, MemoryConfig

# Initialize configuration
config = AppConfig(
    model=ModelConfig(),
    training=TrainingConfig(),
    dataset=DatasetConfig(),
    memory=MemoryConfig()
)

# Initialize dataset processor
processor = DatasetProcessorSynthetic()

# Process GitHub repositories
repo_urls = [
    "https://github.com/karpathy/nanoGPT.git",
    # Add more repository URLs as needed
]

dataset = processor.process_github_repos(
    repo_urls=repo_urls,
    config=config,
    github_token=None  # Add your token for private repositories
)

print(f"Dataset processed successfully with {len(dataset)} samples")
```
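
Under the hood, each sample's "assistant" answer is generated by a local Ollama server. A minimal sketch of the request the processor issues per code file (endpoint, model name, payload fields, and fallback behaviour taken from `src/dataset_processor_synthetic.py` in this commit; the prompt string here is abbreviated and illustrative):

```python
import requests

# The real prompt is built from the file's language, path, repository name and contents
prompt = "Analyze the following python code file 'model.py' from repository 'nanoGPT' ..."

try:
    response = requests.post(
        "http://localhost:11434/api/generate",
        json={"model": "qwen2.5-coder:7b", "prompt": prompt, "stream": False},
        timeout=120,
    )
    response.raise_for_status()
    analysis = response.json().get("response", "No response from model")
except Exception:
    # The processor falls back to a simple template if Ollama is unreachable
    analysis = "This python code file 'model.py' contains the following implementation: ..."
```

An Ollama instance must be running locally (default port 11434) with the `qwen2.5-coder:7b` model pulled; otherwise every sample falls back to the template answer.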

## Saving and Loading Datasets

Both dataset processors support saving and loading datasets to/from disk to avoid reprocessing:

```python
# Save dataset
processor.save_dataset(dataset, "./my_processed_dataset")

# Load dataset
loaded_dataset = processor.load_dataset("./my_processed_dataset")
```

The main script also supports saving/loading datasets via command-line arguments. The same command covers both cases: if a dataset already exists at `--dataset_path` it is loaded, otherwise the repositories are processed and the result is saved there:

```bash
# Process and save dataset
python src/main.py --repo1 https://github.com/repo1 --repo2 https://github.com/repo2 --dataset_path ./my_dataset

# Load and train with existing dataset
python src/main.py --repo1 https://github.com/repo1 --repo2 https://github.com/repo2 --dataset_path ./my_dataset
```

## Using the Example Script

You can run the example script directly:

@ -105,9 +163,19 @@ dataset_config = DatasetConfig(

## Output Format

The processed dataset contains the following fields for each sample:

For the standard `DatasetProcessor`:
- `text`: The content of the code file
- `language`: The programming language detected
- `file_path`: Relative path to the file within the repository
- `repo_name`: Name of the repository
- `file_size`: Size of the file in characters
- `line_count`: Number of lines in the file

For the `DatasetProcessorSynthetic` (an example record is shown below):
- `messages`: List of messages in ChatML format (system, user, assistant)
- `language`: The programming language detected
- `file_path`: Relative path to the file within the repository
- `repo_name`: Name of the repository
- `file_size`: Size of the file in characters
- `line_count`: Number of lines in the file
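
For illustration, a single synthetic sample looks roughly like the record below (field names and message roles mirror the implementation; the file, repository, and numeric values are hypothetical):

```python
sample = {
    "messages": [
        {"role": "system",
         "content": "You are an expert python programmer. Analyze code and explain its purpose and functionality."},
        {"role": "user",
         "content": "Analyze the python code file 'model.py' from the repository 'nanoGPT':\n\n<file contents>"},
        {"role": "assistant",
         "content": "<analysis generated by the local Ollama model>"},
    ],
    "language": "python",
    "file_path": "model.py",
    "repo_name": "nanoGPT",
    "file_size": 10234,
    "line_count": 330,
}
```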

@ -11,15 +11,15 @@ model:

training:
  # Memory-optimized batch size for RTX3070 8GB
- per_device_train_batch_size: 1
+ per_device_train_batch_size: 2
  gradient_accumulation_steps: 16
- max_steps: 50
+ max_steps: 120

  # Training parameters
- num_train_epochs: 1
- learning_rate: 2.0e-4
+ num_train_epochs: 3
+ learning_rate: 1.0e-4
  warmup_steps: 10
- warmup_ratio: 0.1
+ warmup_ratio: 0.03

  # Logging and saving
  logging_steps: 1
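
Note that with these values the effective batch size per optimizer step grows from 1 × 16 = 16 to 2 × 16 = 32 samples (per_device_train_batch_size × gradient_accumulation_steps).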

example_synthetic_dataset_processing.py (new file, 91 lines)
@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""
Example script demonstrating how to get and process a dataset from GitHub repositories
using the DatasetProcessorSynthetic class, and save/load the processed dataset.
"""

import sys
from pathlib import Path

# Add src to path to import our modules
sys.path.append(str(Path(__file__).parent))

from src.dataset_processor_synthetic import DatasetProcessorSynthetic
from src.config import AppConfig, ModelConfig, TrainingConfig, DatasetConfig, MemoryConfig


def main():
    # Initialize configuration
    config = AppConfig(
        model=ModelConfig(
            name="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
            max_seq_length=2048
        ),
        training=TrainingConfig(),
        dataset=DatasetConfig(),
        memory=MemoryConfig()
    )

    # Initialize dataset processor
    processor = DatasetProcessorSynthetic()

    # Example GitHub repositories to process
    # Replace these with your own repositories
    repo_urls = [
        "https://github.com/karpathy/nanoGPT.git",
        # "https://github.com/your-username/your-repo.git"
    ]

    try:
        # Check if a saved dataset exists
        dataset_path = "./processed_synthetic_dataset"
        import os
        if os.path.exists(dataset_path):
            print("Loading previously processed dataset...")
            dataset = processor.load_dataset(dataset_path)
        else:
            print("Processing GitHub repositories...")
            dataset = processor.process_github_repos(
                repo_urls=repo_urls,
                config=config,
                github_token=None  # Add your token here if processing private repositories
            )

            print(f"Dataset processed successfully!")
            print(f"Dataset size: {len(dataset)} samples")

            # Save dataset to disk for future use
            print(f"Saving dataset to {dataset_path}...")
            processor.save_dataset(dataset, dataset_path)
            print("Dataset saved successfully!")

        print(f"Dataset loaded with {len(dataset)} samples")

        # Show some examples from the dataset
        print("\nFirst 2 samples from the dataset:")
        for i in range(min(2, len(dataset))):
            sample = dataset[i]
            print(f"\nSample {i+1}:")
            print(f" Repository: {sample['repo_name']}")
            print(f" File path: {sample['file_path']}")
            print(f" Language: {sample['language']}")
            print(f" File size: {sample['file_size']} characters")
            print(f" Lines: {sample['line_count']}")

            # Show messages structure
            messages = sample['messages']
            print(f" Messages: {len(messages)} messages")
            for j, message in enumerate(messages):
                print(f" Message {j+1} ({message['role']}): {message['content'][:100]}...")

        return dataset

    except Exception as e:
        print(f"Error processing repositories: {e}")
        import traceback
        traceback.print_exc()
        return None


if __name__ == "__main__":
    dataset = main()

@ -15,7 +15,7 @@ import yaml
class ModelConfig:
    """Model-specific configuration"""
    name: str = "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit"
-   max_seq_length: int = 2048
+   max_seq_length: int = 1024
    trust_remote_code: bool = True
    use_fast_tokenizer: bool = True
    padding_side: str = "left"

@ -27,7 +27,7 @@ class TrainingConfig:
    """Training configuration"""
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
-   max_steps: int = 10
+   max_steps: int = 100
    num_train_epochs: int = 2
    learning_rate: float = 2e-4
    warmup_steps: int = 10

@ -75,7 +75,7 @@ class TrainingConfig:

    # Dataset processing
    dataset_shuffle: bool = True
-   dataset_seed: int = 42
+   dataset_seed: int = 3407

    # Output settings
    output_dir: str = "./models"

@ -123,7 +123,7 @@ class DatasetConfig:
@dataclass
class MemoryConfig:
    """Memory optimization settings for RTX3070 8GB"""
-   max_memory_usage: float = 0.85  # Use up to 85% of GPU memory
+   max_memory_usage: float = 0.95  # Use up to 95% of GPU memory
    enable_memory_tracking: bool = True
    clear_cache_between_epochs: bool = True
    use_memory_efficient_attention: bool = True

src/dataset_processor_synthetic.py (new file, 372 lines)
@ -0,0 +1,372 @@
"""
|
||||
Dataset processor for GitHub repositories
|
||||
Processes code from GitHub repositories into training datasets in QA ChatML format
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import git
|
||||
import requests
|
||||
from datasets import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from config import AppConfig
|
||||
|
||||
|
||||
class DatasetProcessorSynthetic:
|
||||
"""Processes GitHub repositories into training datasets in QA ChatML format"""
|
||||
|
||||
# Supported file extensions for code training
|
||||
CODE_EXTENSIONS = {
|
||||
'.py': 'python',
|
||||
'.js': 'javascript',
|
||||
'.ts': 'typescript',
|
||||
'.java': 'java',
|
||||
'.cpp': 'cpp',
|
||||
'.c': 'c',
|
||||
'.h': 'c',
|
||||
'.hpp': 'cpp',
|
||||
'.cs': 'csharp',
|
||||
'.php': 'php',
|
||||
'.rb': 'ruby',
|
||||
'.go': 'go',
|
||||
'.rs': 'rust',
|
||||
'.swift': 'swift',
|
||||
'.kt': 'kotlin',
|
||||
'.scala': 'scala',
|
||||
'.sql': 'sql',
|
||||
'.sh': 'bash',
|
||||
'.yaml': 'yaml',
|
||||
'.yml': 'yaml',
|
||||
'.json': 'json',
|
||||
'.xml': 'xml',
|
||||
'.html': 'html',
|
||||
'.css': 'css',
|
||||
'.md': 'markdown'
|
||||
}
|
||||
|
||||
# Files and directories to exclude
|
||||
EXCLUDE_PATTERNS = [
|
||||
r'\.git/',
|
||||
r'__pycache__/',
|
||||
r'\.pytest_cache/',
|
||||
r'node_modules/',
|
||||
r'\.venv/',
|
||||
r'venv/',
|
||||
r'\.DS_Store',
|
||||
r'\.pyc$',
|
||||
r'\.pyo$',
|
||||
r'\.pyd$',
|
||||
r'\.so$',
|
||||
r'\.dll$',
|
||||
r'\.exe$',
|
||||
r'\.bin$',
|
||||
r'package-lock\.json$',
|
||||
r'yarn\.lock$',
|
||||
r'\.log$',
|
||||
r'\.tmp$',
|
||||
r'\.bak$',
|
||||
r'~\$.*',
|
||||
r'\.swp$',
|
||||
r'\.swo$'
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.temp_dirs = []
|
||||
|
||||
def process_github_repos(self, repo_urls: List[str], config: AppConfig, github_token: Optional[str] = None) -> Dataset:
|
||||
"""
|
||||
Process multiple GitHub repositories into a training dataset in QA ChatML format
|
||||
|
||||
Args:
|
||||
repo_urls: List of GitHub repository URLs
|
||||
config: Training configuration
|
||||
github_token: Optional GitHub token for accessing private repositories
|
||||
|
||||
Returns:
|
||||
Dataset ready for training
|
||||
"""
|
||||
all_code_samples = []
|
||||
|
||||
for repo_url in repo_urls:
|
||||
try:
|
||||
self.logger.info(f"Processing repository: {repo_url}")
|
||||
repo_samples = self._process_single_repo(repo_url, config, github_token)
|
||||
all_code_samples.extend(repo_samples)
|
||||
self.logger.info(f"Extracted {len(repo_samples)} samples from {repo_url}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to process repository {repo_url}: {str(e)}")
|
||||
continue
|
||||
|
||||
if not all_code_samples:
|
||||
raise ValueError("No code samples extracted from any repository")
|
||||
|
||||
self.logger.info(f"Total samples collected: {len(all_code_samples)}")
|
||||
|
||||
# Create HuggingFace dataset
|
||||
dataset = Dataset.from_list(all_code_samples)
|
||||
|
||||
# Filter by sequence length (using messages format)
|
||||
dataset = dataset.filter(
|
||||
lambda x: self._get_total_message_tokens(x['messages']) <= config.model.max_seq_length
|
||||
)
|
||||
|
||||
self.logger.info(f"Dataset size after filtering: {len(dataset)}")
|
||||
return dataset
|
||||
|
||||
def _get_total_message_tokens(self, messages: List[Dict]) -> int:
|
||||
"""
|
||||
Calculate total tokens in messages
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
Total token count
|
||||
"""
|
||||
# Simple approximation: count words in all message content
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
if 'content' in message:
|
||||
total_tokens += len(message['content'].split())
|
||||
return total_tokens
|
||||
|
||||
def _process_single_repo(self, repo_url: str, config: AppConfig, github_token: Optional[str] = None) -> List[Dict]:
|
||||
"""
|
||||
Process a single GitHub repository
|
||||
|
||||
Args:
|
||||
repo_url: GitHub repository URL
|
||||
config: Training configuration
|
||||
github_token: Optional GitHub token for accessing private repositories
|
||||
|
||||
Returns:
|
||||
List of code samples with metadata
|
||||
"""
|
||||
# Create a persistent directory for cloned repositories
|
||||
gitclone_dir = Path("./gitclone")
|
||||
gitclone_dir.mkdir(exist_ok=True)
|
||||
temp_dir = str(gitclone_dir)
|
||||
# Note: We don't add this to temp_dirs since we want to keep it
|
||||
|
||||
depth = 1
|
||||
branch = "18.0"
|
||||
|
||||
try:
|
||||
# Clone repository
|
||||
repo_name = repo_url.split('/')[-1].replace('.git', '')
|
||||
repo_path = os.path.join(temp_dir, repo_name)
|
||||
if not os.path.exists(repo_path):
|
||||
self.logger.info(f"Cloning {repo_url} to {repo_path}")
|
||||
|
||||
# Use token for private repositories if provided
|
||||
clone_url = repo_url
|
||||
if github_token and "github.com" in repo_url:
|
||||
# Handle SSH URLs
|
||||
if repo_url.startswith("git@"):
|
||||
# SSH URL doesn't need token modification
|
||||
pass
|
||||
else:
|
||||
# Add token to HTTPS URL
|
||||
if repo_url.startswith("https://"):
|
||||
clone_url = repo_url.replace("https://", f"https://{github_token}@")
|
||||
elif repo_url.startswith("http://"):
|
||||
clone_url = repo_url.replace("http://", f"http://{github_token}@")
|
||||
else:
|
||||
# For URLs like "github.com/user/repo" or "user/repo"
|
||||
if repo_url.startswith("github.com/"):
|
||||
clone_url = f"https://{github_token}@{repo_url}"
|
||||
else:
|
||||
# Assume it's a GitHub path like "user/repo"
|
||||
clone_url = f"https://{github_token}@github.com/{repo_url}"
|
||||
|
||||
repo = git.Repo.clone_from(clone_url, repo_path, depth=depth, branch=branch)
|
||||
|
||||
# Extract code samples
|
||||
code_samples = self._extract_code_samples(repo_path, config)
|
||||
|
||||
return code_samples
|
||||
|
||||
finally:
|
||||
self.logger.info(f"Finished processing {repo_url}")
|
||||
# Cleanup temporary directories, but keep gitclone folder
|
||||
# if temp_dir != "./gitclone":
|
||||
# shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
def _extract_code_samples(self, repo_path: str, config: AppConfig) -> List[Dict]:
|
||||
"""
|
||||
Extract code samples from a repository
|
||||
|
||||
Args:
|
||||
repo_path: Path to cloned repository
|
||||
config: Training configuration
|
||||
|
||||
Returns:
|
||||
List of code samples
|
||||
"""
|
||||
code_samples = []
|
||||
repo_path_obj = Path(repo_path)
|
||||
|
||||
# Find all code files
|
||||
code_files = []
|
||||
for ext in self.CODE_EXTENSIONS:
|
||||
code_files.extend(repo_path_obj.rglob(f'*{ext}'))
|
||||
|
||||
self.logger.info(f"Found {len(code_files)} code files")
|
||||
|
||||
for code_file in tqdm(code_files, desc="Processing code files"):
|
||||
try:
|
||||
if self._should_exclude_file(str(code_file.relative_to(repo_path))):
|
||||
continue
|
||||
|
||||
sample = self._process_code_file(code_file, repo_path_obj, config)
|
||||
if sample:
|
||||
code_samples.append(sample)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to process {code_file}: {str(e)}")
|
||||
continue
|
||||
|
||||
return code_samples
|
||||
|
||||
def _should_exclude_file(self, relative_path: str) -> bool:
|
||||
"""Check if a file should be excluded based on patterns"""
|
||||
for pattern in self.EXCLUDE_PATTERNS:
|
||||
if re.search(pattern, relative_path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _process_code_file(self, file_path: Path, repo_path: Path, config: AppConfig) -> Optional[Dict]:
|
||||
"""
|
||||
Process a single code file into a training sample in QA ChatML format using Ollama
|
||||
|
||||
Args:
|
||||
file_path: Path to the code file
|
||||
repo_path: Path to the repository root
|
||||
config: Training configuration
|
||||
|
||||
Returns:
|
||||
Dictionary containing the processed sample in QA ChatML format or None if invalid
|
||||
"""
|
||||
try:
|
||||
# Read file content
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# Skip if file is too small or too large
|
||||
if len(content.strip()) < 10:
|
||||
return None
|
||||
if len(content) > config.model.max_seq_length * 4: # Rough character limit
|
||||
return None
|
||||
|
||||
# Get relative path for context
|
||||
relative_path = file_path.relative_to(repo_path)
|
||||
|
||||
# Determine language
|
||||
extension = file_path.suffix.lower()
|
||||
language = self.CODE_EXTENSIONS.get(extension, 'unknown')
|
||||
|
||||
# Create prompt for Ollama
|
||||
prompt = f"Analyze the following {language} code file '{relative_path}' from repository '{repo_path.name}' and provide a detailed explanation of its purpose, functionality, and key components:\n\n{content}"
|
||||
|
||||
# Call Ollama API
|
||||
ollama_url = "http://localhost:11434/api/generate"
|
||||
ollama_payload = {
|
||||
"model": "qwen2.5-coder:7b", # Default model, can be changed as needed
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(ollama_url, json=ollama_payload, timeout=120)
|
||||
response.raise_for_status()
|
||||
ollama_response = response.json()
|
||||
analysis = ollama_response.get("response", "No response from model")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error calling Ollama API for {file_path}: {str(e)}")
|
||||
# Fallback to simple template if Ollama is not available
|
||||
analysis = f"This {language} code file '{relative_path}' from repository '{repo_path.name}' contains the following implementation:\n\n{content}"
|
||||
|
||||
# Create QA ChatML format
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are an expert {language} programmer. Analyze code and explain its purpose and functionality."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Analyze the {language} code file '{relative_path}' from the repository '{repo_path.name}':\n\n{content}"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": analysis
|
||||
}
|
||||
]
|
||||
|
||||
# Create training sample
|
||||
sample = {
|
||||
'messages': messages,
|
||||
'language': language,
|
||||
'file_path': str(relative_path),
|
||||
'repo_name': repo_path.name,
|
||||
'file_size': len(content),
|
||||
'line_count': len(content.splitlines())
|
||||
}
|
||||
|
||||
return sample
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error processing {file_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up temporary directories"""
|
||||
for temp_dir in self.temp_dirs:
|
||||
try:
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to cleanup {temp_dir}: {str(e)}")
|
||||
self.temp_dirs.clear()
|
||||
|
||||
def save_dataset(self, dataset: Dataset, path: str) -> None:
|
||||
"""
|
||||
Save the processed dataset to disk
|
||||
|
||||
Args:
|
||||
dataset: The processed dataset to save
|
||||
path: The path where to save the dataset
|
||||
"""
|
||||
try:
|
||||
self.logger.info(f"Saving dataset to {path}")
|
||||
dataset.save_to_disk(path)
|
||||
self.logger.info("Dataset saved successfully")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to save dataset: {str(e)}")
|
||||
raise
|
||||
|
||||
def load_dataset(self, path: str) -> Dataset:
|
||||
"""
|
||||
Load a previously saved dataset from disk
|
||||
|
||||
Args:
|
||||
path: The path from where to load the dataset
|
||||
|
||||
Returns:
|
||||
The loaded dataset
|
||||
"""
|
||||
try:
|
||||
self.logger.info(f"Loading dataset from {path}")
|
||||
dataset = Dataset.load_from_disk(path)
|
||||
self.logger.info("Dataset loaded successfully")
|
||||
return dataset
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load dataset: {str(e)}")
|
||||
raise
|
||||

src/main.py (50 lines changed)
@ -12,6 +12,7 @@ os.environ['TORCH_COMPILE_DISABLE'] = '1'

from trainer import ModelTrainer
from dataset_processor import DatasetProcessor
+from dataset_processor_synthetic import DatasetProcessorSynthetic
from config import AppConfig
from utils import setup_logging, check_gpu_memory

@ -71,6 +72,21 @@ def parse_arguments():
        help="GitHub token for accessing private repositories"
    )

+   parser.add_argument(
+       "--processor_type",
+       type=str,
+       default="standard",
+       choices=["standard", "synthetic"],
+       help="Type of dataset processor to use"
+   )
+
+   parser.add_argument(
+       "--dataset_path",
+       type=str,
+       default=None,
+       help="Path to save/load dataset (if specified, will save processed dataset or load existing one)"
+   )
+
    return parser.parse_args()

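With the two new flags, a run such as `python src/main.py --repo1 https://github.com/repo1 --repo2 https://github.com/repo2 --processor_type synthetic --dataset_path ./my_dataset` (repository URLs are placeholders) selects the ChatML processor and reuses the processed dataset on later runs.
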
@ -85,6 +101,9 @@ def main():
    logger.info("Starting AI Trainer for Qwen2.5-Coder-7B-Instruct-bnb-4bit")
    logger.info(f"Repository 1: {args.repo1}")
    logger.info(f"Repository 2: {args.repo2}")
+   logger.info(f"Processor type: {args.processor_type}")
+   if args.dataset_path:
+       logger.info(f"Dataset path: {args.dataset_path}")

    try:
        # Check GPU memory

@ -101,14 +120,29 @@ def main():
        logger.info("Configuration loaded successfully")

        # Process datasets from GitHub repositories
-       dataset_processor = DatasetProcessor()
-       logger.info("Processing datasets from GitHub repositories...")
-
-       train_dataset = dataset_processor.process_github_repos(
-           repo_urls=[args.repo1, args.repo2],
-           config=config,
-           github_token=args.github_token
-       )
+       if args.processor_type == "synthetic":
+           dataset_processor = DatasetProcessorSynthetic()
+       else:
+           dataset_processor = DatasetProcessor()
+
+       logger.info(f"Using {args.processor_type} dataset processor")
+
+       # Check if we should load a saved dataset
+       if args.dataset_path and os.path.exists(args.dataset_path):
+           logger.info(f"Loading dataset from {args.dataset_path}")
+           train_dataset = dataset_processor.load_dataset(args.dataset_path)
+       else:
+           logger.info("Processing datasets from GitHub repositories...")
+           train_dataset = dataset_processor.process_github_repos(
+               repo_urls=[args.repo1, args.repo2],
+               config=config,
+               github_token=args.github_token
+           )
+
+       # Save dataset if path is specified
+       if args.dataset_path:
+           logger.info(f"Saving dataset to {args.dataset_path}")
+           dataset_processor.save_dataset(train_dataset, args.dataset_path)

        logger.info(f"Dataset processed successfully. Size: {len(train_dataset)}")

@ -382,6 +382,9 @@ class ModelTrainer:
        # Save the model
        self.model.save_pretrained(str(final_model_dir))
        self.tokenizer.save_pretrained(str(final_model_dir))
+       self.model.save_pretrained_gguf(str(final_model_dir), self.tokenizer, quantization_method = "q4_k_m")
+       self.model.save_pretrained_gguf(str(final_model_dir), self.tokenizer, quantization_method = "q8_0")
+       self.model.save_pretrained_gguf(str(final_model_dir), self.tokenizer, quantization_method = "q6_k")

        # Save configuration
        self.config.save_yaml(final_model_dir / "training_config.yaml")
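
The three `quantization_method` values export GGUF files at different quantization levels (4-bit `q4_k_m`, 8-bit `q8_0`, 6-bit `q6_k`), trading file size against fidelity; keep only the variants your llama.cpp/Ollama deployment actually needs.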