Source code for sharkpy.plotting

# sharkpy/plotting.py

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay, RocCurveDisplay, PrecisionRecallDisplay
import matplotlib.font_manager as fm
import warnings
from typing import Dict, List, Optional, Union

# Suppress glyph warnings if font fallback fails
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

# Set fun, kid-friendly style
sns.set_palette("Set2")  # Accessible, vibrant palette
sns.set_context("notebook", font_scale=1.2)  # Slightly larger text

# Try to use a font that supports emojis
try:
[docs] available_fonts = {f.name for f in fm.fontManager.ttflist}
emoji_fonts = ['Segoe UI Emoji', 'Noto Color Emoji', 'Apple Color Emoji'] selected_font = next((f for f in emoji_fonts if f in available_fonts), 'Arial') plt.rcParams['font.family'] = selected_font except Exception as e: print(f"⚠️ Could not set emoji font, falling back to Arial: {e}") plt.rcParams['font.family'] = 'Arial' # Default SharkPy color scheme (ocean/water theme)
[docs] SHARKPY_COLORS = { 'primary': '#2E86AB', # Deep ocean blue 'secondary': '#A23B72', # Coral pink 'accent': '#F18F01', # Sunny yellow (like shark eyes) 'background': '#F7F7F7', # Light gray background 'grid': '#D3D3D3', # Grid lines 'text': '#2B2D42', # Dark blue-gray text 'bars': ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#6A8EAE', '#56E39F'] # Colorful bars }
[docs] def validate_colors(colors: Dict[str, str]) -> Dict[str, str]: """ Validate color dictionary and fall back to defaults if invalid. Parameters: colors: Dictionary of color specifications Returns: Validated color dictionary """ if not colors or not isinstance(colors, dict): return SHARKPY_COLORS.copy() # Create a copy of default colors and update with provided values validated_colors = SHARKPY_COLORS.copy() for key, value in colors.items(): if key in validated_colors: # Validate color format (basic check) if isinstance(value, str) and (value.startswith('#') or value in plt.colors.cnames): validated_colors[key] = value elif isinstance(value, list) and all(isinstance(c, str) for c in value): validated_colors[key] = value else: print(f"⚠️ Invalid color format for '{key}': {value}. Using default.") return validated_colors
[docs] def plot_model(model, X, y=None, kind="prediction", show=True, save_path=None, feature_names=None, colors: Optional[Dict[str, str]] = None): """ Visualizes model behavior depending on the specified kind. Parameters: model : trained model X : array-like, shape (n_samples, n_features) y : array-like, shape (n_samples,), optional kind : str, one of {"prediction", "residuals", "confusion_matrix", "roc", "pr_curve", "proba_hist", "feature_importance"} show : bool, whether to display the plot save_path : str or None, path to save the plot feature_names : list, optional, names of features for plotting colors : dict, optional, custom color specifications """ print(f"🦈 Debug: plot_model called with kind={kind}") if kind in ["prediction", "residuals", "confusion_matrix", "roc", "pr_curve", "proba_hist"] and y is None: raise ValueError(f"y must be provided for kind='{kind}'.") # Validate and prepare colors plot_colors = validate_colors(colors) # Dispatch to plotting functions if kind == "prediction": _plot_prediction(model, X, y, plot_colors) elif kind == "residuals": _plot_residuals(model, X, y, plot_colors) elif kind == "confusion_matrix": _plot_confusion_matrix(model, X, y, plot_colors) elif kind == "roc": _plot_roc(model, X, y, plot_colors) elif kind == "pr_curve": _plot_pr_curve(model, X, y, plot_colors) elif kind == "proba_hist": _plot_proba_hist(model, X, y, plot_colors) elif kind == "feature_importance": if feature_names is None: if isinstance(X, pd.DataFrame): feature_names = X.columns else: raise ValueError("Feature names required for feature importance plot.") _plot_feature_importance(model, feature_names, plot_colors) else: raise ValueError(f"Unknown kind='{kind}'. Supported kinds: prediction, residuals, confusion_matrix, roc, pr_curve, proba_hist, feature_importance") if save_path: dir_ = os.path.dirname(save_path) if dir_: os.makedirs(dir_, exist_ok=True) plt.savefig(save_path, bbox_inches='tight') if show: plt.show() else: plt.close()
# === Internal Helper Functions === #
[docs] def _plot_prediction(model, X, y, colors): y_pred = model.predict(X) fig, ax = plt.subplots(figsize=(8, 6), facecolor=colors['background']) # Fun scatter plot with star markers ax.scatter(y, y_pred, alpha=0.7, c=colors['primary'], edgecolor='white', s=120, marker='*') # Perfect prediction line with fun style line_range = [min(y), max(y)] ax.plot(line_range, line_range, '--', color=colors['secondary'], linewidth=3, label='Perfect Prediction') # Playful styling ax.set_xlabel("Actual Values", fontsize=12, color=colors['text']) ax.set_ylabel("Predicted Values", fontsize=12, color=colors['text']) ax.set_title("Actual vs Predicted Values\nShark's Predictions", fontsize=14, pad=20, color=colors['text']) ax.grid(True, alpha=0.3, color=colors['grid']) ax.legend(fontsize=10) # Set background color ax.set_facecolor(colors['background']) ax.set_aspect('equal', adjustable='box') plt.tight_layout()
[docs] def _plot_residuals(model, X, y, colors): y_pred = model.predict(X) residuals = y - y_pred fig, ax = plt.subplots(figsize=(8, 6), facecolor=colors['background']) # Fun scatter with different marker ax.scatter(y_pred, residuals, alpha=0.7, c=colors['primary'], edgecolor='white', s=120, marker='o') # Zero line with fun style ax.axhline(y=0, color=colors['secondary'], linestyle='--', linewidth=3, label='Zero Line') # Playful styling ax.set_xlabel("Predicted Values", fontsize=12, color=colors['text']) ax.set_ylabel("Residuals", fontsize=12, color=colors['text']) ax.set_title("Residual Plot\nShark's Accuracy Check", fontsize=14, pad=20, color=colors['text']) ax.grid(True, alpha=0.3, color=colors['grid']) ax.legend(fontsize=10) # Set background color ax.set_facecolor(colors['background']) plt.tight_layout()
[docs] def _plot_confusion_matrix(model, X, y, colors): try: fig, ax = plt.subplots(figsize=(8, 6), facecolor=colors['background']) disp = ConfusionMatrixDisplay.from_estimator( model, X, y, cmap='Blues', # Clearer for confusion matrix ax=ax, colorbar=False ) ax.set_title("Confusion Matrix\nShark's Score Card", fontsize=14, pad=20, color=colors['text']) # Set background color ax.set_facecolor(colors['background']) plt.tight_layout() except Exception as e: print(f"Could not plot confusion matrix: {e}")
[docs] def _plot_roc(model, X, y, colors): try: fig, ax = plt.subplots(figsize=(8, 6), facecolor=colors['background']) RocCurveDisplay.from_estimator(model, X, y, ax=ax) ax.set_title("ROC Curve\nShark's Sensitivity Radar", fontsize=14, pad=20, color=colors['text']) # Set background color ax.set_facecolor(colors['background']) plt.tight_layout() except Exception as e: print(f"Could not plot ROC curve: {e}")
[docs] def _plot_pr_curve(model, X, y, colors): try: fig, ax = plt.subplots(figsize=(8, 6), facecolor=colors['background']) PrecisionRecallDisplay.from_estimator(model, X, y, ax=ax) ax.set_title("Precision-Recall Curve\nShark's Precision Tracker", fontsize=14, pad=20, color=colors['text']) # Set background color ax.set_facecolor(colors['background']) plt.tight_layout() except Exception as e: print(f"Could not plot Precision-Recall curve: {e}")
[docs] def _plot_proba_hist(model, X, y, colors): try: if hasattr(model, "predict_proba"): proba = model.predict_proba(X) if proba.ndim == 2 and proba.shape[1] == 2: pos_proba = proba[:, 1] fig, ax = plt.subplots(figsize=(8, 6), facecolor=colors['background']) ax.hist(pos_proba, bins=20, color=colors['primary'], alpha=0.8) ax.set_title("Predicted Probabilities (Positive Class)\nShark's Confidence Meter", fontsize=14, pad=20, color=colors['text']) ax.set_xlabel("Probability of Positive Class", fontsize=12, color=colors['text']) ax.set_ylabel("Frequency", fontsize=12, color=colors['text']) ax.grid(True, alpha=0.3, color=colors['grid']) # Set background color ax.set_facecolor(colors['background']) plt.tight_layout() else: fig, ax = plt.subplots(figsize=(8, 6), facecolor=colors['background']) for i in range(proba.shape[1]): ax.hist(proba[:, i], bins=20, alpha=0.5, label=f"Class {i}", color=colors['bars'][i % len(colors['bars'])]) ax.legend() ax.set_title("Predicted Probabilities by Class\nShark's Confidence Meter", fontsize=14, pad=20, color=colors['text']) ax.set_xlabel("Predicted Probability", fontsize=12, color=colors['text']) ax.set_ylabel("Frequency", fontsize=12, color=colors['text']) ax.grid(True, alpha=0.3, color=colors['grid']) # Set background color ax.set_facecolor(colors['background']) plt.tight_layout() else: print("Model does not support predict_proba; skipping probability histogram.") except Exception as e: print(f"Could not plot predicted probability histogram: {e}")
[docs] def _plot_feature_importance(model, feature_names, colors): if hasattr(model, "feature_importances_"): importances = model.feature_importances_ ylabel = "Importance Score" elif hasattr(model, "coef_"): importances = np.abs(model.coef_) if len(importances.shape) > 1: # Multi-class logistic regression importances = importances.mean(axis=0) ylabel = "Absolute Coefficient" else: raise AttributeError("Model does not have feature_importances_ or coef_.") indices = np.argsort(importances)[::-1] fig, ax = plt.subplots(figsize=(10, 6), facecolor=colors['background']) # Colorful bars bar_colors = [colors['bars'][i % len(colors['bars'])] for i in range(len(importances))] ax.bar(range(len(importances)), importances[indices], color=bar_colors) # Set feature names as x-tick labels ax.set_xticks(range(len(importances))) ax.set_xticklabels([feature_names[i] for i in indices], rotation=45, ha='right') # Playful styling ax.set_title("Feature Importance\nShark's Favorite Features", fontsize=14, pad=20, color=colors['text']) ax.set_xlabel("Features", fontsize=12, color=colors['text']) ax.set_ylabel(ylabel, fontsize=12, color=colors['text']) ax.grid(True, alpha=0.3, color=colors['grid']) # Set background color ax.set_facecolor(colors['background']) plt.tight_layout()
# Example usage function
[docs] def get_sharkpy_colors() -> Dict[str, str]: """ Returns the default SharkPy color scheme. Returns: Dictionary of color specifications """ return SHARKPY_COLORS.copy()