Skip to content

Commit e2e819e

Browse files
authored
Merge pull request #62 from codelion/feat-trust-remote-enable
Feat trust remote enable
2 parents 6e7fe46 + 1aef514 commit e2e819e

File tree

3 files changed

+24
-13
lines changed

3 files changed

+24
-13
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setup(
1717
name="adaptive-classifier",
18-
version="0.1.1",
18+
version="0.1.2",
1919
author="codelion",
2020
author_email="codelion@okyasoft.com",
2121
description="A flexible, adaptive classification system for dynamic text classification",

src/adaptive_classifier/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .multilabel import MultiLabelAdaptiveClassifier, MultiLabelAdaptiveHead
55
from huggingface_hub import ModelHubMixin
66

7-
__version__ = "0.1.1"
7+
__version__ = "0.1.2"
88

99
__all__ = [
1010
"AdaptiveClassifier",

src/adaptive_classifier/classifier.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def __init__(
3333
device: Optional[str] = None,
3434
config: Optional[Dict[str, Any]] = None,
3535
seed: int = 42, # Add seed parameter
36-
use_onnx: Optional[Union[bool, str]] = "auto" # "auto", True, False
36+
use_onnx: Optional[Union[bool, str]] = "auto", # "auto", True, False
37+
trust_remote_code: bool = False
3738
):
3839
"""Initialize the adaptive classifier.
3940
@@ -44,6 +45,7 @@ def __init__(
4445
seed: Random seed for initialization
4546
use_onnx: Whether to use ONNX Runtime ("auto", True, False).
4647
"auto" uses ONNX on CPU, PyTorch on GPU.
48+
trust_remote_code: Whether to trust remote code when loading models (default: False)
4749
"""
4850
# Set seed for initialization
4951
torch.manual_seed(seed)
@@ -60,7 +62,8 @@ def __init__(
6062
logger.info(f"Initializing ONNX model for {model_name}")
6163
self.model = ORTModelForFeatureExtraction.from_pretrained(
6264
model_name,
63-
export=True # Auto-export to ONNX if not already in ONNX format
65+
export=True, # Auto-export to ONNX if not already in ONNX format
66+
trust_remote_code=trust_remote_code
6467
)
6568
logger.info("Successfully loaded ONNX model")
6669
except ImportError:
@@ -69,17 +72,17 @@ def __init__(
6972
"Install with: pip install optimum[onnxruntime]"
7073
)
7174
self.use_onnx = False
72-
self.model = AutoModel.from_pretrained(model_name).to(self.device)
75+
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code).to(self.device)
7376
except Exception as e:
7477
logger.warning(
7578
f"Failed to load ONNX model: {e}. Falling back to PyTorch."
7679
)
7780
self.use_onnx = False
78-
self.model = AutoModel.from_pretrained(model_name).to(self.device)
81+
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code).to(self.device)
7982
else:
80-
self.model = AutoModel.from_pretrained(model_name).to(self.device)
83+
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code).to(self.device)
8184

82-
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
85+
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
8386

8487
# Initialize memory system
8588
self.embedding_dim = self.model.config.hidden_size
@@ -637,6 +640,7 @@ def _from_pretrained(
637640
token: Optional[Union[str, bool]] = None,
638641
use_onnx: Optional[Union[bool, str]] = "auto",
639642
prefer_quantized: bool = True,
643+
trust_remote_code: bool = False,
640644
**kwargs
641645
) -> "AdaptiveClassifier":
642646
"""Load a model from the HuggingFace Hub or local directory.
@@ -653,6 +657,7 @@ def _from_pretrained(
653657
use_onnx: Whether to use ONNX Runtime ("auto", True, False)
654658
prefer_quantized: Use quantized ONNX model if available (default: True)
655659
Set to False to use unquantized model for maximum accuracy
660+
trust_remote_code: Whether to trust remote code when loading models (default: False)
656661
**kwargs: Additional arguments passed to from_pretrained
657662
658663
Returns:
@@ -667,6 +672,9 @@ def _from_pretrained(
667672
>>>
668673
>>> # Force PyTorch (no ONNX)
669674
>>> classifier = AdaptiveClassifier.load("adaptive-classifier/llm-router", use_onnx=False)
675+
>>>
676+
>>> # Load model requiring custom code
677+
>>> classifier = AdaptiveClassifier.load("model-with-custom-code", trust_remote_code=True)
670678
"""
671679

672680
# Check if model_id is a local directory
@@ -814,9 +822,10 @@ def _from_pretrained(
814822

815823
classifier.model = ORTModelForFeatureExtraction.from_pretrained(
816824
onnx_path,
817-
file_name=onnx_file
825+
file_name=onnx_file,
826+
trust_remote_code=trust_remote_code
818827
)
819-
classifier.tokenizer = AutoTokenizer.from_pretrained(config_dict['model_name'])
828+
classifier.tokenizer = AutoTokenizer.from_pretrained(config_dict['model_name'], trust_remote_code=trust_remote_code)
820829

821830
# Initialize memory and other components
822831
classifier.embedding_dim = classifier.model.config.hidden_size
@@ -852,7 +861,8 @@ def _from_pretrained(
852861
config_dict['model_name'],
853862
device=device,
854863
config=config_dict.get('config', None),
855-
use_onnx=final_use_onnx if isinstance(final_use_onnx, bool) else False
864+
use_onnx=final_use_onnx if isinstance(final_use_onnx, bool) else False,
865+
trust_remote_code=trust_remote_code
856866
)
857867

858868
# Restore label mappings
@@ -1187,19 +1197,20 @@ def save(self, save_dir: str, include_onnx: bool = True, quantize_onnx: bool = T
11871197
)
11881198

11891199
@classmethod
1190-
def load(cls, save_dir: str, device: Optional[str] = None, use_onnx: Optional[Union[bool, str]] = "auto", prefer_quantized: bool = True) -> 'AdaptiveClassifier':
1200+
def load(cls, save_dir: str, device: Optional[str] = None, use_onnx: Optional[Union[bool, str]] = "auto", prefer_quantized: bool = True, trust_remote_code: bool = False) -> 'AdaptiveClassifier':
11911201
"""Legacy load method for backwards compatibility.
11921202
11931203
Args:
11941204
save_dir: Directory to load from
11951205
device: Device to load model on
11961206
use_onnx: Whether to use ONNX Runtime ("auto", True, False)
11971207
prefer_quantized: Use quantized ONNX model if available (default: True)
1208+
trust_remote_code: Whether to trust remote code when loading models (default: False)
11981209
"""
11991210
kwargs = {}
12001211
if device is not None:
12011212
kwargs['device'] = device
1202-
return cls._from_pretrained(save_dir, use_onnx=use_onnx, prefer_quantized=prefer_quantized, **kwargs)
1213+
return cls._from_pretrained(save_dir, use_onnx=use_onnx, prefer_quantized=prefer_quantized, trust_remote_code=trust_remote_code, **kwargs)
12031214

12041215
def to(self, device: str) -> 'AdaptiveClassifier':
12051216
"""Move the model to specified device.

0 commit comments

Comments
 (0)