@@ -83,48 +83,67 @@ def add_examples(self, texts: List[str], labels: List[str]):
             raise ValueError("Empty input lists")
         if len(texts) != len(labels):
             raise ValueError("Mismatched text and label lists")
-
+
+        # Check if classifier has any existing classes (before updating mappings)
+        has_existing_classes = len(self.label_to_id) > 0
+
         # Check for new classes
         new_classes = set(labels) - set(self.label_to_id.keys())
         is_adding_new_classes = len(new_classes) > 0
-
+
         # Update label mappings - sort new classes alphabetically for consistent IDs
         for label in sorted(new_classes):
             idx = len(self.label_to_id)
             self.label_to_id[label] = idx
             self.id_to_label[idx] = label
-
+
         # Get embeddings for all texts
         embeddings = self._get_embeddings(texts)
-
+
         # Add examples to memory and update training history
         for text, embedding, label in zip(texts, embeddings, labels):
            example = Example(text, label, embedding)
            self.memory.add_example(example, label)
-
+
             # Update training history
             if label not in self.training_history:
                 self.training_history[label] = 0
             self.training_history[label] += 1
-
-        # Special handling for new classes
-        if is_adding_new_classes:
-            # Store old head for EWC
+
+        # Determine training strategy: only use special new class handling for incremental learning
+        is_incremental_learning = is_adding_new_classes and has_existing_classes
+
+        if is_incremental_learning:
+            # Adding new classes to existing classifier - use special handling
+            # Store old head for EWC before modifying structure
             old_head = copy.deepcopy(self.adaptive_head) if self.adaptive_head is not None else None
-
-            # Initialize new head with more output classes
-            self._initialize_adaptive_head()
-
+
+            # Expand existing head to accommodate new classes (preserves weights)
+            num_classes = len(self.label_to_id)
+            self.adaptive_head.update_num_classes(num_classes)
+            # Move to correct device after update
+            self.adaptive_head = self.adaptive_head.to(self.device)
+
             # Train with focus on new classes
             self._train_new_classes(old_head, new_classes)
         else:
-            # Regular training for existing classes
+            # Initial training or regular updates - use normal training
+            # Initialize head if needed
+            if self.adaptive_head is None:
+                self._initialize_adaptive_head()
+            elif is_adding_new_classes:
+                # Edge case: expanding head for new classes but treating as regular training
+                num_classes = len(self.label_to_id)
+                self.adaptive_head.update_num_classes(num_classes)
+                self.adaptive_head = self.adaptive_head.to(self.device)
+
+            # Regular training
             self._train_adaptive_head()
-
+
         # Strategic training step if enabled
         if self.strategic_mode and self.train_steps % self.config.strategic_training_frequency == 0:
             self._perform_strategic_training()
-
+
         # Ensure FAISS index is up to date after adding examples
         self.memory._rebuild_index()

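The incremental branch above leans on `update_num_classes` growing the head's output layer in place while keeping the weights already learned for existing classes. That method is not shown in this diff; what follows is a minimal sketch of how such an expansion could work, assuming the head ends in a single `nn.Linear` stored as `output_layer` (the class and attribute names here are hypothetical, not the repository's):

    import torch
    import torch.nn as nn

    class AdaptiveHead(nn.Module):
        """Hypothetical stand-in for the classifier head in this diff."""

        def __init__(self, embedding_dim: int, num_classes: int):
            super().__init__()
            self.output_layer = nn.Linear(embedding_dim, num_classes)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.output_layer(x)

        def update_num_classes(self, num_classes: int) -> None:
            old = self.output_layer
            if num_classes <= old.out_features:
                return  # nothing to expand
            new = nn.Linear(old.in_features, num_classes)
            with torch.no_grad():
                # Keep the learned rows for existing classes; rows for the
                # new classes retain their fresh random initialization
                new.weight[:old.out_features] = old.weight
                new.bias[:old.out_features] = old.bias
            self.output_layer = new

Because new label IDs are appended at the end of `label_to_id`, copying the first `old.out_features` rows keeps each existing class aligned with its old logits.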
@@ -142,48 +161,94 @@ def _train_new_classes(self, old_head: Optional[nn.Module], new_classes: Set[str]):
         for label in self.memory.examples:
             examples_per_class[label] = len(self.memory.examples[label])

-        # Calculate sampling weights to balance old and new classes
+        # Improved sampling strategy for many-class scenarios
         min_examples = min(examples_per_class.values())
-        sampling_weights = {}
-
-        for label, count in examples_per_class.items():
-            if label in new_classes:
-                # Oversample new classes
-                sampling_weights[label] = 2.0
-            else:
-                # Sample old classes proportionally
-                sampling_weights[label] = min_examples / count
-
-        # Sample examples with weights
-        for label, examples in self.memory.examples.items():
-            weight = sampling_weights[label]
-            num_samples = max(min_examples, int(len(examples) * weight))
-
-            # Randomly sample with replacement if needed
-            indices = np.random.choice(
-                len(examples),
-                size=num_samples,
-                replace=num_samples > len(examples)
-            )
-
-            for idx in indices:
-                example = examples[idx]
-                all_embeddings.append(example.embedding)
-                all_labels.append(self.label_to_id[label])
+        max_examples = max(examples_per_class.values())
+
+        # For many-class scenarios, use a more balanced approach
+        num_classes = len(examples_per_class)
+        target_samples_per_class = max(5, min(10, min_examples * 2))  # Adaptive target
+
+        if num_classes > 20:  # Many-class scenario
+            # Use stratified sampling to ensure all classes get representation
+            for label, examples in self.memory.examples.items():
+                if label in new_classes:
+                    # Give new classes more representation, but not excessive
+                    num_samples = min(len(examples), target_samples_per_class * 2)
+                else:
+                    # Ensure old classes maintain representation
+                    num_samples = min(len(examples), target_samples_per_class)
+
+                # Sample without replacement first, then with if needed
+                if num_samples <= len(examples):
+                    indices = np.random.choice(len(examples), size=num_samples, replace=False)
+                else:
+                    indices = np.random.choice(len(examples), size=num_samples, replace=True)
+
+                for idx in indices:
+                    example = examples[idx]
+                    all_embeddings.append(example.embedding)
+                    all_labels.append(self.label_to_id[label])
+        else:
+            # Original strategy for fewer classes
+            sampling_weights = {}
+
+            for label, count in examples_per_class.items():
+                if label in new_classes:
+                    # Oversample new classes
+                    sampling_weights[label] = 2.0
+                else:
+                    # Sample old classes proportionally
+                    sampling_weights[label] = min_examples / count
+
+            # Sample examples with weights
+            for label, examples in self.memory.examples.items():
+                weight = sampling_weights[label]
+                num_samples = max(min_examples, int(len(examples) * weight))
+
+                # Randomly sample with replacement if needed
+                indices = np.random.choice(
+                    len(examples),
+                    size=num_samples,
+                    replace=num_samples > len(examples)
+                )
+
+                for idx in indices:
+                    example = examples[idx]
+                    all_embeddings.append(example.embedding)
+                    all_labels.append(self.label_to_id[label])

         all_embeddings = torch.stack(all_embeddings)
         all_labels = torch.tensor(all_labels)

         # Create dataset and initialize EWC with lower penalty for new classes
         dataset = torch.utils.data.TensorDataset(all_embeddings, all_labels)
-
+
+        ewc = None
         if old_head is not None:
-            ewc = EWC(
-                old_head,
-                dataset,
-                device=self.device,
-                ewc_lambda=10.0  # Lower EWC penalty to allow better learning of new classes
-            )
+            # Create a dataset for EWC that only includes examples from old classes
+            old_embeddings = []
+            old_labels = []
+            old_label_to_id = {label: idx for idx, label in enumerate(self.id_to_label.values())
+                               if label not in new_classes}
+
+            for label, examples in self.memory.examples.items():
+                if label not in new_classes:  # Only old classes
+                    for example in examples[:5]:  # Limit to representative examples
+                        old_embeddings.append(example.embedding)
+                        old_labels.append(old_label_to_id[label])
+
+            if old_embeddings:  # Only create EWC if we have old examples
+                old_embeddings = torch.stack(old_embeddings)
+                old_labels = torch.tensor(old_labels, dtype=torch.long)
+                old_dataset = torch.utils.data.TensorDataset(old_embeddings, old_labels)
+
+                ewc = EWC(
+                    old_head,
+                    old_dataset,
+                    device=self.device,
+                    ewc_lambda=5.0  # Balanced EWC penalty
+                )

         # Training setup
         self.adaptive_head.train()
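As a quick sanity check of the adaptive target introduced above, the formula `max(5, min(10, min_examples * 2))` clamps the per-class sample count into the 5-10 range:

    # Standalone check of the adaptive target formula used above
    for min_examples in (1, 3, 5, 8):
        target = max(5, min(10, min_examples * 2))
        print(f"min_examples={min_examples} -> target={target}")
    # min_examples=1 -> target=5
    # min_examples=3 -> target=6
    # min_examples=5 -> target=10
    # min_examples=8 -> target=10

In the many-class branch, each new class then contributes at most `2 * target` examples and each old class at most `target`, so no single class can dominate a training batch the way unbounded oversampling weights could.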
@@ -220,7 +285,7 @@ def _train_new_classes(self, old_head: Optional[nn.Module], new_classes: Set[str]):
                 task_loss = criterion(outputs, batch_labels)

                 # Add EWC loss if applicable
-                if old_head is not None:
+                if ewc is not None:
                     ewc_loss = ewc.ewc_loss(batch_size=len(batch_embeddings))
                     loss = task_loss + ewc_loss
                 else:
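The `EWC` helper is constructed and called in this diff but never defined in it. For orientation, here is a minimal sketch of diagonal-Fisher Elastic Weight Consolidation matching the interface used above (`EWC(model, dataset, device=..., ewc_lambda=...)` plus `ewc_loss(batch_size=...)`); the internals, including the batch-size normalization, are assumptions rather than the repository's actual implementation:

    import torch
    import torch.nn.functional as F

    class EWC:
        """Sketch of diagonal-Fisher EWC; internals assumed, not the repo's."""

        def __init__(self, model, dataset, device="cpu", ewc_lambda=5.0):
            self.model = model.to(device)
            self.device = device
            self.ewc_lambda = ewc_lambda
            # Snapshot of the parameters to anchor to (theta*)
            self.star_params = {n: p.detach().clone()
                                for n, p in self.model.named_parameters()}
            self.fisher = self._estimate_fisher(dataset)

        def _estimate_fisher(self, dataset):
            # Diagonal Fisher estimate: mean squared gradient per parameter
            fisher = {n: torch.zeros_like(p)
                      for n, p in self.model.named_parameters()}
            loader = torch.utils.data.DataLoader(dataset, batch_size=1)
            for x, y in loader:
                self.model.zero_grad()
                loss = F.cross_entropy(self.model(x.to(self.device)),
                                       y.to(self.device))
                loss.backward()
                for n, p in self.model.named_parameters():
                    if p.grad is not None:
                        fisher[n] += p.grad.detach() ** 2
            return {n: f / max(len(dataset), 1) for n, f in fisher.items()}

        def ewc_loss(self, batch_size=1):
            # Quadratic pull toward theta*, weighted by Fisher importance
            penalty = torch.tensor(0.0, device=self.device)
            for n, p in self.model.named_parameters():
                penalty = penalty + (self.fisher[n]
                                     * (p - self.star_params[n]) ** 2).sum()
            return self.ewc_lambda * penalty / batch_size

In this common formulation the module handed to `EWC` is the one that keeps training afterward; how the repository wires the penalty from the frozen `old_head` snapshot back to the expanded live head is not visible in these hunks.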
@@ -302,10 +367,12 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
         # Get embedding
         embedding = self._get_embeddings([text])[0]

-        # Get prototype predictions
-        proto_preds = self.memory.get_nearest_prototypes(embedding, k=k)
-
-        # Get neural predictions if available
+        # Get prototype predictions for ALL classes (not limited by k)
+        # This ensures complete scoring information for proper combination
+        max_classes = len(self.id_to_label) if self.id_to_label else k
+        proto_preds = self.memory.get_nearest_prototypes(embedding, k=max_classes)
+
+        # Get neural predictions if available for ALL classes (not limited by k)
         if self.adaptive_head is not None:
             self.adaptive_head.eval()  # Ensure eval mode
             # Add batch dimension and move to device
@@ -314,8 +381,9 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
             # Squeeze batch dimension
             logits = logits.squeeze(0)
             probs = F.softmax(logits, dim=0)
-
-            values, indices = torch.topk(probs, min(k, len(self.id_to_label)))
+
+            # Get predictions for ALL classes for proper scoring combination
+            values, indices = torch.topk(probs, len(self.id_to_label))
             head_preds = [
                 (self.id_to_label[idx.item()], val.item())
                 for val, idx in zip(values, indices)
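With both prediction sources now covering every known class, the two score lists can be blended without silently dropping a class that only one source ranked highly. The combination step itself falls outside this diff; here is a sketch of one plausible blend (the function name and the `proto_weight` value are illustrative, not taken from the repository):

    from typing import Dict, List, Tuple

    def combine_predictions(
        proto_preds: List[Tuple[str, float]],
        head_preds: List[Tuple[str, float]],
        proto_weight: float = 0.7,  # illustrative weighting, not the repo's value
        k: int = 5,
    ) -> List[Tuple[str, float]]:
        """Blend prototype and neural scores, then truncate to top-k."""
        scores: Dict[str, float] = {}
        for label, score in proto_preds:
            scores[label] = proto_weight * score
        for label, score in head_preds:
            scores[label] = scores.get(label, 0.0) + (1.0 - proto_weight) * score
        # Truncation to k happens only after blending; cutting either input
        # list earlier is exactly what the change above avoids
        return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:k]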