華為開源自研AI框架昇思MindSpore應用案例:基于MindSpore框架實現PWCNet光流估計

如果你對MindSpore感興趣,可以關注昇思MindSpore社區

在這里插入圖片描述

在這里插入圖片描述

1 環境準備

1.進入ModelArts官網
云平臺幫助用戶快速創建和部署模型,管理全周期AI工作流,選擇下面的云平臺以開始使用昇思MindSpore,可以在昇思教程中進入ModelArts官網

創建notebook,點擊【打開】啟動,進入ModelArts調試環境頁面。

注意選擇西南-貴陽一,mindspore_2.3.0

在這里插入圖片描述

等待環境搭建完成

在這里插入圖片描述

下載案例notebook文件

基于MindSpore框架實現PWCNet光流估計:https://github.com/mindspore-courses/applications/blob/master/pwc_net/pwc_net.ipynb

選擇ModelArts Upload Files上傳.ipynb文件

在這里插入圖片描述

進入昇思MindSpore官網,點擊上方的安裝獲取安裝命令

在這里插入圖片描述

MindSpore版本升級,鏡像自帶的MindSpore版本為2.3,該活動要求在MindSpore2.4.0版本體驗,所以需要進行MindSpore版本升級。
在這里插入圖片描述

命令如下:

export no_proxy='a.test.com,127.0.0.1,2.2.2.2'
pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.4.0/MindSpore/unified/aarch64/mindspore-2.4.0-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

在這里插入圖片描述

回到Notebook中,在第一塊代碼前加命令

pip install --upgrade pippip install mindvisionpip install download

2 案例實現

import os
import logging
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# logging.basicConfig(level=logging.ERROR)
logging.disable(logging.WARNING)
%matplotlib inline
import mindspore as mm.set_context(mode=m.GRAPH_MODE, device_target="GPU") # 訓練時使用靜態圖
# m.set_context(mode=m.PYNATIVE_MODE, device_target="GPU") # 設置為動態圖方便debug

將MindSpore設置為圖執行模式,并設置為使用GPU進行訓練。

train_data_path = r"data/MPI-Sintel-complete/training"
val_data_path = r"data/MPI-Sintel-complete/training"pretrained_path = r"pretrained_model/pwcnet-mindspore.ckpt"batch_size = 4
lr = 0.0001
num_parallel_workers = 4
lr_milestones = '6,10,12,16'
lr_gamma = 0.5
max_epoch = 20
loss_scale = 1024
warmup_epochs = 1

設置數據集路徑,設置訓練參數,包括batch_size、epoch_size、learning_rate等。

import mindspore.dataset.vision as Vfrom src.dataset_utils import RandomGammaaugmentation_list = [V.ToPIL(),V.RandomColorAdjust(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),V.ToTensor(),RandomGamma(min_gamma=0.7, max_gamma=1.5, clip_image=True),
]

設置數據增強方法,包括使用隨機顏色變換和隨機Gamma變換。

from black import out
from src.dataset import getFlyingChairsTrainData, getSintelValDatadl_train, len_dl_train, dataset = getSintelValData(root=train_data_path,split="train",augmentations=augmentation_list,batch_size=batch_size,num_parallel_workers=num_parallel_workers,
)
dl_val, len_dl_val, val_dataset = getSintelValData(root=val_data_path,split="train",augmentations=augmentation_list,batch_size=batch_size,num_parallel_workers=num_parallel_workers,
)
train_len = dl_train.get_dataset_size()
dl_train = dl_train.repeat(max_epoch)
print(f"The dataset size of dl_train: {dl_train.get_dataset_size()}")
print(f"The dataset size of dl_val: {dl_val.get_dataset_size()}")dict_datasets = next(dl_train.create_dict_iterator())
print(dict_datasets.keys())
print(dict_datasets["im1"].shape)
print(dict_datasets["im2"].shape)
print(dict_datasets["flo"].shape)
print(type(dict_datasets["flo"]))
print(dict_datasets["flo"].max(), dict_datasets["flo"].min())
print(dict_datasets["flo"].max() * 0.05, dict_datasets["flo"].min() * 0.05)
dl_train = dl_train.create_tuple_iterator(output_numpy=False, do_copy=False)
dl_val = dl_val.create_tuple_iterator(output_numpy=False, do_copy=False)

