#!/usr/bin/env python3
"""
Test script to verify few-shot learning functionality works correctly.
"""
import logging
import os
import sys
from pathlib import Path

# Add trainer to path so the project package resolves when run as a script.
sys.path.append(str(Path(__file__).parent / "trainer"))

from trainer.voice_recognition.trainer import (
    VoiceRecognitionTrainer,
    create_voice_recognition_trainer_config,
)
def setup_test_environment():
    """Prepare shared test fixtures: an output directory and a logger.

    Returns:
        tuple: ``(Path, logging.Logger)`` — the created output directory
        and the logger used by the tests.
    """
    # Directory for artifacts produced by the tests; created eagerly.
    output_dir = Path("./test_outputs/few_shot_test")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Configure root logging (basicConfig is a no-op on repeated calls).
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    return output_dir, logging.getLogger("few_shot_test")
def test_few_shot_config_creation():
    """Verify a few-shot trainer config can be built with the expected keys.

    Returns:
        bool: True when the factory populates the few-shot parameters,
        False if any step raised or an assertion failed.
    """
    test_output_dir, logger = setup_test_environment()
    logger.info("Testing few-shot configuration creation...")

    try:
        # Build a config that selects the few-shot training path.
        config = create_voice_recognition_trainer_config(
            model_name="test_few_shot",
            data_dir="./trainer/data/voice_recognition",
            training_method="few_shot",  # Use few-shot learning
            num_epochs=100,  # Unused by few-shot, but required for validation
            min_epochs=1,
            batch_size=8,
            learning_rate=0.001,
        )
        config.output_dir = str(test_output_dir)

        # The factory must record the method and its episodic sub-config.
        assert config.custom_params['training_method'] == 'few_shot'
        assert 'few_shot_config' in config.custom_params
        for key in ('n_way', 'k_shot', 'n_episodes'):
            assert key in config.custom_params['few_shot_config']

        logger.info("✅ Few-shot configuration created successfully")
        logger.info(f"Few-shot config: {config.custom_params['few_shot_config']}")
        return True

    except Exception as e:
        logger.error(f"❌ Few-shot configuration test failed: {str(e)}")
        return False
def test_few_shot_trainer_creation():
    """Verify VoiceRecognitionTrainer initializes from a few-shot config.

    Returns:
        bool: True when the trainer exposes the few-shot attributes,
        False if construction raised or an assertion failed.
    """
    test_output_dir, logger = setup_test_environment()
    logger.info("Testing few-shot trainer creation...")

    try:
        config = create_voice_recognition_trainer_config(
            model_name="test_few_shot_trainer",
            data_dir="./trainer/data/voice_recognition",
            training_method="few_shot",
            num_epochs=100,
            min_epochs=1,
            batch_size=8,
            learning_rate=0.001,
        )
        config.output_dir = str(test_output_dir)

        trainer = VoiceRecognitionTrainer(config)

        # The trainer must expose the few-shot machinery as attributes.
        for attr in ('training_method', 'few_shot_config',
                     'alternative_training_manager'):
            assert hasattr(trainer, attr)

        logger.info("✅ Few-shot trainer created successfully")
        logger.info(f"Training method: {trainer.training_method}")
        logger.info(f"Few-shot config: {trainer.few_shot_config}")
        return True

    except Exception as e:
        logger.error(f"❌ Few-shot trainer creation test failed: {str(e)}")
        logger.exception("Full error details:")
        return False
