# ai_github_trainer/src/dataset_processor.py
"""
Dataset processor for GitHub repositories
Processes code from GitHub repositories into training datasets
"""
import logging
import os
import re
import shutil
from pathlib import Path
from typing import Dict, List, Optional

import git
from datasets import Dataset
from tqdm import tqdm

from config import AppConfig

class DatasetProcessor:
    """Processes GitHub repositories into training datasets"""

    # Supported file extensions for code training
    CODE_EXTENSIONS = {
        '.py': 'python',
        '.js': 'javascript',
        '.ts': 'typescript',
        '.java': 'java',
        '.cpp': 'cpp',
        '.c': 'c',
        '.h': 'c',
        '.hpp': 'cpp',
        '.cs': 'csharp',
        '.php': 'php',
        '.rb': 'ruby',
        '.go': 'go',
        '.rs': 'rust',
        '.swift': 'swift',
        '.kt': 'kotlin',
        '.scala': 'scala',
        '.sql': 'sql',
        '.sh': 'bash',
        '.yaml': 'yaml',
        '.yml': 'yaml',
        '.json': 'json',
        '.xml': 'xml',
        '.html': 'html',
        '.css': 'css',
        '.md': 'markdown'
    }

    # Files and directories to exclude
    EXCLUDE_PATTERNS = [
        r'\.git/',
        r'__pycache__/',
        r'\.pytest_cache/',
        r'node_modules/',
        r'\.venv/',
        r'venv/',
        r'\.DS_Store',
        r'\.pyc$',
        r'\.pyo$',
        r'\.pyd$',
        r'\.so$',
        r'\.dll$',
        r'\.exe$',
        r'\.bin$',
        r'package-lock\.json$',
        r'yarn\.lock$',
        r'\.log$',
        r'\.tmp$',
        r'\.bak$',
        r'~\$.*',
        r'\.swp$',
        r'\.swo$'
    ]

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.temp_dirs = []

    def process_github_repos(self, repo_urls: List[str], config: AppConfig,
                             github_token: Optional[str] = None) -> Dataset:
        """
        Process multiple GitHub repositories into a training dataset

        Args:
            repo_urls: List of GitHub repository URLs
            config: Training configuration
            github_token: Optional GitHub token for accessing private repositories

        Returns:
            Dataset ready for training
        """
        all_code_samples = []
        for repo_url in repo_urls:
            try:
                self.logger.info(f"Processing repository: {repo_url}")
                repo_samples = self._process_single_repo(repo_url, config, github_token)
                all_code_samples.extend(repo_samples)
                self.logger.info(f"Extracted {len(repo_samples)} samples from {repo_url}")
            except Exception as e:
                self.logger.error(f"Failed to process repository {repo_url}: {e}")
                continue

        if not all_code_samples:
            raise ValueError("No code samples extracted from any repository")
        self.logger.info(f"Total samples collected: {len(all_code_samples)}")

        # Create HuggingFace dataset
        dataset = Dataset.from_list(all_code_samples)

        # Filter by sequence length. The whitespace-separated word count is only
        # a rough, tokenizer-agnostic proxy for the model's token length.
        dataset = dataset.filter(
            lambda x: len(x['text'].split()) <= config.model.max_seq_length
        )
        self.logger.info(f"Dataset size after filtering: {len(dataset)}")
        return dataset

    def _process_single_repo(self, repo_url: str, config: AppConfig,
                             github_token: Optional[str] = None) -> List[Dict]:
        """
        Process a single GitHub repository

        Args:
            repo_url: GitHub repository URL
            config: Training configuration
            github_token: Optional GitHub token for accessing private repositories

        Returns:
            List of code samples with metadata
        """
        # Clone into a persistent directory so repeated runs can reuse checkouts.
        # Deliberately not added to self.temp_dirs, so cleanup() leaves it in place.
        gitclone_dir = Path("./gitclone")
        gitclone_dir.mkdir(exist_ok=True)
        temp_dir = str(gitclone_dir)

        # Shallow-clone a pinned branch. NOTE: repositories without an "18.0"
        # branch will fail to clone; adjust for other targets.
        depth = 1
        branch = "18.0"

        # Clone the repository unless a checkout already exists
        repo_name = repo_url.split('/')[-1].replace('.git', '')
        repo_path = os.path.join(temp_dir, repo_name)
        if not os.path.exists(repo_path):
            self.logger.info(f"Cloning {repo_url} to {repo_path}")
            # Use the token for private repositories if provided
            clone_url = repo_url
            if github_token and "github.com" in repo_url:
                if repo_url.startswith("git@"):
                    # SSH URLs authenticate via keys; no token modification needed
                    pass
                elif repo_url.startswith("https://"):
                    clone_url = repo_url.replace("https://", f"https://{github_token}@")
                elif repo_url.startswith("http://"):
                    clone_url = repo_url.replace("http://", f"http://{github_token}@")
                elif repo_url.startswith("github.com/"):
                    # Bare host path like "github.com/user/repo"
                    clone_url = f"https://{github_token}@{repo_url}"
                else:
                    # Assume a GitHub path like "user/repo"
                    clone_url = f"https://{github_token}@github.com/{repo_url}"
            git.Repo.clone_from(clone_url, repo_path, depth=depth, branch=branch)

        # Extract code samples; the checkout is intentionally kept on disk
        return self._extract_code_samples(repo_path, config)

    def _extract_code_samples(self, repo_path: str, config: AppConfig) -> List[Dict]:
        """
        Extract code samples from a repository

        Args:
            repo_path: Path to cloned repository
            config: Training configuration

        Returns:
            List of code samples
        """
        code_samples = []
        repo_path_obj = Path(repo_path)

        # Find all code files
        code_files = []
        for ext in self.CODE_EXTENSIONS:
            code_files.extend(repo_path_obj.rglob(f'*{ext}'))
        self.logger.info(f"Found {len(code_files)} code files")

        for code_file in tqdm(code_files, desc="Processing code files"):
            try:
                # Use a POSIX-style relative path so the '/'-based exclusion
                # patterns also match on Windows
                if self._should_exclude_file(code_file.relative_to(repo_path).as_posix()):
                    continue
                sample = self._process_code_file(code_file, repo_path_obj, config)
                if sample:
                    code_samples.append(sample)
            except Exception as e:
                self.logger.warning(f"Failed to process {code_file}: {e}")
                continue
        return code_samples

    def _should_exclude_file(self, relative_path: str) -> bool:
        """Check if a file should be excluded based on patterns
        (e.g. 'node_modules/foo.js' is excluded, 'src/app.py' is not)
        """
        return any(re.search(pattern, relative_path) for pattern in self.EXCLUDE_PATTERNS)

    def _process_code_file(self, file_path: Path, repo_path: Path, config: AppConfig) -> Optional[Dict]:
        """
        Process a single code file into a training sample

        Args:
            file_path: Path to the code file
            repo_path: Path to the repository root
            config: Training configuration

        Returns:
            Dictionary containing the processed sample or None if invalid
        """
        try:
            # Read file content
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()

            # Skip if file is too small or too large
            if len(content.strip()) < 10:
                return None
            if len(content) > config.model.max_seq_length * 4:  # Rough character limit
                return None

            # Get relative path for context
            relative_path = file_path.relative_to(repo_path)

            # Determine language from the file extension
            extension = file_path.suffix.lower()
            language = self.CODE_EXTENSIONS.get(extension, 'unknown')

            # Create training sample
            return {
                'text': content,
                'language': language,
                'file_path': str(relative_path),
                'repo_name': repo_path.name,
                'file_size': len(content),
                'line_count': len(content.splitlines())
            }
        except Exception as e:
            self.logger.warning(f"Error processing {file_path}: {e}")
            return None

    def cleanup(self):
        """Clean up temporary directories"""
        for temp_dir in self.temp_dirs:
            try:
                shutil.rmtree(temp_dir, ignore_errors=True)
            except Exception as e:
                self.logger.warning(f"Failed to cleanup {temp_dir}: {e}")
        self.temp_dirs.clear()
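

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the module's original API.
# It assumes AppConfig() is constructible with defaults (only
# config.model.max_seq_length is used above) and reads an optional token from
# the GITHUB_TOKEN environment variable. The repository URL is illustrative;
# remember that _process_single_repo pins the clone to branch "18.0".
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    processor = DatasetProcessor()
    try:
        dataset = processor.process_github_repos(
            repo_urls=["https://github.com/odoo/odoo"],  # example repo with an 18.0 branch
            config=AppConfig(),
            github_token=os.environ.get("GITHUB_TOKEN"),
        )
        print(f"Prepared {len(dataset)} training samples")
    finally:
        # Removes any tracked temp dirs; the ./gitclone checkouts are kept
        processor.cleanup()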