Tensorflow代碼轉pytorch代碼 函數的轉換

tensoflow函數和pytorch函數之間的轉換

tensorflowpytroch
tf.reshape(input, shape)input.view()
tf.expand_dims(input, dim)input.unsqueeze(dim) / input.view()
tf.squeeze(input, dim)torch.squeeze(dim)/ input.view()
tf.gather(input1, input2)input1[input2]
tf.tile(input, shape)input.repeat(shape)
tf.boolean_mask(input, mask)input[mask] #注意,mask是bool值,不是0,1的數值
tf.concat(input1, input2)torch.cat(input1, input2)
tf.matmul()torch.matmul()
tf.minium(input, min)torch.clamp(input, max=min)
tf.equal(input1, input2)torch.eq(input1, input2)/ input1 == input2
tf.logical_and(input1, input2)input1 & input2
tf.logical_not(input) ~input
tf.reduce_logsumexp(input, [dim])torch.logsumexp(input, dim=dim)
tf.reduce_any(input, dim)input.any(dim)
tf.reduce_mean(input)torch.mean(input)
tf.reduce_sum(input)input.sum()
tf.transpose(input)input.t()
tf.softmax_cross_entroy_with_logits(logits, labels)torch.nn.CrossEntropyLoss(logits, labels)

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/535175.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/535175.shtml
英文地址,請注明出處:http://en.pswp.cn/news/535175.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

在服務器上遠程使用tensorboard查看訓練loss和準確率

本人使用的是vscode 很簡單 from torch.utils.tensorboard import SummaryWriter writer SummaryWriter(./logs)writer.add_scalar(train_loss,loss.val(),iteration) # 名字,數據,迭代次數訓練的過程中會產生一個./logs的文件夾,里面存放的…

python,pytorch:讀取,保存,顯示圖片

文章目錄一,Pytorch1. 直接保存Tensor2.Tensor 轉CV2 保存二、python1. opencv2.matplotlib:3. PIL一,Pytorch 1. 直接保存Tensor #!/usr/bin/env python # _*_ coding:utf-8 _*_ import torch from torchvision import utils as vutilsdef save_image…

Python循環創建變量名

使用命名空間locals locals 中是當前程序段中的全部變量名是一個字典的形式 所以我們新增的話,直接和字典那樣就行了 names locals() #獲取當前程序段中的全體局部變量名 for i in np.arange(0,10):names[fname_{i}]i

pytorch:固定部分層參數,固定單個模型

文章目錄固定部分層參數固定指定層的參數不同層設置不同的學習率固定部分層參數 class RESNET_attention(nn.Module):def __init__(self, model, pretrained):super(RESNET_attetnion, self).__init__()self.resnet model(pretrained) # 這個model被固定for p in self.parame…

圖片拼接的幾種方法

1. torch tensor 格式 from torchvision.utils import save_imageimg_train_list torch.cat([image_s,image_r,im_G[0],fake_A,mask_s[:, :, 0],mask_r[:, :, 0]])result_path self.save_path_dir/imageif not os.path.exists(result_path):os.makedirs(result_path)save_im…

Pytorch 各種模塊:降低學習率,

