pytorch學習筆記(五)-- 計算機視覺的遷移學習

系列文章目錄

pytorch學習筆記(一)-- pytorch深度學習框架基本知識了解

pytorch學習筆記(二)-- pytorch模型開發步驟詳解

pytorch學習筆記(三)-- TensorBoard的介紹

pytorch學習筆記(四)-- TorchVision 物體檢測微調教程

pytorch學習筆記(五)-- 計算機視覺的遷移學習

文章目錄

系列文章目錄

文章目錄

前言

一、加載數據?

二、訓練模型

三、可視化模型預測

四、卷積網絡微調

???微調 ConvNet:

ConvNet 作為固定特征提取器:

五、自定義圖像的推理

總結


前言

????????在本章節,您將學習如何使用遷移學習訓練卷積神經網絡進行圖像分類。您可以在 cs231n notes 筆記中閱讀有關遷移學習的更多信息。

????????一般來說,大家都不會從頭開始訓練卷積神經網絡,而是先在較大的數據集上做預訓練,差不多成熟了,然后再把卷積網絡在自己的任務上做初始化,或者特征提取器。

這兩種主要的遷移學習場景如下:

  • 微調 ConvNet:我們不是使用隨機初始化,而是使用預訓練網絡(例如在 imagenet 1000 數據集上訓練的網絡)來初始化網絡。其余訓練看起來與往常一樣。
  • ConvNet 作為固定特征提取器:在這里,我們將凍結除最終全連接層之外的所有網絡的權重。最后一個全連接層將被替換為具有隨機權重的新層,并且只訓練這一層。

一、加載數據?

????????我們使用torchvision 和 torch.utils.data數據包進行數據加載,今天的任務是訓練一個模型用來分辨螞蟻和蜜蜂,我們有120張螞蟻和蜜蜂的照片用于訓練,以及75張用于測試蜜蜂和螞蟻的照片。

數據集下載路徑:MyDataset: 數據集倉庫,包括各種網站搜刮的,以及一些自定義的數據。方便后續神經網絡的訓練 - Gitee.com

