Commit aceb811

Merge pull request #54 from codelion/fix-multi-class-learning: fix multi class learning

2 parents cdb540d + ef4b2cd

4 files changed: +520 −74 lines

setup.py: 1 addition & 1 deletion

```diff
@@ -15,7 +15,7 @@
 
 setup(
     name="adaptive-classifier",
-    version="0.0.16",
+    version="0.0.17",
     author="codelion",
     author_email="codelion@okyasoft.com",
     description="A flexible, adaptive classification system for dynamic text classification",
```

src/adaptive_classifier/__init__.py: 1 addition & 16 deletions

```diff
@@ -3,22 +3,7 @@
 from .memory import PrototypeMemory
 from huggingface_hub import ModelHubMixin
 
-import os
-import re
-
-def get_version_from_setup():
-    try:
-        setup_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'setup.py')
-        with open(setup_path, 'r') as f:
-            content = f.read()
-            version_match = re.search(r'version=["\']([^"\']+)["\']', content)
-            if version_match:
-                return version_match.group(1)
-    except Exception:
-        pass
-    return "unknown"
-
-__version__ = get_version_from_setup()
+__version__ = "0.0.17"
 
 __all__ = [
     "AdaptiveClassifier",
```

src/adaptive_classifier/classifier.py: 125 additions & 57 deletions

```diff
@@ -83,48 +83,67 @@ def add_examples(self, texts: List[str], labels: List[str]):
             raise ValueError("Empty input lists")
         if len(texts) != len(labels):
             raise ValueError("Mismatched text and label lists")
-
+
+        # Check if classifier has any existing classes (before updating mappings)
+        has_existing_classes = len(self.label_to_id) > 0
+
         # Check for new classes
         new_classes = set(labels) - set(self.label_to_id.keys())
         is_adding_new_classes = len(new_classes) > 0
-
+
         # Update label mappings - sort new classes alphabetically for consistent IDs
         for label in sorted(new_classes):
             idx = len(self.label_to_id)
             self.label_to_id[label] = idx
             self.id_to_label[idx] = label
-
+
         # Get embeddings for all texts
         embeddings = self._get_embeddings(texts)
-
+
         # Add examples to memory and update training history
         for text, embedding, label in zip(texts, embeddings, labels):
             example = Example(text, label, embedding)
             self.memory.add_example(example, label)
-
+
             # Update training history
             if label not in self.training_history:
                 self.training_history[label] = 0
             self.training_history[label] += 1
-
-        # Special handling for new classes
-        if is_adding_new_classes:
-            # Store old head for EWC
+
+        # Determine training strategy: only use special new class handling for incremental learning
+        is_incremental_learning = is_adding_new_classes and has_existing_classes
+
+        if is_incremental_learning:
+            # Adding new classes to existing classifier - use special handling
+            # Store old head for EWC before modifying structure
             old_head = copy.deepcopy(self.adaptive_head) if self.adaptive_head is not None else None
-
-            # Initialize new head with more output classes
-            self._initialize_adaptive_head()
-
+
+            # Expand existing head to accommodate new classes (preserves weights)
+            num_classes = len(self.label_to_id)
+            self.adaptive_head.update_num_classes(num_classes)
+            # Move to correct device after update
+            self.adaptive_head = self.adaptive_head.to(self.device)
+
             # Train with focus on new classes
             self._train_new_classes(old_head, new_classes)
         else:
-            # Regular training for existing classes
+            # Initial training or regular updates - use normal training
+            # Initialize head if needed
+            if self.adaptive_head is None:
+                self._initialize_adaptive_head()
+            elif is_adding_new_classes:
+                # Edge case: expanding head for new classes but treating as regular training
+                num_classes = len(self.label_to_id)
+                self.adaptive_head.update_num_classes(num_classes)
+                self.adaptive_head = self.adaptive_head.to(self.device)
+
+            # Regular training
             self._train_adaptive_head()
-
+
         # Strategic training step if enabled
         if self.strategic_mode and self.train_steps % self.config.strategic_training_frequency == 0:
             self._perform_strategic_training()
-
+
         # Ensure FAISS index is up to date after adding examples
         self.memory._rebuild_index()
 
```
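Note on `update_num_classes`: the fix relies on expanding the existing head in place so weights learned for old classes survive, instead of reinitializing from scratch as the old `_initialize_adaptive_head()` path did. The method body is not part of this diff; below is a minimal sketch of what such a weight-preserving expansion could look like, assuming the head ends in a single `nn.Linear` output layer (the real head in the repo may be structured differently).

```python
import torch
import torch.nn as nn

class AdaptiveHead(nn.Module):
    """Toy stand-in for the repo's adaptive head (structure assumed)."""

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.output = nn.Linear(input_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.output(x)

    def update_num_classes(self, num_classes: int) -> None:
        """Grow the output layer, copying over the old class weights."""
        old = self.output
        if num_classes <= old.out_features:
            return  # nothing to expand
        new = nn.Linear(old.in_features, num_classes)
        with torch.no_grad():
            # Rows for existing classes keep their learned weights;
            # rows for new classes keep their fresh random init.
            new.weight[: old.out_features] = old.weight
            new.bias[: old.out_features] = old.bias
        self.output = new

head = AdaptiveHead(input_dim=384, num_classes=3)
head.update_num_classes(5)  # rows 0-2 keep their learned weights
```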

```diff
@@ -142,48 +161,94 @@ def _train_new_classes(self, old_head: Optional[nn.Module], new_classes: Set[str]):
         for label in self.memory.examples:
             examples_per_class[label] = len(self.memory.examples[label])
 
-        # Calculate sampling weights to balance old and new classes
+        # Improved sampling strategy for many-class scenarios
         min_examples = min(examples_per_class.values())
-        sampling_weights = {}
-
-        for label, count in examples_per_class.items():
-            if label in new_classes:
-                # Oversample new classes
-                sampling_weights[label] = 2.0
-            else:
-                # Sample old classes proportionally
-                sampling_weights[label] = min_examples / count
-
-        # Sample examples with weights
-        for label, examples in self.memory.examples.items():
-            weight = sampling_weights[label]
-            num_samples = max(min_examples, int(len(examples) * weight))
-
-            # Randomly sample with replacement if needed
-            indices = np.random.choice(
-                len(examples),
-                size=num_samples,
-                replace=num_samples > len(examples)
-            )
-
-            for idx in indices:
-                example = examples[idx]
-                all_embeddings.append(example.embedding)
-                all_labels.append(self.label_to_id[label])
+        max_examples = max(examples_per_class.values())
+
+        # For many-class scenarios, use a more balanced approach
+        num_classes = len(examples_per_class)
+        target_samples_per_class = max(5, min(10, min_examples * 2))  # Adaptive target
+
+        if num_classes > 20:  # Many-class scenario
+            # Use stratified sampling to ensure all classes get representation
+            for label, examples in self.memory.examples.items():
+                if label in new_classes:
+                    # Give new classes more representation, but not excessive
+                    num_samples = min(len(examples), target_samples_per_class * 2)
+                else:
+                    # Ensure old classes maintain representation
+                    num_samples = min(len(examples), target_samples_per_class)
+
+                # Sample without replacement first, then with if needed
+                if num_samples <= len(examples):
+                    indices = np.random.choice(len(examples), size=num_samples, replace=False)
+                else:
+                    indices = np.random.choice(len(examples), size=num_samples, replace=True)
+
+                for idx in indices:
+                    example = examples[idx]
+                    all_embeddings.append(example.embedding)
+                    all_labels.append(self.label_to_id[label])
+        else:
+            # Original strategy for fewer classes
+            sampling_weights = {}
+
+            for label, count in examples_per_class.items():
+                if label in new_classes:
+                    # Oversample new classes
+                    sampling_weights[label] = 2.0
+                else:
+                    # Sample old classes proportionally
+                    sampling_weights[label] = min_examples / count
+
+            # Sample examples with weights
+            for label, examples in self.memory.examples.items():
+                weight = sampling_weights[label]
+                num_samples = max(min_examples, int(len(examples) * weight))
+
+                # Randomly sample with replacement if needed
+                indices = np.random.choice(
+                    len(examples),
+                    size=num_samples,
+                    replace=num_samples > len(examples)
+                )
+
+                for idx in indices:
+                    example = examples[idx]
+                    all_embeddings.append(example.embedding)
+                    all_labels.append(self.label_to_id[label])
 
         all_embeddings = torch.stack(all_embeddings)
         all_labels = torch.tensor(all_labels)
 
         # Create dataset and initialize EWC with lower penalty for new classes
         dataset = torch.utils.data.TensorDataset(all_embeddings, all_labels)
-
+
+        ewc = None
         if old_head is not None:
-            ewc = EWC(
-                old_head,
-                dataset,
-                device=self.device,
-                ewc_lambda=10.0  # Lower EWC penalty to allow better learning of new classes
-            )
+            # Create a dataset for EWC that only includes examples from old classes
+            old_embeddings = []
+            old_labels = []
+            old_label_to_id = {label: idx for idx, label in enumerate(self.id_to_label.values())
+                               if label not in new_classes}
+
+            for label, examples in self.memory.examples.items():
+                if label not in new_classes:  # Only old classes
+                    for example in examples[:5]:  # Limit to representative examples
+                        old_embeddings.append(example.embedding)
+                        old_labels.append(old_label_to_id[label])
+
+            if old_embeddings:  # Only create EWC if we have old examples
+                old_embeddings = torch.stack(old_embeddings)
+                old_labels = torch.tensor(old_labels, dtype=torch.long)
+                old_dataset = torch.utils.data.TensorDataset(old_embeddings, old_labels)
+
+                ewc = EWC(
+                    old_head,
+                    old_dataset,
+                    device=self.device,
+                    ewc_lambda=5.0  # Balanced EWC penalty
+                )
 
         # Training setup
         self.adaptive_head.train()
```
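The sampling rewrite is the heart of the multi-class fix: above 20 classes, the old weight-based scheme could leave some classes almost unrepresented in a training batch, so the new path samples a fixed per-class target instead. A standalone sketch of that target computation, using hypothetical class labels and counts for illustration:

```python
import numpy as np

def build_sample_counts(examples_per_class, new_classes):
    """Mirror of the diff's many-class sampling targets (sketch)."""
    min_examples = min(examples_per_class.values())
    # Adaptive target: at least 5, at most 10, scaled by data availability
    target = max(5, min(10, min_examples * 2))
    counts = {}
    for label, available in examples_per_class.items():
        # New classes get up to double the target; old classes get the target
        want = target * 2 if label in new_classes else target
        counts[label] = min(available, want)
    return counts

# Hypothetical 25-class scenario: one new class, skewed old-class counts
per_class = {f"class_{i}": 3 + (i % 7) * 4 for i in range(24)}
per_class["brand_new"] = 8
print(build_sample_counts(per_class, new_classes={"brand_new"}))
```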
```diff
@@ -220,7 +285,7 @@ def _train_new_classes(self, old_head: Optional[nn.Module], new_classes: Set[str]):
                 task_loss = criterion(outputs, batch_labels)
 
                 # Add EWC loss if applicable
-                if old_head is not None:
+                if ewc is not None:
                     ewc_loss = ewc.ewc_loss(batch_size=len(batch_embeddings))
                     loss = task_loss + ewc_loss
                 else:
```
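For context, EWC (Elastic Weight Consolidation) discourages parameters that mattered for the old classes from drifting, by penalizing their movement weighted by a Fisher-information estimate; building that estimate from old-class examples only, as the diff now does, keeps the new-class rows out of the penalty. The repo's `EWC` class is not shown in this diff; below is a minimal sketch of the standard diagonal-Fisher penalty (the names and slicing convention are assumptions, not the repo's API):

```python
import torch

def ewc_penalty(model, old_params, fisher, ewc_lambda=5.0):
    """Standard diagonal-Fisher EWC penalty (sketch, not the repo's EWC).

    old_params / fisher: dicts of name -> tensor, captured from the old
    head and its old-class data before the new classes were added.
    """
    loss = 0.0
    for name, param in model.named_parameters():
        if name in old_params:
            old, f = old_params[name], fisher[name]
            # Penalize only the parameter slice that existed before expansion
            current = param[tuple(slice(0, s) for s in old.shape)]
            loss = loss + (f * (current - old) ** 2).sum()
    return ewc_lambda / 2.0 * loss
```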
```diff
@@ -302,10 +367,12 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
         # Get embedding
         embedding = self._get_embeddings([text])[0]
 
-        # Get prototype predictions
-        proto_preds = self.memory.get_nearest_prototypes(embedding, k=k)
-
-        # Get neural predictions if available
+        # Get prototype predictions for ALL classes (not limited by k)
+        # This ensures complete scoring information for proper combination
+        max_classes = len(self.id_to_label) if self.id_to_label else k
+        proto_preds = self.memory.get_nearest_prototypes(embedding, k=max_classes)
+
+        # Get neural predictions if available for ALL classes (not limited by k)
         if self.adaptive_head is not None:
             self.adaptive_head.eval()  # Ensure eval mode
             # Add batch dimension and move to device
```
```diff
@@ -314,8 +381,9 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
             # Squeeze batch dimension
             logits = logits.squeeze(0)
             probs = F.softmax(logits, dim=0)
-
-            values, indices = torch.topk(probs, min(k, len(self.id_to_label)))
+
+            # Get predictions for ALL classes for proper scoring combination
+            values, indices = torch.topk(probs, len(self.id_to_label))
             head_preds = [
                 (self.id_to_label[idx.item()], val.item())
                 for val, idx in zip(values, indices)
```
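Why score every class: before this change both the prototype and the neural scorers were truncated to their own top-k, so the two lists could cover different classes, and a label that was strong in one scorer but ranked k+1 in the other silently lost half its evidence. Scoring all classes and truncating only after combination avoids that; a small sketch of such a combination step, assuming a simple weighted sum (the repo's actual blending logic may differ):

```python
def combine_predictions(proto_preds, head_preds, k=5, proto_weight=0.5):
    """Combine two full (label, score) lists, then take top-k (sketch)."""
    scores = {}
    for label, score in proto_preds:
        scores[label] = scores.get(label, 0.0) + proto_weight * score
    for label, score in head_preds:
        scores[label] = scores.get(label, 0.0) + (1 - proto_weight) * score
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:k]

# Hypothetical full score lists over 3 classes from each scorer
proto = [("spam", 0.6), ("ham", 0.3), ("promo", 0.1)]
head = [("promo", 0.5), ("spam", 0.4), ("ham", 0.1)]
print(combine_predictions(proto, head, k=2))  # spam first, then promo
```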