1.訓練過程中學習率衰減 if (self.e1) > (self.num_epochs - self.num_epochs_decay):g_lr - (self.g_lr / float(self.num_epochs_decay))d_lr - (self.d_lr / float(self.num_epochs_decay))self.update_lr(g_lr, d_lr)print(Decay learning rate to g_lr: {}, d_lr:{}..…

cudnn.deterministic = True 固定隨機種子

隨機數種子seed確定時,模型的訓練結果將始終保持一致。 隨機數種子seed確定時使用相同的網絡結構,跑出來的效果完全不同,用的學習率,迭代次數,batch size 都是一樣。 torch.backends.cudnn.deterministic是啥&#x…

torch.backends.cudnn.benchmark 加速訓練

設置 torch.backends.cudnn.benchmarkTrue 將會讓程序在開始時花費一點額外時間,為整個網絡的每個卷積層搜索最適合它的卷積實現算法,進而實現網絡的加速。適用場景是網絡結構固定(不是動態變化的),網絡的輸入形狀&…

各種損失損失函數的使用場景和使用方法:KL散度

KL 散度的使用場景 KL散度( Kullback–Leibler divergence),又稱相對熵,是描述兩個概率分布 P 和 Q 差異的一種方法 torch.nn.functional.kl_div(input, target, size_averageNone, reduceNone, reductionmean) torch.nn.KLDivLoss(input, target, si…

RNN,LSTM,GRU的理解

RNN x 為當前狀態下數據的輸入, h 表示接收到的上一個節點的輸入。 y為當前節點狀態下的輸出,而h′h^\primeh′為傳遞到下一個節點的輸出. LSTM #定義網絡 lstm nn.LSTM(input_size20,hidden_size50,num_layers2) #輸入變量 input_data Variable(tor…

常用的loss函數,以及在訓練中的使用

文章目錄KL 散度L2 loss做標準化處理CElossCTCLossAdaptiveAvgPool2dKL 散度 算KL散度的時候要注意前后順序以及加log import torhch.nn as nn d_loss nn.KLDivLoss(reductionreduction_kd)(F.log_softmax(y / T, dim1),F.softmax(teacher_scores / T, dim1)) * T * T蒸餾lo…

Shell 在訓練模型的時候自動保存訓練文件和模型到指定文件夾

在進行深度學習訓練的過程中,往往會跑很多實驗,這就導致有的實驗設置會忘記或者記混淆,我們最好把train test model 的代碼都copy一遍到指定文件夾中,這樣后面檢查也方便。 用shell指令保存文件 #!/bin/sh GRUB_CMDLINE_LINUX&qu…

Pytorch:數據并行和模型并行,解決訓練過程中內存分配不均衡的問題

文章目錄數據并行單機多卡訓練,即并行訓練。并行訓練又分為數據并行 (Data Parallelism) 和模型并行兩種。 數據并行指的是,多張 GPU 使用相同的模型副本,但是使用不同的數據批進行訓練。而模型并行指的是,多張GPU 分別訓練模型的…

DataParallel 和 DistributedDataParallel 的區別和使用方法

1.DataParallel DataParallel更易于使用(只需簡單包裝單GPU模型)。 model nn.DataParallel(model)它使用一個進程來計算模型參數,然后在每個批處理期間將分發到每個GPU,然后每個GPU計算各自的梯度,然后匯總到GPU0中…

torch.cuda.is_available(),torch.cuda.device_count(),torch.cuda.get_device_name(0)

torch.cuda.is_available() cuda是否可用; torch.cuda.device_count() 返回gpu數量; torch.cuda.get_device_name(0) 返回gpu名字,設備索引默認從0開始; torch.cuda.current_device() 返回當前設備索引;

windows, 放方向鍵設置為vim格式,autohotkey-windows

安裝 Autohotkey https://www.autohotkey.com/download/ 設置快捷鍵 隨便找個目錄,鼠標右鍵新建一個autohotkey的腳本。 映射一個鍵——上左下右 經常打字的人都知道,我們編輯文本時要上下左右移動光標,難免要將手移到方向鍵再移回來打字。對我這樣的懶癌后期患者,這簡直不能…

window設置快捷鍵左右方向鍵

autohotkey-windows快捷鍵設置神器 使用方法 地址

Hbase數據模型及Hbase Shell

目錄 1 數據模型 1.1 相關名詞概念 1.2 模型分析 2 Hbase Shell操作 2.1 命名空間 2.2 表操作 2.2.1 創建表 2.2.2 更改表結構 2.2.3 表的其他操作 2.3 數據操作 2.3.1 添加數據(put) 2.3.2 刪除數據(delete) 2.3.3 獲取數據(get|scan) 3 過濾器 3.1 比較運算符…

非關型數據庫之Hbase

目錄 1 Hbase簡介 1.1 初識Hbase 1.2 Hbase的特性 2 HDFS專項模塊 2.1 HDFS的基本架構 2.1.1 HDFS各組件的功能: 2.2 HFDFS多種機制 2.2.1 分塊機制 2.2.2 副本機制 2.2.3 容錯機制 2.2.4 讀寫機制 3 Hbase組件及其功能 3.1 客戶端 3.2 Zookeeper 3.3 …

MongoDB Shell操作

目錄 1 數據庫操作 2 集合操作 3 文檔操作 3.1 插入文檔(insert|insertOne|insertMany) 3.2插入、刪除的循環操作 3.2 刪除文檔(remove|deleteOne|deleteMany) 3.3 更新文檔(update|save) 3.4 查詢文檔(find) 4 游標 5 索引 6 聚合 1 數據庫操作 當新創建的數據庫里…