Skip to content

Commit

Permalink
Merge pull request #361 from BhammarArjun/stage
Browse files Browse the repository at this point in the history
Fixing One hot encoding part
  • Loading branch information
ATaylorAerospace authored Oct 27, 2024
2 parents 436794b + 4617b15 commit d7edaf9
Showing 1 changed file with 40 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -579,29 +579,49 @@
"outputs": [],
"source": [
"def train_transforms(batch):\n",
" # convert all images in batch to RGB to avoid grayscale or transparent images\n",
" batch['image'] = [x.convert('RGB') for x in batch['image']]\n",
" # apply torchvision.transforms per sample in the batch\n",
" inputs = [train_tfms(x) for x in batch['image']]\n",
" batch['pixel_values'] = inputs\n",
" \n",
" # one-hot encoding the labels\n",
" labels = torch.tensor(batch['classes'])\n",
" batch['labels'] = nn.functional.one_hot(labels,num_classes=20).sum(dim=1)\n",
" \n",
" # Convert all images to RGB format\n",
" if isinstance(batch['image'], list):\n",
" # Batch processing\n",
" batch['image'] = [x.convert('RGB') for x in batch['image']]\n",
" inputs = [train_tfms(x) for x in batch['image']]\n",
" batch['pixel_values'] = torch.stack(inputs) # Stack tensor outputs\n",
" else:\n",
" # Single sample processing\n",
" batch['image'] = batch['image'].convert('RGB')\n",
" batch['pixel_values'] = train_tfms(batch['image'])\n",
"\n",
" # One-hot encode the multilabels\n",
" all_labels = [torch.tensor(labels) for labels in batch['classes']]\n",
"\n",
" # Create one-hot encoding for each image's classes\n",
" one_hot_labels = [nn.functional.one_hot(label, num_classes=20).sum(dim=0) for label in all_labels]\n",
"\n",
" # Stack them into a batch\n",
" batch['labels'] = torch.stack(one_hot_labels)\n",
"\n",
" return batch\n",
"\n",
"def valid_transforms(batch):\n",
" # convert all images in batch to RGB to avoid grayscale or transparent images\n",
" batch['image'] = [x.convert('RGB') for x in batch['image']]\n",
" # apply torchvision.transforms per sample in the batch\n",
" inputs = [valid_tfms(x) for x in batch['image']]\n",
" batch['pixel_values'] = inputs\n",
" \n",
" # one-hot encoding the labels\n",
" labels = torch.tensor(batch['classes'])\n",
" batch['labels'] = nn.functional.one_hot(labels,num_classes=20).sum(dim=1)\n",
" \n",
" # Convert all images to RGB format\n",
" if isinstance(batch['image'], list):\n",
" # Batch processing\n",
" batch['image'] = [x.convert('RGB') for x in batch['image']]\n",
" inputs = [train_tfms(x) for x in batch['image']]\n",
" batch['pixel_values'] = torch.stack(inputs) # Stack tensor outputs\n",
" else:\n",
" # Single sample processing\n",
" batch['image'] = batch['image'].convert('RGB')\n",
" batch['pixel_values'] = train_tfms(batch['image'])\n",
"\n",
" # One-hot encode the multilabels\n",
" all_labels = [torch.tensor(labels) for labels in batch['classes']]\n",
"\n",
" # Create one-hot encoding for each image's classes\n",
" one_hot_labels = [nn.functional.one_hot(label, num_classes=20).sum(dim=0) for label in all_labels]\n",
"\n",
" # Stack them into a batch\n",
" batch['labels'] = torch.stack(one_hot_labels)\n",
"\n",
" return batch"
]
},
Expand Down

0 comments on commit d7edaf9

Please sign in to comment.