查看數據集的訓練集和測試集的數量。同時查看數據集中RGB圖片和光流圖片的分辨率大小。

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import flow_visfig = matplotlib.pyplot.gcf()
fig.set_size_inches(18.5, 10.5)
ax = plt.subplot(131)
ax.imshow(np.transpose(dict_datasets["im1"][0].asnumpy(), (1, 2, 0)))
ax.set_title("Image 1")
ax.set_axis_off()
ax = plt.subplot(132)
ax.imshow(np.transpose(dict_datasets["im2"][0].asnumpy(), (1, 2, 0)))
ax.set_title("Image 2")
ax.set_axis_off()
ax = plt.subplot(133)
ax.imshow(flow_vis.flow_to_color(np.transpose(dict_datasets["flo"][0].asnumpy(), (1, 2, 0)))
)
ax.set_axis_off()
ax.set_title("Optical Flow")

在這里插入圖片描述

使用flow_vismatplotlib庫分別將光流圖片與RGB圖片可視化。

# from src.pwc_net import PWCNet
# from src.loss import PyramidEPE, MultiStepLR# from mindspore.nn import Adam# network = PWCNet()
# criterion = PyramidEPE()# optimizer = Adam(params=net.trainable_params(), learning_rate=lr, loss_scale=loss_scale)
from collections import Counter
import numpy as npclass _WarmUp():"""Basic class for warm up"""def __init__(self, warmup_init_lr):self.warmup_init_lr = warmup_init_lrdef get_lr(self):# Get learning rate during warmupraise NotImplementedErrorclass _LRScheduler():"""Basic class for learning rate scheduler"""def __init__(self, lr, max_epoch, steps_per_epoch):self.base_lr = lrself.steps_per_epoch = steps_per_epochself.total_steps = int(max_epoch * steps_per_epoch)def get_lr(self):# Compute learning rate using chainable form of the schedulerraise NotImplementedErrorclass _LinearWarmUp(_WarmUp):"""Class for linear warm up"""def __init__(self, lr, warmup_epochs, steps_per_epoch, warmup_init_lr=0):self.base_lr = lrself.warmup_init_lr = warmup_init_lrself.warmup_steps = int(warmup_epochs * steps_per_epoch)super(_LinearWarmUp, self).__init__(warmup_init_lr)def get_warmup_steps(self):return self.warmup_stepsdef get_lr(self, current_step):lr_inc = (float(self.base_lr) - float(self.warmup_init_lr)) / float(self.warmup_steps)lr = float(self.warmup_init_lr) + lr_inc * current_stepreturn lrclass MultiStepLR(_LRScheduler):"""Multi-step learning rate schedulerDecays the learning rate by gamma once the number of epoch reaches one of the milestones.Args:lr (float): Initial learning rate which is the lower boundary in the cycle.milestones (list): List of epoch indices. Must be increasing.gamma (float): Multiplicative factor of learning rate decay.steps_per_epoch (int): The number of steps per epoch to train for.max_epoch (int): The number of epochs to train for.warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0Outputs:numpy.ndarray, shape=(1, steps_per_epoch*max_epoch)Example:>>> # Assuming optimizer uses lr = 0.05 for all groups>>> # lr = 0.05     if epoch < 30>>> # lr = 0.005    if 30 <= epoch < 80>>> # lr = 0.0005   if epoch >= 80>>> scheduler = MultiStepLR(lr=0.1, milestones=[30,80], gamma=0.1, steps_per_epoch=5000, max_epoch=90)>>> lr = scheduler.get_lr()"""def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch, warmup_epochs=0):self.milestones = Counter(milestones)self.gamma = gammaself.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch)super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch)def get_lr(self):warmup_steps = self.warmup.get_warmup_steps()lr_each_step = []current_lr = self.base_lrfor i in range(self.total_steps):if i < warmup_steps:lr = self.warmup.get_lr(i+1)else:cur_ep = i // self.steps_per_epochif i % self.steps_per_epoch == 0 and cur_ep in self.milestones:current_lr = current_lr * self.gammalr = current_lrlr_each_step.append(lr)return np.array(lr_each_step).astype(np.float32)

