test_few_shot_learning.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. #!/usr/bin/env python3
  2. """
  3. Test script to verify few-shot learning functionality works correctly.
  4. """
  5. import os
  6. import sys
  7. import logging
  8. from pathlib import Path
  9. # Add trainer to path
  10. sys.path.append(str(Path(__file__).parent / "trainer"))
  11. from trainer.voice_recognition.trainer import create_voice_recognition_trainer_config, VoiceRecognitionTrainer
  12. def setup_test_environment():
  13. """Setup test environment and logging."""
  14. # Create test output directory
  15. test_output_dir = Path("./test_outputs/few_shot_test")
  16. test_output_dir.mkdir(parents=True, exist_ok=True)
  17. # Setup logging
  18. logging.basicConfig(
  19. level=logging.INFO,
  20. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  21. )
  22. logger = logging.getLogger("few_shot_test")
  23. return test_output_dir, logger
  24. def test_few_shot_config_creation():
  25. """Test creating few-shot learning configuration."""
  26. test_output_dir, logger = setup_test_environment()
  27. logger.info("Testing few-shot configuration creation...")
  28. try:
  29. # Create few-shot config
  30. config = create_voice_recognition_trainer_config(
  31. model_name="test_few_shot",
  32. data_dir="./trainer/data/voice_recognition",
  33. training_method="few_shot", # Use few-shot learning
  34. num_epochs=100, # Not used in few-shot, but required for config validation
  35. min_epochs=1,
  36. batch_size=8,
  37. learning_rate=0.001,
  38. )
  39. config.output_dir = str(test_output_dir)
  40. # Verify few-shot specific parameters exist
  41. assert config.custom_params['training_method'] == 'few_shot'
  42. assert 'few_shot_config' in config.custom_params
  43. assert 'n_way' in config.custom_params['few_shot_config']
  44. assert 'k_shot' in config.custom_params['few_shot_config']
  45. assert 'n_episodes' in config.custom_params['few_shot_config']
  46. logger.info("✅ Few-shot configuration created successfully")
  47. logger.info(f"Few-shot config: {config.custom_params['few_shot_config']}")
  48. return True
  49. except Exception as e:
  50. logger.error(f"❌ Few-shot configuration test failed: {str(e)}")
  51. return False
  52. def test_few_shot_trainer_creation():
  53. """Test creating few-shot trainer."""
  54. test_output_dir, logger = setup_test_environment()
  55. logger.info("Testing few-shot trainer creation...")
  56. try:
  57. # Create few-shot config
  58. config = create_voice_recognition_trainer_config(
  59. model_name="test_few_shot_trainer",
  60. data_dir="./trainer/data/voice_recognition",
  61. training_method="few_shot",
  62. num_epochs=100,
  63. min_epochs=1,
  64. batch_size=8,
  65. learning_rate=0.001,
  66. )
  67. config.output_dir = str(test_output_dir)
  68. # Create trainer
  69. trainer = VoiceRecognitionTrainer(config)
  70. # Verify few-shot specific attributes
  71. assert hasattr(trainer, 'training_method')
  72. assert hasattr(trainer, 'few_shot_config')
  73. assert hasattr(trainer, 'alternative_training_manager')
  74. logger.info("✅ Few-shot trainer created successfully")
  75. logger.info(f"Training method: {trainer.training_method}")
  76. logger.info(f"Few-shot config: {trainer.few_shot_config}")
  77. return True
  78. except Exception as e:
  79. logger.error(f"❌ Few-shot trainer creation test failed: {str(e)}")
  80. logger.exception("Full error details:")
  81. return False
  82. def test_few_shot_training_simulation():
  83. """Test few-shot training simulation (no actual training data required)."""
  84. test_output_dir, logger = setup_test_environment()
  85. logger.info("Testing few-shot training simulation...")
  86. try:
  87. # Create few-shot config with small parameters for quick testing
  88. config = create_voice_recognition_trainer_config(
  89. model_name="test_few_shot_sim",
  90. data_dir="./trainer/data/voice_recognition",
  91. training_method="few_shot",
  92. num_epochs=100,
  93. min_epochs=1,
  94. batch_size=8,
  95. learning_rate=0.001,
  96. )
  97. config.output_dir = str(test_output_dir)
  98. # Override few-shot config for quick testing
  99. config.custom_params['few_shot_config'] = {
  100. 'n_way': 3, # 3-way classification
  101. 'k_shot': 2, # 2 samples per class
  102. 'n_query': 2, # 2 query samples per class
  103. 'n_episodes': 10, # Only 10 episodes for testing
  104. 'inner_lr': 0.01,
  105. 'outer_lr': 0.001,
  106. 'adaptation_steps': 3
  107. }
  108. # Create trainer
  109. trainer = VoiceRecognitionTrainer(config)
  110. # Check if we have training data
  111. data_dir = Path(config.data_dir)
  112. if not data_dir.exists():
  113. logger.warning(f"Training data not found at {data_dir}")
  114. logger.info("Testing few-shot method parameter retrieval only...")
  115. # Test alternative method parameters
  116. params = trainer._get_alternative_method_params()
  117. assert 'n_way' in params
  118. assert 'k_shot' in params
  119. assert 'n_episodes' in params
  120. assert params['n_way'] == 3
  121. assert params['k_shot'] == 2
  122. assert params['n_episodes'] == 10
  123. logger.info("✅ Few-shot parameters retrieved correctly")
  124. logger.info(f"Parameters: {params}")
  125. return True
  126. # Try actual few-shot training
  127. logger.info("Attempting few-shot training...")
  128. results = trainer.train()
  129. # Verify results contain few-shot specific information
  130. assert 'training_method_info' in results
  131. assert results['training_method_info']['method'] == 'few_shot_prototypical'
  132. logger.info("✅ Few-shot training simulation completed successfully")
  133. logger.info(f"Training method info: {results['training_method_info']}")
  134. return True
  135. except Exception as e:
  136. logger.error(f"❌ Few-shot training simulation test failed: {str(e)}")
  137. logger.exception("Full error details:")
  138. return False
  139. def test_data_augmentation_method():
  140. """Test data augmentation training method."""
  141. test_output_dir, logger = setup_test_environment()
  142. logger.info("Testing data augmentation method...")
  143. try:
  144. # Create data augmentation config
  145. config = create_voice_recognition_trainer_config(
  146. model_name="test_data_aug",
  147. data_dir="./trainer/data/voice_recognition",
  148. training_method="data_augmentation",
  149. num_epochs=5,
  150. min_epochs=1,
  151. batch_size=8,
  152. learning_rate=0.001,
  153. )
  154. config.output_dir = str(test_output_dir)
  155. # Override augmentation factor for quick testing
  156. config.custom_params['augmentation_factor'] = 2 # Small factor for testing
  157. # Create trainer
  158. trainer = VoiceRecognitionTrainer(config)
  159. # Verify data augmentation specific attributes
  160. assert trainer.training_method.value == 'data_augmentation'
  161. assert hasattr(trainer, 'augmentation_factor')
  162. assert trainer.augmentation_factor == 2
  163. assert trainer.use_heavy_augmentation == True
  164. logger.info("✅ Data augmentation trainer created successfully")
  165. logger.info(f"Augmentation factor: {trainer.augmentation_factor}")
  166. return True
  167. except Exception as e:
  168. logger.error(f"❌ Data augmentation method test failed: {str(e)}")
  169. logger.exception("Full error details:")
  170. return False
  171. def main():
  172. """Run all few-shot learning tests."""
  173. test_output_dir, logger = setup_test_environment()
  174. logger.info("🚀 Starting few-shot learning tests...")
  175. logger.info(f"Test output directory: {test_output_dir}")
  176. tests = [
  177. ("Few-shot configuration creation", test_few_shot_config_creation),
  178. ("Few-shot trainer creation", test_few_shot_trainer_creation),
  179. ("Few-shot training simulation", test_few_shot_training_simulation),
  180. ("Data augmentation method", test_data_augmentation_method),
  181. ]
  182. results = []
  183. for test_name, test_func in tests:
  184. logger.info(f"\n📋 Running test: {test_name}")
  185. try:
  186. result = test_func()
  187. results.append((test_name, result))
  188. if result:
  189. logger.info(f"✅ {test_name}: PASSED")
  190. else:
  191. logger.error(f"❌ {test_name}: FAILED")
  192. except Exception as e:
  193. logger.error(f"❌ {test_name}: EXCEPTION - {str(e)}")
  194. results.append((test_name, False))
  195. # Summary
  196. logger.info("\n📊 Test Results Summary:")
  197. logger.info("=" * 50)
  198. passed = 0
  199. total = len(results)
  200. for test_name, result in results:
  201. status = "✅ PASSED" if result else "❌ FAILED"
  202. logger.info(f"{test_name}: {status}")
  203. if result:
  204. passed += 1
  205. logger.info("=" * 50)
  206. logger.info(f"Overall: {passed}/{total} tests passed")
  207. if passed == total:
  208. logger.info("🎉 All few-shot learning tests passed!")
  209. return 0
  210. else:
  211. logger.error(f"⚠️ {total - passed} tests failed")
  212. return 1
  213. if __name__ == "__main__":
  214. exit_code = main()
  215. sys.exit(exit_code)