#!/usr/bin/env python3
"""
Test script to verify the training visualization system works correctly.
This script tests the visualization components independently to ensure
plots are generated correctly before running a full training session.
"""
import os
import sys
import logging
from pathlib import Path

import numpy as np
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend for headless testing
import matplotlib.pyplot as plt

# Make the repo root (the directory containing the `trainer` package)
# importable, so `trainer.*` imports resolve even when the script is run
# from another working directory.
sys.path.append(str(Path(__file__).parent))

from trainer.visualization import create_training_visualizer
from trainer.base import TrainingMetrics
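
# Expected repo layout (an assumption inferred from the imports above;
# adjust if your checkout differs):
#   <repo root>/
#       <this script>.py
#       trainer/
#           base.py                     (TrainingMetrics)
#           visualization.py            (create_training_visualizer)
#           voice_recognition/trainer.py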


def setup_test_environment():
    """Set up the test output directory and logging."""
    # Create test output directory
    test_output_dir = Path("./test_outputs/visualization_test")
    test_output_dir.mkdir(parents=True, exist_ok=True)

    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    logger = logging.getLogger("visualization_test")

    return test_output_dir, logger


def create_mock_training_metrics():
    """Create mock training metrics for testing."""
    metrics = TrainingMetrics()

    # Simulate training progress over a short run
    num_epochs = 10

    # Training metrics with realistic patterns
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    for epoch in range(num_epochs):
        # Exponentially decaying loss with a little noise; validation loss
        # tracks training loss with a small generalization gap
        train_loss = 2.0 * np.exp(-epoch * 0.3) + 0.1 * np.random.random()
        val_loss = train_loss + 0.2 * np.random.random()

        # Saturating accuracy that plateaus near 0.95 (plus a little noise)
        train_acc = min(0.95, 0.3 + 0.7 * (1 - np.exp(-epoch * 0.4))) + 0.05 * np.random.random()
        val_acc = train_acc - 0.1 + 0.05 * np.random.random()

        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

    # Populate the underlying metrics store directly (assumes TrainingMetrics
    # keeps its per-metric series in a `.metrics` dict keyed by metric name)
    metrics.metrics['train_loss'] = train_losses
    metrics.metrics['train_accuracy'] = train_accuracies
    metrics.metrics['val_loss'] = val_losses
    metrics.metrics['val_accuracy'] = val_accuracies

    # Step-decay learning-rate schedule: halve the rate every 3 epochs
    metrics.metrics['learning_rate'] = [0.001 * (0.5 ** (epoch // 3)) for epoch in range(num_epochs)]

    return metrics
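
# For reference, after the function above runs, `metrics.metrics` holds plain
# Python lists with one value per epoch, e.g. (values illustrative only):
#   metrics.metrics['train_loss']    -> [2.05, 1.53, 1.15, ...]
#   metrics.metrics['learning_rate'] -> [0.001, 0.001, 0.001, 0.0005, ...]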


def create_mock_voice_recognition_metrics():
    """Create mock voice-recognition-specific metrics."""
    return {
        'speaker_verification': {
            'equal_error_rate': 0.12,
            'eer_threshold': 0.52,
            'mean_positive_similarity': 0.78,
            'mean_negative_similarity': 0.34,
            'num_positive_pairs': 1250,
            'num_negative_pairs': 8750
        },
        'embedding_analysis': {
            'num_speakers': 3,
            'mean_intra_distance': 0.45,
            'std_intra_distance': 0.12,
            'mean_inter_distance': 1.23,
            'std_inter_distance': 0.34,
            'separability_ratio': 2.73
        }
    }
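
# Sanity note on the mock numbers: if the separability ratio is defined as
# mean inter-speaker distance over mean intra-speaker distance (an assumption
# about the real metric), the values above are self-consistent:
# 1.23 / 0.45 ≈ 2.73.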


def test_basic_matplotlib():
    """Test basic matplotlib functionality."""
    logger = logging.getLogger("visualization_test")
    logger.info("Testing basic matplotlib functionality...")

    try:
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        x = np.linspace(0, 10, 100)
        y = np.sin(x)
        ax.plot(x, y)
        ax.set_title("Test Plot")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")

        test_file = "./test_outputs/visualization_test/basic_matplotlib_test.png"
        plt.savefig(test_file, dpi=150, bbox_inches='tight')
        plt.close()

        if os.path.exists(test_file):
            logger.info(f"✅ Basic matplotlib test passed: {test_file}")
            return True
        else:
            logger.error("❌ Basic matplotlib test failed: file not created")
            return False

    except Exception as e:
        logger.error(f"❌ Basic matplotlib test failed: {str(e)}")
        return False


def test_training_visualizer():
    """Test the training visualizer."""
    test_output_dir, logger = setup_test_environment()
    logger.info("Testing training visualizer...")

    try:
        # Create visualizer
        visualizer = create_training_visualizer(
            output_dir=str(test_output_dir),
            model_name="test_voice_model"
        )
        logger.info("✅ Visualizer created successfully")

        # Create mock data
        training_metrics = create_mock_training_metrics()
        additional_metrics = create_mock_voice_recognition_metrics()
        logger.info("✅ Mock data created successfully")

        # Generate plots
        logger.info("Generating training plots...")
        plot_files = visualizer.create_training_plots(
            metrics=training_metrics,
            additional_metrics=additional_metrics
        )

        if plot_files:
            logger.info(f"✅ Generated {len(plot_files)} plots:")
            for plot_file in plot_files:
                if os.path.exists(plot_file):
                    logger.info(f"  ✅ {plot_file}")
                else:
                    logger.error(f"  ❌ {plot_file} (file not found)")
                    return False
            return True
        else:
            logger.error("❌ No plots were generated")
            return False

    except Exception as e:
        # logger.exception logs the message plus the full traceback
        logger.exception(f"❌ Training visualizer test failed: {str(e)}")
        return False


def test_reports_directory_creation():
    """Test that the reports and plots directories are created properly."""
    logger = logging.getLogger("visualization_test")
    logger.info("Testing reports directory creation...")

    test_model_dir = Path("./test_outputs/visualization_test/test_model")
    reports_dir = test_model_dir / "reports"
    plots_dir = test_model_dir / "plots"

    # Create directories
    reports_dir.mkdir(parents=True, exist_ok=True)
    plots_dir.mkdir(parents=True, exist_ok=True)

    if reports_dir.exists() and plots_dir.exists():
        logger.info("✅ Reports and plots directories created successfully")
        return True
    else:
        logger.error("❌ Failed to create reports/plots directories")
        return False


def test_voice_recognition_training_quick():
    """Run a very quick voice recognition setup to verify plots are generated."""
    test_output_dir, logger = setup_test_environment()
    logger.info("Testing quick voice recognition training with plot generation...")

    try:
        from trainer.voice_recognition.trainer import (
            create_voice_recognition_trainer_config,
            VoiceRecognitionTrainer,
        )

        # Create a minimal config for testing
        config = create_voice_recognition_trainer_config(
            model_name="test_vr_visualization",
            data_dir="./trainer/data/voice_recognition",
            num_epochs=5,   # must satisfy num_epochs >= min_epochs (default 30), so we lower min_epochs below
            min_epochs=2,   # override the default min_epochs
            batch_size=8,
            learning_rate=0.01
        )
        config.output_dir = str(test_output_dir)

        # Check if training data exists
        data_dir = Path(config.data_dir)
        if not data_dir.exists():
            logger.warning(f"⚠️ Training data directory not found: {data_dir}")
            logger.info("Skipping actual training test - data not available")
            return True  # Don't fail the test for missing data

        # Create trainer
        trainer = VoiceRecognitionTrainer(config)
        logger.info("✅ Trainer created successfully")

        # Try to prepare data
        try:
            train_loader, val_loader, test_loader = trainer.prepare_data()
            logger.info("✅ Data prepared successfully")

            # Check if we have any data
            if len(train_loader.dataset) == 0:
                logger.warning("⚠️ No training data found")
                return True  # Don't fail for an empty dataset

            # Build model
            trainer.build_model()
            logger.info("✅ Model built successfully")

            # Test only the visualization generation, without full training,
            # by feeding the trainer mock metrics and test results
            mock_metrics = create_mock_training_metrics()
            mock_test_results = create_mock_voice_recognition_metrics()
            trainer._generate_training_visualizations(mock_metrics, mock_test_results)

            # Check if plots were created
            plots_dir = Path(config.output_dir) / config.model_name / "plots"
            if plots_dir.exists() and any(plots_dir.glob("*.png")):
                logger.info(f"✅ Training plots generated successfully in {plots_dir}")
                return True
            else:
                logger.warning(f"⚠️ No plots found in {plots_dir} - this might be expected")
                return True

        except Exception as e:
            logger.error(f"❌ Data preparation failed: {str(e)}")
            return True  # Don't fail the entire test for data issues

    except Exception as e:
        logger.exception(f"❌ Quick training test failed: {str(e)}")
        return False
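
# Note: the quick test above calls trainer._generate_training_visualizations(),
# a leading-underscore (internal) hook, so it may break if the trainer's
# internals are refactored; switch to a public entry point if one is added.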


def main():
    """Run all visualization tests."""
    test_output_dir, logger = setup_test_environment()

    logger.info("🚀 Starting visualization system tests...")
    logger.info(f"Test output directory: {test_output_dir}")

    tests = [
        ("Basic matplotlib functionality", test_basic_matplotlib),
        ("Reports directory creation", test_reports_directory_creation),
        ("Training visualizer", test_training_visualizer),
        ("Quick voice recognition training", test_voice_recognition_training_quick),
    ]

    results = []
    for test_name, test_func in tests:
        logger.info(f"\n📋 Running test: {test_name}")
        try:
            result = test_func()
            results.append((test_name, result))
            if result:
                logger.info(f"✅ {test_name}: PASSED")
            else:
                logger.error(f"❌ {test_name}: FAILED")
        except Exception as e:
            logger.error(f"❌ {test_name}: EXCEPTION - {str(e)}")
            results.append((test_name, False))

    # Summary
    logger.info("\n📊 Test Results Summary:")
    logger.info("=" * 50)

    passed = 0
    total = len(results)

    for test_name, result in results:
        status = "✅ PASSED" if result else "❌ FAILED"
        logger.info(f"{test_name}: {status}")
        if result:
            passed += 1

    logger.info("=" * 50)
    logger.info(f"Overall: {passed}/{total} tests passed")

    if passed == total:
        logger.info("🎉 All visualization tests passed!")
        return 0
    else:
        logger.error(f"⚠️ {total - passed} tests failed")
        return 1


if __name__ == "__main__":
    sys.exit(main())
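
# Usage (run from the repo root; the file name is not fixed by this script):
#   python <this script>.py
# Logs and generated .png plots land under ./test_outputs/visualization_test/,
# and the process exits with status 0 only if every test passes.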