初始化神經網絡、損失函數、優化器、模型和回調函數。

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore import Model, load_checkpoint, load_param_into_netfrom src.pwc_net import BuildTrainNetwork, PWCNet
from src.loss import PyramidEPEclass CustomWithLossCell(nn.Cell):def __init__(self, network, criterion):super(CustomWithLossCell, self).__init__(auto_prefix=False)self.network = networkself.criterion = criteriondef construct(self, im1, im2, flow):out = self.network(im1, im2)loss = self.criterion(out, flow)return lossnetwork = PWCNet()
criterion = PyramidEPE()param_dict = load_checkpoint(pretrained_path)
param_dict_new = {}
for key, values in param_dict.items():if key.startswith('moment1.' or 'moment2' or 'global_step' or 'beta1_power' or 'beta2_power' or'learning_rate'):continueelif key.startswith('network.'):param_dict_new[key[8:]] = valueselse:param_dict_new[key] = values
load_param_into_net(network, param_dict_new)train_net = BuildTrainNetwork(network, criterion)# model = Model(
#     network=net_with_loss,
#     # loss_fn=criterion,
#     optimizer=optimizer,
#     eval_network=net_with_loss,
#     metrics={"loss"},
#     amp_level="O0",
# )
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore import Model, load_checkpoint, load_param_into_netfrom src.pwc_net import BuildTrainNetwork, PWCNet
from src.loss import PyramidEPEclass CustomWithLossCell(nn.Cell):def __init__(self, network, criterion):super(CustomWithLossCell, self).__init__(auto_prefix=False)self.network = networkself.criterion = criteriondef construct(self, im1, im2, flow):out = self.network(im1, im2)loss = self.criterion(out, flow)return lossnetwork = PWCNet()
criterion = PyramidEPE()param_dict = load_checkpoint(pretrained_path)
param_dict_new = {}
for key, values in param_dict.items():if key.startswith('moment1.' or 'moment2' or 'global_step' or 'beta1_power' or 'beta2_power' or'learning_rate'):continueelif key.startswith('network.'):param_dict_new[key[8:]] = valueselse:param_dict_new[key] = values
load_param_into_net(network, param_dict_new)train_net = BuildTrainNetwork(network, criterion)# model = Model(
#     network=net_with_loss,
#     # loss_fn=criterion,
#     optimizer=optimizer,
#     eval_network=net_with_loss,
#     metrics={"loss"},
#     amp_level="O0",
# )
print('Start training...')
for i, data in enumerate(dl_train):# clean grad + adjust lr + put data into device + forward + backward + optimizer, return loss# print(data[0].shape, data[1].shape, data[2].shape)# print(data[0].max(), data[0].min(), data[1].max(), data[1].min(), data[2].max(), data[2].min())loss = train_net_step(data[0], data[1], data[2])# print(loss)loss_meter.update(loss.asnumpy())if i == 0:time_for_graph_compile = time.time() - create_network_startprint('graph compile time={:.2f}s'.format(time_for_graph_compile))if i % 10 == 0 and i > 0:t_now = time.time()epoch = int(i / train_len)print('epoch: [{}], iter: [{}], loss: [{:.4f}], time: [{:.2f}]s'.format(epoch, i, loss_meter.avg, t_now - t_end))t_end = t_nowloss_meter.reset()if i % train_len == 0  and i > 0:epoch_time_used = time.time() - t_epochepoch = int(i / train_len)fps = batch_size * train_len / epoch_time_usedprint('=================================================')print('epoch[{}], iter[{}], [{:.2f}] imgs/sec'.format(epoch, i, fps))t_epoch = time.time()validation_loss = 0sum_num = 0for _, val_data in enumerate(dl_val):network.set_train(False)val_output = network(val_data[0], val_data[1], training=False)val_loss = criterion(val_output, val_data[2], training=False)validation_loss += val_losssum_num += 1if (validation_loss / sum_num) < best_val_loss:best_val_loss = validation_loss / sum_numprint('validation EPE: {}, best validation EPE: {}'.format(validation_loss / sum_num, best_val_loss))

在這里插入圖片描述

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import flow_visout_flow = network(dict_datasets['im1'][3][None, ...], dict_datasets['im2'][3][None, ...], training=False)fig = matplotlib.pyplot.gcf()
fig.set_size_inches(18.5, 10.5)
ax = plt.subplot(141)
ax.imshow(np.transpose(dict_datasets["im1"][3].asnumpy(), (1, 2, 0)))
ax.set_title("Image 1")
ax.set_axis_off()
ax = plt.subplot(142)
ax.imshow(np.transpose(dict_datasets["im2"][3].asnumpy(), (1, 2, 0)))
ax.set_title("Image 2")
ax.set_axis_off()
ax = plt.subplot(143)
ax.imshow(flow_vis.flow_to_color(np.transpose(dict_datasets["flo"][3].asnumpy(), (1, 2, 0)))
)
ax.set_axis_off()
ax.set_title("Optical Flow")
ax = plt.subplot(144)
ax.imshow(flow_vis.flow_to_color(np.transpose(out_flow[0].asnumpy(), (1, 2, 0)))
)
ax.set_axis_off()
ax.set_title("Predicted Optical Flow")
# plt.show()

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/75149.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/75149.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/75149.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

