#!/usr/bin/env python3
# ai_github_trainer/example_dataset_processing.py
"""
Example script demonstrating how to get and process a dataset from GitHub repositories
using the DatasetProcessor class.
"""
import sys
from pathlib import Path
# Put this script's directory on sys.path so the local `src` package can be imported.
sys.path.append(str(Path(__file__).parent))
from src.dataset_processor import DatasetProcessor
from src.config import AppConfig, ModelConfig, TrainingConfig, DatasetConfig, MemoryConfig
def _build_config():
    """Return the AppConfig used by this example.

    Only the model name and maximum sequence length are set explicitly; the
    training, dataset and memory sections are default-constructed.
    """
    return AppConfig(
        model=ModelConfig(
            name="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
            max_seq_length=2048,
        ),
        training=TrainingConfig(),
        dataset=DatasetConfig(),
        memory=MemoryConfig(),
    )


def _print_samples(dataset, num_samples=3, preview_chars=200):
    """Print metadata and a text preview for the first *num_samples* rows.

    Args:
        dataset: Object supporting ``len()`` and integer indexing; each row is
            read by the keys ``repo_name``, ``file_path``, ``language``,
            ``file_size``, ``line_count`` and ``text`` (presumably the schema
            produced by DatasetProcessor -- TODO confirm).
        num_samples: Maximum number of rows to print.
        preview_chars: Number of characters of ``text`` to show before
            truncating with "...".
    """
    print(f"\nFirst {num_samples} samples from the dataset:")
    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]
        print(f"\nSample {i + 1}:")
        print(f" Repository: {sample['repo_name']}")
        print(f" File path: {sample['file_path']}")
        print(f" Language: {sample['language']}")
        print(f" File size: {sample['file_size']} characters")
        print(f" Lines: {sample['line_count']}")
        text = sample['text']
        # Truncate long files so the console output stays readable.
        if len(text) > preview_chars:
            preview_text = text[:preview_chars] + "..."
        else:
            preview_text = text
        print(f" Text preview: {preview_text}")


def main():
    """Build a dataset from example GitHub repositories and preview it.

    Returns:
        Whatever ``DatasetProcessor.process_github_repos`` returned, or
        ``None`` if that call raised (the error message and traceback are
        written to stderr).
    """
    config = _build_config()
    processor = DatasetProcessor()

    # Example GitHub repositories to process.
    # Replace these with your own repositories.
    repo_urls = [
        "https://github.com/karpathy/nanoGPT.git",
        # "https://github.com/your-username/your-repo.git",
    ]

    print("Processing GitHub repositories...")
    try:
        dataset = processor.process_github_repos(
            repo_urls=repo_urls,
            config=config,
            github_token=None,  # Add your token here if processing private repositories
        )
    except Exception as e:
        # Top-level boundary of an example script: report the failure and
        # signal it to the caller with None instead of crashing.
        import traceback

        print(f"Error processing repositories: {e}", file=sys.stderr)
        traceback.print_exc()
        return None

    print("Dataset processed successfully!")
    print(f"Dataset size: {len(dataset)} samples")

    # The preview is best-effort: a row missing an expected field must not
    # discard a dataset that was already processed successfully.
    try:
        _print_samples(dataset)
    except (KeyError, TypeError) as e:
        print(f"Could not display sample preview: {e}", file=sys.stderr)

    # Save dataset to disk (optional)
    # dataset.save_to_disk("./processed_dataset")
    # print("\nDataset saved to ./processed_dataset")
    return dataset
if __name__ == "__main__":
dataset = main()