推广 热搜: 京东  联通  iphone11  摄像头  企业存储  iPhone  XSKY  京东智能采购  网络安全  自动驾驶 

PyTorch 中 数据集Torchvision和Torchtext

   日期:2021-09-06     来源:51cto    作者:itcg    浏览:390    我要评论    
导读:对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。

对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。

之前使用 torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集

Torchvision 中的数据集 MNIST

MNIST是一个由标准化和中心裁剪的手写图像组成的数据集。它有超过 60,000 张训练图像和 10,000 张测试图像。这是用于学习和实验目的最常用的数据集之一。要加载和使用数据集,使用以下语法导入:torchvision.datasets.MNIST()。

Fashion MNIST

Fashion MNIST数据集类似于MNIST,但该数据集包含T恤、裤子、包包等服装项目,而不是手写数字,训练和测试样本数分别为60,000和10,000。要加载和使用数据集,使用以下语法导入:torchvision.datasets.FashionMNIST()

CIFAR

CIFAR数据集有两个版本,CIFAR10和CIFAR100。CIFAR10 由 10 个不同标签的图像组成,而 CIFAR100 有 100 个不同的类。这些包括常见的图像,如卡车、青蛙、船、汽车、鹿等。

torchvision.datasets.CIFAR10() torchvision.datasets.CIFAR100()  COCO

COCO数据集包含超过 100,000 个日常对象,如人、瓶子、文具、书籍等。这个图像数据集广泛用于对象检测和图像字幕应用。下面是可以加载 COCO 的位置:torchvision.datasets.CocoCaptions()

EMNIST

EMNIST数据集是 MNIST 数据集的高级版本。它由包括数字和字母的图像组成。如果您正在处理基于从图像中识别文本的问题,EMNIST是一个不错的选择。下面是可以加载 EMNIST的位置::torchvision.datasets.EMNIST()

IMAGE-NET

ImageNet 是用于训练高端神经网络的旗舰数据集之一。它由分布在 10,000 个类别中的超过 120 万张图像组成。通常,这个数据集加载在高端硬件系统上,因为单独的 CPU 无法处理这么大的数据集。下面是加载 ImageNet 数据集的类:torchvision.datasets.ImageNet()

Torchtext 中的数据集 IMDB

IMDB是一个用于情感分类的数据集,其中包含一组 25,000 条高度极端的电影评论用于训练,另外 25,000 条用于测试。使用以下类加载这些数据torchtext:torchtext.datasets.IMDB()

WikiText2

WikiText2语言建模数据集是一个超过 1 亿个标记的集合。它是从维基百科中提取的,并保留了标点符号和实际的字母大小写。它广泛用于涉及长期依赖的应用程序。可以从torchtext以下位置加载此数据:torchtext.datasets.WikiText2()

除了上述两个流行的数据集,torchtext库中还有更多可用的数据集,例如 SST、TREC、SNLI、MultiNLI、WikiText-2、WikiText103、PennTreebank、Multi30k 等。

深入查看 MNIST 数据集

MNIST 是最受欢迎的数据集之一。现在我们将看到 PyTorch 如何从 pytorch/vision 存储库加载 MNIST 数据集。让我们首先下载数据集并将其加载到名为 的变量中data_train

from torchvision.datasets import MNIST  # Download MNIST  data_train = MNIST('~/mnist_data', train=True, download=True)  import matplotlib.pyplot as plt  random_image = data_train[0][0] random_image_label = data_train[0][1]  # Print the Image using Matplotlib plt.imshow(random_image) print("The label of the image is:", random_image_label)  DataLoader加载MNIST

下面我们使用DataLoader该类加载数据集,如下所示。

import torch from torchvision import transforms  data_train = torch.utils.data.DataLoader(     MNIST(           '~/mnist_data', train=True, download=True,            transform = transforms.Compose([               transforms.ToTensor()           ])),           batch_size=64,           shuffle=True           )  for batch_idx, samples in enumerate(data_train):       print(batch_idx, samples)  CUDA加载

我们可以启用 GPU 来更快地训练我们的模型。现在让我们使用CUDA加载数据时可以使用的(GPU 支持 PyTorch)的配置。

device = "cuda" if torch.cuda.is_available() else "cpu" kwargs = {'num_workers': 1, 'pin_memory': True} if device=='cuda' else {}  train_loader = torch.utils.data.DataLoader(   torchvision.datasets.MNIST('/files/', train=True, download=True),   batch_size=batch_size_train, **kwargs)  test_loader = torch.utils.data.DataLoader(   torchvision.datasets.MNIST('files/', train=False, download=True),   batch_size=batch_size, **kwargs)  ImageFolder

ImageFolder是一个通用数据加载器类torchvision,可帮助加载自己的图像数据集。处理一个分类问题并构建一个神经网络来识别给定的图像是apple还是orange。要在 PyTorch 中执行此操作,第一步是在默认文件夹结构中排列图像,如下所示:

root ├── orange │   ├── orange_image1.png │   └── orange_image1.png ├── apple │   └── apple_image1.png │   └── apple_image2.png │   └── apple_image3.png 