????????通常,如果從頭開始訓練,這個數據集太小了,無法進行推廣。由于我們使用遷移學習,我們就可以相當好地進行推廣。

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#可視化部分圖片,確認下效果
def imshow(inp, title=None):"""Display image for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)  # pause a bit so that plots are updated# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))# Make a grid from batch
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])

二、訓練模型

編寫一個通用函數來訓練模型。

  • 安排這個learning rate
  • 保存模型

???????參數 scheduler 是來自 torch.optim.lr_scheduler 的 LR 調度程序對象,關于這個scheduler在后續模型優化的章節會講到,也是一個非常強大的功能,這里就先不贅述了。

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):since = time.time()# Create a temporary directory to save training checkpointswith TemporaryDirectory() as tempdir:best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')torch.save(model.state_dict(), best_model_params_path)best_acc = 0.0for epoch in range(num_epochs):print(f'Epoch {epoch}/{num_epochs - 1}')print('-' * 10)# Each epoch has a training and validation phasefor phase in ['train', 'val']:if phase == 'train':model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# Iterate over data.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# backward + optimize only if in training phaseif phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# deep copy the modelif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_acctorch.save(model.state_dict(), best_model_params_path)print()time_elapsed = time.time() - sinceprint(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')print(f'Best val Acc: {best_acc:4f}')# load best model weightsmodel.load_state_dict(torch.load(best_model_params_path, weights_only=True))return model

三、可視化模型預測

? ? ? ? 經過上一節的訓練,我們現在看看模型的預測結果怎么樣。作為傳統的程序開發者,我們習慣通過打印來驗證結果,但是這玩意兒只有開發兄弟能懂哈。Pytorch畢竟是基于Python的深度學習框架,工具包是應用盡有,所以我們可以以可視化的圖形來顯示預測的效果。說個題外話,將工作結果可視化這個習慣,各位開發兄弟得學起來,以后就可以一手抓開發,一手抓產品,既可以跟開發同事一起奮斗,又可以跟老板以及客戶吹牛逼,路就走寬了。

def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title(f'predicted: {class_names[preds[j]]}')imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)

四、卷積網絡微調

? ? ? ? 上三節都是準備工作,這一節,我們就講一下兩種遷移學習的使用。

???微調 ConvNet:

#加載一個預訓練的模型并且重置全連接層
model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)
visualize_model(model_ft)

ConvNet 作為固定特征提取器:

? ? ? ? 注意:這里,我們需要凍結除最后一層之外的所有網絡。我們需要設置requires_grad = False來凍結參數,這樣梯度就不會在backward()中計算。

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():param.requires_grad = False# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)model_conv = model_conv.to(device)
criterion = nn.CrossEntropyLoss()# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)
visualize_model(model_conv)
plt.ioff()
plt.show()

五、自定義圖像的推理

????????使用訓練的模型進行自定義圖片的預測并且顯示預測圖片對應的標簽

def visualize_model_predictions(model,img_path):was_training = model.trainingmodel.eval()img = Image.open(img_path)img = data_transforms['val'](img)img = img.unsqueeze(0)img = img.to(device)with torch.no_grad():outputs = model(img)_, preds = torch.max(outputs, 1)ax = plt.subplot(2,2,1)ax.axis('off')ax.set_title(f'Predicted: {class_names[preds[0]]}')imshow(img.cpu().data[0])model.train(mode=was_training)visualize_model_predictions(model_conv,img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
)plt.ioff()
plt.show()

總結

? ? 遷移學習其實是實際工作中用的非常多的一種神經網絡開發方法,對于開發者來說,從頭構建一個模型,開發難度很大,并且個人很難去實現它的訓練,這個需要龐大的數據集以及場景測試。? ??????????

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/89801.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/89801.shtml
英文地址,請注明出處:http://en.pswp.cn/web/89801.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

數字IC后端培訓教程之數字后端項目典型項目案例解析

數字IC后端低功耗設計實現案例分享(3個power domain,2個voltage domain) Q1: 電路如下圖,clk是一個很慢的時鐘test_clk(屬于DFT的),DFF1與and 形成一個clock gating check。跑pr 發現,時鐘樹綜合CTS階段(C…

2025 Data Whale x PyTorch 安裝學習筆記(Windows 版)

一、Anaconda 的安裝與基本操作 1. 安裝 Anaconda/miniconda 官方鏈接:Anaconda | Individual Edition 根據系統版本選擇合適的安裝包下載并安裝。 2. 檢驗安裝 打開 “開始” 菜單,找到 “Anaconda Prompt”(一般在 Anaconda3 文件夾…

mac OS上docker安裝zookeeper

拉取鏡像:$ docker pull zookeeper:3.5.7 3.5.7: Pulling from library/zookeeper 3.5.7: Pulling from library/zookeeper 3.5.7: Pulling from library/zookeeper no matching manifest for linux/arm64/v8 in the manifest list entries報錯:由于時M3…

設備通過4G網卡接入EasyCVR視頻融合平臺,出現無法播放的問題排查和解決

EasyCVR視頻融合平臺作為支持多協議接入、多設備集中管理的綜合性視頻解決方案,可實現各類終端設備的視頻流匯聚與實時播放。近期收到用戶反饋,在EasyCVR平臺接入設備后出現視頻流無法播放的情況。為幫助更多用戶快速排查同類問題,現將具體處…

板凳-------Mysql cookbook學習 (十二--------3)

第二章 抽象數據類型和python類 2.5類定義實例: 學校人事管理系統中的類 import datetimeclass PersonValueError(ValueError):"""自定義異常類"""passclass PersonTypeError(TypeError):"""自定義異常類""…

css flex 布局中 flex-direction為column,如何讓子元素的寬度根據內容自動變化

在 display: flex 且 flex-direction: column 的布局中,默認情況下子元素會占滿容器的寬度。要讓子元素的寬度根據內容自適應,而不是自動拉伸填滿父容器,你可以這樣處理:? 解決方案一:設置子元素 align-self: start 或…

性能優化實踐:Modbus 在高并發場景下的吞吐量提升(二)

四、Modbus 吞吐量提升實戰策略4.1 優化網絡配置選擇合適的網絡硬件是提升 Modbus 通信性能的基礎。在工業現場,應優先選用高性能的工業級交換機和路由器。工業級交換機具備更好的抗干擾能力和穩定性,其背板帶寬和包轉發率更高,能夠滿足高并發…

上傳ipa到appstore的幾種工具

無論是用原生開發也好,使用uniapp或flutter開發也好,最好打包好的APP是需要上架appstore的。而在app store connect上架的時候,需要上傳ipa文件到app store的構建版本上。因此,需要上傳工具。下面分析下幾種上傳工具的優缺點&…

數控調壓BUCK電路 —— 基于TPS56637(TI)

