這段代碼是為了并行地處理多個 CSV 文件,并使用機器學習模型進行預測和回測。主要涉及以下步驟:
-
初始化環境與設置:
- 引入必要的庫,如
ray
用于并行計算,pandas
用于數據處理,tqdm
用于進度條顯示等。 - 設置一些路徑,用于保存結果、圖像、模型等。
- 定義一些處理特征、數據預處理的函數。
- 引入必要的庫,如
-
并行處理函數
csv_predict
:- 使用
ray.remote
將csv_predict
函數并行化。 - 在每個函數中,加載訓練好的模型,并對新的 CSV 文件進行預測和回測。
- 使用
-
具體步驟:
- 讀取 CSV 文件:讀取并處理每個 CSV 文件,確保數據格式正確。
- 數據預處理:包括特征計算、標準化等。
- 構建驗證數據集:將處理后的數據轉換為模型可接受的格式。
- 預測與回測:使用模型對數據進行預測,并根據預測結果進行回測計算,模擬交易策略。
-
結果保存:
- 根據回測結果,將交易數據保存到不同的文件夾中。
- 以不同的策略和條件,將結果分門別類保存。
代碼解讀
import ray# 驗證集數據處理
a = []
sum_dam_data = []# 定義并行處理函數
@ray.remote
def csv_predict(csv_path):# 創建和訓練模型參數nhits_params = {'sampling_stride': 8,'eval_metrics': ["mse", "mae"],'batch_size': 32,'max_epochs': 100,'patience': 10}rnn_params = {'sampling_stride': 8,'eval_metrics': ["mse", "mae"],'batch_size': 32,'max_epochs': 100,'patience': 10,}mlp_params = {'sampling_stride': 8,'eval_metrics': ["mse", "mae"],'batch_size': 32,'max_epochs': 100,'patience': 10,'use_bn': True,}# 加載訓練好的加權集成預測模型reg = WeightingEnsembleForecaster(in_chunk_len=64,out_chunk_len=1,skip_chunk_len=0,estimators=[(NHiTSModel, nhits_params), (RNNBlockRegressor, rnn_params), (MLPRegressor, mlp_params)])reg = reg.load(os.path.join(model_center, "low_high"))# 讀取 CSV 文件new_data = pd.read_csv(csv_path)new_data[['open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'vol', 'amount']] = new_data[['open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'vol', 'amount']].apply(pd.to_numeric, errors='coerce')# 如果數據長度不足,返回空結果if len(new_data) < 2048:return {}base_case = 0base_num = 0money = 0reverse_data = new_data.iloc[::-1] # 反轉數據順序# 計算特征reverse_data = calculate_features(reverse_data)# 逐天進行預測和回測for day_i in range(64):new_data = reverse_data[-256:-64+day_i]new_data['index_new'] = range(1, len(new_data) + 1)# 構建驗證數據集valid_tsdataset = TSDataset.load_from_dataframe(new_data,time_col="index_new",target_cols=['open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'MA5', 'MA10', 'MA20', 'EMA12', 'EMA26', 'Volatility_5', 'Volatility_10', 'Volume_MA5', 'Volume_Change_Rate', 'RSI14', 'Momentum_3', 'Momentum_7', 'Middle_Band', 'Upper_Band', 'Lower_Band'])valid_tsdataset = scaler.transform(valid_tsdataset)predicted = reg.recursive_predict(valid_tsdataset, 3)predicted = scaler.inverse_transform(predicted)predicted = predicted.to_dataframe()# 根據預測結果進行回測計算high_value = predicted.max().to_dict()['high']low_value = predicted.min().to_dict()['low']round_value = round((high_value - low_value) / low_value, 3) * 1000high_index = predicted[predicted['high'] == high_value].index.values[0] - len(new_data)low_index = predicted[predicted['low'] == low_value].index.values[0] - len(new_data)if high_value > low_value:if high_index > low_index:if base_num < 100000:if reverse_data[-(64 - day_i)][3] > low_value > reverse_data[-(64 - day_i)][4]:base_case += 10000 * low_valuebase_num += 10000money -= 10000 * low_valueif low_index > high_index:high_value = predicted['high'].tolist()[0]if reverse_data[-(64 - day_i)][3] > high_value > reverse_data[-(64 - day_i)][4]:base_case -= base_num * high_valuemoney += base_num * high_valuebase_num = 0else:base_case -= base_num * high_valuemoney += base_num * high_valuebase_num = 0sum_money = money + base_num * reverse_data[-(64 - day_i)][5]# 保存回測結果deal.append({"base_case": base_case,"base_num": base_num,"money": money,"index": reverse_data[-(64 - day_i)][-1],"close": reverse_data[-(64 - day_i)][5] * 10000,"total": base_num * reverse_data[-(64 - day_i)][5],"sum": sum_money,"rate": 100 * (sum_money / (reverse_data[-(64 - day_i)][5] * 10000))})try:pd.DataFrame(deal).to_csv(os.path.join("./back_test/low_high_128_5_100", last_price_data, str(int(deal[-1]['rate'])) + "_" + export_csv),index=False)except Exception as e:print(e)returnif deal[-1]['rate'] > 10:if pd.DataFrame(deal)['rate'].sum() > 0:pd.DataFrame(deal).to_csv(os.path.join("./back_test/good_low_high_5_100_deal_101", last_price_data, str(int(deal[-1]['rate'])) + "_" + export_csv),index=False)if deal[-1]['rate'] > 50:if pd.DataFrame(deal)['rate'].sum() > 0:pd.DataFrame(deal).to_csv(os.path.join("./back_test/good_low_high_5_100_deal_105", last_price_data, str(int(deal[-1]['rate'])) + "_" + export_csv),index=False)
主要功能
-
模型加載與預測:
- 加載預訓練模型
WeightingEnsembleForecaster
并進行預測。 - 預測未來幾天的高低價格。
- 加載預訓練模型
-
回測策略:
- 根據預測的高低價進行模擬交易,計算收益。
- 基于交易規則買入或賣出,計算資金和持倉。
-
結果保存:
- 將回測結果保存到 CSV 文件中。
- 根據不同的收益率將結果分開保存。
使用說明
- 確保已安裝
ray
庫用于并行計算。 - 確保所有依賴庫(如
pandas
,paddlets
,tqdm
等)已安裝。 - 將代碼中的路徑和參數調整為實際數據和模型的位置。
- 運行代碼,通過
ray
并行處理多個 CSV 文件,提高處理效率。
注意事項
- 確保數據格式和模型參數與實際情況匹配。
- 在并行化時,要確保每個子任務的獨立性,避免數據沖突。
- 根據需要調整回測策略和交易規則,以滿足實際需求。