Dataset.map 方法概要
可以將datasets中的Dataset實例看做是一張數據表。map方法會將輸入的function
按照指定的方式應用在每一行(每一行稱為一個example)上。本文采用一下示例進行說明:
from datasets import Dataset # datasets.__version__ = '2.13.0'
x = [{"text": "good", "label": 1}, {"text": "bad", "label": 0}, {"text": "great", "label": 2}]
ds = Dataset.from_list(x)
function
是map方法的核心,其介紹單獨放在下列章節。其它常用參數的說明如下:
- remove_columns:當function可調用對象處理完數據表中的全部樣例之后,可以將指定的列刪掉。在文本處理中,一般會將編碼之后的text原文刪除掉。
- fn_kwargs: 傳入可調用對象function的關鍵詞參數。
- desc: 自定義的描述性信息。
?
function位置參數
function位置參數接受一個可調用對象,本質是該可調用對象對數據表中的每行進行處理。按照布爾型位置參數with_indices, with_rank, batched
的取值, function有8種簽名。其中batched表示可調用對象一次處理一行還是多行,with_indices表示是否將樣本的索引編號傳入可調用對象, with_rank表示是否將進程rank傳入可調用對象。
單樣本處理(batched=False)
-
當設置
batched=False
時,可調用對象依次處理每一行。在默認情況下,可調用對象的簽名為function(example: Dict[str, Any]) -> Dict[str, Any]
,example中的鍵為數據表中的所有列。可使用input_columns
指定哪些列進入example中。可調用對象應該返回一個Dict對象,并使用返回結果更新數據表。from typing import Dict, Anydef func(exam: Dict[str, Any]) -> Dict[str, Any]:return {"text_length": len(exam["text"])}print(ds.map(func).data)
-
如果指定
with_indices=True
,則可調用對象中應該指定一個位置參數用于接受樣本索引編號。def func_with_indices(exam, index):return {"text_length": len(exam["text"]), "row_index": int(index)}print(ds.map(func_with_indices, with_indices=True, batched=False).data)
-
如果指定
with_rank=True
,則可調用對象中應指定一個位置參數用于接受進程rank。并且應當設置num_proc
大于1,否則進程rank的值是相同的。def func_with_rank(exam, rank):return {"text_length": len(exam["text"]), "proc_rank": int(rank)}print(ds.map(func_with_rank, with_rank=True, batched=False, num_proc=2).data)
-
如果同時指定
with_indices=True
與with_rank=True
,則可調用對象中接受樣本索引編號的位置參數應置于接受進程rank位置參數之前。如def function(example, idx, rank): ...
。def func_with_index_rank(exam, index, rank):return {"text_length": len(exam["text"]), "row_index": int(index), "proc_rank": int(rank)}print(ds.map(func_with_rank, with_indices=True, with_rank=True, batched=False, num_proc=2).data)
樣本批處理(Batched=True)
當設置batched=True
時,可調用對象會對樣本進行批處理,批的大小可以通過batch_size
控制,默認一個批為1000條樣本。此情況下簽名應滿足function(batch: Dict[str, List]) -> Dict[str, List]
。batch中的鍵仍然是數據表中的列名稱,值為多行數據組成的列表。
def batched_func(batch):return {"text_length": [len(text) for text in batch["text]]}print(ds.map(batched_func, batched=True, batch_size=2).data)
?
map方法返回的torch.tensor會被轉換為list
x = [{'text': 'good', 'label': torch.tensor(1, dtype=torch.long)},{'text': 'bad', 'label': torch.tensor(0, dtype=torch.long)},{'text': 'great', 'label': torch.tensor(2, dtype=torch.long)}]
ds = Dataset.from_list(x)
print(type(x[0]["label"])) # torch.Tensorprint(ds.map().data)def to_tensor(exam):return {"label" : torch.tensor(exam["label"], dtype=torch.long)}print(ds.map(to_tensor).data) # 結果一致,數據表中的label均為整數型