# Source code for sharkpy.battle

# sharkpy/battle.py

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Union, Dict, Optional
from .learning import learn

# Human-readable metadata for each model name, keyed by the ``model_choice``
# string. ``battle`` prints ``strengths``/``best_for`` for each contender.
# It compares ``data_size`` literally against 'Small to medium',
# 'Medium to large' and 'Large' to emit dataset-size warnings, so keep those
# spellings in sync with it.
MODEL_DETAILS: Dict[str, Dict[str, str]] = {
    'linear_regression': {
        'name': 'Linear Regression',
        'strengths': 'Simple, interpretable, fast',
        'best_for': 'Linear relationships, baseline models',
        'data_size': 'Any size',
        'training_speed': 'Very fast',
    },
    'logistic_regression': {
        'name': 'Logistic Regression',
        'strengths': 'Simple, probabilistic output',
        'best_for': 'Binary/multiclass classification, baseline models',
        'data_size': 'Any size',
        'training_speed': 'Very fast',
    },
    'random_forest': {
        'name': 'Random Forest',
        'strengths': 'Robust, handles non-linearity, feature importance',
        'best_for': 'Complex relationships, feature selection',
        'data_size': 'Medium to large',
        'training_speed': 'Fast',
    },
    'svm': {
        'name': 'Support Vector Machine',
        'strengths': 'Effective in high dimensions',
        'best_for': 'Non-linear relationships, high-dimensional data',
        'data_size': 'Small to medium',
        'training_speed': 'Slow for large datasets',
    },
    'ridge': {
        'name': 'Ridge Regression',
        'strengths': 'L2 regularization, handles multicollinearity',
        'best_for': 'Linear relationships with correlated features',
        'data_size': 'Any size',
        'training_speed': 'Fast',
    },
    'lasso': {
        'name': 'Lasso Regression',
        'strengths': 'L1 regularization, feature selection',
        'best_for': 'Linear relationships with feature selection',
        'data_size': 'Any size',
        'training_speed': 'Fast',
    },
    'knn': {
        'name': 'K-Nearest Neighbors',
        'strengths': 'Simple, non-parametric',
        'best_for': 'Pattern recognition, simple non-linear relationships',
        'data_size': 'Small to medium',
        'training_speed': 'Very fast (but slow predictions)',
    },
    'xgboost': {
        'name': 'XGBoost',
        'strengths': 'High performance, feature importance, handles missing values',
        'best_for': 'Complex relationships, competitions',
        'data_size': 'Medium to large',
        'training_speed': 'Medium (with GPU support)',
    },
    'lightgbm': {
        'name': 'LightGBM',
        'strengths': 'Fast training, low memory usage',
        'best_for': 'Large datasets, speed-critical applications',
        'data_size': 'Large',
        'training_speed': 'Very fast',
    },
    'catboost': {
        'name': 'CatBoost',
        'strengths': 'Handles categorical features, reduced overfitting',
        'best_for': 'Datasets with many categorical features',
        'data_size': 'Medium to large',
        'training_speed': 'Medium',
    },
}
# Model names that receive hyperparameter-search trials (``n_trials``) in
# ``battle``; every other model is trained with ``n_trials=None``.
_BOOSTING_MODELS = frozenset({'xgboost', 'lightgbm', 'catboost'})

# Default contenders when ``battle`` is called without ``models``. A tuple, so
# no call can mutate the default seen by later calls.
_DEFAULT_BATTLE_MODELS = ('linear_regression', 'random_forest', 'xgboost')


def _warn_data_size(models, data_size):
    """Print a warning for each model whose recommended data size does not fit ``data_size`` rows."""
    for model_name in models:
        details = MODEL_DETAILS.get(model_name, {})
        recommended_size = details.get('data_size', 'Any size')
        if recommended_size == 'Small to medium' and data_size > 10000:
            print(f"⚠️ Warning: {model_name} is best for small to medium datasets, but you have {data_size} rows.")
        elif recommended_size == 'Medium to large' and data_size < 1000:
            print(f"⚠️ Warning: {model_name} performs better with medium to large datasets, but you have {data_size} rows.")
        elif recommended_size == 'Large' and data_size < 10000:
            print(f"⚠️ Warning: {model_name} is optimized for large datasets, but you have {data_size} rows.")


def _score_from_cv(cv_results, problem_type, metric):
    """Return ``(metric_name, mean_test_score)`` picked from cross-validation results.

    Uses ``cv_results['test_<metric>']`` when that key is present. Otherwise it
    falls back to ``test_r2`` for regression and ``test_accuracy`` for any other
    problem type, which is the scoring ``battle`` always used before ``metric``
    was honoured.
    """
    requested_key = f'test_{metric}'
    if requested_key in cv_results:
        return metric, cv_results[requested_key].mean()
    fallback = 'r2' if problem_type == "regression" else 'accuracy'
    return fallback, cv_results[f'test_{fallback}'].mean()


