trainer.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # -*- coding: utf-8 -*-
  2. """
  3. Trainer-Anwendung für Trixy ML-Modelle.
  4. Trainiert Wakeword- und Sprechererkennungsmodelle.
  5. """
  6. from enum import Enum, auto
  7. from pathlib import Path
  8. from trixy_core.application import IApplication
  9. from trixy_core.config.datasets.trainer import TrainerConfig
  10. from trixy_core.utils.debug import pinfo, pdebug, perror, pwarn
  11. class TrainingTarget(Enum):
  12. """Zu trainierende Modelltypen."""
  13. WAKEWORD = auto()
  14. VOICE_RECOGNITION = auto()
  15. ALL = auto()
  16. class TrainerApplication(IApplication):
  17. """
  18. Trainer-Anwendung für ML-Modelle.
  19. Unterstützt das Training von:
  20. - Wakeword-Modellen
  21. - Sprechererkennungsmodellen
  22. """
  23. def __init__(
  24. self,
  25. target: TrainingTarget = TrainingTarget.ALL,
  26. config_path: str | Path = "config/trainer_config.json",
  27. debug: bool = False
  28. ) -> None:
  29. """
  30. Initialisiert den Trainer.
  31. Args:
  32. target: Zu trainierende Modelltypen
  33. config_path: Pfad zur Konfigurationsdatei
  34. debug: Debug-Modus aktivieren
  35. """
  36. super().__init__(debug)
  37. self._target = target
  38. self._config_path = Path(config_path)
  39. self._trainer_config: TrainerConfig | None = None
  40. @property
  41. def trainer_config(self) -> TrainerConfig:
  42. """Trainer-Konfiguration."""
  43. if self._trainer_config is None:
  44. raise RuntimeError("Trainer nicht initialisiert")
  45. return self._trainer_config
  46. @property
  47. def target(self) -> TrainingTarget:
  48. """Trainings-Ziel."""
  49. return self._target
  50. async def initialize(self) -> None:
  51. """Initialisiert den Trainer."""
  52. pinfo("Initialisiere Trainer...")
  53. # Konfiguration laden
  54. self._trainer_config = self.config_manager.load(
  55. self._config_path,
  56. TrainerConfig,
  57. name="trainer"
  58. )
  59. # Verzeichnisse erstellen
  60. data_dir = Path(self._trainer_config.data.raw_data_directory)
  61. processed_dir = Path(self._trainer_config.data.processed_data_directory)
  62. cache_dir = Path(self._trainer_config.data.cache_directory)
  63. output_dir = Path(self._trainer_config.output.models_directory)
  64. for directory in [data_dir, processed_dir, cache_dir, output_dir]:
  65. directory.mkdir(parents=True, exist_ok=True)
  66. pdebug(f"Trainer initialisiert für: {self._target.name}")
  67. async def start(self) -> None:
  68. """Startet das Training."""
  69. pinfo(f"Starte Training: {self._target.name}")
  70. if self._target in (TrainingTarget.WAKEWORD, TrainingTarget.ALL):
  71. await self._train_wakeword()
  72. if self._target in (TrainingTarget.VOICE_RECOGNITION, TrainingTarget.ALL):
  73. await self._train_voice_recognition()
  74. pinfo("Training abgeschlossen")
  75. # Trainer beendet sich nach dem Training
  76. self.shutdown()
  77. async def stop(self) -> None:
  78. """Stoppt den Trainer."""
  79. pinfo("Trainer beendet")
  80. self.config_manager.stop_watching()
  81. async def _train_wakeword(self) -> None:
  82. """Trainiert Wakeword-Modelle."""
  83. pinfo("Trainiere Wakeword-Modell...")
  84. config = self._trainer_config.wakeword
  85. pdebug(f"Modell: {config.model_name}")
  86. pdebug(f"Epochen: {config.epochs}")
  87. pdebug(f"Batch-Größe: {config.batch_size}")
  88. # TODO: Tatsächliches Training implementieren
  89. # Hier nur Platzhalter
  90. data_path = Path(self._trainer_config.data.raw_data_directory) / "wakeword"
  91. if not data_path.exists():
  92. pwarn(f"Keine Trainingsdaten gefunden: {data_path}")
  93. return
  94. # Trainingsdaten zählen
  95. samples = list(data_path.rglob("*.wav"))
  96. pinfo(f"Gefundene Samples: {len(samples)}")
  97. if len(samples) < self._trainer_config.data.min_samples_per_class:
  98. pwarn("Zu wenige Trainingsdaten")
  99. return
  100. pinfo("Wakeword-Training abgeschlossen")
  101. async def _train_voice_recognition(self) -> None:
  102. """Trainiert Sprechererkennungsmodelle."""
  103. pinfo("Trainiere Sprechererkennungs-Modell...")
  104. config = self._trainer_config.voice_recognition
  105. pdebug(f"Modell: {config.model_name}")
  106. pdebug(f"Epochen: {config.epochs}")
  107. pdebug(f"Embedding-Größe: {config.embedding_size}")
  108. # TODO: Tatsächliches Training implementieren
  109. # Hier nur Platzhalter
  110. data_path = Path(self._trainer_config.data.raw_data_directory) / "voice_recognition"
  111. if not data_path.exists():
  112. pwarn(f"Keine Trainingsdaten gefunden: {data_path}")
  113. return
  114. # Sprecher-Verzeichnisse zählen
  115. speakers = [d for d in data_path.iterdir() if d.is_dir()]
  116. pinfo(f"Gefundene Sprecher: {len(speakers)}")
  117. for speaker_dir in speakers:
  118. samples = list(speaker_dir.glob("*.wav"))
  119. pdebug(f" {speaker_dir.name}: {len(samples)} Samples")
  120. pinfo("Sprechererkennungs-Training abgeschlossen")