Skip to content

Commit 4e50064

Browse files
mufaddal-rohawalamufiAmazonrsareddy0329
authored
fix: Add lambda_type parameter for Nova models and migrate to ModelPackage (#5374)
* Add lambda_type parameter for Nova models and migrate to ModelPackage.get() - Add lambda_type: rft parameter for Nova models in custom_scorer_evaluator - Update pipeline templates to conditionally include lambda_type hyperparameter - Migrate model_resolution.py from boto3 to sagemaker.core ModelPackage.get() - Fix hub_content_arn construction when not provided by API - Update and add unit tests for all changes Testing done: 1. Ran Nova eval custom scorer. 2. Updated/ran unit tests * skip eval tests --------- Co-authored-by: Mufaddal Rohawala <mufi@amazon.com> Co-authored-by: rsareddy0329 <rsareddy0329@gmail.com>
1 parent d9bfe7c commit 4e50064

File tree

11 files changed

+251
-150
lines changed

11 files changed

+251
-150
lines changed

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 33 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def __init__(self, sagemaker_session=None):
5858
5959
Args:
6060
sagemaker_session: SageMaker session to use for API calls.
61-
If None, will be created with beta endpoint if configured.
61+
If None, will be created with endpoint if configured.
6262
"""
6363
self.sagemaker_session = sagemaker_session
64-
self._beta_endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
64+
self._endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
6565

6666
def resolve_model_info(
6767
self,
@@ -188,8 +188,21 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model
188188
base_model_name = hub_content_name
189189
if hasattr(container.base_model, 'hub_content_arn'):
190190
base_model_arn = container.base_model.hub_content_arn
191+
192+
# If hub_content_arn is not present, construct it from hub_content_name and version
193+
if not base_model_arn and hasattr(container.base_model, 'hub_content_version'):
194+
hub_content_version = container.base_model.hub_content_version
195+
model_pkg_arn = getattr(model_package, 'model_package_arn', None)
196+
197+
if hub_content_name and hub_content_version and model_pkg_arn:
198+
# Extract region from model package ARN
199+
arn_parts = model_pkg_arn.split(':')
200+
if len(arn_parts) >= 4:
201+
region = arn_parts[3]
202+
# Construct hub content ARN for SageMaker public hub
203+
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
191204

192-
# If we couldn't extract base model ARN, this is not a supported model package
205+
# If we couldn't extract or construct base model ARN, this is not a supported model package
193206
if not base_model_arn:
194207
raise ValueError(
195208
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
@@ -234,82 +247,23 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo:
234247
# Validate ARN format
235248
self._validate_model_package_arn(model_package_arn)
236249

237-
# TODO: Switch to sagemaker_core ModelPackage.get() once the bug is fixed
238-
# Currently, ModelPackage.get() has a Pydantic validation issue where
239-
# the transform() function doesn't include model_package_name in the response,
240-
# causing: "1 validation error for ModelPackage - model_package_name: Field required"
241-
# Using boto3 directly as a workaround.
242-
243-
# Use the sagemaker client from the session (which has the correct endpoint configured)
244-
sm_client = session.sagemaker_client if hasattr(session, 'sagemaker_client') else session.boto_session.client('sagemaker')
245-
response = sm_client.describe_model_package(ModelPackageName=model_package_arn)
246-
247-
# Extract base model info from response
248-
base_model_name = None
249-
base_model_arn = None
250-
hub_content_name = None
250+
# Use sagemaker.core ModelPackage.get() to retrieve model package information
251+
from sagemaker.core.resources import ModelPackage
251252

252-
# Check inference specification
253-
if 'InferenceSpecification' not in response:
254-
raise ValueError(
255-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
256-
f"The provided model package (ARN: {model_package_arn}) "
257-
f"does not have an inference_specification."
258-
)
253+
import logging
254+
logger = logging.getLogger(__name__)
259255

260-
inf_spec = response['InferenceSpecification']
261-
if 'Containers' not in inf_spec or len(inf_spec['Containers']) == 0:
262-
raise ValueError(
263-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
264-
f"The provided model package (ARN: {model_package_arn}) "
265-
f"does not have any containers in its inference_specification."
266-
)
267-
268-
container = inf_spec['Containers'][0]
269-
270-
# Extract base model info
271-
if 'BaseModel' not in container:
272-
raise ValueError(
273-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
274-
f"The provided model package (ARN: {model_package_arn}) "
275-
f"does not have base_model metadata in its inference_specification.containers[0]. "
276-
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
277-
)
278-
279-
base_model_info = container['BaseModel']
280-
hub_content_name = base_model_info.get('HubContentName')
281-
hub_content_version = base_model_info.get('HubContentVersion')
282-
base_model_arn = base_model_info.get('HubContentArn')
283-
284-
# If HubContentArn is None, construct it from HubContentName and version
285-
# This handles cases where the API doesn't return the full ARN
286-
if not base_model_arn and hub_content_name and hub_content_version:
287-
# Extract region from model_package_arn
288-
arn_parts = model_package_arn.split(':')
289-
if len(arn_parts) >= 4:
290-
region = arn_parts[3]
291-
# Construct hub content ARN for SageMaker public hub
292-
base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
293-
294-
if not base_model_arn:
295-
raise ValueError(
296-
f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
297-
f"The provided model package (ARN: {model_package_arn}) "
298-
f"does not have base_model metadata with HubContentArn or sufficient information to construct it. "
299-
f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
300-
)
256+
# Get the model package using sagemaker.core
257+
model_package = ModelPackage.get(
258+
model_package_name=model_package_arn,
259+
session=session.boto_session,
260+
region=session.boto_session.region_name
261+
)
301262

302-
# Use hub_content_name as base_model_name
303-
base_model_name = hub_content_name if hub_content_name else response.get('ModelPackageGroupName', 'unknown')
263+
logger.info(f"Retrieved ModelPackage in region: {session.boto_session.region_name}")
304264

305-
return _ModelInfo(
306-
base_model_name=base_model_name,
307-
base_model_arn=base_model_arn,
308-
source_model_package_arn=model_package_arn,
309-
model_type=_ModelType.FINE_TUNED,
310-
hub_content_name=hub_content_name,
311-
additional_metadata={}
312-
)
265+
# Now use the existing _resolve_model_package_object method to extract base model info
266+
return self._resolve_model_package_object(model_package)
313267

314268
except ValueError:
315269
# Re-raise ValueError as-is (our custom error messages)
@@ -342,7 +296,7 @@ def _validate_model_package_arn(self, arn: str) -> bool:
342296

343297
def _get_session(self):
344298
"""
345-
Get or create SageMaker session with beta endpoint support.
299+
Get or create SageMaker session with endpoint support.
346300
347301
Returns:
348302
SageMaker session
@@ -352,12 +306,11 @@ def _get_session(self):
352306

353307
from sagemaker.core.helper.session_helper import Session
354308

355-
# Check for beta endpoint in environment variable
356-
if self._beta_endpoint:
309+
# Check for endpoint in environment variable
310+
if self._endpoint:
357311
sm_client = boto3.client(
358312
'sagemaker',
359-
endpoint_url=self._beta_endpoint,
360-
region_name=os.environ.get('AWS_REGION', 'us-west-2')
313+
endpoint_url=self._endpoint
361314
)
362315
return Session(sagemaker_client=sm_client)
363316

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,8 @@ def _get_or_create_artifact_arn(self, source_uri: str, region: str) -> str:
546546
properties['HubContentArn'] = source_uri
547547
else:
548548
properties['SourceUri'] = source_uri
549+
550+
_logger.info(f"source_uri: {source_uri}, region: {region}, properties: {properties}")
549551

550552
# Create artifact using Artifact.create()
551553
artifact = Artifact.create(

sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ def _get_custom_scorer_template_additions(self, evaluator_config: dict) -> dict:
308308
'evaluator_arn': evaluator_config['evaluator_arn'],
309309
}
310310

311+
# Add lambda_type for Nova models
312+
if is_nova:
313+
custom_scorer_context['lambda_type'] = 'rft'
314+
311315
# Add preset_reward_function if present
312316
if evaluator_config['preset_reward_function']:
313317
custom_scorer_context['preset_reward_function'] = evaluator_config['preset_reward_function']

sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,8 @@
632632
"task": "{{ task }}",
633633
"strategy": "{{ strategy }}"{% if metric is defined %},
634634
"metric": "{{ metric }}"{% elif evaluation_metric is defined %},
635-
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
635+
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
636+
"lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
636637
"max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
637638
"temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
638639
"top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
@@ -694,7 +695,8 @@
694695
"task": "{{ task }}",
695696
"strategy": "{{ strategy }}"{% if metric is defined %},
696697
"metric": "{{ metric }}"{% elif evaluation_metric is defined %},
697-
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
698+
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
699+
"lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
698700
"max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
699701
"temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
700702
"top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
@@ -872,7 +874,8 @@
872874
"task": "{{ task }}",
873875
"strategy": "{{ strategy }}"{% if metric is defined %},
874876
"metric": "{{ metric }}"{% elif evaluation_metric is defined %},
875-
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
877+
"evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
878+
"lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
876879
"max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
877880
"temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
878881
"top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Integration tests for SageMaker Modules Evaluate"""
14+
from __future__ import absolute_import

tests/integ/sagemaker/modules/evaluate/test_benchmark_evaluator.py renamed to sagemaker-train/tests/integ/train/test_benchmark_evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
}
7373

7474

75+
@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
7576
class TestBenchmarkEvaluatorIntegration:
7677
"""Integration tests for BenchmarkEvaluator with fine-tuned model package"""
7778

tests/integ/sagemaker/modules/evaluate/test_custom_scorer_evaluator.py renamed to sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
}
5656

5757

58+
@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
5859
class TestCustomScorerEvaluatorIntegration:
5960
"""Integration tests for CustomScorerEvaluator with custom evaluator"""
6061

tests/integ/sagemaker/modules/evaluate/test_llm_as_judge_evaluator.py renamed to sagemaker-train/tests/integ/train/test_llm_as_judge_evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
}
8585

8686

87+
@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
8788
class TestLLMAsJudgeEvaluatorIntegration:
8889
"""Integration tests for LLMAsJudgeEvaluator"""
8990

0 commit comments

Comments
 (0)