[pytorch、學習] - 9.1 圖像增廣

參考

9.1 圖像增廣

在5.6節(深度卷積神經網絡)里我們提過,大規模數據集是成功應用神經網絡的前提。圖像增廣(image augmentation)技術通過對訓練圖像做一系列隨機改變,來產生相似但又不相同的訓練樣本,從而擴大訓練數據集的規模。圖像增廣的另一種解釋是,隨機改變訓練樣本可以降低模型對某些屬性的依賴,從而提高模型的泛化能力。

import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Imageimport sys
sys.path.append("..")
import d2lzh_pytorch as d2ldevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

9.1.1 常用的圖像增廣方法

d2l.set_figsize()
img = Image.open('../img/cat2.jpg')
d2l.plt.imshow(img)

在這里插入圖片描述
下面定義繪圖函數show_images

# 傳入圖片, 行、列 和規模
def show_images(imgs, num_rows, num_cols, scale=2):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)for i in range(num_rows):for j in range(num_cols):axes[i][j].imshow(imgs[i * num_cols + j])axes[i][j].axes.get_xaxis().set_visible(False)axes[i][j].axes.get_yaxis().set_visible(False)return axes

大部分圖像增廣都有一定的隨機性。為了方便觀察圖像增廣效果,接下來我們定義一個輔助函數apply。這個函數對輸入圖像img多次運行圖像增廣方法aug并展示所有的結果。

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):Y = [aug(img) for  _ in range(num_rows * num_cols)]show_images(Y, num_rows, num_cols, scale)

9.1.1.1 翻轉和裁剪

左右翻轉圖像通常不改變物體的類別。它是最早也是最廣泛使用的一種圖像增廣方法。下面我們通過torchvision.transforms模塊創建RandomHorizontalFlip實例來實現一半概率的圖像水平(左右)翻轉。

apply(img, torchvision.transforms.RandomHorizontalFlip(),num_rows= 2, scale=3)

在這里插入圖片描述
上下翻轉不如左右翻轉通用。但是至少對于樣例圖像,上下翻轉不會造成識別障礙。下面我們創建RandomVerticalFllip實例來實現一半概率的圖像垂直(上下翻轉)

apply(img, torchvision.transforms.RandomVerticalFlip(), scale=2.7)

在這里插入圖片描述
在我們使用的樣例圖像里,貓在圖像正中間,但一般情況下可能不是這樣。在5.4節(池化層)里我們解釋了池化層能降低卷積層對目標位置的敏感度。除此之外,我們還可以通過對圖像隨機裁剪來讓物體以不同的比例出現在圖像的不同位置,這樣能夠降低模型對目標位置的敏感性。

在下面的代碼里,我們每次隨機裁剪出一塊面積為原面積10% ~ 100%的區域,且該區域的寬高和高之比隨機取自 0.5 ~ 2, 然后再將該區域的寬和高分別縮放到200像素。若無特殊說明, 本節中 a 和 b之間的隨機數指的是從區間[a,b]中隨機均勻采樣所得的連續值.

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio = (0.5, 2))
apply(img, shape_aug, scale =3)

在這里插入圖片描述

9.1.1.2 變化顏色

另一類增廣方法是變化顏色。我們可以從4個方面改變圖像的顏色: 亮度(brightness)、對比度(contrast)、飽和度(saturation)和色調(hue)。再下面的例子里,我們將圖像的亮度隨機變化為原亮度的50% (1 - 0.5) ~ 150% (1 + 0.5)

apply(img, torchvision.transforms.ColorJitter(brightness = 0.5), scale = 2.5)

在這里插入圖片描述
我們也可以隨機色調

apply(img, torchvision.transforms.ColorJitter(hue = 0.5), scale = 3)

在這里插入圖片描述
類似地,我們也可以隨機變化圖像的對比度。

apply(img, torchvision.transforms.ColorJitter(contrast = 0.5), scale = 3)

在這里插入圖片描述
我們也可以同時設置如何隨機變化圖像的亮度(brightness)、對比度(contrast)、飽和度(saturation)和色調(hue)。

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue = 0.5
)apply(img, color_aug, num_cols=3, num_rows=2, scale =2.7)

在這里插入圖片描述

9.1.1.3 疊加多個圖像增廣方法

實際應用中我們將會將多個圖像增廣方法疊加使用。我們可以通過Compose實例將上面定義的多個圖像增廣方法疊加起來,再應用到每張圖像上mm

augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug
])apply(img, augs)

在這里插入圖片描述

9.1.2 使用圖像增廣訓練模型

下面我們來看一個將圖像增廣應用在實際訓練中的例子。這里我們使用CIFAR-10數據集,而不是之前我們一直使用的Fashion-MNIST數據集。這是因為Fashion-MNIST數據集中物體的位置和尺寸都已經經過歸一化處理,而CIFAR-10數據集中物體的顏色和大小更加顯著。下面展示了CIFAR-10數據集中前32張訓練圖像。