def battle(self, data: pd.DataFrame, target: str, models: Optional[List[str]] = None,
           metric: str = 'r2', n_trials: int = 30, early_stopping: bool = False,
           min_score: float = 0.5, verbose: int = 0) -> Dict:
    """Battle multiple models against each other and return the champion.

    Each contender is trained on a fresh instance of ``self``'s class and
    scored with 5-fold cross-validation via ``report``. The highest-scoring
    model is stored on ``self``.

    Parameters
    ----------
    data : pd.DataFrame
        Input data for training (features plus the target column).
    target : str
        Name of target column.
    models : list of str, optional
        Model names to compete (keys of ``MODEL_DETAILS``). Defaults to
        ``['linear_regression', 'random_forest', 'xgboost']``.
    metric : str
        Metric to compare models (default: 'r2'). The mean of the
        ``'test_<metric>'`` entry of the CV results is used when present.
        Otherwise r2 (regression) or accuracy (other problem types) is used.
        Scores are assumed to be higher-is-better.
    n_trials : int
        Number of optimization trials for boosting models (xgboost, lightgbm,
        catboost). Other models receive ``n_trials=None`` (default: 30).
    early_stopping : bool, optional
        If True, stop the battle as soon as a model's score exceeds
        ``min_score``. Remaining models are not trained and do not appear in
        ``all_results``. Not recommended as it may miss better models later
        (default: False).
    min_score : float
        Score a model must exceed to trigger early stopping (ignored if
        early_stopping=False) (default: 0.5).
    verbose : int
        Verbosity level forwarded to ``learn`` (default: 0).

    Returns
    -------
    dict
        On success, the dict contains:

        - ``champion``: the winning model name
        - ``model``: the winning model object
        - ``score``: the champion's score
        - ``all_results``: name -> score, with ``None`` for failed models
        - ``details``: the champion's ``MODEL_DETAILS`` entry
        - ``comparison``: the bar-chart figure

        If no model produced a usable (non-NaN) score, the dict contains only
        ``champion`` (``None``) and ``all_results``.

    Raises
    ------
    ValueError
        If ``data`` is not a DataFrame.
    KeyError
        If ``target`` is not a column of ``data`` (raised by ``DataFrame.drop``).

    Notes
    -----
    Sets ``self.features`` and ``self.target``. On success it also sets
    ``self.model`` and ``self.problem_type`` from the champion.
    """
    print("🦈⚔️ MODEL BATTLE ROYALE ⚔️🦈")
    print("Preparing the arena...")

    if models is None:
        models = list(_DEFAULT_BATTLE_MODELS)

    # Store data in main instance for later use
    if not isinstance(data, pd.DataFrame):
        raise ValueError("🦈 Data must be a DataFrame!")
    self.features = data.drop(columns=[target])
    self.target = data[target]

    _warn_data_size(models, len(data))

    results = {}
    best_name = None
    best_model = None
    best_problem_type = None
    best_score = -float('inf')
    best_metric = metric

    # Train and evaluate all models
    for model_name in models:
        print(f"\n🗡️ {model_name.upper()} enters the fray!")
        details = MODEL_DETAILS.get(model_name, {})
        print(f" Strengths: {details.get('strengths', 'Unknown')}")
        print(f" Best for: {details.get('best_for', 'Unknown')}")

        # A fresh instance per model so contenders cannot share fitted state.
        temp_shark = self.__class__()
        try:
            temp_shark.learn(
                data=data,
                target=target,
                model_choice=model_name,
                n_trials=n_trials if model_name in _BOOSTING_MODELS else None,
                verbose=verbose,
            )
            cv_results, _ = temp_shark.report(cv_folds=5, export_path=None)
            metric_name, score = _score_from_cv(cv_results, temp_shark.problem_type, metric)
            results[model_name] = score
            print(f" Score: {score:.4f}")

            # NaN compares False, so a NaN score is never crowned. Selecting the
            # champion here, in one place, keeps the reported name and the
            # stored model in agreement.
            if score > best_score:
                best_model, best_problem_type = temp_shark.model, temp_shark.problem_type
                best_name, best_score, best_metric = model_name, score, metric_name
        except Exception as e:
            # Per-model boundary: one failing model must not abort the battle.
            print(f" ❌ {model_name} failed: {str(e)}")
            results[model_name] = None
            continue

        if early_stopping and score > min_score:
            print(f" ⏹️ Early stopping: {model_name} exceeded min_score={min_score}.")
            break

    if best_name is None:
        print("\n💀 ALL MODELS FAILED! The battle is a draw.")
        print("💡 Suggestion: Try simpler models (e.g., linear_regression) or check your data for issues like missing values or incorrect formats.")
        return {'champion': None, 'all_results': results}

    # Report the metric actually used (it may be the fallback, not ``metric``).
    print(f"\n🏆 CHAMPION: {best_name.upper()} with {best_metric}: {best_score:.4f}")

    # Store the winning model on the main instance
    self.model = best_model
    self.problem_type = best_problem_type

    valid_results = {k: v for k, v in results.items() if v is not None}
    return {
        'champion': best_name,
        'model': self.model,
        'score': best_score,
        'all_results': results,
        'details': MODEL_DETAILS.get(best_name, {}),
        'comparison': _visualize_battle_results(valid_results),
    }
def _visualize_battle_results(results):
    """Draw the battle leaderboard as a bar chart.

    Entries whose score is ``None`` (models that failed) are left out.

    Parameters
    ----------
    results : dict
        Mapping of model name to score (or ``None``).

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure holding the chart.
    """
    scored = {name: value for name, value in results.items() if value is not None}

    fig = plt.figure(figsize=(10, 6))
    ax = sns.barplot(x=list(scored.keys()), y=list(scored.values()))
    ax.set_title("🦈 Model Battle Results 🦈")
    plt.setp(ax.get_xticklabels(), rotation=45)
    fig.tight_layout()
    return fig