Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 33 additions & 80 deletions sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def __init__(self, sagemaker_session=None):

Args:
sagemaker_session: SageMaker session to use for API calls.
If None, will be created with beta endpoint if configured.
If None, will be created with endpoint if configured.
"""
self.sagemaker_session = sagemaker_session
self._beta_endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
self._endpoint = os.environ.get('SAGEMAKER_ENDPOINT')

def resolve_model_info(
self,
Expand Down Expand Up @@ -188,8 +188,21 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model
base_model_name = hub_content_name
if hasattr(container.base_model, 'hub_content_arn'):
base_model_arn = container.base_model.hub_content_arn

# If hub_content_arn is not present, construct it from hub_content_name and version
if not base_model_arn and hasattr(container.base_model, 'hub_content_version'):
hub_content_version = container.base_model.hub_content_version
model_pkg_arn = getattr(model_package, 'model_package_arn', None)

if hub_content_name and hub_content_version and model_pkg_arn:
# Extract region from model package ARN
arn_parts = model_pkg_arn.split(':')
if len(arn_parts) >= 4:
region = arn_parts[3]
# Construct hub content ARN for SageMaker public hub
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"

# If we couldn't extract base model ARN, this is not a supported model package
# If we couldn't extract or construct base model ARN, this is not a supported model package
if not base_model_arn:
raise ValueError(
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
Expand Down Expand Up @@ -234,82 +247,23 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo:
# Validate ARN format
self._validate_model_package_arn(model_package_arn)

# TODO: Switch to sagemaker_core ModelPackage.get() once the bug is fixed
# Currently, ModelPackage.get() has a Pydantic validation issue where
# the transform() function doesn't include model_package_name in the response,
# causing: "1 validation error for ModelPackage - model_package_name: Field required"
# Using boto3 directly as a workaround.

# Use the sagemaker client from the session (which has the correct endpoint configured)
sm_client = session.sagemaker_client if hasattr(session, 'sagemaker_client') else session.boto_session.client('sagemaker')
response = sm_client.describe_model_package(ModelPackageName=model_package_arn)

# Extract base model info from response
base_model_name = None
base_model_arn = None
hub_content_name = None
# Use sagemaker.core ModelPackage.get() to retrieve model package information
from sagemaker.core.resources import ModelPackage

# Check inference specification
if 'InferenceSpecification' not in response:
raise ValueError(
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
f"The provided model package (ARN: {model_package_arn}) "
f"does not have an inference_specification."
)
import logging
logger = logging.getLogger(__name__)

inf_spec = response['InferenceSpecification']
if 'Containers' not in inf_spec or len(inf_spec['Containers']) == 0:
raise ValueError(
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
f"The provided model package (ARN: {model_package_arn}) "
f"does not have any containers in its inference_specification."
)

container = inf_spec['Containers'][0]

# Extract base model info
if 'BaseModel' not in container:
raise ValueError(
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
f"The provided model package (ARN: {model_package_arn}) "
f"does not have base_model metadata in its inference_specification.containers[0]. "
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
)

base_model_info = container['BaseModel']
hub_content_name = base_model_info.get('HubContentName')
hub_content_version = base_model_info.get('HubContentVersion')
base_model_arn = base_model_info.get('HubContentArn')

# If HubContentArn is None, construct it from HubContentName and version
# This handles cases where the API doesn't return the full ARN
if not base_model_arn and hub_content_name and hub_content_version:
# Extract region from model_package_arn
arn_parts = model_package_arn.split(':')
if len(arn_parts) >= 4:
region = arn_parts[3]
# Construct hub content ARN for SageMaker public hub
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"

if not base_model_arn:
raise ValueError(
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
f"The provided model package (ARN: {model_package_arn}) "
f"does not have base_model metadata with HubContentArn or sufficient information to construct it. "
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
)
# Get the model package using sagemaker.core
model_package = ModelPackage.get(
model_package_name=model_package_arn,
session=session.boto_session,
region=session.boto_session.region_name
)

# Use hub_content_name as base_model_name
base_model_name = hub_content_name if hub_content_name else response.get('ModelPackageGroupName', 'unknown')
logger.info(f"Retrieved ModelPackage in region: {session.boto_session.region_name}")

return _ModelInfo(
base_model_name=base_model_name,
base_model_arn=base_model_arn,
source_model_package_arn=model_package_arn,
model_type=_ModelType.FINE_TUNED,
hub_content_name=hub_content_name,
additional_metadata={}
)
# Now use the existing _resolve_model_package_object method to extract base model info
return self._resolve_model_package_object(model_package)

except ValueError:
# Re-raise ValueError as-is (our custom error messages)
Expand Down Expand Up @@ -342,7 +296,7 @@ def _validate_model_package_arn(self, arn: str) -> bool:

def _get_session(self):
"""
Get or create SageMaker session with beta endpoint support.
Get or create SageMaker session with endpoint support.

Returns:
SageMaker session
Expand All @@ -352,12 +306,11 @@ def _get_session(self):

from sagemaker.core.helper.session_helper import Session

# Check for beta endpoint in environment variable
if self._beta_endpoint:
# Check for endpoint in environment variable
if self._endpoint:
sm_client = boto3.client(
'sagemaker',
endpoint_url=self._beta_endpoint,
region_name=os.environ.get('AWS_REGION', 'us-west-2')
endpoint_url=self._endpoint
)
return Session(sagemaker_client=sm_client)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,8 @@ def _get_or_create_artifact_arn(self, source_uri: str, region: str) -> str:
properties['HubContentArn'] = source_uri
else:
properties['SourceUri'] = source_uri

_logger.info(f"source_uri: {source_uri}, region: {region}, properties: {properties}")

# Create artifact using Artifact.create()
artifact = Artifact.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ def _get_custom_scorer_template_additions(self, evaluator_config: dict) -> dict:
'evaluator_arn': evaluator_config['evaluator_arn'],
}

# Add lambda_type for Nova models
if is_nova:
custom_scorer_context['lambda_type'] = 'rft'

# Add preset_reward_function if present
if evaluator_config['preset_reward_function']:
custom_scorer_context['preset_reward_function'] = evaluator_config['preset_reward_function']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,8 @@
"task": "{{ task }}",
"strategy": "{{ strategy }}"{% if metric is defined %},
"metric": "{{ metric }}"{% elif evaluation_metric is defined %},
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
"lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
"max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
"temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
"top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
Expand Down Expand Up @@ -694,7 +695,8 @@
"task": "{{ task }}",
"strategy": "{{ strategy }}"{% if metric is defined %},
"metric": "{{ metric }}"{% elif evaluation_metric is defined %},
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
"lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
"max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
"temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
"top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
Expand Down Expand Up @@ -872,7 +874,8 @@
"task": "{{ task }}",
"strategy": "{{ strategy }}"{% if metric is defined %},
"metric": "{{ metric }}"{% elif evaluation_metric is defined %},
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
"lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
"max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
"temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
"top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
Expand Down
14 changes: 14 additions & 0 deletions sagemaker-train/tests/integ/train/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Integration tests for SageMaker Modules Evaluate"""
from __future__ import absolute_import
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
}


@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
class TestBenchmarkEvaluatorIntegration:
"""Integration tests for BenchmarkEvaluator with fine-tuned model package"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
}


@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
class TestCustomScorerEvaluatorIntegration:
"""Integration tests for CustomScorerEvaluator with custom evaluator"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
}


@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
class TestLLMAsJudgeEvaluatorIntegration:
"""Integration tests for LLMAsJudgeEvaluator"""

Expand Down
Loading
Loading