all_images = torchvision.datasets.CIFAR10(train= True, root="~/Datasets/CIFAR", download=True)
# all_images的每一個元素都是(image, label)

在這里插入圖片描述

注: 此處根據下載的位置,用迅雷下載會比較快

show_images([all_images[i][0] for i in range(8)], 2, 4, scale=3)

在這里插入圖片描述
為了在預測時得到確定的結果,我們通常只將圖像增廣應用在訓練樣本上,而不在預測時使用含隨機操作的圖像增廣。在這里我們只使用最簡單的隨機左右翻轉。此外,我們使用ToTensor將小批量圖像轉成PyTorch需要的格式,即形狀為(batch_size, channels, height, width)、值域在0~1之間且類型為32位浮點數

flip_aug = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor()
])no_aug = torchvision.transforms.Compose([torchvision.transforms.ToTensor()
])

接下來,我們定義一個輔助函數來方便讀取圖像并應用圖像增廣.

num_workers = 0 if sys.platform.startswith('win32') else 4def load_cifar10(is_train, augs, batch_size, root = "~/Datasets/CIFAR"):datasets = torchvision.datasets.CIFAR10(root=root, train=is_train, transform=augs, download =True)return DataLoader(datasets, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)

使用圖像增廣訓練模型

# 定義train函數使用GPU訓練并評價模型def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):net = net.to(device)print("training on", device)batch_count =0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = d2l.evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

然后可以定義train_with_data_aug函數使用圖像增廣來訓練模型了。該函數使用Adam算法作為訓練使用的優化算法,然后將圖像增廣應用于訓練數據集之上,最好調用剛才定義的train函數訓練并評價模型。

def train_with_data_aug(train_augs, test_augs, lr = 0.001):batch_size, net = 256, d2l.resnet18(10)optimizer = torch.optim.Adam(net.parameters(), lr=lr)loss = torch.nn.CrossEntropyLoss()train_iter = load_cifar10(True, train_augs, batch_size)test_iter = load_cifar10(False, test_augs, batch_size)train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=20)

下面使用隨機左右翻轉的圖像增廣來訓練模型

train_with_data_aug(flip_aug, no_aug)

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/250113.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/250113.shtml
英文地址,請注明出處:http://en.pswp.cn/news/250113.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

mysql綠色版安裝

導讀:MySQL是一款關系型數據庫產品,官網給出了兩種安裝包格式:MSI和ZIP。MSI格式是圖形界面安裝方式,基本只需下一步即可,這篇文章主要介紹ZIP格式的安裝過程。ZIP Archive版是免安裝的。只要解壓就行了。 一、首先下…

在微信瀏覽器字體被調大導致頁面錯亂的解決辦法

iOS的解決方案是覆蓋掉微信的樣式: body { /* IOS禁止微信調整字體大小 */-webkit-text-size-adjust: 100% !important; } 安卓的解決方案是通過 WeixinJSBridge 對象將網頁的字體大小設置為默認大小,并且重寫設置字體大小的方法,讓用戶不能在…

[pytorch、學習] - 9.2 微調

參考 9.2 微調 在前面得一些章節中,我們介紹了如何在只有6萬張圖像的Fashion-MNIST訓練數據集上訓練模型。我們還描述了學術界當下使用最廣泛規模圖像數據集ImageNet,它有超過1000萬的圖像和1000類的物體。然而,我們平常接觸到數據集的規模通常在這兩者之間。 假設我們想從圖…

Springboot默認加載application.yml原理

Springboot默認加載application.yml原理以及擴展 SpringApplication.run(…)默認會加載classpath下的application.yml或application.properties配置文件。公司要求搭建的框架默認加載一套默認的配置文件demo.properties,讓開發人員實現“零”配置開發,但…

java 集合(Set接口)

Set接口:無序集合,不允許有重復值,允許有null值 存入與取出的順序有可能不一致 HashSet:具有set集合的基本特性,不允許重復值,允許null值 底層實現是哈希表結構 初始容量為16 保存自定義對象時,保證數據的唯…

關于mac機抓包的幾點基礎知識

1. 我使用的抓包工具為WireShark,以下操作按我當前的版本(Version 2.6.1)做的,以前的版本或者以后的版本可能有稍微的區別。 2. 將mac設置為熱點:打開系統偏好設置,點擊共享: 然后點擊WIFI選項,設置WIFI名…

SpringBoot啟動如何加載application.yml配置文件

一、前言 在spring時代配置文件的加載都是通過web.xml配置加載的(Servlet3.0之前)&#xff0c;可能配置方式有所不同&#xff0c;但是大多數都是通過指定路徑的文件名的形式去告訴spring該加載哪個文件&#xff1b; <context-param><param-name>contextConfigLocat…

