Commit 0a656e5

Merge pull request #59 from codelion/feat-add-onnx-support

Feat add onnx support

2 parents 064ecdc + d019067

9 files changed: +1237 -44 lines

README.md

Lines changed: 72 additions & 1 deletion
@@ -30,9 +30,10 @@ Adaptive Classifier is a PyTorch-based machine learning library that revolutioni
 
 ### 🎯 **Core Capabilities**
 - **🚀 Universal Compatibility** - Works with any HuggingFace transformer model
+- **⚡ Optimized Inference** - Built-in ONNX Runtime for 2-4x faster CPU predictions
 - **📈 Continuous Learning** - Add new examples without catastrophic forgetting
 - **🔄 Dynamic Classes** - Add new classes at runtime without retraining
-- ** Zero Downtime** - Update models in production without service interruption
+- **⏱️ Zero Downtime** - Update models in production without service interruption
 
 ### 🛡️ **Advanced Defense**
 - **🎮 Strategic Classification** - Game-theoretic defense against adversarial manipulation
@@ -99,6 +100,8 @@ Tested on arena-hard-auto-v0.1 dataset (500 queries):
 pip install adaptive-classifier
 ```
 
+**Includes:** ONNX Runtime for 2-4x faster CPU inference out-of-the-box
+
 ### 🛠️ Development Setup
 ```bash
 # Clone the repository
@@ -191,6 +194,74 @@ predictions = strategic_classifier.predict("This product has amazing quality fea
 # Returns predictions that consider potential gaming attempts
 ```
 
+### ⚡ Optimized CPU Inference with ONNX
+
+Adaptive Classifier includes **built-in ONNX Runtime support** for **2-4x faster CPU inference**, with zero code changes required.
+
+#### Automatic Optimization (Default)
+
+ONNX Runtime is automatically used on CPU for optimal performance:
+
+```python
+# Automatically uses ONNX on CPU, PyTorch on GPU
+classifier = AdaptiveClassifier("bert-base-uncased")
+
+# That's it! Predictions are 2-4x faster on CPU
+predictions = classifier.predict("Fast inference!")
+```
+
+#### Performance Comparison
+
+| Configuration | Speed | Use Case |
+|--------------|-------|----------|
+| PyTorch (GPU) | Fastest | GPU servers |
+| **ONNX (CPU)** | **2-4x faster** | **Production CPU deployments** |
+| PyTorch (CPU) | Baseline | Development, training |
+
+#### Save & Deploy with ONNX
+
+```python
+# Save with ONNX export (both quantized & unquantized versions)
+classifier.save("./model")
+
+# Push to Hub with ONNX (both versions included by default)
+classifier.push_to_hub("username/model")
+
+# Loading automatically uses the quantized ONNX model on CPU (fastest, 4x smaller)
+fast_classifier = AdaptiveClassifier.load("./model")
+
+# Choose unquantized ONNX for maximum accuracy
+accurate_classifier = AdaptiveClassifier.load("./model", prefer_quantized=False)
+
+# Force PyTorch (no ONNX)
+pytorch_classifier = AdaptiveClassifier.load("./model", use_onnx=False)
+
+# Opt out of ONNX export when saving
+classifier.save("./model", include_onnx=False)
+```
+
+**ONNX Model Versions:**
+- **Quantized (default)**: INT8, ~4x smaller; ~1.14x faster on ARM, 2-4x faster on x86
+- **Unquantized**: Full precision, maximum accuracy, larger file size
+
+By default, models are saved with both versions, and the quantized version is loaded automatically for best performance. Use `prefer_quantized=False` if you need maximum accuracy.
+
+#### Benchmark Your Model
+
+```bash
+# Compare PyTorch vs ONNX performance
+python scripts/benchmark_onnx.py --model bert-base-uncased --runs 100
+```
+
+**Example Results:**
+```
+Model: bert-base-uncased (CPU)
+PyTorch: 8.3ms/query (baseline)
+ONNX: 2.1ms/query (4.0x faster) ✓
+```
+
+> **Note:** ONNX optimization is included by default. For GPU inference, PyTorch is used automatically for best performance.
+
 ## Advanced Usage
 
 ### Adding New Classes Dynamically
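
The INT8 quantization that the new README section describes is not itself shown in this diff. As a rough illustration only, a dynamic-quantization pass with optimum's public API looks like the sketch below; the model name and output paths are assumptions for illustration, not files this commit creates.

```python
# Sketch: export an ONNX backbone and produce an INT8-quantized copy,
# roughly mirroring the "quantized & unquantized" pair the README describes.
# Paths and model name are illustrative assumptions.
from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Export the transformer backbone to ONNX (the unquantized version)
model = ORTModelForFeatureExtraction.from_pretrained("bert-base-uncased", export=True)
model.save_pretrained("./model/onnx")

# Dynamic (weight-only) INT8 quantization -- no calibration data needed
quantizer = ORTQuantizer.from_pretrained("./model/onnx")
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
quantizer.quantize(save_dir="./model/onnx-quantized", quantization_config=qconfig)
```

Dynamic quantization needs no calibration dataset, which fits the out-of-the-box behavior the README claims; static quantization can be faster still but requires representative inputs.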

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -7,4 +7,5 @@ tqdm>=4.65.0
 setuptools>=65.0.0
 wheel>=0.40.0
 scikit-learn
-huggingface_hub>=0.17.0
+huggingface_hub>=0.17.0
+optimum[onnxruntime]>=1.14.0
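
With `optimum[onnxruntime]` now a pinned dependency, a quick way to confirm the runtime actually works on a target machine is to list its execution providers. A minimal check using the plain onnxruntime API (nothing specific to this repo):

```python
# Sanity-check the ONNX Runtime install pulled in via optimum[onnxruntime].
# CPUExecutionProvider is the provider behind the README's CPU speedup numbers.
import onnxruntime as ort

print(ort.__version__)
print(ort.get_available_providers())  # typically ['CPUExecutionProvider'] on CPU-only machines
```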

scripts/benchmark_onnx.py

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
"""Benchmark script comparing PyTorch vs ONNX vs Quantized ONNX performance."""

import time
import argparse
import tempfile
from pathlib import Path
import numpy as np
from adaptive_classifier import AdaptiveClassifier


def check_optimum_installed():
    """Check if optimum is installed."""
    try:
        import optimum.onnxruntime
        return True
    except ImportError:
        return False


def benchmark_inference(classifier, texts, num_runs=100):
    """Benchmark inference speed."""
    # Warmup
    for _ in range(5):
        classifier.predict(texts[0])

    # Benchmark
    start_time = time.time()
    for _ in range(num_runs):
        for text in texts:
            classifier.predict(text)

    end_time = time.time()
    total_time = end_time - start_time
    avg_time_per_query = (total_time / (num_runs * len(texts))) * 1000  # ms

    return avg_time_per_query, total_time


def main():
    parser = argparse.ArgumentParser(description="Benchmark ONNX vs PyTorch performance")
    parser.add_argument("--model", type=str, default="prajjwal1/bert-tiny",
                        help="HuggingFace model name to benchmark")
    parser.add_argument("--runs", type=int, default=100,
                        help="Number of benchmark runs")
    parser.add_argument("--skip-quantized", action="store_true",
                        help="Skip quantized ONNX benchmarking")
    args = parser.parse_args()

    if not check_optimum_installed():
        print("⚠️ optimum[onnxruntime] not installed. Skipping ONNX benchmarks.")
        print("Install with: pip install optimum[onnxruntime]")
        return

    print("=" * 70)
    print("ONNX Runtime Benchmark for Adaptive Classifier")
    print("=" * 70)
    print(f"Model: {args.model}")
    print(f"Runs per test: {args.runs}")
    print()

    # Prepare test data
    test_texts = [
        "This is a positive example",
        "This seems negative to me",
        "A neutral statement here",
        "Another test case for benchmarking performance",
        "The quick brown fox jumps over the lazy dog"
    ]

    print("Preparing classifiers...")
    print()

    # Train a baseline classifier
    classifier_base = AdaptiveClassifier(args.model, use_onnx=False, device="cpu")
    training_texts = [
        "great product", "terrible experience", "okay item",
        "loved it", "hated it", "it's fine",
        "amazing quality", "poor service", "average performance"
    ]
    training_labels = [
        "positive", "negative", "neutral",
        "positive", "negative", "neutral",
        "positive", "negative", "neutral"
    ]
    classifier_base.add_examples(training_texts, training_labels)

    # Save and create ONNX versions
    with tempfile.TemporaryDirectory() as tmpdir:
        save_path = Path(tmpdir) / "classifier"

        # Save with ONNX versions
        print("Exporting ONNX models...")
        classifier_base._save_pretrained(
            save_path,
            include_onnx=True,
            quantize_onnx=not args.skip_quantized
        )

        # Load PyTorch version
        print("Loading PyTorch model...")
        classifier_pytorch = AdaptiveClassifier._from_pretrained(
            str(save_path),
            use_onnx=False
        )

        # Load ONNX version
        print("Loading ONNX model...")
        classifier_onnx = AdaptiveClassifier._from_pretrained(
            str(save_path),
            use_onnx=True
        )

        print()
        print("Starting benchmarks...")
        print("-" * 70)

        # Benchmark PyTorch
        print("\n1. PyTorch Baseline")
        print("   Running benchmark...")
        pytorch_avg, pytorch_total = benchmark_inference(
            classifier_pytorch, test_texts, args.runs
        )
        print(f"   ✓ Average time per query: {pytorch_avg:.2f}ms")
        print(f"   ✓ Total time: {pytorch_total:.2f}s")

        # Benchmark ONNX
        print("\n2. ONNX Runtime")
        print("   Running benchmark...")
        onnx_avg, onnx_total = benchmark_inference(
            classifier_onnx, test_texts, args.runs
        )
        print(f"   ✓ Average time per query: {onnx_avg:.2f}ms")
        print(f"   ✓ Total time: {onnx_total:.2f}s")
        speedup = pytorch_avg / onnx_avg
        print(f"   ✓ Speedup: {speedup:.2f}x faster than PyTorch")

        # Test prediction accuracy
        print("\n3. Accuracy Verification")
        test_text = "This is amazing!"
        pred_pytorch = classifier_pytorch.predict(test_text)
        pred_onnx = classifier_onnx.predict(test_text)

        print(f"   PyTorch top prediction: {pred_pytorch[0]}")
        print(f"   ONNX top prediction: {pred_onnx[0]}")

        if pred_pytorch[0][0] == pred_onnx[0][0]:
            print("   ✓ Predictions match!")
        else:
            print("   ⚠️ Predictions differ slightly")

    print()
    print("=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print(f"PyTorch: {pytorch_avg:.2f}ms/query (baseline)")
    print(f"ONNX: {onnx_avg:.2f}ms/query ({speedup:.2f}x faster)")
    print()

    if speedup > 2.0:
        print("🚀 ONNX provides significant speedup! (>2x)")
    elif speedup > 1.2:
        print("⚡ ONNX provides moderate speedup")
    else:
        print("ℹ️ ONNX provides marginal speedup")

    print()
    print("=" * 70)
    print("\nRecommendation:")
    if speedup > 1.5:
        print("✓ Use ONNX for CPU inference for better performance!")
        print("  classifier = AdaptiveClassifier(model_name, use_onnx=True)")
    else:
        print("ℹ️ ONNX speedup is modest for this model.")
        print("  Consider using smaller models (distilbert, MiniLM) for better gains.")


if __name__ == "__main__":
    main()
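
One caveat with the script above: it reports only the mean over all calls, and `time.time()` has coarse resolution for sub-millisecond work. If tail latency matters, a per-call variant of `benchmark_inference` using `time.perf_counter` is straightforward. The sketch below is an addition for illustration, not part of this commit; it also puts the script's otherwise-unused `numpy` import to work.

```python
# Sketch: per-call latency sampling with percentile reporting.
import time

import numpy as np


def benchmark_inference_percentiles(classifier, texts, num_runs=100):
    """Like benchmark_inference, but returns mean/p50/p95/p99 latency in ms."""
    # Warmup, matching the original script
    for _ in range(5):
        classifier.predict(texts[0])

    latencies_ms = []
    for _ in range(num_runs):
        for text in texts:
            start = time.perf_counter()  # monotonic; better than time.time() here
            classifier.predict(text)
            latencies_ms.append((time.perf_counter() - start) * 1000)

    p50, p95, p99 = np.percentile(latencies_ms, [50, 95, 99])
    return float(np.mean(latencies_ms)), float(p50), float(p95), float(p99)
```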