0 前言 本文基于 TI 的 TPS56637 實現一個支持調壓的 BUCK 電路,包含從零開始詳細的 原理解析、原理圖、PCB 及 實測數據 本文屬于《DIY迷你數控電源》系列,本系列我們一起實現一個簡單的迷你數控電源 我是 LNY,一個在對嵌入式的所有都感興…

prometheus UI 和node_exporter節點圖形化Grafana

prometheus UI 和node_exporter節點圖形化Grafana 先簡單的安裝一下 進行時間的同步操作安裝Prometheus之前必須要先安裝ntp時間同步,因為prometheus server對系統時間的準確性要求很高,必須保證本機時間實時同步。# 用crontab進行定時的時間的同步 yum …

RabbitMQ—TTL、死信隊列、延遲隊列

上篇文章: RabbitMQ—消息可靠性保證https://blog.csdn.net/sniper_fandc/article/details/149311576?fromshareblogdetail&sharetypeblogdetail&sharerId149311576&sharereferPC&sharesourcesniper_fandc&sharefromfrom_link 目錄 1 TTL …

LVS 集群技術詳解與實戰部署

目錄 引言 一、實驗環境準備 二、理論基礎:集群與 LVS 核心原理 2.1 集群與分布式 2.2 LVS 核心原理 LVS 的 4 種工作模式 LVS 調度算法 三、LVS 部署工具:ipvsadm 命令詳解 四、實戰案例:LVS 部署詳解 案例 1:NAT 模式…

前端vue3獲取excel二進制流在頁面展示

excel二進制流在頁面展示安裝xlsx在頁面中定義一個div來展示html數據定義二進制流請求接口拿到數據并展示安裝xlsx npm install xlsx import * as XLSX from xlsx;在頁面中定義一個div來展示html數據 <div class"file-input" id"file-input" v-html&qu…

android 信息驗證動畫效果

layout_check_pro <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"android:id"id/parent"android:layout_width"wrap_content"android:layout_…

【iOS】繼承鏈

文章目錄前言什么是繼承鏈OC中的根類關于NSProxy關鍵作用1.方法查找與動態綁定2. 消息轉發3. **類型判斷與多態**繼承鏈的底層實現元類的繼承鏈總結前言 在objective-c中&#xff0c;繼承鏈是類與類之間通過父類&#xff08;Superclass&#xff09;關系形成的一層層繼承結構&am…

論文閱讀:Instruct BLIP (2023.5)

文章目錄InstructBLIP&#xff1a;邁向通用視覺語言模型的指令微調研究總結一、研究背景與目標二、核心方法數據構建與劃分模型架構訓練策略三、實驗結果零樣本性能消融實驗下游任務微調定性分析可視化結果展示四、結論與貢獻InstructBLIP&#xff1a;邁向通用視覺語言模型的指…

Elasticsearch+Logstash+Filebeat+Kibana部署【7.1.1版本】

目錄 一、準備階段 二、實驗階段 1.配置kibana主機 2.配置elasticsearch主機 3.配置logstash主機 4.配置/etc/filebeat/filebeat.yml 三、驗證 1.開啟Filebeat 2.在logstash查看 3.瀏覽器訪問kibana 一、準備階段 1.準備四臺主機kibana、es、logstash、filebeat 2.在…

Vue開發前端報錯:‘vue-cli-service‘ 不是內部或外部命令解決方案

1.Bug: 最近調試一個現有的Vue前端代碼&#xff0c;發現如下錯誤&#xff1a; vue-cli-service’ 不是內部或外部命令&#xff0c;也不是可運行的程序 或批處理文件。 2.Bug原因&#xff1a; 導入的工程缺少依賴包&#xff1a;即缺少node_modules文件夾 3.解決方案&#xff1…

AI生態,釘釘再「出招」

如果說之前釘釘的AI生態加持更多的圍繞資源和商業的底層助力&#xff0c;那么如今這種加持則是向更深層次進化&#xff0c;即真正的AI模型訓練能力加持&#xff0c;為垂類大模型創業者提供全方位的助力&#xff0c;提高創業成功率和模型產品商業化確定性。作者|皮爺出品|產業家…

XSS GAME靶場

要求用戶不參與&#xff0c;觸發alert(1337) 目錄 Ma Spaghet! Jefff Ugandan Knuckles Ricardo Milos Ah Thats Hawt Ligma Mafia Ok, Boomer Exmaple 1 - Create Example 2 - Overwrite Example 3 - Overwrite2 toString Ma Spaghet! <h2 id"spaghet&qu…