[github] - git使用小結(分支拉取、版本回退)

1. 首次(fork項目之后) $ git clone [master] $ git branch -a $ git checkout -b [自己的分支名] [遠程倉庫的分支名]克隆的是主干網絡 2. 再次拉取代碼 $ git pull [master下選擇分支名] [分支名] $ git push origin HEAD:[分支名]拉取首先得進入主倉(不是自己的遠程倉)然后…

MYSQL 查看最大連接數和修改最大連接數

MySQL查看最大連接數和修改最大連接數 1、查看最大連接數show variables like %max_connections%;2、修改最大連接數set GLOBAL max_connections 200; 以下的文章主要是向大家介紹的是MySQL最大連接數的修改&#xff0c;我們大家都知道MySQL最大連接數的默認值是100, 這個數值…

阿里云服務器端口開放對外訪問權限

登陸阿里云管理控制臺 點擊自己的實例 點擊安全組配置 點擊配置規則 點擊添加安全組規則 配置出入放心&#xff0c;和開放的端口號&#xff0c;以及那些網段可以訪問&#xff0c;這里設置所有網段都可以訪問 轉自&#xff1a;https://jingyan.baidu.com/article/95c9d20d624d1e…

PageHelper工作原理

數據分頁功能是我們軟件系統中必備的功能&#xff0c;在持久層使用mybatis的情況下&#xff0c;pageHelper來實現后臺分頁則是我們常用的一個選擇&#xff0c;所以本文專門類介紹下。 PageHelper原理 相關依賴 <dependency><groupId>org.mybatis</groupId>&…

10-多寫一個@Autowired導致程序崩了

再是javaweb實驗六中&#xff0c;是讓我們改代碼&#xff0c;讓它跑起來&#xff0c;結果我少注釋了一個&#xff0c;導致一直報錯&#xff0c;檢查許久沒有找到&#xff0c;最后通過代碼替換逐步查找&#xff0c;才發現問題。 轉載于:https://www.cnblogs.com/zhumengdexiaoba…

Java class不分32位和64位

1、32位JDK編譯的java class在32位系統和64位系統下都可以運行&#xff0c;64位系統兼容32位程序&#xff0c;可以理解。2、無論是Linux還是Windows平臺下的JDK編譯的java class在Linux、Windows平臺下通用&#xff0c;Java跨平臺特性。3、64位JDK編譯的java class在32位的系統…

包裝對象

原文地址&#xff1a;https://wangdoc.com/javascript/ 定義 對象是JavaScript語言最主要的數據類型&#xff0c;三種原始類型的值--數值、字符串、布爾值--在一定條件下&#xff0c;也會自動轉為對象&#xff0c;也就是原始類型的包裝對象。所謂包裝對象&#xff0c;就是分別與…

[C++] 轉義序列

參考 C Primer(第5版)P36 名稱轉義序列換行符\n橫向制表符\t報警(響鈴)符\a縱向制表符\v退格符\b雙引號"反斜杠\問號?單引號’回車符\r進紙符\f

vue使用(二)

本節目標&#xff1a; 1.數據路徑的三種方式 2.{{}}和v-html的區別 1.綁定圖片的路徑 方法一&#xff1a;直接寫路徑 <img src"http://pic.baike.soso.com/p/20140109/20140109142534-188809525.jpg"> 方法二&#xff1a;在data中寫路徑&#xff0c;在…

typedef 為類型取別名

#include <stdio.h> int main() {   typedef int myint; // 為int 類型取自己想要的名字   myint a 10;   printf("%d", a);   return 0;} 其他類型的用法也是一樣的 typedef 類型 自己想要取得名字; 轉載于:https://www.cnblogs.com/hello-dummy/p/9…

【C++】如何提高Cache的命中率,示例

參考鏈接 https://stackoverflow.com/questions/16699247/what-is-a-cache-friendly-code 只是堆積&#xff1a;緩存不友好與緩存友好代碼的典型例子是矩陣乘法的“緩存阻塞”。 樸素矩陣乘法看起來像 for(i0;i<N;i) {for(j0;j<N;j) {dest[i][j] 0;for( k;k<N;i)…

springboot---整合redis

pom.xml新增 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency>代碼結構如下 其中redis.yml是連接redis的配置文件&#xff0c;RedisConfig.java是java配置…

[Head First Java] - 簡單的建議程序

參考 - p481、p484 與我對接的業務層使用的是JAVA語言,因此花點時間入門java.下面幾篇博客可能都是關于java的,我覺得在工作中可能會遇到的 簡單的通信 DailyAdviceClient(客戶端程序) import java.io.*; import java.net.*;public class DailyAdviceClient{public void go()…