Skip to content

Commit 479ee4d

Browse files
committed
Updating evaluator to append timestamp to function name
1 parent db12b1d commit 479ee4d

File tree

2 files changed

+124
-111
lines changed

2 files changed

+124
-111
lines changed

sagemaker-train/src/sagemaker/ai_registry/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def _create_lambda_function(cls, name: str, source_file: str, role: Optional[str
381381

382382
# Create Lambda function
383383
lambda_client = boto3.client("lambda")
384-
function_name = f"SageMaker-evaluator-{name}"
384+
function_name = f"SageMaker-evaluator-{name}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
385385
handler_name = f"{os.path.splitext(os.path.basename(source_file))[0]}.lambda_handler"
386386

387387
try:

sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py

Lines changed: 123 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
with progress tracking and MLflow integration.
55
"""
66

7+
import logging
78
import time
9+
from contextlib import contextmanager
810
from typing import Optional, Tuple
911

1012
from sagemaker.core.resources import TrainingJob
@@ -13,6 +15,18 @@
1315
from sagemaker.train.common_utils.mlflow_metrics_util import _MLflowMetricsUtil
1416

1517

18+
@contextmanager
19+
def _suppress_info_logging():
20+
"""Context manager to temporarily suppress INFO level logging."""
21+
logger = logging.getLogger()
22+
original_level = logger.level
23+
logger.setLevel(logging.WARNING)
24+
try:
25+
yield
26+
finally:
27+
logger.setLevel(original_level)
28+
29+
1630
def _setup_mlflow_integration(training_job: TrainingJob) -> Tuple[
1731
Optional[str], Optional[_MLflowMetricsUtil], Optional[str]]:
1832
"""Setup MLflow integration for training job monitoring.
@@ -172,124 +186,123 @@ def wait(
172186
from rich.panel import Panel
173187
from rich.text import Text
174188
from rich.console import Group
189+
with _suppress_info_logging():
190+
console = Console(force_jupyter=True)
191+
192+
iteration = 0
193+
while True:
194+
iteration += 1
195+
time.sleep(poll)
196+
training_job.refresh()
197+
clear_output(wait=True)
198+
199+
status = training_job.training_job_status
200+
secondary_status = training_job.secondary_status
201+
elapsed = time.time() - start_time
202+
203+
# Header section with training job name and MLFlow URL
204+
header_table = Table(show_header=False, box=None, padding=(0, 1))
205+
header_table.add_column("Property", style="cyan bold", width=20)
206+
header_table.add_column("Value", style="white")
207+
header_table.add_row("TrainingJob Name", f"[bold green]{training_job.training_job_name}[/bold green]")
208+
if mlflow_url:
209+
header_table.add_row("MLFlow URL",
210+
f"[link={mlflow_url}][bold bright_blue underline]{mlflow_run_name}(link valid for 5 mins)[/bright_blue bold underline][/link]")
211+
212+
status_table = Table(show_header=False, box=None, padding=(0, 1))
213+
status_table.add_column("Property", style="cyan bold", width=20)
214+
status_table.add_column("Value", style="white")
215+
216+
status_table.add_row("Job Status", f"[bold][orange3]{status}[/][/]")
217+
status_table.add_row("Secondary Status", f"[bold yellow]{secondary_status}[/bold yellow]")
218+
status_table.add_row("Elapsed Time", f"[bold bright_red]{elapsed:.1f}s[/bold bright_red]")
219+
220+
failure_reason = training_job.failure_reason
221+
if failure_reason and not _is_unassigned_attribute(failure_reason):
222+
status_table.add_row("Failure Reason", f"[bright_red]{failure_reason}[/bright_red]")
223+
224+
# Calculate training progress
225+
training_progress_pct = None
226+
training_progress_text = ""
227+
if secondary_status == "Training" and training_job.progress_info:
228+
if not progress_started:
229+
progress_started = True
230+
time.sleep(poll)
231+
training_job.refresh()
232+
233+
training_progress_pct, training_progress_text = _calculate_training_progress(
234+
training_job.progress_info, metrics_util, mlflow_run_name, training_job
235+
)
236+
237+
# Build transitions table if available
238+
transitions_table = None
239+
if training_job.secondary_status_transitions:
240+
from rich.box import SIMPLE
241+
transitions_table = Table(show_header=True, header_style="bold magenta", box=SIMPLE, padding=(0, 1))
242+
transitions_table.add_column("", style="green", width=2)
243+
transitions_table.add_column("Step", style="cyan", width=15)
244+
transitions_table.add_column("Details", style="orange3", width=35)
245+
transitions_table.add_column("Duration", style="green", width=12)
246+
247+
for trans in training_job.secondary_status_transitions:
248+
duration, check = _calculate_transition_duration(trans)
249+
250+
# Add progress bar for Training step
251+
if trans.status == "Training" and training_progress_pct is not None:
252+
bar = f"[green][{'█' * int(training_progress_pct / 5)}{'░' * (20 - int(training_progress_pct / 5))}][/green] {training_progress_pct:.1f}% {training_progress_text}"
253+
transitions_table.add_row(check, trans.status, bar, duration)
254+
else:
255+
transitions_table.add_row(check, trans.status, trans.status_message or "", duration)
256+
257+
# Prepare metrics table for terminal states
258+
metrics_table = None
259+
if status in ["Completed", "Failed", "Stopped"]:
260+
try:
261+
steps_per_epoch = training_job.progress_info.total_step_count_per_epoch
262+
loss_metrics_by_epoch = metrics_util._get_loss_metrics_by_epoch(run_name=mlflow_run_name,
263+
steps_per_epoch=steps_per_epoch)
264+
if loss_metrics_by_epoch:
265+
metrics_table = Table(show_header=True, header_style="bold magenta", box=SIMPLE,
266+
padding=(0, 1))
267+
metrics_table.add_column("Epochs", style="cyan", width=8)
268+
metrics_table.add_column("Loss Metrics", style="white")
175269

176-
console = Console(force_jupyter=True)
177-
178-
iteration = 0
179-
while True:
180-
iteration += 1
181-
time.sleep(poll)
182-
training_job.refresh()
183-
clear_output(wait=False)
184-
185-
status = training_job.training_job_status
186-
secondary_status = training_job.secondary_status
187-
elapsed = time.time() - start_time
188-
189-
# Header section with training job name and MLFlow URL
190-
header_table = Table(show_header=False, box=None, padding=(0, 1))
191-
header_table.add_column("Property", style="cyan bold", width=20)
192-
header_table.add_column("Value", style="white")
193-
header_table.add_row("TrainingJob Name", f"[bold green]{training_job.training_job_name}[/bold green]")
194-
if mlflow_url:
195-
header_table.add_row("MLFlow URL",
196-
f"[link={mlflow_url}][bold bright_blue underline]{mlflow_run_name}(link valid for 5 mins)[/bright_blue bold underline][/link]")
197-
198-
status_table = Table(show_header=False, box=None, padding=(0, 1))
199-
status_table.add_column("Property", style="cyan bold", width=20)
200-
status_table.add_column("Value", style="white")
201-
202-
status_table.add_row("Job Status", f"[bold][orange3]{status}[/][/]")
203-
status_table.add_row("Secondary Status", f"[bold yellow]{secondary_status}[/bold yellow]")
204-
status_table.add_row("Elapsed Time", f"[bold bright_red]{elapsed:.1f}s[/bold bright_red]")
205-
206-
failure_reason = training_job.failure_reason
207-
if failure_reason and not _is_unassigned_attribute(failure_reason):
208-
status_table.add_row("Failure Reason", f"[bright_red]{failure_reason}[/bright_red]")
209-
210-
# Calculate training progress
211-
training_progress_pct = None
212-
training_progress_text = ""
213-
if secondary_status == "Training" and training_job.progress_info:
214-
if not progress_started:
215-
progress_started = True
216-
time.sleep(10)
217-
continue
218-
# training_job.refresh()
219-
220-
training_progress_pct, training_progress_text = _calculate_training_progress(
221-
training_job.progress_info, metrics_util, mlflow_run_name, training_job
222-
)
223-
224-
# Build transitions table if available
225-
transitions_table = None
226-
if training_job.secondary_status_transitions:
227-
from rich.box import SIMPLE
228-
transitions_table = Table(show_header=True, header_style="bold magenta", box=SIMPLE, padding=(0, 1))
229-
transitions_table.add_column("", style="green", width=2)
230-
transitions_table.add_column("Step", style="cyan", width=15)
231-
transitions_table.add_column("Details", style="orange3", width=35)
232-
transitions_table.add_column("Duration", style="green", width=12)
233-
234-
for trans in training_job.secondary_status_transitions:
235-
duration, check = _calculate_transition_duration(trans)
270+
for epoch, metrics in list(loss_metrics_by_epoch.items())[:-1]:
271+
metrics_str = ", ".join([f"{k}: {v:.6f}" for k, v in metrics.items()])
272+
metrics_table.add_row(str(epoch + 1), metrics_str, style="yellow")
273+
except Exception:
274+
pass
236275

237-
# Add progress bar for Training step
238-
if trans.status == "Training" and training_progress_pct is not None:
239-
bar = f"[green][{'█' * int(training_progress_pct / 5)}{'░' * (20 - int(training_progress_pct / 5))}][/green] {training_progress_pct:.1f}% {training_progress_text}"
240-
transitions_table.add_row(check, trans.status, bar, duration)
276+
# Build combined group with metrics if available
277+
if training_job.secondary_status_transitions:
278+
if metrics_table:
279+
combined = Group(header_table, Text(""), status_table, Text(""),
280+
Text("Status Transitions", style="bold magenta"), transitions_table, Text(""),
281+
Text("Loss Metrics by Epoch", style="bold magenta"), metrics_table)
241282
else:
242-
transitions_table.add_row(check, trans.status, trans.status_message or "", duration)
243-
244-
# Prepare metrics table for terminal states
245-
metrics_table = None
246-
if status in ["Completed", "Failed", "Stopped"]:
247-
try:
248-
steps_per_epoch = training_job.progress_info.total_step_count_per_epoch
249-
loss_metrics_by_epoch = metrics_util._get_loss_metrics_by_epoch(run_name=mlflow_run_name,
250-
steps_per_epoch=steps_per_epoch)
251-
if loss_metrics_by_epoch:
252-
metrics_table = Table(show_header=True, header_style="bold magenta", box=SIMPLE,
253-
padding=(0, 1))
254-
metrics_table.add_column("Epochs", style="cyan", width=8)
255-
metrics_table.add_column("Loss Metrics", style="white")
256-
257-
for epoch, metrics in list(loss_metrics_by_epoch.items())[:-1]:
258-
metrics_str = ", ".join([f"{k}: {v:.6f}" for k, v in metrics.items()])
259-
metrics_table.add_row(str(epoch + 1), metrics_str, style="yellow")
260-
except Exception:
261-
pass
262-
263-
# Build combined group with metrics if available
264-
if training_job.secondary_status_transitions:
265-
if metrics_table:
266-
combined = Group(header_table, Text(""), status_table, Text(""),
267-
Text("Status Transitions", style="bold magenta"), transitions_table, Text(""),
268-
Text("Loss Metrics by Epoch", style="bold magenta"), metrics_table)
269-
else:
270-
combined = Group(header_table, Text(""), status_table, Text(""),
271-
Text("Status Transitions", style="bold magenta"), transitions_table)
272-
else:
273-
if metrics_table:
274-
combined = Group(header_table, Text(""), status_table, Text(""),
275-
Text("Loss Metrics by Epoch", style="bold magenta"), metrics_table)
283+
combined = Group(header_table, Text(""), status_table, Text(""),
284+
Text("Status Transitions", style="bold magenta"), transitions_table)
276285
else:
277-
combined = Group(header_table, Text(""), status_table)
286+
if metrics_table:
287+
combined = Group(header_table, Text(""), status_table, Text(""),
288+
Text("Loss Metrics by Epoch", style="bold magenta"), metrics_table)
289+
else:
290+
combined = Group(header_table, Text(""), status_table)
278291

279-
panel_width = 80
280-
if console.width and not _is_unassigned_attribute(console.width):
281-
panel_width = int(console.width * 0.8)
282-
console.print(Panel(combined, title="[bold bright_blue]Training Job Status[/bold bright_blue]",
283-
border_style="orange3", width=panel_width))
292+
panel_width = 80
293+
if console.width and not _is_unassigned_attribute(console.width):
294+
panel_width = int(console.width * 0.8)
295+
console.print(Panel(combined, title="[bold bright_blue]Training Job Status[/bold bright_blue]",
296+
border_style="orange3", width=panel_width))
284297

285-
if status in ["Completed", "Failed", "Stopped"]:
286-
return
298+
if status in ["Completed", "Failed", "Stopped"]:
299+
return
287300

288-
if status == "Failed" or (failure_reason and not _is_unassigned_attribute(failure_reason)):
289-
raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason)
301+
if status == "Failed" or (failure_reason and not _is_unassigned_attribute(failure_reason)):
302+
raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason)
290303

291-
if timeout and elapsed >= timeout:
292-
raise TimeoutExceededError(resouce_type="TrainingJob", status=status)
304+
if timeout and elapsed >= timeout:
305+
raise TimeoutExceededError(resouce_type="TrainingJob", status=status)
293306

294307
else:
295308
print(f"\nTrainingJob Name: {training_job.training_job_name}")

0 commit comments

Comments
 (0)