目錄
■導言
①定義模型
②加載文本數據
③加載預訓練模型
④測試動態量化
■結論
■導言
量化涉及將模型的權重和激活從float轉換為int,這可以導致更小的模型大小和更快的推理,并且只對準確性造成很小的影響。
本文將把最簡單的量化形式-動態量化-應用于基于lstm的下一個單詞預測模型,與PyTorch示例中的單詞語言模型非常相似。
# importsimport osfrom io import openimport timeimport torchimport torch.nn as nnimport torch.nn.functional as F
①定義模型
定義LSTM模型體系結構,遵循單詞語言模型示例中的模型。
# 定義一個包含編碼器、循環層和解碼器的LSTM模型
class LSTMModel(nn.Module):"""容器模塊,包含編碼器、遞歸模塊(LSTM)和解碼器。"""def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):"""初始化LSTM模型。參數:ntoken (int): 詞匯表大小,即輸入數據的類別數。ninp (int): 輸入層維度,即詞嵌入的維度。nhid (int): 隱藏層維度,即LSTM中每個單元的隱藏狀態維度。nlayers (int): LSTM的層數。dropout (float): Dropout率,用于防止過擬合,默認為0.5。"""super(LSTMModel, self).__init__()# Dropout層,用于在訓練過程中隨機丟棄一部分神經元,防止過擬合self.drop = nn.Dropout(dropout)# 編碼器層:將輸入的離散詞索引轉換為密集向量表示(詞嵌入)self.encoder = nn.Embedding(ntoken, ninp)# LSTM層:負責處理序列數據,提取時序特征self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)# 解碼器層:將LSTM輸出的隱藏狀態映射回詞匯表空間,預測下一個詞self.decoder = nn.Linear(nhid, ntoken)# 初始化模型參數self.init_weights()# 保存隱藏層維度和網絡層數供后續使用self.nhid = nhidself.nlayers = nlayersdef init_weights(self):"""初始化模型參數,采用均勻分布進行初始化"""initrange = 0.1# 對編碼器的權重進行初始化self.encoder.weight.data.uniform_(-initrange, initrange)# 解碼器偏置初始化為0self.decoder.bias.data.zero_()# 解碼器權重也使用相同的均勻分布初始化self.decoder.weight.data.uniform_(-initrange, initrange)def forward(self, input, hidden):"""前向傳播函數。參數:input (Tensor): 當前批次的輸入數據,形狀為(seq_len, batch_size)。hidden (tuple): 初始隱藏狀態(h0, c0)。返回:decoded (Tensor): 輸出結果,表示每個時間步的預測詞概率。hidden (tuple): 更新后的隱藏狀態。"""# 將輸入通過編碼器轉化為嵌入向量,并應用Dropoutemb = self.drop(self.encoder(input))# 將嵌入后的數據傳入LSTM進行處理,得到輸出和新的隱藏狀態output, hidden = self.rnn(emb, hidden)# 對LSTM的輸出應用Dropoutoutput = self.drop(output)# 通過解碼器將輸出映射到詞匯表空間,得到每個詞的概率分布decoded = self.decoder(output)return decoded, hiddendef init_hidden(self, bsz):"""初始化隱藏狀態(h0 和 c0),通常在每輪開始時調用。參數:bsz (int): batch size,當前批次的大小。返回:tuple: 包含初始隱藏狀態的兩個張量(h0, c0),形狀均為(nlayers, bsz, nhid)"""weight = next(self.parameters()) # 獲取第一個參數作為參考,創建相同類型的張量# 初始化為全零張量return (weight.new_zeros(self.nlayers, bsz, self.nhid),weight.new_zeros(self.nlayers, bsz, self.nhid))
這段代碼定義了一個基于 LSTM 的語言模型,主要包括以下幾個部分:
- 編碼器:將輸入的詞索引轉換為詞向量(Embedding)。
- LSTM 層:處理序列數據,提取時序信息。
- 解碼器:將 LSTM 的輸出映射回詞匯表空間,用于預測下一個詞。
- Dropout:防止訓練過程中的過擬合。
- 參數初始化:對 Embedding 和 Linear 層的參數進行均勻分布初始化。
- 隱藏狀態管理:提供初始化隱藏狀態的方法,便于模型在每次新序列開始時重置記憶。
②加載文本數據
接下來,加載 Wikitext-2?數據集進入Corpus, 再次跟隨單詞語言模型預處理示例。
# 定義一個詞典類,用于建立詞語與索引之間的映射
class Dictionary(object):def __init__(self):# word2idx: 詞語到索引的映射字典self.word2idx = {}# idx2word: 索引到詞語的列表(用于反查)self.idx2word = []def add_word(self, word):"""向詞典中添加一個詞。如果該詞尚未存在,則將其加入列表和字典;否則直接返回已有的索引。參數:word (str): 要添加的詞語返回:int: 該詞對應的索引"""if word not in self.word2idx:self.idx2word.append(word)self.word2idx[word] = len(self.idx2word) - 1return self.word2idx[word]def __len__(self):"""返回詞典中不同詞語的數量"""return len(self.idx2word)# 定義語料庫類,用于加載和處理文本數據集
class Corpus(object):def __init__(self, path):"""初始化語料庫對象,并加載訓練、驗證和測試數據。參數:path (str): 數據集存放路徑"""self.dictionary = Dictionary()# 加載并分詞訓練集、驗證集和測試集self.train = self.tokenize(os.path.join(path, 'train.txt'))self.valid = self.tokenize(os.path.join(path, 'valid.txt'))self.test = self.tokenize(os.path.join(path, 'test.txt'))def tokenize(self, path):"""對指定路徑的文本文件進行分詞處理。步驟:1. 遍歷文件內容,將所有詞語加入詞典;2. 再次遍歷文件,將每句話轉換為對應的索引張量;3. 將所有句子的索引拼接成一個大的張量返回。參數:path (str): 文件路徑返回:Tensor: 包含整個文件詞匯索引的一維張量"""assert os.path.exists(path)# 第一步:讀取文本并構建詞典with open(path, 'r', encoding="utf8") as f:for line in f:words = line.split() + ['<eos>'] # 每句話以 '<eos>' 結尾表示結束for word in words:self.dictionary.add_word(word)# 第二步:將文本轉換為索引張量with open(path, 'r', encoding="utf8") as f:idss = [] # 存儲每個句子的索引張量for line in f:words = line.split() + ['<eos>']ids = []for word in words:ids.append(self.dictionary.word2idx[word])idss.append(torch.tensor(ids).type(torch.int64)) # 轉換為PyTorch張量ids = torch.cat(idss) # 將多個句子拼接為一個一維張量return ids# 設置數據集路徑
model_data_filepath = 'data/'# 創建語料庫對象,加載WikiText-2數據集
corpus = Corpus(model_data_filepath + 'wikitext-2')
這段代碼的主要功能是:
Dictionary
?類:構建詞典,實現詞語與索引之間的雙向映射。Corpus
?類:加載并處理文本數據集,包括訓練集、驗證集和測試集。
使用?
tokenize
?方法將文本轉換為詞索引張量。每句話末尾添加特殊標記?
<eos>
?表示句子結束。
- 最終通過?
corpus
?實例加載 WikiText-2 數據集,可用于后續模型訓練或評估。
③加載預訓練模型
將一些預訓練的權重加載到這個模型架構中。下載所需的預訓練模型:
wget https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth
?
將下載的文件放在數據目錄中或相應地更新 model_data_filepath。
ntokens = len(corpus.dictionary)model = LSTMModel(ntoken = ntokens,ninp = 512,nhid = 256,nlayers = 5,
)model.load_state_dict(torch.load(model_data_filepath + 'word_language_model_quantize.pth',map_location=torch.device('cpu'),weights_only=True))model.eval()
print(model)
輸出:
LSTMModel((drop): Dropout(p=0.5, inplace=False)(encoder): Embedding(33278, 512)(rnn): LSTM(512, 256, num_layers=5, dropout=0.5)(decoder): Linear(in_features=256, out_features=33278, bias=True))
現在生成一些文本,以確保預訓練模型工作正確。
# 初始化一個隨機輸入,表示起始詞的索引(形狀為 (1, 1),即一個詞)
input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)# 初始化模型的隱藏狀態,batch_size = 1
hidden = model.init_hidden(1)# 設置溫度參數,用于控制生成結果的隨機性:
# 溫度越高,輸出越隨機;溫度越低,輸出越確定。
temperature = 1.0# 要生成的總詞數
num_words = 1000# 打開文件以寫入生成的文本
with open(model_data_filepath + 'out.txt', 'w') as outf:with torch.no_grad(): # 不需要計算梯度,加快推理速度for i in range(num_words):# 模型前向傳播,得到當前詞的輸出和新的隱藏狀態output, hidden = model(input_, hidden)# 對輸出進行處理,得到每個詞的概率分布word_weights = output.squeeze().div(temperature).exp().cpu()# 根據概率分布隨機選取下一個詞的索引word_idx = torch.multinomial(word_weights, 1)[0]# 將當前預測的詞作為下一次生成的輸入input_.fill_(word_idx)# 將索引轉換為實際詞語word = corpus.dictionary.idx2word[word_idx]# 寫入文件,每20個詞換行,否則空格分隔outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))# 每生成100詞打印進度if i % 100 == 0:print('| 已生成 {}/{} 個詞'.format(i, num_words))# 讀取并打印生成的全部文本
with open(model_data_filepath + 'out.txt', 'r') as outf:all_output = outf.read()print(all_output)
輸出:
| Generated 0/1000 words
| Generated 100/1000 words
| Generated 200/1000 words
| Generated 300/1000 words
| Generated 400/1000 words
| Generated 500/1000 words
| Generated 600/1000 words
| Generated 700/1000 words
| Generated 800/1000 words
| Generated 900/1000 words
b'.' b'David' b'<unk>' b'states' b'the' b'album' b'of' b'the' b'key' b'(' b'3' b'@.@' b'2' b'miles' b'per' b'hour' b'destructive' b'73' b'@.@' b'8'
b'm' b')' b'ten' b'years' b'with' b'edible' b'intellectual' b'instruments' b',' b'that' b'was' b'subdivided' b'into' b'an' b'star' b'Hampshires' b'.' b'1981' b',' b'Megan'
b'Room' b'campaigned' b'in' b'1956' b'in' b'Jacob' b'Lake' b'in' b'Floyd' b'of' b'Garden' b'which' b'was' b'introduced' b'by' b'Enuff' b'<unk>' b',' b'<unk>' b'of'
b'a' b'special' b'state' b'opens' b'with' b'a' b'Amusements' b'from' b'North' b'Korea' b'and' b'Temple' b'County' b'.' b'Everson' b',' b'with' b'a' b'Shanghai' b'ultimate'
b'potential' b'play' b',' b'also' b'on' b'October' b'7' b',' b'1848' b',' b'and' b'collaborated' b'up' b'an' b'main' b'(' b'three' b'species' b'hills' b')'
b'.' b'<eos>' b'A.' b'galericulata' b'is' b'a' b'popular' b'white' b'post' b'@-@' b'spored' b',' b'raising' b'in' b'the' b'fourth' b'century' b',' b'and' b'it'
b'was' b'many' b'450' b'claims' b'where' b'there' b'are' b'no' b'official' b'amounts' b'of' b'digital' b'species' b'that' b'are' b'found' b'to' b'be' b'mild' b'.'
b'<eos>' b'<eos>' b'=' b'=' b'Musical' b'records' b'=' b'=' b'<eos>' b'<eos>' b'Gods' b'Narvik' b'a' b'year' b'meets' b'to' b'be' b'found' b'for' b'Hawaii'
b',' b'and' b'the' b'common' b'starling' b'has' b'also' b'described' b'it' b'weak' b'community' b'kings' b'.' b'Continuing' b'regarding' b'<unk>' b',' b'they' b'were' b'partially'
b'deeply' b'distinguished' b'in' b'Ireland' b'.' b'The' b'range' b'of' b'Bullet' b'material' b'is' b'a' b'non' b'@-@' b'imposed' b'planet' b'each' b',' b'even' b'outright'
b'terminology' b',' b'usually' b'available' b',' b'by' b'Russia' b'.' b'The' b'spotless' b'transit' b'along' b'a' b'24' b'Bulletin' b'and' b'could' b'be' b'expected' b'.'
b'There' b'are' b'90' b'mg' b'phallic' b'method' b'<unk>' b'as' b',' b'between' b'2016' b'and' b'1913' b',' b'and' b'first' b'a' b'important' b'sensors' b'since'
b'it' b'has' b'reached' b'a' b'clade' b',' b'distinguished' b',' b'other' b'predators' b'and' b'air' b'long' b'.' b'There' b'are' b'no' b'time' b'of' b'Ceres'
b'in' b'within' b'two' b'years' b'.' b'Five' b'size' b',' b'with' b'35' b'@,@' b'Catholics' b',' b'is' b'proposed' b'to' b'do' b'so' b'unless' b'they'
b'were' b'extremes' b'to' b'ask' b'songs' b'.' b'Chapels' b',' b'until' b'January' b'1988' b',' b'when' b'Dawn' b"'s" b'gravity' b',' b'even' b'eaten' b'by'
b'another' b'shortage' b'.' b'It' b'is' b'possible' b'that' b'group' b'suspended' b',' b'because' b'of' b'the' b'Intermediate' b',' b'related' b'after' b'this' b',' b'they'
b'makes' b'as' b'two' b'as' b'they' b'affect' b'them' b',' b'merged' b'up' b'on' b'DAGs' b'.' b'In' b'Molecular' b'<unk>' b',' b'males' b'may' b'be'
b'occasionally' b'closed' b'in' b'England' b',' b'eye' b'children' b'and' b'unrelated' b'attempts' b'to' b'start' b'variable' b'and' b'coordination' b'.' b'<unk>' b',' b'which' b','
b'unlike' b'other' b'birds' b',' b'mainly' b'fresh' b',' b'<unk>' b'they' b'are' b'<unk>' b'and' b'they' b'are' b'spotted' b'in' b'trees' b'.' b'This' b'kakapo'
b'can' b'be' b'of' b'one' b'side' b'of' b'William' b'<unk>' b'to' b'produce' b'for' b'electron' b'behaviour' b'prove' b'.' b'In' b'the' b'19th' b'century' b','
b'the' b'Skye' b'mural' b'(' b'immigrants' b'has' b'the' b'highest' b'birds' b')' b',' b'to' b'look' b'round' b'or' b'<unk>' b'.' b'With' b'8' b'million'
b'tons' b'(' b'13' b'@.@' b'5' b'in' b')' b'or' b'knowledge' b',' b'their' b'mass' b'can' b'have' b'relocated' b'to' b'the' b'distances' b'of' b'peak'
b'birds' b'.' b'<eos>' b'Mycena' b'galericulata' b'reaches' b'a' b'very' b'apparent' b'amount' b'of' b'scholars' b',' b'as' b'of' b'late' b'November' b'230' b',' b'2006'
b',' b'will' b'buy' b'the' b'overlap' b',' b'especially' b'by' b'<unk>' b'Trypanosoma' b',' b'<unk>' b',' b'and' b'sur' b'Notably' b'\xe2\x80\x93' b'Africa' b'.' b'The'
b'birds' b"'" b'clutch' b'surface' b'weight' b'and' b'kakapo' b'asserts' b'that' b'they' b'are' b'serving' b'as' b'"' b'<unk>' b'"' b',' b'a' b'flock' b'that'
b'feed' b'increase' b'by' b'Call' b'daylight' b'on' b'22' b'August' b'1801' b'.' b'The' b'surviving' b'species' b'of' b'this' b'similar' b'classification' b'is' b'meant' b'to'
b'be' b'(' b'"' b'Always' b'Bird' b'"' b')' b'that' b'could' b'be' b'seen' b'as' b'the' b'country' b'for' b'their' b'sight' b'as' b'it' b'ruler'
b',' b'as' b'it' b'hit' b'that' b'it' b'is' b'probably' b'more' b'time' b'.' b'<eos>' b'Family' b'therapy' b',' b'about' b'3' b'@.@' b'5' b'million'
b'years' b',' b'suggests' b'of' b'1' b'@,@' b'000' b'to' b'eight' b'kilometres' b'(' b'19' b'@.@' b'4' b'\xe2\x80\x93' b'5' b'a.m.' b')' b'long' b','
b'than' b'which' b'may' b'one' b'reopened' b'of' b'the' b'works' b'.' b'It' b'may' b'exist' b'with' b'agriculture' b'such' b'as' b'local' b'areas' b'as' b'such'
b'as' b'a' b'type' b'that' b'may' b'occur' b'back' b'into' b'the' b'night' b',' b'but' b'further' b'does' b'not' b'include' b'military' b'forests' b'.' b'<eos>'
b'Mycena' b'Evans' b'Wi\xc5\x9bniowiecki' b'suggests' b'that' b'"' b'when' b'early' b'practice' b'carried' b'a' b'subspecies' b'strips' b',' b'adding' b'below' b'too' b'distant' b'or' b'to'
b'be' b'the' b'subject' b'of' b'it' b'during' b'Baby' b'terms' b'.' b'"' b'<eos>' b'The' b'Australian' b'starling' b'"' b'The' b'One' b'best' b'name' b'"'
b'have' b'already' b'sold' b'.' b'There' b'is' b'few' b'people' b'in' b'common' b'areas' b'that' b'may' b'be' b'obtained' b'it' b',' b'but' b'sort' b'of'
b'more' b'late' b'recorded' b',' b'hard' b'Ozawa' b',' b'nucleolar' b'bound' b',' b'and' b'Xemnas' b';' b'and' b'is' b'foraging' b'in' b'2000' b'.' b'<eos>'
b'Northern' b'Ireland' b'is' b'very' b'invisible' b'for' b'their' b'state' b',' b'and' b'in' b'<unk>' b',' b'they' b'once' b'have' b'been' b'became' b'inconclusive' b'.'
b'The' b'Maasai' b'which' b'develop' b'near' b'the' b'breeding' b'Trade' b'Island' b'below' b'of' b'theologian' b"'s" b'husband' b'to' b'increase' b',' b'their' b'mouth' b'Colfer'
b'they' b'are' b'.' b'<eos>' b'Scotland' b'blocks' b'by' b'the' b'species' b'and' b'domains' b'preventing' b'tends' b'to' b'work' b'about' b'into' b'the' b'population' b'.'
b'<unk>' b'of' b'volunteers' b'are' b'referring' b'to' b'other' b'ribosomal' b'motifs' b'where' b'other' b'birds' b'have' b'fallen' b'.' b'For' b'this' b'statistics' b'are' b'short'
b'near' b'horns' b',' b'but' b'R\xc3\xa9union' b'has' b'praised' b'it' b'with' b'his' b'main' b'body' b',' b'whereas' b'admire' b'little' b'eye' b'sequences' b'of' b'Bay'
b'177' b',' b'both' b'of' b'which' b'are' b'more' b'small' b'scoring' b'.' b'<eos>' b'Other' b'groups' b'may' b'have' b'completely' b'allowed' b'lobbying' b'to' b'within'
b'the' b'excuse' b'of' b'cameras' b'.' b'Unlike' b'also' b'droplets' b'or' b'<unk>' b',' b'they' b'are' b'loose' b'simultaneously' b'.' b'They' b'have' b'parallel' b'to'
b'pandemic' b'and' b'often' b'beat' b'contact' b'over' b'about' b'60' b'thousand' b'months' b'old' b'.' b'This' b'bird' b'is' b'three' b'more' b'active' b'.' b'Other'
b'other' b'species' b'were' b'<unk>' b'restricted' b'to' b'the' b'standard' b',' b'leaving' b'<unk>' b',' b'particularly' b'slightly' b'Abdi' b'devil' b',' b'resulting' b'on' b'fifty'
b'<unk>' b',' b'known' b'as' b'an' b'bonnet' b'/' b'possibly' b'scale' b',' b'with' b'simplistic' b'and' b'<unk>' b'.' b'Mycena' b'Josip' b'Roth' b',' b'infected'
b'a' b'distinctive' b'item' b'for' b'moving' b'in' b'County' b'City' b'"' b'Mr' b'<unk>' b'"' b'and' b'"' b'patriarch' b'"' b'as' b'"' b'apprehend' b'"'
b',' b'frame' b'expanding' b'capability' b'between' b'behind' b'history' b',' b'red' b',' b'isolated' b',' b'urine' b'(' b'back' b')' b'which' b'are' b'known' b'with'
b'ASCAP' b'.' b'In' b'fact' b',' b'for' b'26' b'%' b'of' b'a' b'year' b'point' b'around' b'65' b'million' b'50' b'(' b'blue' b'long' b')'
b'.' b'These' b'legs' b'John' b'Europos' b'(' b'A' b'<unk>' b')' b'is' b'a' b'recent' b'.' b'In' b'DD' b'eggs' b',' b'it' b'was' b'described'
b'by' b'Hawks' b'as' b'<unk>' b',' b'an' b'pair' b'of' b'charm' b'.' b'It' b'was' b'also' b'hunted' b'that' b'they' b'lived' b'with' b'interior' b'pagan'
它不是GPT-2,但看起來模型已經開始學習語言結構了!
演示動態量化,只需要再定義幾個輔助函數:
# 設置BPTT(Backpropagation Through Time)的長度為25,即每次處理25個時間步的數據
bptt = 25# 定義損失函數為交叉熵損失函數,用于計算模型輸出與真實標簽之間的誤差
criterion = nn.CrossEntropyLoss()# 測試時使用的batch size為1
eval_batch_size = 1# 創建測試數據集
def batchify(data, bsz):"""將原始數據分割成多個批次,每個批次大小為bsz。參數:data (Tensor): 原始數據bsz (int): 每個批次的大小返回:Tensor: 分批后的數據"""# 計算可以完整分成多少個批次nbatch = data.size(0) // bsz# 去掉不能整除的部分數據data = data.narrow(0, 0, nbatch * bsz)# 將數據均勻分配到每個批次中,并調整形狀為(bsz, -1),然后進行轉置和連續化return data.view(bsz, -1).t().contiguous()# 對測試數據進行分批處理
test_data = batchify(corpus.test, eval_batch_size)# 定義獲取單個批次數據的函數
def get_batch(source, i):"""從source中提取一個批次的數據和對應的標簽。參數:source (Tensor): 輸入數據源i (int): 當前批次的起始位置返回:data (Tensor): 輸入數據target (Tensor): 對應的標簽數據"""# 確定當前批次的時間步長度,不超過bptt且不超過剩余數據長度seq_len = min(bptt, len(source) - 1 - i)# 提取輸入數據data = source[i:i+seq_len]# 提取對應的標簽數據,并將其展平為一維張量target = source[i+1:i+1+seq_len].reshape(-1)return data, target# 重新封裝隱藏狀態,使其脫離歷史梯度
def repackage_hidden(h):"""包裝隱藏狀態以斷開其歷史記錄,防止梯度在反向傳播時回傳到前面的批次。參數:h (Tensor or tuple): 隱藏狀態返回:Tensor or tuple: 脫離歷史后的隱藏狀態"""if isinstance(h, torch.Tensor):return h.detach()else:return tuple(repackage_hidden(v) for v in h)# 定義評估函數
def evaluate(model_, data_source):"""對模型在給定數據源上的表現進行評估。參數:model_ (nn.Module): 訓練好的模型data_source (Tensor): 數據源返回:float: 平均損失值"""# 將模型設置為評估模式,禁用dropout等訓練專用操作model_.eval()total_loss = 0.# 初始化隱藏狀態hidden = model_.init_hidden(eval_batch_size)# 在不計算梯度的情況下進行評估with torch.no_grad():for i in range(0, data_source.size(0) - 1, bptt):# 獲取當前批次的數據和標簽data, targets = get_batch(data_source, i)# 模型前向傳播,得到輸出和新的隱藏狀態output, hidden = model_(data, hidden)# 重新封裝隱藏狀態,避免占用過多內存hidden = repackage_hidden(hidden)# 展平輸出,以便與標簽進行損失計算output_flat = output.view(-1, ntokens)# 累加當前批次的損失total_loss += len(data) * criterion(output_flat, targets).item()# 計算平均損失return total_loss / (len(data_source) - 1)
這段代碼的主要功能是:
- 數據預處理:通過?
batchify
?函數將原始數據劃分為固定大小的批次,適用于模型輸入。 - 批量讀取:
get_batch
?函數從數據源中提取指定位置的一個批次數據及其標簽。 - 隱藏狀態管理:
repackage_hidden
?函數用于切斷隱藏狀態的歷史記錄,防止梯度回傳到前面的批次。 - 模型評估:
evaluate
?函數對模型在測試數據上的性能進行評估,計算并返回平均損失。
④測試動態量化
最后,可以調用torch.quantization.quantize_dynamic。具體地說,
▲nn.LSTM和?nn.Linear模塊將被量化;
▲指定要將權重轉換為int8值。
import torch.quantizationquantized_model = torch.quantization.quantize_dynamic(model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
輸出:
LSTMModel((drop): Dropout(p=0.5, inplace=False)(encoder): Embedding(33278, 512)(rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)(decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)
模型看起來是一樣的,這如何受益?首先,可以看到顯著減小了模型尺寸。
def print_size_of_model(model):torch.save(model.state_dict(), "temp.p")print('Size (MB):', os.path.getsize("temp.p")/1e6)os.remove('temp.p')print_size_of_model(model)
print_size_of_model(quantized_model)
輸出:
Size (MB): 113.944455
Size (MB): 79.738939
其次,可以看到更快的推理時間,評估損失沒有差異。
注意:將線程數設置為單個線程比較的線程數,因為量化模型運行單線程。
# 設置 PyTorch 使用單個線程,以便更準確地測量模型推理時間
torch.set_num_threads(1)def time_model_evaluation(model, test_data):"""評估模型在測試數據上的性能,并統計推理時間。參數:model (nn.Module): 要評估的模型test_data (Tensor): 測試數據返回:打印模型損失和所用時間"""s = time.time() # 記錄開始時間loss = evaluate(model, test_data) # 調用評估函數計算損失elapsed = time.time() - s # 計算耗時# 打印損失值和耗時(秒)print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))# 分別對原始模型和量化模型進行評估并計時
time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)
這段代碼的主要目的是:
- 限制線程數:通過?
torch.set_num_threads(1)
?確保模型推理只使用一個線程,避免多線程影響時間測量準確性。 - 定義評估計時函數?
time_model_evaluation
:- 對輸入模型在測試數據上做評估;
- 輸出交叉熵損失和推理耗時。
- 對比評估兩個模型:
- 原始浮點模型?
model
- 量化后的模型?
quantized_model
- 原始浮點模型?
這通常用于比較量化前后模型的推理速度和精度損失,是模型壓縮與優化中的常見做法。
輸出:
loss: 5.167
elapsed time (seconds): 198.3
loss: 5.168
elapsed time (seconds): 111.4
在 MacBook Pro上本地運行此功能,無需量化,推理大約需要 200 秒, 量化只需要大約100秒。
■結論
動態量化可以是一種減小模型尺寸的簡單方法,而只對準確性帶來有限的影響。
至此,本文分享的內容就結束了。