def test_few_shot_training_simulation():
    """Test few-shot training simulation (no actual training data required).

    When the data directory is absent, only parameter retrieval is checked;
    otherwise a real (small) few-shot training run is attempted.

    Returns:
        bool: True on success, False if any step raised or an assert failed.
    """
    test_output_dir, logger = setup_test_environment()
    logger.info("Testing few-shot training simulation...")

    try:
        # Small settings so the (optional) training run finishes quickly.
        config = create_voice_recognition_trainer_config(
            model_name="test_few_shot_sim",
            data_dir="./trainer/data/voice_recognition",
            training_method="few_shot",
            num_epochs=100,
            min_epochs=1,
            batch_size=8,
            learning_rate=0.001,
        )
        config.output_dir = str(test_output_dir)

        # Shrink the episodic setup for a fast test run.
        config.custom_params['few_shot_config'] = dict(
            n_way=3,           # 3-way classification
            k_shot=2,          # 2 support samples per class
            n_query=2,         # 2 query samples per class
            n_episodes=10,     # only 10 episodes for testing
            inner_lr=0.01,
            outer_lr=0.001,
            adaptation_steps=3,
        )

        trainer = VoiceRecognitionTrainer(config)

        data_dir = Path(config.data_dir)
        if not data_dir.exists():
            # No audio data available: exercise parameter plumbing only.
            logger.warning(f"Training data not found at {data_dir}")
            logger.info("Testing few-shot method parameter retrieval only...")

            params = trainer._get_alternative_method_params()
            for key, expected in (('n_way', 3), ('k_shot', 2),
                                  ('n_episodes', 10)):
                assert key in params
                assert params[key] == expected

            logger.info("✅ Few-shot parameters retrieved correctly")
            logger.info(f"Parameters: {params}")
            return True

        # Data is present: run the real few-shot training path.
        logger.info("Attempting few-shot training...")
        results = trainer.train()

        # The result payload must identify the prototypical-network method.
        assert 'training_method_info' in results
        assert results['training_method_info']['method'] == 'few_shot_prototypical'

        logger.info("✅ Few-shot training simulation completed successfully")
        logger.info(f"Training method info: {results['training_method_info']}")
        return True

    except Exception as e:
        logger.error(f"❌ Few-shot training simulation test failed: {str(e)}")
        logger.exception("Full error details:")
        return False
def test_data_augmentation_method():
    """Test the data-augmentation training method.

    Builds a config with ``training_method="data_augmentation"``, overrides
    the augmentation factor to keep the run small, constructs the trainer,
    and checks the augmentation-related attributes.

    Returns:
        bool: True on success, False if any step raised or an assert failed.
    """
    test_output_dir, logger = setup_test_environment()
    logger.info("Testing data augmentation method...")

    try:
        # Create data augmentation config
        config = create_voice_recognition_trainer_config(
            model_name="test_data_aug",
            data_dir="./trainer/data/voice_recognition",
            training_method="data_augmentation",
            num_epochs=5,
            min_epochs=1,
            batch_size=8,
            learning_rate=0.001,
        )
        config.output_dir = str(test_output_dir)

        # Override augmentation factor for quick testing
        config.custom_params['augmentation_factor'] = 2  # Small factor for testing

        # Create trainer
        trainer = VoiceRecognitionTrainer(config)

        # Verify data augmentation specific attributes
        assert trainer.training_method.value == 'data_augmentation'
        assert hasattr(trainer, 'augmentation_factor')
        assert trainer.augmentation_factor == 2
        # PEP 8: don't compare to True with ==; assert the flag directly.
        assert trainer.use_heavy_augmentation

        logger.info("✅ Data augmentation trainer created successfully")
        logger.info(f"Augmentation factor: {trainer.augmentation_factor}")
        return True

    except Exception as e:
        logger.error(f"❌ Data augmentation method test failed: {str(e)}")
        logger.exception("Full error details:")
        return False
def main():
    """Run all few-shot learning tests.

    Returns:
        int: 0 when every test passed, 1 otherwise (usable as exit code).
    """
    test_output_dir, logger = setup_test_environment()

    logger.info("🚀 Starting few-shot learning tests...")
    logger.info(f"Test output directory: {test_output_dir}")

    tests = [
        ("Few-shot configuration creation", test_few_shot_config_creation),
        ("Few-shot trainer creation", test_few_shot_trainer_creation),
        ("Few-shot training simulation", test_few_shot_training_simulation),
        ("Data augmentation method", test_data_augmentation_method),
    ]

    results = []
    for test_name, test_func in tests:
        logger.info(f"\n📋 Running test: {test_name}")
        try:
            outcome = test_func()
        except Exception as e:
            # A crashing test counts as a failure, not a runner abort.
            logger.error(f"❌ {test_name}: EXCEPTION - {str(e)}")
            outcome = False
        else:
            if outcome:
                logger.info(f"✅ {test_name}: PASSED")
            else:
                logger.error(f"❌ {test_name}: FAILED")
        results.append((test_name, outcome))

    # Summary
    logger.info("\n📊 Test Results Summary:")
    logger.info("=" * 50)

    total = len(results)
    passed = sum(1 for _, ok in results if ok)

    for test_name, ok in results:
        logger.info(f"{test_name}: {'✅ PASSED' if ok else '❌ FAILED'}")

    logger.info("=" * 50)
    logger.info(f"Overall: {passed}/{total} tests passed")

    if passed == total:
        logger.info("🎉 All few-shot learning tests passed!")
        return 0

    logger.error(f"⚠️ {total - passed} tests failed")
    return 1
if __name__ == "__main__":
    # Propagate main()'s status code to the shell.
    sys.exit(main())