@@ -33,7 +33,8 @@ def __init__(
3333 device : Optional [str ] = None ,
3434 config : Optional [Dict [str , Any ]] = None ,
3535 seed : int = 42 , # Add seed parameter
36- use_onnx : Optional [Union [bool , str ]] = "auto" # "auto", True, False
36+ use_onnx : Optional [Union [bool , str ]] = "auto" , # "auto", True, False
37+ trust_remote_code : bool = False
3738 ):
3839 """Initialize the adaptive classifier.
3940
@@ -44,6 +45,7 @@ def __init__(
4445 seed: Random seed for initialization
4546 use_onnx: Whether to use ONNX Runtime ("auto", True, False).
4647 "auto" uses ONNX on CPU, PyTorch on GPU.
48+ trust_remote_code: Whether to trust remote code when loading models (default: False)
4749 """
4850 # Set seed for initialization
4951 torch .manual_seed (seed )
@@ -60,7 +62,8 @@ def __init__(
6062 logger .info (f"Initializing ONNX model for { model_name } " )
6163 self .model = ORTModelForFeatureExtraction .from_pretrained (
6264 model_name ,
63- export = True # Auto-export to ONNX if not already in ONNX format
65+ export = True , # Auto-export to ONNX if not already in ONNX format
66+ trust_remote_code = trust_remote_code
6467 )
6568 logger .info ("Successfully loaded ONNX model" )
6669 except ImportError :
@@ -69,17 +72,17 @@ def __init__(
6972 "Install with: pip install optimum[onnxruntime]"
7073 )
7174 self .use_onnx = False
72- self .model = AutoModel .from_pretrained (model_name ).to (self .device )
75+ self .model = AutoModel .from_pretrained (model_name , trust_remote_code = trust_remote_code ).to (self .device )
7376 except Exception as e :
7477 logger .warning (
7578 f"Failed to load ONNX model: { e } . Falling back to PyTorch."
7679 )
7780 self .use_onnx = False
78- self .model = AutoModel .from_pretrained (model_name ).to (self .device )
81+ self .model = AutoModel .from_pretrained (model_name , trust_remote_code = trust_remote_code ).to (self .device )
7982 else :
80- self .model = AutoModel .from_pretrained (model_name ).to (self .device )
83+ self .model = AutoModel .from_pretrained (model_name , trust_remote_code = trust_remote_code ).to (self .device )
8184
82- self .tokenizer = AutoTokenizer .from_pretrained (model_name )
85+ self .tokenizer = AutoTokenizer .from_pretrained (model_name , trust_remote_code = trust_remote_code )
8386
8487 # Initialize memory system
8588 self .embedding_dim = self .model .config .hidden_size
@@ -637,6 +640,7 @@ def _from_pretrained(
637640 token : Optional [Union [str , bool ]] = None ,
638641 use_onnx : Optional [Union [bool , str ]] = "auto" ,
639642 prefer_quantized : bool = True ,
643+ trust_remote_code : bool = False ,
640644 ** kwargs
641645 ) -> "AdaptiveClassifier" :
642646 """Load a model from the HuggingFace Hub or local directory.
@@ -653,6 +657,7 @@ def _from_pretrained(
653657 use_onnx: Whether to use ONNX Runtime ("auto", True, False)
654658 prefer_quantized: Use quantized ONNX model if available (default: True)
655659 Set to False to use unquantized model for maximum accuracy
660+ trust_remote_code: Whether to trust remote code when loading models (default: False)
656661 **kwargs: Additional arguments passed to from_pretrained
657662
658663 Returns:
@@ -667,6 +672,9 @@ def _from_pretrained(
667672 >>>
668673 >>> # Force PyTorch (no ONNX)
669674 >>> classifier = AdaptiveClassifier.load("adaptive-classifier/llm-router", use_onnx=False)
675+ >>>
676+ >>> # Load model requiring custom code
677+ >>> classifier = AdaptiveClassifier.load("model-with-custom-code", trust_remote_code=True)
670678 """
671679
672680 # Check if model_id is a local directory
@@ -814,9 +822,10 @@ def _from_pretrained(
814822
815823 classifier .model = ORTModelForFeatureExtraction .from_pretrained (
816824 onnx_path ,
817- file_name = onnx_file
825+ file_name = onnx_file ,
826+ trust_remote_code = trust_remote_code
818827 )
819- classifier .tokenizer = AutoTokenizer .from_pretrained (config_dict ['model_name' ])
828+ classifier .tokenizer = AutoTokenizer .from_pretrained (config_dict ['model_name' ], trust_remote_code = trust_remote_code )
820829
821830 # Initialize memory and other components
822831 classifier .embedding_dim = classifier .model .config .hidden_size
@@ -852,7 +861,8 @@ def _from_pretrained(
852861 config_dict ['model_name' ],
853862 device = device ,
854863 config = config_dict .get ('config' , None ),
855- use_onnx = final_use_onnx if isinstance (final_use_onnx , bool ) else False
864+ use_onnx = final_use_onnx if isinstance (final_use_onnx , bool ) else False ,
865+ trust_remote_code = trust_remote_code
856866 )
857867
858868 # Restore label mappings
@@ -1187,19 +1197,20 @@ def save(self, save_dir: str, include_onnx: bool = True, quantize_onnx: bool = T
11871197 )
11881198
11891199 @classmethod
1190- def load (cls , save_dir : str , device : Optional [str ] = None , use_onnx : Optional [Union [bool , str ]] = "auto" , prefer_quantized : bool = True ) -> 'AdaptiveClassifier' :
1200+ def load (cls , save_dir : str , device : Optional [str ] = None , use_onnx : Optional [Union [bool , str ]] = "auto" , prefer_quantized : bool = True , trust_remote_code : bool = False ) -> 'AdaptiveClassifier' :
11911201 """Legacy load method for backwards compatibility.
11921202
11931203 Args:
11941204 save_dir: Directory to load from
11951205 device: Device to load model on
11961206 use_onnx: Whether to use ONNX Runtime ("auto", True, False)
11971207 prefer_quantized: Use quantized ONNX model if available (default: True)
1208+ trust_remote_code: Whether to trust remote code when loading models (default: False)
11981209 """
11991210 kwargs = {}
12001211 if device is not None :
12011212 kwargs ['device' ] = device
1202- return cls ._from_pretrained (save_dir , use_onnx = use_onnx , prefer_quantized = prefer_quantized , ** kwargs )
1213+ return cls ._from_pretrained (save_dir , use_onnx = use_onnx , prefer_quantized = prefer_quantized , trust_remote_code = trust_remote_code , ** kwargs )
12031214
12041215 def to (self , device : str ) -> 'AdaptiveClassifier' :
12051216 """Move the model to specified device.
0 commit comments