可以使用ImageLoader该类加载所有这些图像。

torchvision.datasets.ImageFolder(root, transform)  transforms

PyTorch 转换定义了简单的图像转换技术,可将整个数据集转换为独特的格式。

如果是一个包含不同分辨率的不同汽车图片的数据集,在训练时,我们训练数据集中的所有图像都应该具有相同的分辨率大小。如果我们手动将所有图像转换为所需的输入大小,则很耗时,因此我们可以使用transforms;使用几行 PyTorch 代码,我们数据集中的所有图像都可以转换为所需的输入大小和分辨率。

现在让我们加载 CIFAR10torchvision.datasets并应用以下转换:

将所有图像调整为 32×32 对图像应用中心裁剪变换 将裁剪后的图像转换为张量 标准化图像 import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as np  transform = transforms.Compose([     # resize 32×32     transforms.Resize(32),     # center-crop裁剪变换     transforms.CenterCrop(32),     # to-tensor     transforms.ToTensor(),     # normalize 标准化     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ])  trainset = torchvision.datasets.CIFAR10(root='./data', train=True,                                         download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,                                           shuffle=False)  在 PyTorch 中创建自定义数据集

下面将创建一个由数字和文本组成的简单自定义数据集。需要封装Dataset 类中的__getitem__()和__len__()方法。

__getitem__()方法通过索引返回数据集中的选定样本。 __len__()方法返回数据集的总大小。

下面是曾经封装FruitImagesDataset数据集的代码,基本是比较好的 PyTorch 中创建自定义数据集的模板。

import os import numpy as np import cv2 import torch import matplotlib.patches as patches import albumentations as A from albumentations.pytorch.transforms import ToTensorV2 from matplotlib import pyplot as plt from torch.utils.data import Dataset from xml.etree import ElementTree as et from torchvision import transforms as torchtrans  class FruitImagesDataset(torch.utils.data.Dataset):     def __init__(self, files_dir, width, height, transforms=None):         self.transforms = transforms         self.files_dir = files_dir         self.height = height         self.width = width           self.imgs = [image for image in sorted(os.listdir(files_dir))                      if image[-4:] == '.jpg']          self.classes = ['_','apple', 'banana', 'orange']      def __getitem__(self, idx):          img_name = self.imgs[idx]         image_path = os.path.join(self.files_dir, img_name)          # reading the images and converting them to correct size and color         img = cv2.imread(image_path)         img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)         img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA)         # diving by 255         img_res /= 255.0          # annotation file         annot_filename = img_name[:-4] + '.xml'         annot_file_path = os.path.join(self.files_dir, annot_filename)          boxes = []         labels = []         tree = et.parse(annot_file_path)         root = tree.getroot()          # cv2 image gives size as height x width         wt = img.shape[1]         ht = img.shape[0]          # box coordinates for xml files are extracted and corrected for image size given         for member in root.findall('object'):             labels.append(self.classes.index(member.find('name').text))              # bounding box             xmin = int(member.find('bndbox').find('xmin').text)             xmax = int(member.find('bndbox').find('xmax').text)              ymin = int(member.find('bndbox').find('ymin').text)             ymax = int(member.find('bndbox').find('ymax').text)              xmin_corr = (xmin / wt) * self.width             xmax_corr = (xmax / wt) * self.width             ymin_corr = (ymin / ht) * self.height             ymax_corr = (ymax / ht) * self.height              boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])          # convert boxes into a torch.Tensor         boxes = torch.as_tensor(boxes, dtype=torch.float32)          # getting the areas of the boxes         area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])          # suppose all instances are not crowd         iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)          labels = torch.as_tensor(labels, dtype=torch.int64)          target = {}         target["boxes"] = boxes         target["labels"] = labels         target["area"] = area         target["iscrowd"] = iscrowd         # image_id         image_id = torch.tensor([idx])         target["image_id"] = image_id          if self.transforms:             sample = self.transforms(image=img_res,                                      bboxes=target['boxes'],                                      labels=labels)              img_res = sample['image']             target['boxes'] = torch.Tensor(sample['bboxes'])         return img_res, target     def __len__(self):         return len(self.imgs)  def get_transform(train):     if train:         return A.Compose([             A.HorizontalFlip(0.5),             ToTensorV2(p=1.0)         ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})     else:         return A.Compose([             ToTensorV2(p=1.0)         ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})  files_dir = '../input/fruit-images-for-object-detection/train_zip/train' test_dir = '../input/fruit-images-for-object-detection/test_zip/test'  dataset = FruitImagesDataset(train_dir, 480, 480) 

 
反对 0举报 0 收藏 0 打赏 0评论 0
 
更多>同类资讯
0相关评论

头条阅读
推荐图文
相关资讯
网站首页  |  物流配送  |  关于我们  |  联系方式  |  使用协议  |  版权隐私  |  网站地图  |  排名推广  |  广告服务  |  积分换礼  |  RSS订阅  |  违规举报  |  京ICP备14047533号-2