fix_audio_processing.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. #!/usr/bin/env python3
  2. """
  3. Audio Processing Parameter Fix Script
  4. This script identifies and fixes common audio processing parameter issues
  5. in the Trixy voice assistant codebase, specifically MelSpectrogram and MFCC
  6. parameter compatibility with current torchaudio versions.
  7. """
  8. import os
  9. import re
  10. import json
  11. import logging
  12. from pathlib import Path
  13. from typing import List, Dict, Any
  14. # Configure logging
  15. logging.basicConfig(level=logging.INFO)
  16. logger = logging.getLogger(__name__)
  17. class AudioProcessingFixer:
  18. """Fix audio processing parameter issues."""
  19. def __init__(self, source_dir: str = "."):
  20. """Initialize the fixer."""
  21. self.source_dir = Path(source_dir)
  22. self.issues_found = []
  23. self.fixes_applied = []
  24. def scan_for_issues(self) -> List[Dict[str, Any]]:
  25. """Scan for potential audio processing parameter issues."""
  26. logger.info("Scanning for audio processing parameter issues...")
  27. # Patterns to look for
  28. patterns = [
  29. # Wrong parameter names in MelSpectrogram
  30. (r'MelSpectrogram\([^)]*fmin\s*=', 'MelSpectrogram using fmin instead of f_min'),
  31. (r'MelSpectrogram\([^)]*fmax\s*=', 'MelSpectrogram using fmax instead of f_max'),
  32. # Wrong parameter names in MFCC melkwargs
  33. (r'melkwargs\s*=\s*{[^}]*["\']fmin["\']', 'MFCC melkwargs using fmin instead of f_min'),
  34. (r'melkwargs\s*=\s*{[^}]*["\']fmax["\']', 'MFCC melkwargs using fmax instead of f_max'),
  35. # Deprecated normalized parameter usage
  36. (r'MelSpectrogram\([^)]*normalized\s*=\s*True', 'MelSpectrogram using deprecated normalized=True'),
  37. ]
  38. # Scan Python files
  39. python_files = list(self.source_dir.rglob("*.py"))
  40. for file_path in python_files:
  41. try:
  42. with open(file_path, 'r', encoding='utf-8') as f:
  43. content = f.read()
  44. for pattern, description in patterns:
  45. matches = re.finditer(pattern, content, re.MULTILINE | re.DOTALL)
  46. for match in matches:
  47. line_num = content[:match.start()].count('\n') + 1
  48. self.issues_found.append({
  49. 'file': str(file_path),
  50. 'line': line_num,
  51. 'pattern': pattern,
  52. 'description': description,
  53. 'match': match.group(0)
  54. })
  55. except Exception as e:
  56. logger.warning(f"Error scanning {file_path}: {e}")
  57. # Scan configuration files
  58. config_files = list(self.source_dir.rglob("*.json"))
  59. for file_path in config_files:
  60. try:
  61. with open(file_path, 'r', encoding='utf-8') as f:
  62. content = f.read()
  63. # Look for audio processing configuration
  64. if any(term in content.lower() for term in ['mel', 'mfcc', 'audio', 'spectrogram']):
  65. # Check for wrong parameter names in JSON
  66. if 'fmin' in content and 'f_min' not in content:
  67. self.issues_found.append({
  68. 'file': str(file_path),
  69. 'line': 1,
  70. 'pattern': 'fmin in JSON config',
  71. 'description': 'JSON config using fmin instead of f_min',
  72. 'match': 'fmin in configuration'
  73. })
  74. except Exception as e:
  75. logger.warning(f"Error scanning {file_path}: {e}")
  76. logger.info(f"Found {len(self.issues_found)} potential issues")
  77. return self.issues_found
  78. def test_current_audio_pipeline(self) -> bool:
  79. """Test the current audio processing pipeline."""
  80. logger.info("Testing current audio processing pipeline...")
  81. try:
  82. # Test data pipeline
  83. from trainer.data_pipeline import AudioProcessingConfig, AudioProcessor
  84. config = AudioProcessingConfig()
  85. processor = AudioProcessor(config)
  86. logger.info("✓ AudioProcessor creation successful")
  87. # Test voice recognition audio features
  88. from trixy_core.ml.voice_recognition.audio_features import create_feature_extractor
  89. log_mel_extractor = create_feature_extractor("log_mel", sample_rate=16000)
  90. mfcc_extractor = create_feature_extractor("mfcc", sample_rate=16000)
  91. logger.info("✓ Voice recognition feature extractors creation successful")
  92. # Test wakeword audio features
  93. from trixy_core.ml.wakeword.audio_features import create_feature_extractor as ww_create_extractor
  94. ww_extractor = ww_create_extractor()
  95. logger.info("✓ Wakeword feature extractor creation successful")
  96. # Test actual feature extraction with dummy data
  97. import torch
  98. dummy_audio = torch.randn(1, 16000) # 1 second of audio
  99. # Test all extractors
  100. features1 = processor.extract_features(dummy_audio)
  101. features2 = log_mel_extractor(dummy_audio)
  102. features3 = mfcc_extractor(dummy_audio)
  103. features4 = ww_extractor.extract_features(dummy_audio)
  104. logger.info("✓ All audio feature extraction tests passed")
  105. return True
  106. except Exception as e:
  107. logger.error(f"Audio pipeline test failed: {e}")
  108. import traceback
  109. traceback.print_exc()
  110. return False
  111. def validate_torchaudio_compatibility(self) -> Dict[str, Any]:
  112. """Validate torchaudio compatibility."""
  113. logger.info("Validating torchaudio compatibility...")
  114. import torch
  115. import torchaudio
  116. version_info = {
  117. 'torch_version': torch.__version__,
  118. 'torchaudio_version': torchaudio.__version__,
  119. 'compatible': True,
  120. 'recommendations': []
  121. }
  122. # Test MelSpectrogram parameters
  123. try:
  124. import torchaudio.transforms as T
  125. # Test correct parameter names
  126. mel_spec = T.MelSpectrogram(
  127. sample_rate=16000,
  128. n_fft=512,
  129. hop_length=160,
  130. n_mels=40,
  131. f_min=80.0,
  132. f_max=8000.0
  133. )
  134. logger.info("✓ MelSpectrogram with f_min/f_max works")
  135. # Test MFCC with melkwargs
  136. mfcc = T.MFCC(
  137. sample_rate=16000,
  138. n_mfcc=40,
  139. melkwargs={
  140. 'n_fft': 512,
  141. 'hop_length': 160,
  142. 'n_mels': 40,
  143. 'f_min': 80.0,
  144. 'f_max': 8000.0
  145. }
  146. )
  147. logger.info("✓ MFCC with melkwargs f_min/f_max works")
  148. except Exception as e:
  149. version_info['compatible'] = False
  150. version_info['recommendations'].append(f"MelSpectrogram/MFCC parameter error: {e}")
  151. # Check for deprecated features
  152. try:
  153. from torch.cuda.amp import GradScaler
  154. version_info['recommendations'].append(
  155. "Consider updating GradScaler usage to torch.amp.GradScaler('cuda') "
  156. "to avoid deprecation warnings"
  157. )
  158. except ImportError:
  159. pass
  160. return version_info
  161. def create_compatibility_fixes(self) -> str:
  162. """Create a compatibility fix patch."""
  163. fixes = []
  164. # Check if we need to fix GradScaler usage
  165. grad_scaler_files = []
  166. for file_path in self.source_dir.rglob("*.py"):
  167. try:
  168. with open(file_path, 'r', encoding='utf-8') as f:
  169. content = f.read()
  170. if 'torch.cuda.amp.GradScaler()' in content:
  171. grad_scaler_files.append(file_path)
  172. except:
  173. pass
  174. if grad_scaler_files:
  175. fixes.append("Fix GradScaler deprecation warning:")
  176. for file_path in grad_scaler_files:
  177. fixes.append(f" - {file_path}: Replace torch.cuda.amp.GradScaler() with torch.amp.GradScaler('cuda')")
  178. return "\n".join(fixes) if fixes else "No compatibility fixes needed."
  179. def run_comprehensive_check(self) -> Dict[str, Any]:
  180. """Run comprehensive audio processing check."""
  181. logger.info("Running comprehensive audio processing check...")
  182. results = {
  183. 'issues_found': self.scan_for_issues(),
  184. 'pipeline_test': self.test_current_audio_pipeline(),
  185. 'compatibility': self.validate_torchaudio_compatibility(),
  186. 'recommendations': []
  187. }
  188. # Generate recommendations
  189. if not results['issues_found']:
  190. results['recommendations'].append("✓ No audio processing parameter issues found")
  191. else:
  192. results['recommendations'].append(f"⚠️ Found {len(results['issues_found'])} potential issues")
  193. for issue in results['issues_found']:
  194. results['recommendations'].append(f" - {issue['file']}:{issue['line']} - {issue['description']}")
  195. if results['pipeline_test']:
  196. results['recommendations'].append("✓ Audio processing pipeline tests passed")
  197. else:
  198. results['recommendations'].append("❌ Audio processing pipeline tests failed")
  199. if results['compatibility']['compatible']:
  200. results['recommendations'].append("✓ Torchaudio compatibility validated")
  201. else:
  202. results['recommendations'].append("❌ Torchaudio compatibility issues found")
  203. # Add fix suggestions
  204. fix_suggestions = self.create_compatibility_fixes()
  205. if fix_suggestions != "No compatibility fixes needed.":
  206. results['recommendations'].append("Suggested fixes:")
  207. results['recommendations'].append(fix_suggestions)
  208. return results
  209. def main():
  210. """Main function."""
  211. fixer = AudioProcessingFixer()
  212. results = fixer.run_comprehensive_check()
  213. print("\n" + "="*60)
  214. print("AUDIO PROCESSING COMPATIBILITY CHECK RESULTS")
  215. print("="*60)
  216. print(f"\nTorch version: {results['compatibility']['torch_version']}")
  217. print(f"Torchaudio version: {results['compatibility']['torchaudio_version']}")
  218. print(f"\nIssues found: {len(results['issues_found'])}")
  219. print(f"Pipeline test passed: {results['pipeline_test']}")
  220. print(f"Compatibility validated: {results['compatibility']['compatible']}")
  221. print("\nRecommendations:")
  222. for rec in results['recommendations']:
  223. print(f" {rec}")
  224. if results['issues_found']:
  225. print("\nDetailed issues:")
  226. for issue in results['issues_found']:
  227. print(f" {issue['file']}:{issue['line']} - {issue['description']}")
  228. print(f" Match: {issue['match']}")
  229. print("\n" + "="*60)
  230. return results
  231. if __name__ == "__main__":
  232. main()