pytorch: 在訓練中保存模型,加載模型

文章目錄

  • 1. 保存整個模型
  • 2.僅保存和加載模型參數(推薦使用)
  • 3. 保存其他參數到模型中,比如optimizer,epoch 等

1. 保存整個模型

torch.save(model, 'model.pkl')
model = torch.load('model.pkl')

2.僅保存和加載模型參數(推薦使用)

torch.save(model_object.state_dict(), 'params.pkl')
model.load_state_dict(torch.load('params.pkl'))

注意在多卡訓練的時候,會多保留一個module
因此我們可以設置

if torch.cuda.device_count() > 1:  #多卡訓練torch.save(model.module.state_dict(),'params.pkl')
else:torch.save(model.state_dict(),'params.pkl')

3. 保存其他參數到模型中,比如optimizer,epoch 等

torch.save({'epoch':edix, 'state_dict':model.state_dict(),  \
'optimizer':optimizer.state_dict()}, \
f'./saved_models/{opt.exp_name}/epoch_{edix+1}.pth')
model.load_state_dict(torch.load(opt.saved_model)['state_dict'], strict=False)model.load_state_dict(torch.load(opt.saved_model)['epoch'], strict=False)
model.load_state_dict(torch.load(opt.saved_model)['optimizer'], strict=False)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/535185.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/535185.shtml
英文地址,請注明出處:http://en.pswp.cn/news/535185.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Pytorch:保存圖片

1. 直接保存Tensor #!/usr/bin/env python # _*_ coding:utf-8 _*_ import torch from torchvision import utils as vutilsdef save_image_tensor(input_tensor: torch.Tensor, filename):"""將tensor保存為圖片:param input_tensor: 要保存的tensor:param fi…

Python List:合并多個list,listd的合并

第一種方法 a [1,3,3] b [2,3,3] a a b print(a) [1,3,3,2,3,3]第二種方法 a [1,3,3] b [2,3,3] a.extend(b) print(a) [1,3,3,2,3,3]

掛載硬盤問題:mount: wrong fs type, bad option, bad superblock on /dev/sdb,

mount: wrong fs type, bad option, bad superblock on /dev/sdb,missing codepage or helper program, or other error 解決方案: # create mount dir sudo mkdir /hdd6T# new file system sudo mkfs.ext4 /dev/sdc# mount drive sudo mount /dev/sdc /hdd6T/# c…

linux 安裝python3.8的幾種方法

1.命令行搞定 git clone https://github.com/waketzheng/carstino cd carstino python3 upgrade_py.py2.離線安裝 自己在官網下載安裝包 https://www.python.org/ftp/python/3.8.0/ 解壓: tar -zvf Python-3.8.0.tgz安裝 cd Python-3.8.0 ./configure --prefix/u…

面試題目:欠擬合、過擬合及如何防止過擬合

對于深度學習或機器學習模型而言,我們不僅要求它對訓練數據集有很好的擬合(訓練誤差),同時也希望它可以對未知數據集(測試集)有很好的擬合結果(泛化能力),所產生的測試誤…

LaTeX:equation, aligned 書寫公式換行,頂部對齊

使用aligined 函數,其中aligned就是用來公式對齊的,在中間公式中,\ 表示換行, & 表示對齊。在公式中等號之前加&,等號介紹要換行的地方加\就可以了。 \begin{equation} \begin{aligned} L_{task} &\lamb…

Latex: 表格中 自動換行居中

1、在導言區添加宏包: \usepackage{makecell}2、環境:tabular 命令: \makecell[居中情況]{第1行內容 \\ 第2行內容 \\ 第3行內容 ...} \makecell [c]{ResNet101\\ (11.7M)}參數說明: [c]是水平居中,[l]水平左居中&am…

argparse:shell向Python中傳參數

一般是 python train.py --bath_size 5利用argparse解析參數 import argparse parser argparse.ArgumentParser() parser.add_argument(integer, typeint, helpdisplay an integer) args parser.parse_args()參數類型 可選參數 import argparse parser argparse.Argumen…

FTP命令:下載,上傳FTP服務器中的文件

