"""
|
|
Dataset processor for GitHub repositories
|
|
Processes code from GitHub repositories into training datasets
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import shutil
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import git
|
|
from datasets import Dataset
|
|
from tqdm import tqdm
|
|
|
|
from config import AppConfig
|
|
|
|
|
|
class DatasetProcessor:
    """Processes GitHub repositories into training datasets"""

    # Supported file extensions for code training
    CODE_EXTENSIONS = {
        '.py': 'python',
        '.js': 'javascript',
        '.ts': 'typescript',
        '.java': 'java',
        '.cpp': 'cpp',
        '.c': 'c',
        '.h': 'c',
        '.hpp': 'cpp',
        '.cs': 'csharp',
        '.php': 'php',
        '.rb': 'ruby',
        '.go': 'go',
        '.rs': 'rust',
        '.swift': 'swift',
        '.kt': 'kotlin',
        '.scala': 'scala',
        '.sql': 'sql',
        '.sh': 'bash',
        '.yaml': 'yaml',
        '.yml': 'yaml',
        '.json': 'json',
        '.xml': 'xml',
        '.html': 'html',
        '.css': 'css',
        '.md': 'markdown'
    }

    # Files and directories to exclude
    EXCLUDE_PATTERNS = [
        r'\.git/',
        r'__pycache__/',
        r'\.pytest_cache/',
        r'node_modules/',
        r'\.venv/',
        r'venv/',
        r'\.DS_Store',
        r'\.pyc$',
        r'\.pyo$',
        r'\.pyd$',
        r'\.so$',
        r'\.dll$',
        r'\.exe$',
        r'\.bin$',
        r'package-lock\.json$',
        r'yarn\.lock$',
        r'\.log$',
        r'\.tmp$',
        r'\.bak$',
        r'~\$.*',
        r'\.swp$',
        r'\.swo$'
    ]
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.temp_dirs = []
    def process_github_repos(self, repo_urls: List[str], config: AppConfig, github_token: Optional[str] = None) -> Dataset:
        """
        Process multiple GitHub repositories into a training dataset

        Args:
            repo_urls: List of GitHub repository URLs
            config: Training configuration
            github_token: Optional GitHub token for accessing private repositories

        Returns:
            Dataset ready for training
        """
        all_code_samples = []

        for repo_url in repo_urls:
            try:
                self.logger.info(f"Processing repository: {repo_url}")
                repo_samples = self._process_single_repo(repo_url, config, github_token)
                all_code_samples.extend(repo_samples)
                self.logger.info(f"Extracted {len(repo_samples)} samples from {repo_url}")
            except Exception as e:
                self.logger.error(f"Failed to process repository {repo_url}: {str(e)}")
                continue

        if not all_code_samples:
            raise ValueError("No code samples extracted from any repository")

        self.logger.info(f"Total samples collected: {len(all_code_samples)}")

        # Create HuggingFace dataset
        dataset = Dataset.from_list(all_code_samples)

        # Filter out samples longer than the model context window
        # (whitespace-delimited word count is a rough proxy for token count)
        dataset = dataset.filter(
            lambda x: len(x['text'].split()) <= config.model.max_seq_length
        )

        self.logger.info(f"Dataset size after filtering: {len(dataset)}")
        return dataset
    def _process_single_repo(self, repo_url: str, config: AppConfig, github_token: Optional[str] = None) -> List[Dict]:
        """
        Process a single GitHub repository

        Args:
            repo_url: GitHub repository URL
            config: Training configuration
            github_token: Optional GitHub token for accessing private repositories

        Returns:
            List of code samples with metadata
        """
        # Create a persistent directory for cloned repositories
        gitclone_dir = Path("./gitclone")
        gitclone_dir.mkdir(exist_ok=True)
        temp_dir = str(gitclone_dir)
        # Note: We don't add this to temp_dirs since we want to keep it

        # NOTE: depth and branch are hardcoded; cloning will fail for
        # repositories that do not have an "18.0" branch
        depth = 1
        branch = "18.0"

        try:
            # Clone repository (skipped if a previous clone already exists)
            repo_name = repo_url.split('/')[-1].replace('.git', '')
            repo_path = os.path.join(temp_dir, repo_name)
            if not os.path.exists(repo_path):
                self.logger.info(f"Cloning {repo_url} to {repo_path}")

                # Use token for private repositories if provided
                clone_url = repo_url
                if github_token and "github.com" in repo_url:
                    if repo_url.startswith("git@"):
                        # SSH URLs authenticate via keys; no token modification needed
                        pass
                    elif repo_url.startswith("https://"):
                        # Add token to HTTPS URL
                        clone_url = repo_url.replace("https://", f"https://{github_token}@")
                    elif repo_url.startswith("http://"):
                        clone_url = repo_url.replace("http://", f"http://{github_token}@")
                    elif repo_url.startswith("github.com/"):
                        # Bare URL like "github.com/user/repo"
                        clone_url = f"https://{github_token}@{repo_url}"
                    else:
                        # Assume it's a GitHub path like "user/repo"
                        clone_url = f"https://{github_token}@github.com/{repo_url}"

                git.Repo.clone_from(clone_url, repo_path, depth=depth, branch=branch)

            # Extract code samples
            code_samples = self._extract_code_samples(repo_path, config)

            return code_samples

        finally:
            # Cleanup intentionally disabled: the gitclone folder is kept so
            # repeated runs can reuse existing clones
            # shutil.rmtree(temp_dir, ignore_errors=True)
            pass
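    # Illustration of the token rewrite performed above (hypothetical values;
    # <TOKEN> stands in for a real GitHub personal access token):
    #   "https://github.com/user/repo" -> "https://<TOKEN>@github.com/user/repo"
    #   "github.com/user/repo"         -> "https://<TOKEN>@github.com/user/repo"
    #   "git@github.com:user/repo.git" -> unchanged (SSH auth uses keys)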
    def _extract_code_samples(self, repo_path: str, config: AppConfig) -> List[Dict]:
        """
        Extract code samples from a repository

        Args:
            repo_path: Path to cloned repository
            config: Training configuration

        Returns:
            List of code samples
        """
        code_samples = []
        repo_path_obj = Path(repo_path)

        # Find all code files
        code_files = []
        for ext in self.CODE_EXTENSIONS:
            code_files.extend(repo_path_obj.rglob(f'*{ext}'))

        self.logger.info(f"Found {len(code_files)} code files")

        for code_file in tqdm(code_files, desc="Processing code files"):
            try:
                if self._should_exclude_file(str(code_file.relative_to(repo_path))):
                    continue

                sample = self._process_code_file(code_file, repo_path_obj, config)
                if sample:
                    code_samples.append(sample)

            except Exception as e:
                self.logger.warning(f"Failed to process {code_file}: {str(e)}")
                continue

        return code_samples
    def _should_exclude_file(self, relative_path: str) -> bool:
        """Check if a file should be excluded based on patterns"""
        for pattern in self.EXCLUDE_PATTERNS:
            if re.search(pattern, relative_path):
                return True
        return False
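    # Illustration (hypothetical paths):
    #   _should_exclude_file("node_modules/pkg/index.js") -> True  (node_modules/)
    #   _should_exclude_file("build/app.pyc")             -> True  (\.pyc$)
    #   _should_exclude_file("src/main.py")               -> False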
    def _process_code_file(self, file_path: Path, repo_path: Path, config: AppConfig) -> Optional[Dict]:
        """
        Process a single code file into a training sample

        Args:
            file_path: Path to the code file
            repo_path: Path to the repository root
            config: Training configuration

        Returns:
            Dictionary containing the processed sample or None if invalid
        """
        try:
            # Read file content
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()

            # Skip if file is too small or too large
            if len(content.strip()) < 10:
                return None
            if len(content) > config.model.max_seq_length * 4:  # Rough character limit
                return None

            # Get relative path for context
            relative_path = file_path.relative_to(repo_path)

            # Determine language
            extension = file_path.suffix.lower()
            language = self.CODE_EXTENSIONS.get(extension, 'unknown')

            # Create training sample
            sample = {
                'text': content,
                'language': language,
                'file_path': str(relative_path),
                'repo_name': repo_path.name,
                'file_size': len(content),
                'line_count': len(content.splitlines())
            }

            return sample

        except Exception as e:
            self.logger.warning(f"Error processing {file_path}: {str(e)}")
            return None
    def cleanup(self):
        """Clean up temporary directories"""
        for temp_dir in self.temp_dirs:
            try:
                shutil.rmtree(temp_dir, ignore_errors=True)
            except Exception as e:
                self.logger.warning(f"Failed to cleanup {temp_dir}: {str(e)}")
        self.temp_dirs.clear()
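

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Assumes AppConfig can be built
# with defaults and exposes `model.max_seq_length` as used above; the repo
# URL and the GITHUB_TOKEN environment variable are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    config = AppConfig()  # assumption: default construction is supported
    processor = DatasetProcessor()
    try:
        dataset = processor.process_github_repos(
            ["https://github.com/user/repo"],  # placeholder repository URL
            config,
            github_token=os.environ.get("GITHUB_TOKEN"),
        )
        print(f"Prepared {len(dataset)} training samples")
    finally:
        # Remove any tracked temporary directories (the ./gitclone cache is kept)
        processor.cleanup()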