虛幻基礎:UI

文章目錄 控件藍圖可以裝載其他控件藍圖可以安裝其他藍圖接口 填充&#xff1a;相對于父組件填充水平框尺寸—填充—0.5&#xff1a;改變填充的尺寸填充—0.5&#xff1a;改變與父組件的距離 錨點&#xff1a;相對于父組件的控件坐標系原點&#xff0c;屏幕比例改變時&#xff…

監控平臺——SkyWalking部署

一、環境準備 先下載SkyWalking安裝包&#xff0c;需要注意的是SkyWalking 版本在10.X以上使用的nacos-client是2.X&#xff0c;如果安裝的Nacos版本是1.X就會存在兼容性的問題。由于本人使用的SpringBoot項目是2.7.X版本&#xff0c;安裝的Nacos版本只能是1.X版本的&#xff…

熱門索尼S-Log3電影感氛圍旅拍LUTS調色預設 Christian Mate Grab - Sony S-Log3 Cinematic LUTs

熱門索尼S-Log3電影感氛圍旅拍LUTS調色預設 Christian Mate Grab – Sony S-Log3 Cinematic LUTs 我們最好的 Film Look S-Log3 LUT 的集合&#xff0c;適用于索尼無反光鏡相機。無論您是在戶外、室內、風景還是旅行電影中拍攝&#xff0c;這些 LUT 都經過優化&#xff0c;可為…

自動化工作流工具的綜合對比與推薦

最近收到很多朋友私信我說&#xff1a;“刷短視頻的時候&#xff0c;總是刷到自動化工作流的工具&#xff0c;有好多直播間都在宣傳&#xff0c;不知道哪款工具好”。我花了點時間&#xff0c;做了一下測試&#xff0c;大家可以參考一下&#xff0c;以下內容&#xff1a; 以下…

fircrawl本地部署

企業內部的網站作為知識庫給dify使用&#xff0c;使用fircrawl來爬蟲并且轉換為markdown。 ? git clone https://github.com/mendableai/firecrawl.gitcd ./firecrawl/apps/api/ cp .env.example .env cd ~/firecrawl docker compose up -d 官方&#xff1a; https://githu…

day17 學習筆記

文章目錄 前言一、數組的增刪改查1.resize函數2.append函數3.insert函數4.delete函數5.argwhere函數6.unique函數 二、統計函數1.amax&#xff0c;amin函數2.ptp函數3.median函數4.mean函數5.average函數6.var&#xff0c;std函數 前言 通過今天的學習&#xff0c;我掌握了num…

CentOS 8 Stream 配置在線yum源參考 —— 筑夢之路

CentOS 8 Stream ISO 文件下載地址&#xff1a;http://mirrors.aliyun.com/centos-vault/8-stream/isos/x86_64/CentOS-Stream-8-20240603.0-x86_64-dvd1.isoCentOS 8 Stream 網絡引導ISO 文件下載地址&#xff1a;http://mirrors.aliyun.com/centos-vault/8-stream/isos/x86_6…

網絡原理-TCP/IP

網絡原理學習筆記&#xff1a;TCP/IP 核心概念 本文是我在學習網絡原理時整理的筆記&#xff0c;主要涵蓋傳輸層、網絡層和數據鏈路層的核心協議和概念&#xff0c;特別是 TCP, UDP, IP, 和以太網。 一、傳輸層 (Transport Layer) 傳輸層負責提供端到端&#xff08;進程到進…

EF Core 執行原生SQL語句

文章目錄 前言一、執行查詢&#xff08;返回數據&#xff09;1&#xff09; 使用 FromSqlRaw或 FromSqlInterpolated 方法&#xff0c;適用于 DbSet<T>&#xff0c;返回實體集合。2&#xff09;結合 LINQ 查詢3&#xff09;執行任意原生SQL查詢語句&#xff08;使用ADO.N…

Unity LOD Group動態精度切換算法(基于視錐+運動速度)技術詳解

一、動態LOD技術背景與核心挑戰 1. 傳統LOD系統的局限 靜態閾值切換&#xff1a;僅基于距離的切換在動態場景中表現不佳 視覺突變&#xff1a;快速移動時LOD層級跳變明顯 性能浪費&#xff1a;靜態算法無法適應復雜場景變化 對惹&#xff0c;這里有一個游戲開發交流小組&…