步驟 1: 建立 FTP 連接 想要連接 FTP 服務器,在命令上中先輸入ftp然后空格跟上 FTP 服務器的域名 domain.com 或者 IP 地址例如:1.ftp domain.com2.ftp 192.168.0.13.ftp userftpdomain.com注意: 本例中使用匿名服務器。替換下面例子中 IP 或域名為你的服務器地址。…

Tensorflow代碼轉pytorch代碼 函數的轉換

tensoflow函數和pytorch函數之間的轉換 tensorflowpytrochtf.reshape(input, shape)input.view()tf.expand_dims(input, dim)input.unsqueeze(dim) / input.view()tf.squeeze(input, dim)torch.squeeze(dim)/ input.view()tf.gather(input1, input2)input1[input2]tf.tile(inp…

在服務器上遠程使用tensorboard查看訓練loss和準確率

本人使用的是vscode 很簡單 from torch.utils.tensorboard import SummaryWriter writer SummaryWriter(./logs)writer.add_scalar(train_loss,loss.val(),iteration) # 名字,數據,迭代次數訓練的過程中會產生一個./logs的文件夾,里面存放的…

python,pytorch:讀取,保存,顯示圖片

文章目錄一,Pytorch1. 直接保存Tensor2.Tensor 轉CV2 保存二、python1. opencv2.matplotlib:3. PIL一,Pytorch 1. 直接保存Tensor #!/usr/bin/env python # _*_ coding:utf-8 _*_ import torch from torchvision import utils as vutilsdef save_image…

Python循環創建變量名

使用命名空間locals locals 中是當前程序段中的全部變量名是一個字典的形式 所以我們新增的話,直接和字典那樣就行了 names locals() #獲取當前程序段中的全體局部變量名 for i in np.arange(0,10):names[fname_{i}]i

pytorch:固定部分層參數,固定單個模型

文章目錄固定部分層參數固定指定層的參數不同層設置不同的學習率固定部分層參數 class RESNET_attention(nn.Module):def __init__(self, model, pretrained):super(RESNET_attetnion, self).__init__()self.resnet model(pretrained) # 這個model被固定for p in self.parame…

圖片拼接的幾種方法

1. torch tensor 格式 from torchvision.utils import save_imageimg_train_list torch.cat([image_s,image_r,im_G[0],fake_A,mask_s[:, :, 0],mask_r[:, :, 0]])result_path self.save_path_dir/imageif not os.path.exists(result_path):os.makedirs(result_path)save_im…

Pytorch 各種模塊:降低學習率,

1.訓練過程中學習率衰減 if (self.e1) > (self.num_epochs - self.num_epochs_decay):g_lr - (self.g_lr / float(self.num_epochs_decay))d_lr - (self.d_lr / float(self.num_epochs_decay))self.update_lr(g_lr, d_lr)print(Decay learning rate to g_lr: {}, d_lr:{}..…

cudnn.deterministic = True 固定隨機種子

隨機數種子seed確定時,模型的訓練結果將始終保持一致。 隨機數種子seed確定時使用相同的網絡結構,跑出來的效果完全不同,用的學習率,迭代次數,batch size 都是一樣。 torch.backends.cudnn.deterministic是啥&#x…

torch.backends.cudnn.benchmark 加速訓練

設置 torch.backends.cudnn.benchmarkTrue 將會讓程序在開始時花費一點額外時間,為整個網絡的每個卷積層搜索最適合它的卷積實現算法,進而實現網絡的加速。適用場景是網絡結構固定(不是動態變化的),網絡的輸入形狀&…

各種損失損失函數的使用場景和使用方法:KL散度

KL 散度的使用場景 KL散度( Kullback–Leibler divergence),又稱相對熵,是描述兩個概率分布 P 和 Q 差異的一種方法 torch.nn.functional.kl_div(input, target, size_averageNone, reduceNone, reductionmean) torch.nn.KLDivLoss(input, target, si…

RNN,LSTM,GRU的理解

RNN x 為當前狀態下數據的輸入, h 表示接收到的上一個節點的輸入。 y為當前節點狀態下的輸出,而h′h^\primeh′為傳遞到下一個節點的輸出. LSTM #定義網絡 lstm nn.LSTM(input_size20,hidden_size50,num_layers2) #輸入變量 input_data Variable(tor…