img should be PIL Image. Got lt;class #39;torch.Tensor#39;gt;(img 应该是 PIL Image.得到了 lt;class torch.Tensorgt;)
问题描述
我正在尝试遍历加载器以检查它是否正常工作,但是给出了以下错误:
I'm trying to iterate through a loader to check if it's working, however the below error is given:
TypeError: img 应该是 PIL Image.得到了
我已经尝试添加 transforms.ToTensor()
和 transforms.ToPILImage()
并且它给了我一个错误要求相反.即,使用 ToPILImage()
,它将要求张量,反之亦然.
I've tried adding both transforms.ToTensor()
and transforms.ToPILImage()
and it gives me an error asking for the opposite. i.e, with ToPILImage()
, it will ask for tensor, and vice versa.
# Imports here
%matplotlib inline
import matplotlib.pyplot as plt
from torch import nn, optim
import torch.nn.functional as F
import torch
from torchvision import transforms, datasets, models
import seaborn as sns
import pandas as pd
import numpy as np
data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'
#Creating transform for training set
train_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
#Creating transform for test set
test_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
#transforming for all data
train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
test_data = datasets.ImageFolder(test_dir, transform = test_transforms)
valid_data = datasets.ImageFolder(valid_dir, transform = test_transforms)
#Creating data loaders for test and training sets
trainloader = torch.utils.data.DataLoader(train_data, batch_size = 32,
shuffle = True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
images, labels = next(iter(trainloader))
它应该允许我在运行 plt.imshow(images[0])
后简单地看到图像,如果它工作正常.
It should allow me to simply see the image once I run plt.imshow(images[0])
, if its working correctly.
推荐答案
transforms.RandomHorizontalFlip()
适用于 PIL.Images
,而不是 torch.Tensor代码>.在上面的代码中,您在
transforms.RandomHorizontalFlip()
之前应用 transforms.ToTensor()
,这会产生张量.
transforms.RandomHorizontalFlip()
works on PIL.Images
, not torch.Tensor
. In your code above, you are applying transforms.ToTensor()
prior to transforms.RandomHorizontalFlip()
, which results in tensor.
但是,根据官方 pytorch 文档这里、
But, as per the official pytorch documentation here,
transforms.RandomHorizontalFlip() 水平翻转给定的 PIL以给定的概率随机图像.
transforms.RandomHorizontalFlip() horizontally flip the given PIL Image randomly with a given probability.
因此,只需更改上面代码中的转换顺序,如下所示:
So, just change the order of your transformation in above code, like below:
train_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
这篇关于img 应该是 PIL Image.得到了 <class 'torch.Tensor'>的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
本文标题为:img 应该是 PIL Image.得到了 <class 'torch.Te


- 如何使用PYSPARK从Spark获得批次行 2022-01-01
- 使用 Cython 将 Python 链接到共享库 2022-01-01
- 我如何卸载 PyTorch? 2022-01-01
- 使用公司代理使Python3.x Slack(松弛客户端) 2022-01-01
- 我如何透明地重定向一个Python导入? 2022-01-01
- CTR 中的 AES 如何用于 Python 和 PyCrypto? 2022-01-01
- 检查具有纬度和经度的地理点是否在 shapefile 中 2022-01-01
- 计算测试数量的Python单元测试 2022-01-01
- YouTube API v3 返回截断的观看记录 2022-01-01
- ";find_element_by_name(';name';)";和&QOOT;FIND_ELEMENT(BY NAME,';NAME';)";之间有什么区别? 2022-01-01