MyBatis復雜查詢——一對一、一對多

目錄 &#xff08;一&#xff09;復雜查詢&#xff1a;1對1關系 【任務】數據庫里有學生表(student)和學生證信息表(student_card)&#xff0c;表結構如下所示&#xff0c;要求使用MyBatis框架查詢所有的學生信息以及每位學生的學生證信息 解決方案1&#xff1a;關聯查詢實現…

【服務端】使用conda虛擬環境部署Django項目

寫在開頭 為了與客戶端的Deep search配合&#xff0c;需要整一個后臺管理來保存和管理deep search的數據資料。選擇前端框架Vue-Vben-Admin Django后臺服務來實現這個項目。 廢話結束&#xff0c;從零開始。。。。 一、環境搭建 1. 安裝 Anaconda 下載 Anaconda&#xff1…

Python爬蟲-爬取大麥網演出詳情頁面數據

前言 本文是該專欄的第50篇,后面會持續分享python爬蟲干貨知識,記得關注。 本文,筆者以大麥網平臺為例。基于Python,實現獲取演出詳情頁面的演出信息。 廢話不多說,具體實現思路和詳細邏輯,筆者將在正文結合完整代碼進行詳細介紹。接下來,跟著筆者直接往下看正文詳細內…

多onnx模型導出合并調研(文本檢測+方向分類+文本識別)

??主頁:吾名招財 ??簡介:工科學碩,研究方向機器視覺,愛好較廣泛… ???簽名:面朝大海,春暖花開! 多onnx模型合并導出調研(文本檢測+方向分類+文本識別) 引言1,嘗試合并兩個模型(文本方向分類+文本識別模型)(并行合并)(1)文本方向分類(2)文本識別模型(…

Flink介紹——實時計算核心論文之S4論文詳解

引入 在上一篇我們對Flink的發展歷史有了全局的了解&#xff0c;下面我們會通讀幾篇分布式實時處理相關的重要論文&#xff0c;從S4到Storm&#xff0c;再從MillWheel到Dataflow&#xff0c;最后到Flink。 通過深入梳理分布式實時處理技術的發展脈絡&#xff0c;了解這些年技…

【商城實戰(97)】ELK日志管理系統的全面應用

【商城實戰】專欄重磅來襲!這是一份專為開發者與電商從業者打造的超詳細指南。從項目基礎搭建,運用 uniapp、Element Plus、SpringBoot 搭建商城框架,到用戶、商品、訂單等核心模塊開發,再到性能優化、安全加固、多端適配,乃至運營推廣策略,102 章內容層層遞進。無論是想…

Linux系統-ls命令

一、ls命令的定義 Linux ls命令&#xff08;英文全拼&#xff1a;list directory contents&#xff09;用于顯示指定工作目錄下之內容&#xff08;列出目前工作目錄所含的文件及子目錄)。 二、ls命令的語法 ls [選項] [目錄或文件名] ls [-alrtAFR] [name...] 三、參數[選項…

游戲被外掛攻破?金融數據遭篡改?AI反作弊系統實戰方案(代碼+詳細步驟)

一、背景與需求分析 隨著游戲行業與金融領域的數字化進程加速,作弊行為(如游戲外掛、金融數據篡改)日益復雜化。傳統基于規則的防御手段已難以應對新型攻擊,而AI技術通過動態行為分析、異常檢測等能力,為安全領域提供了革命性解決方案。本文以游戲反作弊系統和金融數據安…

Node.js 路由 - 初識 Express 中的路由

目錄 Node.js 路由 - 初識 Express 中的路由 1. 什么是路由&#xff1f; 2. 安裝 Express 3. 創建 server.js 4. 運行服務器 5. 測試路由 5.1 訪問主頁 5.2 訪問用戶路由 5.3 發送 POST 請求 6. 結語 1. 什么是路由&#xff1f; 路由&#xff08;Routing&#xff09…

面經-項目

項目 項目(重點)問題1:描述在網頁中題目點擊提交后到題目結果出現的一系列后臺反應【1】如何獲取到用戶提交的代碼的?【2】_1. 題目細節都有哪些?【2】_2. 題目信息怎么存儲的?【3】負載均衡算法的實現?【4】oj_server怎么連接對應的compile_server(編譯主機)的?【5】oj_…