|
476 | 476 | } |
477 | 477 | ], |
478 | 478 | "source": [ |
479 | | - "# ---------------------------------------------\n", |
| 479 | + "\n", |
480 | 480 | "# Check number of samples per split\n", |
481 | | - "# ---------------------------------------------\n", |
| 481 | + "\n", |
482 | 482 | "print(\"\\nDataset sample counts per split:\")\n", |
483 | 483 | "for split_name in dataset_dict:\n", |
484 | 484 | " print(f\"{split_name}: {len(dataset_dict[split_name])} samples\")\n", |
485 | 485 | "\n", |
486 | | - "# ---------------------------------------------\n", |
| 486 | + "\n", |
487 | 487 | "# Print keys of first example\n", |
488 | | - "# ---------------------------------------------\n", |
| 488 | + "\n", |
489 | 489 | "example = dataset_dict[\"train\"][0]\n", |
490 | 490 | "print(\"\\nFirst example keys:\", list(example.keys()))\n", |
491 | 491 | "\n", |
|
704 | 704 | " ax.plot(stroke[:, 0], -stroke[:, 1], color=color, linewidth=2)\n", |
705 | 705 | " ax.axis(\"off\")\n", |
706 | 706 | "\n", |
707 | | - "# From your dataloader, get raw traces instead of padded\n", |
| 707 | + "# From the dataloader, the model can pull raw traces instead of padded\n", |
708 | 708 | "for batch in tf_train.take(1):\n", |
709 | | - " # This assumes you still have original \"traces\" column in dataset_dict\n", |
| 709 | + " # This is if you still have original \"traces\" in a dataset_dict\n", |
710 | 710 | " sample_idx = 0\n", |
711 | 711 | " sample_file_path = dataset_dict[\"train\"][sample_idx][\"file_path\"]\n", |
712 | 712 | " sample_label = dataset_dict[\"train\"][sample_idx][\"normalized_label\"]\n", |
|
0 commit comments