fix some bugs
This commit is contained in:
parent
c73b0d247a
commit
c7a84c520c
141
.gitignore
vendored
Normal file
141
.gitignore
vendored
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
/models
|
||||||
|
/unsloth_compiled_cache
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
build/
|
||||||
|
temp/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
poetry.lock
|
||||||
|
|
||||||
|
# PEP 582; used by pythonloc
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# VS Code settings
|
||||||
|
.vscode/
|
||||||
31
README.md
31
README.md
@ -32,9 +32,27 @@ A Python application for training various unsloth models using data from GitHub
|
|||||||
- Git
|
- Git
|
||||||
- Dependencies listed in `requirements.txt`
|
- Dependencies listed in `requirements.txt`
|
||||||
|
|
||||||
|
## Private Repository Support
|
||||||
|
|
||||||
|
The application now supports processing private GitHub repositories by using a GitHub token for authentication.
|
||||||
|
To use this feature:
|
||||||
|
|
||||||
|
1. Generate a GitHub personal access token with appropriate permissions
|
||||||
|
2. Pass the token using the `--github_token` command line argument
|
||||||
|
3. Use private repository URLs in the same format as public repositories
|
||||||
|
|
||||||
|
Supported URL formats for private repositories:
|
||||||
|
- `https://github.com/user/private-repo.git`
|
||||||
|
- `github.com/user/private-repo`
|
||||||
|
- `user/private-repo`
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
1. Clone this repository
|
1. Clone this repository
|
||||||
|
2. if have CUDA GPU install PyTorch:
|
||||||
|
```bash
|
||||||
|
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu129
|
||||||
|
```
|
||||||
2. Install dependencies:
|
2. Install dependencies:
|
||||||
```bash
|
```bash
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
@ -56,6 +74,12 @@ python src/main.py \
|
|||||||
python run_training.py \
|
python run_training.py \
|
||||||
--repo1 https://github.com/user/repo1 \
|
--repo1 https://github.com/user/repo1 \
|
||||||
--repo2 https://github.com/user/repo2
|
--repo2 https://github.com/user/repo2
|
||||||
|
|
||||||
|
# Using private repositories with a GitHub token
|
||||||
|
python run_training.py \
|
||||||
|
--repo1 https://github.com/user/private-repo1 \
|
||||||
|
--repo2 https://github.com/user/private-repo2 \
|
||||||
|
--github_token YOUR_GITHUB_TOKEN
|
||||||
```
|
```
|
||||||
|
|
||||||
### Training Qwen3-8B
|
### Training Qwen3-8B
|
||||||
@ -72,6 +96,12 @@ python src/main.py \
|
|||||||
python run_training_qwen3.py \
|
python run_training_qwen3.py \
|
||||||
--repo1 https://github.com/user/repo1 \
|
--repo1 https://github.com/user/repo1 \
|
||||||
--repo2 https://github.com/user/repo2
|
--repo2 https://github.com/user/repo2
|
||||||
|
|
||||||
|
# Using private repositories with a GitHub token
|
||||||
|
python run_training_qwen3.py \
|
||||||
|
--repo1 https://github.com/user/private-repo1 \
|
||||||
|
--repo2 https://github.com/user/private-repo2 \
|
||||||
|
--github_token YOUR_GITHUB_TOKEN
|
||||||
```
|
```
|
||||||
|
|
||||||
### Command Line Arguments
|
### Command Line Arguments
|
||||||
@ -81,6 +111,7 @@ python run_training_qwen3.py \
|
|||||||
- `--config`: Path to training configuration file (default: configs/training_config.yaml)
|
- `--config`: Path to training configuration file (default: configs/training_config.yaml)
|
||||||
- `--output_dir`: Directory to save trained model (default: ./models)
|
- `--output_dir`: Directory to save trained model (default: ./models)
|
||||||
- `--log_level`: Logging level (DEBUG, INFO, WARNING, ERROR)
|
- `--log_level`: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||||
|
- `--github_token`: GitHub token for accessing private repositories (optional)
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
# Core ML libraries
|
# Core ML libraries
|
||||||
torch>=2.1.0
|
# torch>=2.1.0
|
||||||
torchvision>=0.16.0
|
# torchvision>=0.16.0
|
||||||
torchaudio>=2.1.0
|
# torchaudio>=2.1.0
|
||||||
|
|
||||||
# Unsloth for efficient model training
|
# Unsloth for efficient model training
|
||||||
unsloth[cu121]>=2024.5
|
unsloth[cu129]>=2024.5
|
||||||
unsloth_zoo>=2024.5
|
unsloth_zoo>=2024.5
|
||||||
|
|
||||||
# Transformers and tokenizers
|
# Transformers and tokenizers
|
||||||
|
|||||||
@ -16,7 +16,7 @@ import git
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from config import TrainingConfig
|
from config import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class DatasetProcessor:
|
class DatasetProcessor:
|
||||||
@ -81,13 +81,14 @@ class DatasetProcessor:
|
|||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
self.temp_dirs = []
|
self.temp_dirs = []
|
||||||
|
|
||||||
def process_github_repos(self, repo_urls: List[str], config: TrainingConfig) -> Dataset:
|
def process_github_repos(self, repo_urls: List[str], config: AppConfig, github_token: Optional[str] = None) -> Dataset:
|
||||||
"""
|
"""
|
||||||
Process multiple GitHub repositories into a training dataset
|
Process multiple GitHub repositories into a training dataset
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_urls: List of GitHub repository URLs
|
repo_urls: List of GitHub repository URLs
|
||||||
config: Training configuration
|
config: Training configuration
|
||||||
|
github_token: Optional GitHub token for accessing private repositories
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset ready for training
|
Dataset ready for training
|
||||||
@ -97,7 +98,7 @@ class DatasetProcessor:
|
|||||||
for repo_url in repo_urls:
|
for repo_url in repo_urls:
|
||||||
try:
|
try:
|
||||||
self.logger.info(f"Processing repository: {repo_url}")
|
self.logger.info(f"Processing repository: {repo_url}")
|
||||||
repo_samples = self._process_single_repo(repo_url, config)
|
repo_samples = self._process_single_repo(repo_url, config, github_token)
|
||||||
all_code_samples.extend(repo_samples)
|
all_code_samples.extend(repo_samples)
|
||||||
self.logger.info(f"Extracted {len(repo_samples)} samples from {repo_url}")
|
self.logger.info(f"Extracted {len(repo_samples)} samples from {repo_url}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -120,13 +121,14 @@ class DatasetProcessor:
|
|||||||
self.logger.info(f"Dataset size after filtering: {len(dataset)}")
|
self.logger.info(f"Dataset size after filtering: {len(dataset)}")
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
def _process_single_repo(self, repo_url: str, config: TrainingConfig) -> List[Dict]:
|
def _process_single_repo(self, repo_url: str, config: AppConfig, github_token: Optional[str] = None) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Process a single GitHub repository
|
Process a single GitHub repository
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_url: GitHub repository URL
|
repo_url: GitHub repository URL
|
||||||
config: Training configuration
|
config: Training configuration
|
||||||
|
github_token: Optional GitHub token for accessing private repositories
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of code samples with metadata
|
List of code samples with metadata
|
||||||
@ -134,13 +136,38 @@ class DatasetProcessor:
|
|||||||
temp_dir = tempfile.mkdtemp()
|
temp_dir = tempfile.mkdtemp()
|
||||||
self.temp_dirs.append(temp_dir)
|
self.temp_dirs.append(temp_dir)
|
||||||
|
|
||||||
|
depth = 1
|
||||||
|
branch = "18.0"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Clone repository
|
# Clone repository
|
||||||
repo_name = repo_url.split('/')[-1].replace('.git', '')
|
repo_name = repo_url.split('/')[-1].replace('.git', '')
|
||||||
repo_path = os.path.join(temp_dir, repo_name)
|
repo_path = os.path.join(temp_dir, repo_name)
|
||||||
|
|
||||||
self.logger.info(f"Cloning {repo_url} to {repo_path}")
|
self.logger.info(f"Cloning {repo_url} to {repo_path}")
|
||||||
repo = git.Repo.clone_from(repo_url, repo_path)
|
|
||||||
|
# Use token for private repositories if provided
|
||||||
|
clone_url = repo_url
|
||||||
|
if github_token and "github.com" in repo_url:
|
||||||
|
# Handle SSH URLs
|
||||||
|
if repo_url.startswith("git@"):
|
||||||
|
# SSH URL doesn't need token modification
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Add token to HTTPS URL
|
||||||
|
if repo_url.startswith("https://"):
|
||||||
|
clone_url = repo_url.replace("https://", f"https://{github_token}@")
|
||||||
|
elif repo_url.startswith("http://"):
|
||||||
|
clone_url = repo_url.replace("http://", f"http://{github_token}@")
|
||||||
|
else:
|
||||||
|
# For URLs like "github.com/user/repo" or "user/repo"
|
||||||
|
if repo_url.startswith("github.com/"):
|
||||||
|
clone_url = f"https://{github_token}@{repo_url}"
|
||||||
|
else:
|
||||||
|
# Assume it's a GitHub path like "user/repo"
|
||||||
|
clone_url = f"https://{github_token}@github.com/{repo_url}"
|
||||||
|
|
||||||
|
repo = git.Repo.clone_from(clone_url, repo_path, depth=depth, branch=branch)
|
||||||
|
|
||||||
# Extract code samples
|
# Extract code samples
|
||||||
code_samples = self._extract_code_samples(repo_path, config)
|
code_samples = self._extract_code_samples(repo_path, config)
|
||||||
@ -151,7 +178,7 @@ class DatasetProcessor:
|
|||||||
# Cleanup
|
# Cleanup
|
||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
def _extract_code_samples(self, repo_path: str, config: TrainingConfig) -> List[Dict]:
|
def _extract_code_samples(self, repo_path: str, config: AppConfig) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Extract code samples from a repository
|
Extract code samples from a repository
|
||||||
|
|
||||||
@ -194,7 +221,7 @@ class DatasetProcessor:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _process_code_file(self, file_path: Path, repo_path: Path, config: TrainingConfig) -> Optional[Dict]:
|
def _process_code_file(self, file_path: Path, repo_path: Path, config: AppConfig) -> Optional[Dict]:
|
||||||
"""
|
"""
|
||||||
Process a single code file into a training sample
|
Process a single code file into a training sample
|
||||||
|
|
||||||
|
|||||||
19
src/main.py
19
src/main.py
@ -15,7 +15,7 @@ sys.path.append(str(Path(__file__).parent))
|
|||||||
|
|
||||||
from trainer import ModelTrainer
|
from trainer import ModelTrainer
|
||||||
from dataset_processor import DatasetProcessor
|
from dataset_processor import DatasetProcessor
|
||||||
from config import TrainingConfig
|
from config import AppConfig
|
||||||
from utils import setup_logging, check_gpu_memory
|
from utils import setup_logging, check_gpu_memory
|
||||||
|
|
||||||
|
|
||||||
@ -59,6 +59,13 @@ def parse_arguments():
|
|||||||
help="Logging level"
|
help="Logging level"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--github_token",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="GitHub token for accessing private repositories"
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -80,7 +87,12 @@ def main():
|
|||||||
logger.info(f"GPU Memory Info: {gpu_info}")
|
logger.info(f"GPU Memory Info: {gpu_info}")
|
||||||
|
|
||||||
# Load configuration
|
# Load configuration
|
||||||
config = TrainingConfig.from_yaml(args.config)
|
logger.debug(f"Attempting to load config from: {args.config}")
|
||||||
|
logger.debug(f"AppConfig methods: {[m for m in dir(AppConfig) if not m.startswith('_')]}")
|
||||||
|
|
||||||
|
# Load configuration using AppConfig
|
||||||
|
config = AppConfig.from_yaml(args.config)
|
||||||
|
|
||||||
logger.info("Configuration loaded successfully")
|
logger.info("Configuration loaded successfully")
|
||||||
|
|
||||||
# Process datasets from GitHub repositories
|
# Process datasets from GitHub repositories
|
||||||
@ -89,7 +101,8 @@ def main():
|
|||||||
|
|
||||||
train_dataset = dataset_processor.process_github_repos(
|
train_dataset = dataset_processor.process_github_repos(
|
||||||
repo_urls=[args.repo1, args.repo2],
|
repo_urls=[args.repo1, args.repo2],
|
||||||
config=config
|
config=config,
|
||||||
|
github_token=args.github_token
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Dataset processed successfully. Size: {len(train_dataset)}")
|
logger.info(f"Dataset processed successfully. Size: {len(train_dataset)}")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user