|
4 | 4 | with progress tracking and MLflow integration. |
5 | 5 | """ |
6 | 6 |
|
| 7 | +import logging |
7 | 8 | import time |
| 9 | +from contextlib import contextmanager |
8 | 10 | from typing import Optional, Tuple |
9 | 11 |
|
10 | 12 | from sagemaker.core.resources import TrainingJob |
|
13 | 15 | from sagemaker.train.common_utils.mlflow_metrics_util import _MLflowMetricsUtil |
14 | 16 |
|
15 | 17 |
|
| 18 | +@contextmanager |
| 19 | +def _suppress_info_logging(): |
| 20 | + """Context manager to temporarily suppress INFO level logging.""" |
| 21 | + logger = logging.getLogger() |
| 22 | + original_level = logger.level |
| 23 | + logger.setLevel(logging.WARNING) |
| 24 | + try: |
| 25 | + yield |
| 26 | + finally: |
| 27 | + logger.setLevel(original_level) |
| 28 | + |
| 29 | + |
16 | 30 | def _setup_mlflow_integration(training_job: TrainingJob) -> Tuple[ |
17 | 31 | Optional[str], Optional[_MLflowMetricsUtil], Optional[str]]: |
18 | 32 | """Setup MLflow integration for training job monitoring. |
@@ -172,124 +186,123 @@ def wait( |
172 | 186 | from rich.panel import Panel |
173 | 187 | from rich.text import Text |
174 | 188 | from rich.console import Group |
| 189 | + with _suppress_info_logging(): |
| 190 | + console = Console(force_jupyter=True) |
| 191 | + |
| 192 | + iteration = 0 |
| 193 | + while True: |
| 194 | + iteration += 1 |
| 195 | + time.sleep(poll) |
| 196 | + training_job.refresh() |
| 197 | + clear_output(wait=True) |
| 198 | + |
| 199 | + status = training_job.training_job_status |
| 200 | + secondary_status = training_job.secondary_status |
| 201 | + elapsed = time.time() - start_time |
| 202 | + |
| 203 | + # Header section with training job name and MLFlow URL |
| 204 | + header_table = Table(show_header=False, box=None, padding=(0, 1)) |
| 205 | + header_table.add_column("Property", style="cyan bold", width=20) |
| 206 | + header_table.add_column("Value", style="white") |
| 207 | + header_table.add_row("TrainingJob Name", f"[bold green]{training_job.training_job_name}[/bold green]") |
| 208 | + if mlflow_url: |
| 209 | + header_table.add_row("MLFlow URL", |
| 210 | + f"[link={mlflow_url}][bold bright_blue underline]{mlflow_run_name}(link valid for 5 mins)[/bright_blue bold underline][/link]") |
| 211 | + |
| 212 | + status_table = Table(show_header=False, box=None, padding=(0, 1)) |
| 213 | + status_table.add_column("Property", style="cyan bold", width=20) |
| 214 | + status_table.add_column("Value", style="white") |
| 215 | + |
| 216 | + status_table.add_row("Job Status", f"[bold][orange3]{status}[/][/]") |
| 217 | + status_table.add_row("Secondary Status", f"[bold yellow]{secondary_status}[/bold yellow]") |
| 218 | + status_table.add_row("Elapsed Time", f"[bold bright_red]{elapsed:.1f}s[/bold bright_red]") |
| 219 | + |
| 220 | + failure_reason = training_job.failure_reason |
| 221 | + if failure_reason and not _is_unassigned_attribute(failure_reason): |
| 222 | + status_table.add_row("Failure Reason", f"[bright_red]{failure_reason}[/bright_red]") |
| 223 | + |
| 224 | + # Calculate training progress |
| 225 | + training_progress_pct = None |
| 226 | + training_progress_text = "" |
| 227 | + if secondary_status == "Training" and training_job.progress_info: |
| 228 | + if not progress_started: |
| 229 | + progress_started = True |
| 230 | + time.sleep(poll) |
| 231 | + training_job.refresh() |
| 232 | + |
| 233 | + training_progress_pct, training_progress_text = _calculate_training_progress( |
| 234 | + training_job.progress_info, metrics_util, mlflow_run_name, training_job |
| 235 | + ) |
| 236 | + |
| 237 | + # Build transitions table if available |
| 238 | + transitions_table = None |
| 239 | + if training_job.secondary_status_transitions: |
| 240 | + from rich.box import SIMPLE |
| 241 | + transitions_table = Table(show_header=True, header_style="bold magenta", box=SIMPLE, padding=(0, 1)) |
| 242 | + transitions_table.add_column("", style="green", width=2) |
| 243 | + transitions_table.add_column("Step", style="cyan", width=15) |
| 244 | + transitions_table.add_column("Details", style="orange3", width=35) |
| 245 | + transitions_table.add_column("Duration", style="green", width=12) |
| 246 | + |
| 247 | + for trans in training_job.secondary_status_transitions: |
| 248 | + duration, check = _calculate_transition_duration(trans) |
| 249 | + |
| 250 | + # Add progress bar for Training step |
| 251 | + if trans.status == "Training" and training_progress_pct is not None: |
| 252 | + bar = f"[green][{'█' * int(training_progress_pct / 5)}{'░' * (20 - int(training_progress_pct / 5))}][/green] {training_progress_pct:.1f}% {training_progress_text}" |
| 253 | + transitions_table.add_row(check, trans.status, bar, duration) |
| 254 | + else: |
| 255 | + transitions_table.add_row(check, trans.status, trans.status_message or "", duration) |
| 256 | + |
| 257 | + # Prepare metrics table for terminal states |
| 258 | + metrics_table = None |
| 259 | + if status in ["Completed", "Failed", "Stopped"]: |
| 260 | + try: |
| 261 | + steps_per_epoch = training_job.progress_info.total_step_count_per_epoch |
| 262 | + loss_metrics_by_epoch = metrics_util._get_loss_metrics_by_epoch(run_name=mlflow_run_name, |
| 263 | + steps_per_epoch=steps_per_epoch) |
| 264 | + if loss_metrics_by_epoch: |
| 265 | + metrics_table = Table(show_header=True, header_style="bold magenta", box=SIMPLE, |
| 266 | + padding=(0, 1)) |
| 267 | + metrics_table.add_column("Epochs", style="cyan", width=8) |
| 268 | + metrics_table.add_column("Loss Metrics", style="white") |
175 | 269 |
|
176 | | - console = Console(force_jupyter=True) |
177 | | - |
178 | | - iteration = 0 |
179 | | - while True: |
180 | | - iteration += 1 |
181 | | - time.sleep(poll) |
182 | | - training_job.refresh() |
183 | | - clear_output(wait=False) |
184 | | - |
185 | | - status = training_job.training_job_status |
186 | | - secondary_status = training_job.secondary_status |
187 | | - elapsed = time.time() - start_time |
188 | | - |
189 | | - # Header section with training job name and MLFlow URL |
190 | | - header_table = Table(show_header=False, box=None, padding=(0, 1)) |
191 | | - header_table.add_column("Property", style="cyan bold", width=20) |
192 | | - header_table.add_column("Value", style="white") |
193 | | - header_table.add_row("TrainingJob Name", f"[bold green]{training_job.training_job_name}[/bold green]") |
194 | | - if mlflow_url: |
195 | | - header_table.add_row("MLFlow URL", |
196 | | - f"[link={mlflow_url}][bold bright_blue underline]{mlflow_run_name}(link valid for 5 mins)[/bright_blue bold underline][/link]") |
197 | | - |
198 | | - status_table = Table(show_header=False, box=None, padding=(0, 1)) |
199 | | - status_table.add_column("Property", style="cyan bold", width=20) |
200 | | - status_table.add_column("Value", style="white") |
201 | | - |
202 | | - status_table.add_row("Job Status", f"[bold][orange3]{status}[/][/]") |
203 | | - status_table.add_row("Secondary Status", f"[bold yellow]{secondary_status}[/bold yellow]") |
204 | | - status_table.add_row("Elapsed Time", f"[bold bright_red]{elapsed:.1f}s[/bold bright_red]") |
205 | | - |
206 | | - failure_reason = training_job.failure_reason |
207 | | - if failure_reason and not _is_unassigned_attribute(failure_reason): |
208 | | - status_table.add_row("Failure Reason", f"[bright_red]{failure_reason}[/bright_red]") |
209 | | - |
210 | | - # Calculate training progress |
211 | | - training_progress_pct = None |
212 | | - training_progress_text = "" |
213 | | - if secondary_status == "Training" and training_job.progress_info: |
214 | | - if not progress_started: |
215 | | - progress_started = True |
216 | | - time.sleep(10) |
217 | | - continue |
218 | | - # training_job.refresh() |
219 | | - |
220 | | - training_progress_pct, training_progress_text = _calculate_training_progress( |
221 | | - training_job.progress_info, metrics_util, mlflow_run_name, training_job |
222 | | - ) |
223 | | - |
224 | | - # Build transitions table if available |
225 | | - transitions_table = None |
226 | | - if training_job.secondary_status_transitions: |
227 | | - from rich.box import SIMPLE |
228 | | - transitions_table = Table(show_header=True, header_style="bold magenta", box=SIMPLE, padding=(0, 1)) |
229 | | - transitions_table.add_column("", style="green", width=2) |
230 | | - transitions_table.add_column("Step", style="cyan", width=15) |
231 | | - transitions_table.add_column("Details", style="orange3", width=35) |
232 | | - transitions_table.add_column("Duration", style="green", width=12) |
233 | | - |
234 | | - for trans in training_job.secondary_status_transitions: |
235 | | - duration, check = _calculate_transition_duration(trans) |
| 270 | + for epoch, metrics in list(loss_metrics_by_epoch.items())[:-1]: |
| 271 | + metrics_str = ", ".join([f"{k}: {v:.6f}" for k, v in metrics.items()]) |
| 272 | + metrics_table.add_row(str(epoch + 1), metrics_str, style="yellow") |
| 273 | + except Exception: |
| 274 | + pass |
236 | 275 |
|
237 | | - # Add progress bar for Training step |
238 | | - if trans.status == "Training" and training_progress_pct is not None: |
239 | | - bar = f"[green][{'█' * int(training_progress_pct / 5)}{'░' * (20 - int(training_progress_pct / 5))}][/green] {training_progress_pct:.1f}% {training_progress_text}" |
240 | | - transitions_table.add_row(check, trans.status, bar, duration) |
| 276 | + # Build combined group with metrics if available |
| 277 | + if training_job.secondary_status_transitions: |
| 278 | + if metrics_table: |
| 279 | + combined = Group(header_table, Text(""), status_table, Text(""), |
| 280 | + Text("Status Transitions", style="bold magenta"), transitions_table, Text(""), |
| 281 | + Text("Loss Metrics by Epoch", style="bold magenta"), metrics_table) |
241 | 282 | else: |
242 | | - transitions_table.add_row(check, trans.status, trans.status_message or "", duration) |
243 | | - |
244 | | - # Prepare metrics table for terminal states |
245 | | - metrics_table = None |
246 | | - if status in ["Completed", "Failed", "Stopped"]: |
247 | | - try: |
248 | | - steps_per_epoch = training_job.progress_info.total_step_count_per_epoch |
249 | | - loss_metrics_by_epoch = metrics_util._get_loss_metrics_by_epoch(run_name=mlflow_run_name, |
250 | | - steps_per_epoch=steps_per_epoch) |
251 | | - if loss_metrics_by_epoch: |
252 | | - metrics_table = Table(show_header=True, header_style="bold magenta", box=SIMPLE, |
253 | | - padding=(0, 1)) |
254 | | - metrics_table.add_column("Epochs", style="cyan", width=8) |
255 | | - metrics_table.add_column("Loss Metrics", style="white") |
256 | | - |
257 | | - for epoch, metrics in list(loss_metrics_by_epoch.items())[:-1]: |
258 | | - metrics_str = ", ".join([f"{k}: {v:.6f}" for k, v in metrics.items()]) |
259 | | - metrics_table.add_row(str(epoch + 1), metrics_str, style="yellow") |
260 | | - except Exception: |
261 | | - pass |
262 | | - |
263 | | - # Build combined group with metrics if available |
264 | | - if training_job.secondary_status_transitions: |
265 | | - if metrics_table: |
266 | | - combined = Group(header_table, Text(""), status_table, Text(""), |
267 | | - Text("Status Transitions", style="bold magenta"), transitions_table, Text(""), |
268 | | - Text("Loss Metrics by Epoch", style="bold magenta"), metrics_table) |
269 | | - else: |
270 | | - combined = Group(header_table, Text(""), status_table, Text(""), |
271 | | - Text("Status Transitions", style="bold magenta"), transitions_table) |
272 | | - else: |
273 | | - if metrics_table: |
274 | | - combined = Group(header_table, Text(""), status_table, Text(""), |
275 | | - Text("Loss Metrics by Epoch", style="bold magenta"), metrics_table) |
| 283 | + combined = Group(header_table, Text(""), status_table, Text(""), |
| 284 | + Text("Status Transitions", style="bold magenta"), transitions_table) |
276 | 285 | else: |
277 | | - combined = Group(header_table, Text(""), status_table) |
| 286 | + if metrics_table: |
| 287 | + combined = Group(header_table, Text(""), status_table, Text(""), |
| 288 | + Text("Loss Metrics by Epoch", style="bold magenta"), metrics_table) |
| 289 | + else: |
| 290 | + combined = Group(header_table, Text(""), status_table) |
278 | 291 |
|
279 | | - panel_width = 80 |
280 | | - if console.width and not _is_unassigned_attribute(console.width): |
281 | | - panel_width = int(console.width * 0.8) |
282 | | - console.print(Panel(combined, title="[bold bright_blue]Training Job Status[/bold bright_blue]", |
283 | | - border_style="orange3", width=panel_width)) |
| 292 | + panel_width = 80 |
| 293 | + if console.width and not _is_unassigned_attribute(console.width): |
| 294 | + panel_width = int(console.width * 0.8) |
| 295 | + console.print(Panel(combined, title="[bold bright_blue]Training Job Status[/bold bright_blue]", |
| 296 | + border_style="orange3", width=panel_width)) |
284 | 297 |
|
285 | | - if status in ["Completed", "Failed", "Stopped"]: |
286 | | - return |
| 298 | + if status in ["Completed", "Failed", "Stopped"]: |
| 299 | + return |
287 | 300 |
|
288 | | - if status == "Failed" or (failure_reason and not _is_unassigned_attribute(failure_reason)): |
289 | | - raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason) |
| 301 | + if status == "Failed" or (failure_reason and not _is_unassigned_attribute(failure_reason)): |
| 302 | + raise FailedStatusError(resource_type="TrainingJob", status=status, reason=failure_reason) |
290 | 303 |
|
291 | | - if timeout and elapsed >= timeout: |
292 | | - raise TimeoutExceededError(resouce_type="TrainingJob", status=status) |
| 304 | + if timeout and elapsed >= timeout: |
| 305 | + raise TimeoutExceededError(resouce_type="TrainingJob", status=status) |
293 | 306 |
|
294 | 307 | else: |
295 | 308 | print(f"\nTrainingJob Name: {training_job.training_job_name}") |
|
0 commit comments