@@ -58,10 +58,10 @@ def __init__(self, sagemaker_session=None):
5858
5959 Args:
6060 sagemaker_session: SageMaker session to use for API calls.
61- If None, will be created with beta endpoint if configured.
61+ If None, will be created with endpoint if configured.
6262 """
6363 self .sagemaker_session = sagemaker_session
64- self ._beta_endpoint = os .environ .get ('SAGEMAKER_ENDPOINT' )
64+ self ._endpoint = os .environ .get ('SAGEMAKER_ENDPOINT' )
6565
6666 def resolve_model_info (
6767 self ,
@@ -188,8 +188,21 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model
188188 base_model_name = hub_content_name
189189 if hasattr (container .base_model , 'hub_content_arn' ):
190190 base_model_arn = container .base_model .hub_content_arn
191+
192+ # If hub_content_arn is not present, construct it from hub_content_name and version
193+ if not base_model_arn and hasattr (container .base_model , 'hub_content_version' ):
194+ hub_content_version = container .base_model .hub_content_version
195+ model_pkg_arn = getattr (model_package , 'model_package_arn' , None )
196+
197+ if hub_content_name and hub_content_version and model_pkg_arn :
198+ # Extract region from model package ARN
199+ arn_parts = model_pkg_arn .split (':' )
200+ if len (arn_parts ) >= 4 :
201+ region = arn_parts [3 ]
202+ # Construct hub content ARN for SageMaker public hub
203+ base_model_arn = f"arn:aws:sagemaker:{ region } :aws:hub-content/SageMakerPublicHub/Model/{ hub_content_name } /{ hub_content_version } "
191204
192- # If we couldn't extract base model ARN, this is not a supported model package
205+ # If we couldn't extract or construct base model ARN, this is not a supported model package
193206 if not base_model_arn :
194207 raise ValueError (
195208 f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
@@ -234,82 +247,23 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo:
234247 # Validate ARN format
235248 self ._validate_model_package_arn (model_package_arn )
236249
237- # TODO: Switch to sagemaker_core ModelPackage.get() once the bug is fixed
238- # Currently, ModelPackage.get() has a Pydantic validation issue where
239- # the transform() function doesn't include model_package_name in the response,
240- # causing: "1 validation error for ModelPackage - model_package_name: Field required"
241- # Using boto3 directly as a workaround.
242-
243- # Use the sagemaker client from the session (which has the correct endpoint configured)
244- sm_client = session .sagemaker_client if hasattr (session , 'sagemaker_client' ) else session .boto_session .client ('sagemaker' )
245- response = sm_client .describe_model_package (ModelPackageName = model_package_arn )
246-
247- # Extract base model info from response
248- base_model_name = None
249- base_model_arn = None
250- hub_content_name = None
250+ # Use sagemaker.core ModelPackage.get() to retrieve model package information
251+ from sagemaker .core .resources import ModelPackage
251252
252- # Check inference specification
253- if 'InferenceSpecification' not in response :
254- raise ValueError (
255- f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
256- f"The provided model package (ARN: { model_package_arn } ) "
257- f"does not have an inference_specification."
258- )
253+ import logging
254+ logger = logging .getLogger (__name__ )
259255
260- inf_spec = response ['InferenceSpecification' ]
261- if 'Containers' not in inf_spec or len (inf_spec ['Containers' ]) == 0 :
262- raise ValueError (
263- f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
264- f"The provided model package (ARN: { model_package_arn } ) "
265- f"does not have any containers in its inference_specification."
266- )
267-
268- container = inf_spec ['Containers' ][0 ]
269-
270- # Extract base model info
271- if 'BaseModel' not in container :
272- raise ValueError (
273- f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
274- f"The provided model package (ARN: { model_package_arn } ) "
275- f"does not have base_model metadata in its inference_specification.containers[0]. "
276- f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
277- )
278-
279- base_model_info = container ['BaseModel' ]
280- hub_content_name = base_model_info .get ('HubContentName' )
281- hub_content_version = base_model_info .get ('HubContentVersion' )
282- base_model_arn = base_model_info .get ('HubContentArn' )
283-
284- # If HubContentArn is None, construct it from HubContentName and version
285- # This handles cases where the API doesn't return the full ARN
286- if not base_model_arn and hub_content_name and hub_content_version :
287- # Extract region from model_package_arn
288- arn_parts = model_package_arn .split (':' )
289- if len (arn_parts ) >= 4 :
290- region = arn_parts [3 ]
291- # Construct hub content ARN for SageMaker public hub
292- base_model_arn = f"arn:aws:sagemaker:{ region } :aws:hub-content/SageMakerPublicHub/Model/{ hub_content_name } /{ hub_content_version } "
293-
294- if not base_model_arn :
295- raise ValueError (
296- f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
297- f"The provided model package (ARN: { model_package_arn } ) "
298- f"does not have base_model metadata with HubContentArn or sufficient information to construct it. "
299- f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
300- )
256+ # Get the model package using sagemaker.core
257+ model_package = ModelPackage .get (
258+ model_package_name = model_package_arn ,
259+ session = session .boto_session ,
260+ region = session .boto_session .region_name
261+ )
301262
302- # Use hub_content_name as base_model_name
303- base_model_name = hub_content_name if hub_content_name else response .get ('ModelPackageGroupName' , 'unknown' )
263+ logger .info (f"Retrieved ModelPackage in region: { session .boto_session .region_name } " )
304264
305- return _ModelInfo (
306- base_model_name = base_model_name ,
307- base_model_arn = base_model_arn ,
308- source_model_package_arn = model_package_arn ,
309- model_type = _ModelType .FINE_TUNED ,
310- hub_content_name = hub_content_name ,
311- additional_metadata = {}
312- )
265+ # Now use the existing _resolve_model_package_object method to extract base model info
266+ return self ._resolve_model_package_object (model_package )
313267
314268 except ValueError :
315269 # Re-raise ValueError as-is (our custom error messages)
@@ -342,7 +296,7 @@ def _validate_model_package_arn(self, arn: str) -> bool:
342296
343297 def _get_session (self ):
344298 """
345- Get or create SageMaker session with beta endpoint support.
299+ Get or create SageMaker session with endpoint support.
346300
347301 Returns:
348302 SageMaker session
@@ -352,12 +306,11 @@ def _get_session(self):
352306
353307 from sagemaker .core .helper .session_helper import Session
354308
355- # Check for beta endpoint in environment variable
356- if self ._beta_endpoint :
309+ # Check for endpoint in environment variable
310+ if self ._endpoint :
357311 sm_client = boto3 .client (
358312 'sagemaker' ,
359- endpoint_url = self ._beta_endpoint ,
360- region_name = os .environ .get ('AWS_REGION' , 'us-west-2' )
313+ endpoint_url = self ._endpoint
361314 )
362315 return Session (sagemaker_client = sm_client )
363316
0 commit comments