Official GitHub: https://github.com/hiyouga/EasyR1
Reference: https://opendeep.wiki/hiyouga/EasyR1/quickstart
Code and environment setup
GitHub: https://github.com/hiyouga/EasyR1
Create a new virtual environment:
python -m venv easyr1
source easyr1/bin/activate
python -m pip install transformers==4.51.0
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1
python -m pip install wheel
python -m pip install flash-attn==2.7.4.post1
python -m pip install vllm==0.8.3
Install EasyR1:
git clone https://github.com/hiyouga/EasyR1.git
cd EasyR1
pip install -e .
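After the install finishes, it can help to confirm that the pinned versions actually ended up in the virtual environment, since a later `pip install` can silently replace an earlier pin. The following is a minimal sketch using only the standard library. The `EXPECTED` table is copied from the commands above; the helper names `installed_version` and `find_mismatches` are my own and not part of EasyR1.

```python
from importlib import metadata

# Pins copied from the install commands in these notes.
EXPECTED = {
    "transformers": "4.51.0",
    "torch": "2.5.1",
    "torchvision": "0.20.1",
    "torchaudio": "2.5.1",
    "flash-attn": "2.7.4.post1",
    "vllm": "0.8.3",
}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(expected, lookup=installed_version):
    """Map each package that deviates from its pin to a (wanted, found) pair."""
    mismatches = {}
    for package, wanted in expected.items():
        found = lookup(package)
        # Ignore a local build tag such as "+cu124" when comparing.
        if found is None or found.split("+")[0] != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, found) in find_mismatches(EXPECTED).items():
        print(f"{package}: expected {wanted}, found {found}")
```

Running the script inside the activated `easyr1` environment prints one line per package that is missing or at a different version, and prints nothing when everything matches.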
Dataset
Text dataset: https://huggingface.co/datasets/hiyouga/math12k
Reference: dataset construction
Reference code:
import json
import os
from datasets import Dataset, DatasetDict


def generate_data(data_path: str):
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            yield {
                "problem": data["problem"],
                "answer": data["answer"],
            }


def main():
    trainset = Dataset.from_generator(generate_data, gen_kwargs={"data_path": os.path.join("prm800k", "math_splits", "train.jsonl")})
    testset = Dataset.from_generator(generate_data, gen_kwargs={"data_path": os.path.join("prm800k", "math_splits", "test.jsonl")})
    dataset = DatasetDict({"train": trainset, "test": testset})
    dataset.push_to_hub("hiyouga/math12k")


if __name__ == "__main__":
    main()
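When building your own dataset in the same shape, the conversion script above will fail on the first line that is not valid JSON or lacks a `problem` or `answer` field. A small pre-check such as the one below can locate those lines first. It is only a sketch: `find_bad_lines` is a hypothetical helper of mine, and it assumes one JSON object per line, as the reference code does.

```python
import json

# The two fields the reference conversion script reads from each record.
REQUIRED_KEYS = ("problem", "answer")


def find_bad_lines(jsonl_path):
    """Return (line_number, reason) for each line the conversion could not use."""
    problems = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                problems.append((line_number, f"invalid JSON: {exc.msg}"))
                continue
            if not isinstance(record, dict):
                problems.append((line_number, "not a JSON object"))
                continue
            missing = [key for key in REQUIRED_KEYS if key not in record]
            if missing:
                problems.append((line_number, "missing keys: " + ", ".join(missing)))
    return problems
```

An empty list means every line can be passed through `generate_data` as written; otherwise the line numbers point at the records to fix.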
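Main parameters to modify
Parameter meanings
Parameter meanings 2
Config path: examples/config.yaml
data
- train_files: path to the training set
- val_files: path to the test set
- max_prompt_length: maximum input length
- max_response_length: maximum output length
- rollout_batch_size:
- mini_rollout_batch_size:
- format_prompt: the jinja file that matches the LLM in use
worker
- actor.model: model path
- rollout.n: number of responses sampled per prompt within a group; the default is 5, I set it to 8
- reward.reward_function: path to the reward function.
trainer
- experiment_name: experiment name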
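The parameters above live in a nested YAML file, and it is easy to lose track of which ones were changed. One way to keep a record is to write the changes as a nested dict and flatten it into dotted paths. This is a sketch under assumptions: `flatten` is my own helper, the path and name values are placeholders, and only `rollout.n = 8` comes from these notes.

```python
def flatten(config, prefix=""):
    """Turn a nested dict into {"a.b.c": value} pairs."""
    flat = {}
    for key, value in config.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


# Placeholder values for illustration; only worker.rollout.n = 8 is from the notes.
overrides = {
    "data": {"train_files": "path/to/train", "val_files": "path/to/test"},
    "worker": {"rollout": {"n": 8}},
    "trainer": {"experiment_name": "my_experiment"},
}

for path, value in flatten(overrides).items():
    print(f"{path}={value}")
```

Whether the launcher accepts such `key=value` pairs on the command line is an assumption to verify against the example scripts in the repository; otherwise the list simply documents what to edit in examples/config.yaml.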
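Errors encountered
- The run hangs at "Started a local Ray instance. View the dashboard at 127.0.0.1:8265"
- Failed to register worker to Raylet: IOError
These two problems showed up together.
Reference solution 1
Reference solution 2
What I did:
Set every batch size to 1, except global_batch_size, which I set to the number of GPUs. rollout_batch_size must be a multiple of global_batch_size; I set rollout_batch_size to 8 or 16.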
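The batch-size rule above can be checked before launching a run. The sketch below encodes only what these notes state: global_batch_size equal to the GPU count (my workaround, not necessarily a framework requirement) and rollout_batch_size a multiple of global_batch_size. The function name `check_batch_sizes` is hypothetical.

```python
def check_batch_sizes(num_gpus, global_batch_size, rollout_batch_size):
    """Return a list of violations of the batch-size rule from these notes."""
    errors = []
    if global_batch_size <= 0 or rollout_batch_size <= 0:
        errors.append("batch sizes must be positive")
        return errors
    # Workaround from the notes: one sample per GPU in the global batch.
    if global_batch_size != num_gpus:
        errors.append(
            f"global_batch_size ({global_batch_size}) != number of GPUs ({num_gpus})"
        )
    # rollout_batch_size has to divide evenly into global batches.
    if rollout_batch_size % global_batch_size != 0:
        errors.append(
            f"rollout_batch_size ({rollout_batch_size}) is not a multiple of "
            f"global_batch_size ({global_batch_size})"
        )
    return errors


if __name__ == "__main__":
    # Example: 8 GPUs with the settings used in these notes.
    print(check_batch_sizes(num_gpus=8, global_batch_size=8, rollout_batch_size=16))
```

An empty list means the combination is consistent with the rule; each string otherwise names the setting to change.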
Code path: verl/trainer/config.py
Parameter to adjust:
worker:reward:num_cpus: 1
In addition, force num_cpus in:
/mnt/gemininjceph3/geminicephfs/mmsearch-luban-universal/group_2/user_skylarshao/EasyR1/verl/trainer/main.py
ray.init(runtime_env=runtime_env)
Change it to ray.init(runtime_env=runtime_env, num_cpus=1)