Augmenting only the training set in K-fold cross validation
Problem description
I am trying to create a binary CNN classifier for an unbalanced dataset (class 0 = 4000 images, class 1 = around 250 images), on which I want to perform 5-fold cross validation. Currently I am loading my training set into an ImageFolder that applies my transformations/augmentations(?) and then passing it to a DataLoader. However, this results in both my training splits and validation splits containing augmented data.
I originally applied transformations offline (offline augmentation?) to balance my dataset, but according to this thread (https://stats.stackexchange.com/questions/175504/how-to-do-data-augmentation-and-train-validate-split), it seems ideal to augment only the training set. I would also prefer to train my model solely on augmented training data and then validate it on non-augmented data in a 5-fold cross validation.
My data is organized as root/label/images, where there are 2 label folders (0 and 1) and the images are sorted into their respective label folders.
import torch
from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler
from torchvision import datasets

# The same transforms are attached to the whole dataset here
total_set = datasets.ImageFolder(ROOT, transform=data_transforms['my_transforms'])

# Eventually I plan to run cross-validation as such:
splits = KFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=valid_sampler)
    model.train()
    # Model train/eval works but may be overpredicting
I'm sure I'm doing something sub-optimally or wrong in this code, but I can't seem to find any documentation on augmenting only the training splits in cross-validation!
Any help would be greatly appreciated!
Recommended answer
One approach is to implement a wrapper Dataset class that applies transforms to the output of your ImageFolder dataset. For example:
from torch.utils.data import Dataset

class WrapperDataset(Dataset):
    """Applies transforms on top of an existing dataset's (image, label) output."""

    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # Fetch the raw item, then apply this wrapper's own transforms
        image, label = self.dataset[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.dataset)
Then you could use this in your code by wrapping the larger dataset with different transforms.
# No transform here: each wrapper below applies its own
total_set = datasets.ImageFolder(ROOT)

# Eventually I plan to run cross-validation as such:
splits = KFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    # Same underlying images; only the training view is augmented.
    # Both transform sets should still convert images to tensors for batching.
    train_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['train_transforms']),
        batch_size=32, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['valid_transforms']),
        batch_size=32, sampler=valid_sampler)
    # train/validate now
I haven't tested this code since I don't have your full code/models, but the concept should be clear.
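The wrapper relies only on `__getitem__` and `__len__`, so the idea can be sanity-checked without PyTorch or any images. The following is a minimal sketch using only the standard library, not the original poster's code: a plain list stands in for the ImageFolder dataset, `fake_augment` is a hypothetical stand-in for the training transforms, and `kfold_indices` is a simple hand-written stand-in for scikit-learn's KFold. It checks that, in every fold, training items pass through the augmentation, validation items stay untouched, and the two index sets do not overlap.

```python
import random


class WrapperDataset:
    """Applies a transform on top of an existing (sample, label) dataset."""

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.dataset)


def kfold_indices(n_samples, n_splits, seed):
    """Hypothetical stand-in for KFold: yields (train_idx, valid_idx) lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices round-robin into n_splits folds
    folds = [indices[i::n_splits] for i in range(n_splits)]
    for k in range(n_splits):
        train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
        yield train_idx, folds[k]


def fake_augment(sample):
    """Hypothetical stand-in for the training augmentations: tags the sample."""
    return ("augmented", sample)


# Toy base dataset: 20 (sample, label) pairs with no transform attached
base = [(i, i % 2) for i in range(20)]
train_view = WrapperDataset(base, transform=fake_augment)
valid_view = WrapperDataset(base)  # validation view: no augmentation

fold_report = []
for train_idx, valid_idx in kfold_indices(len(base), n_splits=5, seed=42):
    train_items = [train_view[i] for i in train_idx]
    valid_items = [valid_view[i] for i in valid_idx]
    all_train_augmented = all(s[0] == "augmented" for s, _ in train_items)
    all_valid_raw = all(isinstance(s, int) for s, _ in valid_items)
    no_overlap = set(train_idx).isdisjoint(valid_idx)
    fold_report.append((all_train_augmented, all_valid_raw, no_overlap))

print(fold_report)
```

With real data, the two views would be passed to DataLoader together with the per-fold samplers, as in the answer's loop. Because augmentation happens lazily in `__getitem__`, the stored images are never modified, which is what keeps the validation split clean.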