Dataset
- 是否需要自己定義:如果你使用的數據集不是 PyTorch 提供的標準數據集(如 MNIST、CIFAR-10 等),那么你需要繼承?
torch.utils.data.Dataset
?類并實現兩個方法:__len__()
?和?__getitem__()
。 __len__()
?應該返回數據集的總大小。__getitem__()
?應該根據索引返回一個數據樣本。
DataLoader
- 是否需要自己定義:
DataLoader
?不需要自己定義,它是 PyTorch 提供的一個類,用于包裝?Dataset
?并在數據集上提供迭代功能。它支持批量處理、打亂數據、多線程加載等。 - 使用?
DataLoader
?時,你可以指定批處理大小(batch_size
)、是否打亂數據(shuffle
)、數據加載的線程數(num_workers
)等。
model定義【繼承nn.module父類】
forward:input--forward-->output
forward(self,x)中x表示輸入,即x->卷積->relu->卷積-->relu-->輸出
class HeightPredictor(nn.Module):def __init__(self):super(HeightPredictor, self).__init__()self.conv1 = nn.Conv2d(1,20,5)self.conv2 = nn.Conv2d(20,20,5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))
Dict
building_info = {}【dict,key--value】
這是一個字典(dictionary)的創建語句。在Python中,字典是一種可變的、無序的、鍵值對(key-value pairs)的集合。每個鍵(key)都是唯一的,且必須是不可變的類型(如字符串、數字或元組),而值(value)可以是任何類型的數據。字典通過鍵來訪問對應的值,提供了快速查找和插入的能力。
特殊:defaultdict:defaultdict
是Python標準庫collections
模塊中的一個類。defaultdict
與普通字典類似,但它在創建時提供了一個默認工廠函數【比如defaultdict(list):當訪問一個不存在的鍵時,defaultdict
會自動為該鍵創建一個空列表作為默認值。】,當嘗試訪問一個不存在的鍵時,defaultdict
會自動為該鍵創建一個默認值,而不會拋出KeyError
。
整理csv
df = pd.read_csv(file_path, encoding="utf-8")#讀取csv
#根據某個屬性分組
area_bins = [0, 100, 200, 300, 400, np.inf]
area_labels = [f"{left}-{right}" if right != np.inf else f">{left}"
? ? ? ? ? ? ? for left, right in zip(area_bins[:-1], area_bins[1:])]
df['area_bins'] = pd.cut(df['area'], bins=area_bins, labels=area_labels)
methods = ["a","b"]
attributes = ['material', 'height_bin']
for attr in attributes:
? ? results = []
? ? for method in methods:
? ? ? ? ? ? Num_col = f"{method}_Num"
? ? ? ? ? ? predict_col = f"{method}_predict"
? ? ? ? ? ? if Num_col not in df.columns or predict_col not in df.columns:
? ? ? ? ? ? ? ? print(f"跳過 {method},缺少必要列")
? ? ? ? ? ? ? ? continue
? ? ? ? ? ?
? ? ? ? ? ? valid_data = df[['true_Num', Num_col, predict_col, attr]].dropna()
? ? ? ? ? ? if valid_data.empty:
? ? ? ? ? ? ? ? print(f"{method} 在屬性 {attr} 下無有效數據")
? ? ? ? ? ? ? ? continue
? ? ? ? ? ?
? ? ? ? ? ? # 計算完整指標
? ? ? ? ? ? grouped = valid_data.groupby(attr).apply(
? ? ? ? ? ? ? ? lambda x: pd.Series({
? ? ? ? ? ? ? ? ? ? 'Ori_RMSE': np.sqrt(mean_squared_error(x['true_Num'], x[Num_col])),
? ? ? ? ? ? ? ? ? ? 'Pred_RMSE': np.sqrt(mean_squared_error(x['true_Num'], x[predict_col])),
? ? ? ? ? ? ? ? ? ? 'Ori_MAE': mean_absolute_error(x['true_Num'], x[Num_col]),
? ? ? ? ? ? ? ? ? ? 'Pred_MAE': mean_absolute_error(x['true_Num'], x[predict_col]),
? ? ? ? ? ? ? ? ? ? 'Group_Size': len(x),
? ? ? ? ? ? ? ? ? ? 'Sample_Optimized': np.sum(
? ? ? ? ? ? ? ? ? ? ? ? np.abs(x[Num_col] - x['true_Num']) >
? ? ? ? ? ? ? ? ? ? ? ? np.abs(x[predict_col] - x['true_height'])
? ? ? ? ? ? ? ? ? ? )
? ? ? ? ? ? ? ? })
? ? ? ? ? ? ).reset_index()
? ? ? ?
? ? ? ? grouped['method'] = method
? ? ? ? results.append(grouped)
? ?
? ? if not results:
? ? ? ? print(f"屬性 {attr} 無數據,跳過")
? ? ? ? continue
? ?
? ? # 合并結果
? ? combined_df = pd.concat(results, ignore_index=True)
? ?
? ? # 生成透視表
? ? pivot_df = combined_df.pivot(
? ? ? ? index=attr,
? ? ? ? columns='method',
? ? ? ? values=['Ori_RMSE', 'Pred_RMSE', 'Ori_MAE', 'Pred_MAE']
? ? )
? ?
? ? # 扁平化列名并填充NaN
? ? pivot_df.columns = [f"{method}_{metric}" for metric, method in pivot_df.columns]
? ? pivot_df = pivot_df.fillna(0)
? ?
? ? # 保存到CSV
? ? csv_path = os.path.join(output_dir, f"{attr}.csv")
? ? pivot_df.reset_index().to_csv(csv_path, index=False)
實現了分別對每個方法依據不同屬性評估的功能