|
|
@@ -0,0 +1,279 @@
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+BIO-Tag Generierung und Decoding fuer Slot-Extraktion.
|
|
|
+
|
|
|
+Generiert Wort-Level und Token-Level BIO-Tags fuer Trainings-Samples.
|
|
|
+Decodiert BIO-Tags zurueck zu Slot-Werten bei der Inferenz.
|
|
|
+
|
|
|
+BIO-Schema:
|
|
|
+ O = Kein Slot (Outside)
|
|
|
+ PAD = Padding / Spezial-Token (ignoriert bei Loss)
|
|
|
+ B-{slot} = Beginn eines Slots (z.B. B-query, B-city)
|
|
|
+ I-{slot} = Fortsetzung eines Slots (z.B. I-query, I-city)
|
|
|
+
|
|
|
+Beispiel:
|
|
|
+ Text: "spiele musik von rammstein"
|
|
|
+ Words: ["spiele", "musik", "von", "rammstein"]
|
|
|
+ Slots: {query: "rammstein"}
|
|
|
+ BIO: ["O", "O", "O", "B-query"]
|
|
|
+
|
|
|
+ Tokenized: ["<s>", "spiel", "e", "mus", "ik", "von", "ramm", "stein", "</s>"]
|
|
|
+ Token-BIO: ["PAD", "O", "O", "O", "O", "O", "B-query", "I-query", "PAD"]
|
|
|
+"""
|
|
|
+
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+from dataclasses import dataclass, field
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+from trixy_core.utils.debug import pdebug
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class BIOTagSet:
    """Bidirectional mapping between BIO tag strings and integer indices.

    The inventory always starts with ``"O"`` (index 0) and ``"PAD"``
    (index 1), followed by a ``B-{slot}`` / ``I-{slot}`` pair for every
    distinct slot name in alphabetical order.
    """

    slot_names: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        self._tag_to_idx: dict[str, int] = {}
        self._idx_to_tag: dict[int, str] = {}
        self._build_index()

    def _build_index(self) -> None:
        """Build both lookup tables from ``slot_names``."""
        tags = ["O", "PAD"]
        # De-duplicate first: a repeated slot name would otherwise claim two
        # index pairs, leaving gaps in the index range so that ``num_tags``
        # undercounts and ``tags`` raises KeyError.
        for slot in sorted(set(self.slot_names)):
            tags.append(f"B-{slot}")
            tags.append(f"I-{slot}")

        self._tag_to_idx = {tag: i for i, tag in enumerate(tags)}
        self._idx_to_tag = {i: tag for tag, i in self._tag_to_idx.items()}

    @property
    def num_tags(self) -> int:
        """Number of distinct tags, including ``O`` and ``PAD``."""
        return len(self._tag_to_idx)

    @property
    def tags(self) -> list[str]:
        """All tags ordered by their index."""
        return [self._idx_to_tag[i] for i in range(len(self._idx_to_tag))]

    @property
    def o_idx(self) -> int:
        """Index of the ``O`` (outside) tag."""
        return self._tag_to_idx["O"]

    @property
    def pad_idx(self) -> int:
        """Index of the ``PAD`` tag."""
        return self._tag_to_idx["PAD"]

    def tag_to_idx(self, tag: str) -> int:
        """Return the index of ``tag``; unknown tags map to the ``O`` index."""
        return self._tag_to_idx.get(tag, self.o_idx)

    def idx_to_tag(self, idx: int) -> str:
        """Return the tag at ``idx``; unknown indices map to ``"O"``."""
        return self._idx_to_tag.get(idx, "O")

    def to_dict(self) -> dict[str, int]:
        """Return a copy of the tag -> index mapping."""
        return dict(self._tag_to_idx)

    @classmethod
    def from_dict(cls, data: dict[str, int]) -> "BIOTagSet":
        """Recreate a tag set from a mapping produced by ``to_dict``.

        Slot names are recovered from the ``B-``/``I-`` prefixed keys. If
        ``data`` is a complete mapping (contains ``O`` and ``PAD`` and its
        values are exactly ``0..n-1``), the stored indices are restored
        verbatim so that they keep matching whatever was trained against
        them. Otherwise the index is rebuilt from the slot names alone.
        """
        slots = {tag[2:] for tag in data if tag.startswith(("B-", "I-"))}
        tagset = cls(slot_names=sorted(slots))

        indices = list(data.values())
        is_complete = (
            "O" in data
            and "PAD" in data
            and all(isinstance(i, int) for i in indices)
            and sorted(indices) == list(range(len(indices)))
        )
        if is_complete:
            # Keep the persisted order instead of the freshly sorted one.
            tagset._tag_to_idx = dict(data)
            tagset._idx_to_tag = {i: tag for tag, i in data.items()}
        return tagset
|
|
|
+
|
|
|
+
|
|
|
+def generate_word_bio_tags(
|
|
|
+ text: str,
|
|
|
+ slots: dict[str, str | list[str]],
|
|
|
+) -> list[str]:
|
|
|
+ """
|
|
|
+ Generiert Wort-Level BIO-Tags.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ text: Der Eingabetext
|
|
|
+ slots: Dict mit Slot-Namen und Werten
|
|
|
+ z.B. {"city": "Berlin", "date": "morgen"}
|
|
|
+ oder {"topping": ["Salami", "Pilze"]}
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Liste von BIO-Tags pro Wort
|
|
|
+ """
|
|
|
+ words = text.split()
|
|
|
+ tags = ["O"] * len(words)
|
|
|
+ text_lower = text.lower()
|
|
|
+
|
|
|
+ if not slots:
|
|
|
+ return tags
|
|
|
+
|
|
|
+ for slot_name, slot_value in slots.items():
|
|
|
+ # Einzelwert oder Liste
|
|
|
+ values = slot_value if isinstance(slot_value, list) else [slot_value]
|
|
|
+
|
|
|
+ for value in values:
|
|
|
+ value_lower = value.lower()
|
|
|
+ value_words = value_lower.split()
|
|
|
+
|
|
|
+ # Position im Wort-Array suchen
|
|
|
+ for i in range(len(words)):
|
|
|
+ match = True
|
|
|
+ for j, vw in enumerate(value_words):
|
|
|
+ if i + j >= len(words):
|
|
|
+ match = False
|
|
|
+ break
|
|
|
+ if words[i + j].lower().rstrip(".,!?") != vw.rstrip(".,!?"):
|
|
|
+ match = False
|
|
|
+ break
|
|
|
+
|
|
|
+ if match:
|
|
|
+ tags[i] = f"B-{slot_name}"
|
|
|
+ for j in range(1, len(value_words)):
|
|
|
+ if i + j < len(tags):
|
|
|
+ tags[i + j] = f"I-{slot_name}"
|
|
|
+ break # Nur erstes Vorkommen pro Wert
|
|
|
+
|
|
|
+ return tags
|
|
|
+
|
|
|
+
|
|
|
def align_bio_to_tokens(
    word_bio_tags: list[str],
    text: str,
    tokenizer: Any,
    max_length: int = 64,
) -> list[str]:
    """
    Align word-level BIO tags with the tokenizer's subword tokens.

    The tokenizer may split a word into several subwords, e.g.
    "Wohnzimmer" -> ["▁Wohn", "zimmer"].

    Alignment rules:
        - The first subword of a word keeps the word's BIO tag.
        - Further subwords of the same word: B-xxx -> I-xxx; I-xxx and O
          are kept unchanged.
        - Tokens without a word id (special / padding tokens) -> "PAD".
        - Word ids beyond the end of ``word_bio_tags`` -> "O".

    NOTE(review): this assumes the word indices reported by
    ``encoding.word_ids()`` line up with the whitespace-separated words
    that ``word_bio_tags`` was built from -- confirm for the tokenizer in
    use.

    Args:
        word_bio_tags: One BIO tag per word.
        text: The original text.
        tokenizer: Callable tokenizer whose result offers ``word_ids()``
            (presumably a HuggingFace fast tokenizer).
        max_length: Length the encoding is padded/truncated to.

    Returns:
        One BIO tag per entry of ``encoding.word_ids()``.
    """
    # Only word_ids() is needed, so no offsets mapping is requested.
    encoding = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

    # None for special tokens, otherwise the index of the source word.
    word_ids = encoding.word_ids()
    known_words = len(word_bio_tags)

    token_tags: list[str] = []
    prev_word_id = None
    for word_id in word_ids:
        if word_id is None:
            tag = "PAD"
        elif word_id >= known_words:
            # No word-level tag available for this word index.
            tag = "O"
        else:
            tag = word_bio_tags[word_id]
            # A continuation subword must not open a new span.
            if word_id == prev_word_id and tag.startswith("B-"):
                tag = "I-" + tag[2:]
        token_tags.append(tag)
        prev_word_id = word_id

    return token_tags
|
|
|
+
|
|
|
+
|
|
|
def decode_bio_tags(
    tag_indices: list[int],
    token_ids: list[int],
    tagset: BIOTagSet,
    tokenizer: Any,
) -> dict[str, str | list[str]]:
    """
    Turn predicted BIO tag indices back into slot values.

    Tokens are walked in parallel with their tags. A ``B-`` tag opens a new
    span, a matching ``I-`` tag extends the open span, and anything else
    (``O``, ``PAD``, or an ``I-`` tag that does not match the open span)
    closes it. Each closed span is decoded with ``tokenizer.decode`` and
    stored under its slot name if the stripped text is non-empty.

    Args:
        tag_indices: Predicted tag index per token.
        token_ids: Token ids, parallel to ``tag_indices``.
        tagset: Provides ``idx_to_tag`` for index -> tag lookup.
        tokenizer: Provides ``decode(ids, skip_special_tokens=True)``.

    Returns:
        Slot name -> value. A slot with exactly one extracted value maps to
        that string; a slot with several maps to a list, e.g.
        {"topping": ["Salami", "Pilze"]}.
    """
    collected: dict[str, list[str]] = {}
    open_slot: str | None = None
    open_ids: list[int] = []

    def close_span() -> None:
        # Decode the currently open span (if any) and record its text.
        if not (open_slot and open_ids):
            return
        decoded = tokenizer.decode(open_ids, skip_special_tokens=True).strip()
        if decoded:
            collected.setdefault(open_slot, []).append(decoded)

    for tag_idx, token_id in zip(tag_indices, token_ids):
        tag = tagset.idx_to_tag(tag_idx)

        if tag.startswith("B-"):
            # A new span begins; whatever was open ends here.
            close_span()
            open_slot, open_ids = tag[2:], [token_id]
        elif tag.startswith("I-") and open_slot == tag[2:]:
            open_ids.append(token_id)
        else:
            # O, PAD, or an I- tag without a matching open span.
            close_span()
            open_slot, open_ids = None, []

    # The sequence may end while a span is still open.
    close_span()

    # Collapse single-element lists to a plain string.
    return {
        name: found[0] if len(found) == 1 else found
        for name, found in collected.items()
    }
|