#!/usr/bin/env python3
"""
Example script demonstrating how to get and process a dataset from GitHub repositories
using the DatasetProcessorSynthetic class, and save/load the processed dataset.
"""

import sys
from pathlib import Path

# Add this script's directory to sys.path so the `src.*` package imports below
# resolve even when the script is launched from another working directory.
sys.path.append(str(Path(__file__).parent))

from src.dataset_processor_synthetic import DatasetProcessorSynthetic
from src.config import AppConfig, ModelConfig, TrainingConfig, DatasetConfig, MemoryConfig
def main():
    """Process (or load from cache) a synthetic GitHub-repo dataset and preview it.

    Builds an AppConfig, then either loads a previously processed dataset from
    disk or processes the configured GitHub repositories from scratch, saving
    the result for future runs. Prints a short preview of the first samples.

    Returns:
        The processed dataset on success, or ``None`` if processing failed.
    """
    # Model/training configuration handed through to the processor.
    config = AppConfig(
        model=ModelConfig(
            name="unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
            max_seq_length=2048,
        ),
        training=TrainingConfig(),
        dataset=DatasetConfig(),
        memory=MemoryConfig(),
    )

    # Initialize dataset processor
    processor = DatasetProcessorSynthetic()

    # Example GitHub repositories to process.
    # Replace these with your own repositories.
    repo_urls = [
        "https://github.com/karpathy/nanoGPT.git",
        # "https://github.com/your-username/your-repo.git"
    ]

    dataset_path = "./processed_synthetic_dataset"
    try:
        # Reuse a cached dataset when one exists; otherwise process the
        # repositories and cache the result on disk for subsequent runs.
        if Path(dataset_path).exists():
            print("Loading previously processed dataset...")
            dataset = processor.load_dataset(dataset_path)
        else:
            print("Processing GitHub repositories...")
            dataset = processor.process_github_repos(
                repo_urls=repo_urls,
                config=config,
                github_token=None  # Add your token here if processing private repositories
            )

            print("Dataset processed successfully!")
            print(f"Dataset size: {len(dataset)} samples")

            # Save dataset to disk for future use
            print(f"Saving dataset to {dataset_path}...")
            processor.save_dataset(dataset, dataset_path)
            print("Dataset saved successfully!")

        print(f"Dataset loaded with {len(dataset)} samples")

        # Show some examples from the dataset (at most two).
        print("\nFirst 2 samples from the dataset:")
        for i in range(min(2, len(dataset))):
            sample = dataset[i]
            print(f"\nSample {i+1}:")
            print(f" Repository: {sample['repo_name']}")
            print(f" File path: {sample['file_path']}")
            print(f" Language: {sample['language']}")
            print(f" File size: {sample['file_size']} characters")
            print(f" Lines: {sample['line_count']}")

            # Show messages structure (ChatML-style role/content pairs);
            # content is truncated to the first 100 characters for readability.
            messages = sample['messages']
            print(f" Messages: {len(messages)} messages")
            for j, message in enumerate(messages):
                print(f" Message {j+1} ({message['role']}): {message['content'][:100]}...")

        return dataset

    except Exception as e:
        # Top-level boundary for an example script: report the failure with a
        # traceback and return None instead of crashing.
        print(f"Error processing repositories: {e}")
        import traceback
        traceback.print_exc()
        return None
if __name__ == "__main__":
    # Exit non-zero when processing failed (main() returns None on error) so
    # shell callers and CI can detect the failure.
    sys.exit(0 if main() is not None else 1)