"""Benchmark script comparing PyTorch vs ONNX Runtime inference performance for AdaptiveClassifier."""
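# Example invocation (the script filename is illustrative; flags match the argparse setup below):
#   python benchmark_onnx.py --model prajjwal1/bert-tiny --runs 100 --skip-quantized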

import argparse
import tempfile
import time
from pathlib import Path

from adaptive_classifier import AdaptiveClassifier


def check_optimum_installed():
    """Check if optimum is installed."""
    try:
        import optimum.onnxruntime
        return True
    except ImportError:
        return False

| 19 | + |
| 20 | +def benchmark_inference(classifier, texts, num_runs=100): |
| 21 | + """Benchmark inference speed.""" |
| 22 | + # Warmup |
| 23 | + for _ in range(5): |
| 24 | + classifier.predict(texts[0]) |
| 25 | + |
| 26 | + # Benchmark |
| 27 | + start_time = time.time() |
| 28 | + for _ in range(num_runs): |
| 29 | + for text in texts: |
| 30 | + classifier.predict(text) |
| 31 | + |
| 32 | + end_time = time.time() |
| 33 | + total_time = end_time - start_time |
| 34 | + avg_time_per_query = (total_time / (num_runs * len(texts))) * 1000 # ms |
| 35 | + |
| 36 | + return avg_time_per_query, total_time |
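
# Standalone usage sketch (hypothetical variable names; any trained classifier works):
#   avg_ms, total_s = benchmark_inference(classifier, ["hello world"], num_runs=10)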


def main():
    parser = argparse.ArgumentParser(description="Benchmark ONNX vs PyTorch performance")
    parser.add_argument("--model", type=str, default="prajjwal1/bert-tiny",
                        help="HuggingFace model name to benchmark")
    parser.add_argument("--runs", type=int, default=100,
                        help="Number of benchmark runs")
    parser.add_argument("--skip-quantized", action="store_true",
                        help="Skip exporting a quantized ONNX model")
    args = parser.parse_args()

    if not check_optimum_installed():
        print("⚠️ optimum[onnxruntime] not installed. Skipping ONNX benchmarks.")
        print('Install with: pip install "optimum[onnxruntime]"')
        return

    print("=" * 70)
    print("ONNX Runtime Benchmark for Adaptive Classifier")
    print("=" * 70)
    print(f"Model: {args.model}")
    print(f"Runs per test: {args.runs}")
    print()

    # Prepare test data
    test_texts = [
        "This is a positive example",
        "This seems negative to me",
        "A neutral statement here",
        "Another test case for benchmarking performance",
        "The quick brown fox jumps over the lazy dog"
    ]

    print("Preparing classifiers...")
    print()

    # Train a baseline classifier
    classifier_base = AdaptiveClassifier(args.model, use_onnx=False, device="cpu")
    training_texts = [
        "great product", "terrible experience", "okay item",
        "loved it", "hated it", "it's fine",
        "amazing quality", "poor service", "average performance"
    ]
    training_labels = [
        "positive", "negative", "neutral",
        "positive", "negative", "neutral",
        "positive", "negative", "neutral"
    ]
    classifier_base.add_examples(training_texts, training_labels)

    # Save the trained classifier, exporting ONNX (and optionally quantized ONNX) versions
    with tempfile.TemporaryDirectory() as tmpdir:
        save_path = Path(tmpdir) / "classifier"

        print("Exporting ONNX models...")
        classifier_base._save_pretrained(
            save_path,
            include_onnx=True,
            quantize_onnx=not args.skip_quantized
        )

        # Load PyTorch version
        print("Loading PyTorch model...")
        classifier_pytorch = AdaptiveClassifier._from_pretrained(
            str(save_path),
            use_onnx=False
        )

        # Load ONNX version
        print("Loading ONNX model...")
        classifier_onnx = AdaptiveClassifier._from_pretrained(
            str(save_path),
            use_onnx=True
        )

    print()
    print("Starting benchmarks...")
    print("-" * 70)

    # Benchmark PyTorch
    print("\n1. PyTorch Baseline")
    print(" Running benchmark...")
    pytorch_avg, pytorch_total = benchmark_inference(
        classifier_pytorch, test_texts, args.runs
    )
    print(f" ✓ Average time per query: {pytorch_avg:.2f}ms")
    print(f" ✓ Total time: {pytorch_total:.2f}s")

    # Benchmark ONNX
    print("\n2. ONNX Runtime")
    print(" Running benchmark...")
    onnx_avg, onnx_total = benchmark_inference(
        classifier_onnx, test_texts, args.runs
    )
    print(f" ✓ Average time per query: {onnx_avg:.2f}ms")
    print(f" ✓ Total time: {onnx_total:.2f}s")
    speedup = pytorch_avg / onnx_avg
    print(f" ✓ Speedup: {speedup:.2f}x faster than PyTorch")
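
    # NOTE: when --skip-quantized is not passed, a quantized ONNX model is also
    # exported during _save_pretrained above, but it is not loaded or timed here;
    # only the PyTorch baseline and the standard ONNX session are compared.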

    # Verify that both backends agree on the top prediction
    print("\n3. Accuracy Verification")
    test_text = "This is amazing!"
    pred_pytorch = classifier_pytorch.predict(test_text)
    pred_onnx = classifier_onnx.predict(test_text)

    print(f" PyTorch top prediction: {pred_pytorch[0]}")
    print(f" ONNX top prediction: {pred_onnx[0]}")

    if pred_pytorch[0][0] == pred_onnx[0][0]:
        print(" ✓ Top predictions match!")
    else:
        print(" ⚠️ Top predictions differ")

    print()
    print("=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print(f"PyTorch: {pytorch_avg:.2f}ms/query (baseline)")
    print(f"ONNX: {onnx_avg:.2f}ms/query ({speedup:.2f}x speedup)")
    print()

    if speedup > 2.0:
        print("🚀 ONNX provides significant speedup! (>2x)")
    elif speedup > 1.2:
        print("⚡ ONNX provides moderate speedup")
    else:
        print("ℹ️ ONNX provides little or no speedup")

    print()
    print("=" * 70)
    print("\nRecommendation:")
    if speedup > 1.5:
        print("✓ Use ONNX for CPU inference for better performance!")
        print(" classifier = AdaptiveClassifier(model_name, use_onnx=True)")
    else:
        print("ℹ️ ONNX speedup is modest for this model.")
        print(" Consider using smaller models (distilbert, MiniLM) for better gains.")


if __name__ == "__main__":
    main()