CV 醫學影像分類、分割、目標檢測,之【3D肝臟分割】項目拆解
- 第1行:`from posixpath import join`
- 第2行:`from torch.utils.data import DataLoader`
- 第3行:`import os`
- 第4行:`import sys`
- 第5行:`import random`
- 第6行:`from torchvision.transforms import RandomCrop`
- 第7行:`import numpy as np`
- 第8行:`import SimpleITK as sitk`
- 第11行:`import torchvision`
- 第14行:`import glob`
- 第15行:`import pandas as pd`
- 第16行:`import matplotlib.pyplot as plt`
- 第17行:`from PIL import Image`
- 第18-23行:重復導入
- 第24行:`import PIL`
- 第26行:`images=os.listdir('D:/LiverDataset/image')`
- 第27行:`labels=os.listdir('D:/LiverDataset/label')`
- 第29-30行:`image_list=[]` `label_list=[]`
- 第32-35行:第一個for循環
- 第37-40行:第二個for循環
- 第42-46行:`getarrayFromslice`函數
- 第48-49行:調用函數
- 第52-99行:D3UnetData類
- 第52-53行:類定義
- 第54-56行:初始化
- 第58-59行:`__getitem__`方法
- 第62-63行:讀取圖像
- 第65-66行:切片選擇
- 第68行:二值化
- 第70行:類型轉換
- 第72-73行:轉為張量
- 第75-76行:應用變換
- 第78行:返回
- 第80-81行:`__len__`方法
- 第83-99行:D3UnetData_test類
- 第103-105行:定義變換
- 第107-109行:創建數據集
- 第112-113行:創建DataLoader
- 第117-118行:測試加載
- 第121-122行:可視化
- 第125-126行:顯示標簽
- 第129-145行:DoubleConv類
- 第148-151行:測試DoubleConv
- 第154-161行:Down類(下采樣)
- 第164-179行:Up類(上采樣)
- 第182-187行:Out類
- 第190-216行:UNet3d類
- 第219-223行:創建模型
- 第228-229行:定義損失和優化器
- 第232-289行:train函數
- 第292-301行:訓練循環
- 第305-318行:預測和可視化
?
第1行:from posixpath import join
問1:為什么要導入posixpath?
答1:處理文件路徑的拼接。
問2:posixpath和os.path有什么區別?
答2:posixpath專門處理UNIX風格路徑(用/),os.path會根據操作系統自動選擇。
問3:這里為什么不用os.path.join?
答3:實際上這是個錯誤!后面代碼沒用到posixpath.join,應該刪除或改成os.path。
問4:什么是POSIX?
答4:Portable Operating System Interface,可移植操作系統接口標準。
第2行:from torch.utils.data import DataLoader
問5:DataLoader的作用是什么?
答5:批量加載數據,打亂順序,多進程讀取。
問6:為什么需要批量加載?
答6:GPU并行計算需要批量數據,單個樣本浪費算力。
問7:utils是什么意思?
答7:utilities的縮寫,工具集。
第3行:import os
問8:os模塊提供什么功能?
答8:操作系統接口,文件操作、路徑處理、環境變量。
問9:為什么第1行導入了posixpath.join,這里又導入os?
答9:os提供更多功能如listdir,posixpath只處理路徑。
第4行:import sys
問10:sys模塊是做什么的?
答10:Python解釋器相關的變量和函數。
問11:這段代碼為什么導入sys但沒用?
答11:可能是遺留代碼,或準備用于sys.path添加路徑。
第5行:import random
問12:random在深度學習中用來做什么?
答12:數據增強、打亂順序、隨機初始化。
問13:這里導入了但沒用到,可能用途是什么?
答13:可能原計劃做隨機裁剪或隨機選擇訓練樣本。
第6行:from torchvision.transforms import RandomCrop
問14:RandomCrop是什么?
答14:隨機裁剪圖像的數據增強方法。
問15:為什么要隨機裁剪?
答15:增加數據多樣性,防止過擬合。
問16:這里導入了但沒用,為什么?
答16:可能原計劃用但后來改用Resize了。
第7行:import numpy as np
問17:為什么起別名np?
答17:約定俗成,簡化書寫。
問18:numpy和torch的關系是什么?
答18:numpy是CPU數組運算,torch支持GPU,可相互轉換。
第8行:import SimpleITK as sitk
問19:ITK是什么的縮寫?
答19:Insight Segmentation and Registration Toolkit。
問20:為什么叫Simple?
答20:簡化版的ITK,C++庫的Python封裝。
問21:sitk相比其他圖像庫的優勢?
答21:保留醫學圖像元數據:間距、方向、原點。
第11行:import torchvision
問23:torchvision提供什么?
答23:計算機視覺的數據集、模型、變換工具。
問24:vision是指什么?
答24:計算機視覺,讓計算機"看懂"圖像。
第14行:import glob
問27:glob是做什么的?
答27:文件路徑模式匹配,如*.jpg找所有jpg文件。
問28:glob這個名字的由來?
答28:global的縮寫,全局通配符擴展。
問29:這里導入glob但沒用到,可能的原因?
答29:可能原來用glob查找文件,后來改用os.listdir。
第15行:import pandas as pd
問30:pandas在這里可能的用途?
答30:讀取CSV格式的標注信息或記錄訓練日志。
問31:為什么沒用到?
答31:可能改用直接讀取文件夾的方式。
第16行:import matplotlib.pyplot as plt
問32:pyplot是什么?
答32:matplotlib的面向對象繪圖接口。
問33:為什么叫pyplot?
答33:模仿MATLAB的plot函數設計的Python版本。
第17行:from PIL import Image
問34:PIL是什么的縮寫?
答34:Python Imaging Library。
問35:為什么已有PIL還要導入SimpleITK?
答35:PIL處理普通圖像,SimpleITK處理醫學圖像的3D體積和元數據。
第18-23行:重復導入
問36:這么多重復導入說明什么?
答36:代碼沒有經過清理,可能是多次實驗的累積。
第24行:import PIL
問37:已經from PIL import Image,為什么還import PIL?
答37:可能想用PIL的其他功能,但實際沒用到。
第26行:images=os.listdir('D:/LiverDataset/image')
問38:listdir返回什么?
答38:文件夾內所有文件和子文件夾的名稱列表。
問39:為什么用D盤?
答39:Windows系統,D盤通常是數據盤,C盤是系統盤。
問40:LiverDataset說明什么?
答40:Liver是肝臟,這是肝臟分割數據集。
第27行:labels=os.listdir('D:/LiverDataset/label')
問41:label和image的對應關系是什么?
答41:同名文件,一個是原圖,一個是標注。
第29-30行:image_list=[]
label_list=[]
問42:為什么要創建空列表?
答42:準備存儲完整的文件路徑。
問43:列表和數組的區別?
答43:列表可變長度、存儲任意類型;數組固定類型、支持向量運算。
第32-35行:第一個for循環
for i in images:file_path='D:/LiverDataset/image/{}/{}'.format(i,i)print(file_path)image_list.append(file_path)
問44:format是什么?
答44:字符串格式化方法,{}是占位符。
問45:為什么路徑是image/{}/{}
兩個i?
答45:數據組織是:image文件夾/病人ID文件夾/病人ID文件。
問46:append是什么操作?
答46:在列表末尾添加元素。
問47:print的作用?
答47:調試用,確認路徑正確。
第37-40行:第二個for循環
for i in labels:file_path='D:/LiverDataset/label/{}/{}'.format(i,i)print(file_path)label_list.append(file_path)
問48:這段代碼有什么問題?
答48:沒有檢查image和label是否一一對應。
第42-46行:getarrayFromslice
函數
def getarrayFromslice(file_path):image=sitk.ReadImage(file_path)img_array=sitk.GetArrayFromImage(image)shape=img_array.shapeprint(shape)
問49:為什么函數名有From和slice?
答49:從切片文件獲取數組,但命名不準確,應該是from volume。
問50:ReadImage讀取的是什么格式?
答50:SimpleITK的Image對象,包含像素數據和元數據。
問51:GetArrayFromImage做了什么轉換?
答51:從SimpleITK Image轉為numpy array,丟棄元數據。
問52:shape會是什么樣的?
答52:(切片數, 高度, 寬度),如(512, 512, 512)。
第48-49行:調用函數
for i in image_list:getarrayFromslice(i)
問53:這個循環的目的是什么?
答53:檢查所有圖像的尺寸,確保數據一致性。
問54:為什么只打印不保存?
答54:這是數據探索階段,了解數據結構。
第52-99行:D3UnetData類
第52-53行:類定義
class D3UnetData(Dataset):def __init__(self,image_list,label_list,transformer):
問55:D3是什么意思?
答55:3D的意思,表示三維U-Net。
問56:為什么繼承Dataset?
答56:PyTorch要求,實現__getitem__
和__len__
接口。
問57:transformer是什么?
答57:圖像變換操作,如縮放、裁剪。
第54-56行:初始化
self.image_list=image_list
self.label_list=label_list
self.transformer=transformer
問58:self是什么?
答58:實例自身的引用,Python的面向對象機制。
第58-59行:__getitem__
方法
def __getitem__(self,index):image=self.image_list[index]label=self.label_list[index]
問59:index從哪里來?
答59:DataLoader自動生成,從0到len-1。
問60:為什么叫__getitem__
不叫get_item?
答60:Python魔術方法,支持[]索引操作。
第62-63行:讀取圖像
image_ct=sitk.ReadImage(image,sitk.sitkInt16)
label_ct=sitk.ReadImage(label,sitk.sitkInt8)
問61:sitkInt16是什么?
答61:16位有符號整數,范圍-32768到32767。
問62:為什么image用Int16,label用Int8?
答62:CT值范圍大需要16位,標簽只有0/1用8位省內存。
問63:CT值的范圍是多少?
答63:-1000(空氣)到+3000(骨骼),單位是Hounsfield。
第65-66行:切片選擇
ct_array=sitk.GetArrayFromImage(image_ct)[250:300]
label_array=sitk.GetArrayFromImage(label_ct)[250:300]
問64:為什么選250:300?
答64:肝臟在腹部中間位置,這50層大概覆蓋肝臟區域。
問65:如果總共只有200層怎么辦?
答65:會報錯,代碼沒有邊界檢查。
第68行:二值化
label_array[label_array>0]=1
問66:這是什么操作?
答66:布爾索引,將所有大于0的值設為1。
問67:為什么要這樣做?
答67:確保標簽只有0和1兩類,可能原始數據有多個器官標注。
第70行:類型轉換
ct_array=ct_array.astype(np.float32)
問68:為什么轉float32?
答68:神經網絡計算需要浮點數,32位在精度和內存間平衡。
問69:float32和float64的區別?
答69:32位精度7位小數,64位精度15位,深度學習32位夠用。
第72-73行:轉為張量
ct_array=torch.FloatTensor(ct_array).unsqueeze(0)
label_array=torch.LongTensor(label_array)
問70:FloatTensor和LongTensor的區別?
答70:FloatTensor存浮點數,LongTensor存整數索引。
問71:unsqueeze(0)在哪個維度添加?
答71:第0維,[50,512,512]變[1,50,512,512]。
問72:為什么label不需要unsqueeze?
答72:CrossEntropyLoss期望標簽是類別索引,不需要通道維。
第75-76行:應用變換
ct_array=self.transformer(ct_array)
label_array=self.transformer(label_array)
問73:transformer會做什么?
答73:這里是Resize到96×96,減少內存使用。
第78行:返回
return ct_array,label_array
問74:返回的是元組嗎?
答74:是的,Python自動打包成元組。
第80-81行:__len__
方法
def __len__(self):return len(self.image_list)
問75:為什么需要__len__
?
答75:DataLoader需要知道數據集大小來創建批次。
我繼續分析剩余的每一行代碼…
第83-99行:D3UnetData_test類
問76:為什么要單獨的test類?
答76:測試集可能需要不同的預處理或采樣策略。
ct_array=sitk.GetArrayFromImage(image_ct)[200:250]
問77:為什么test用200:250,train用250:300?
答77:避免數據泄露,訓練和測試用不同的切片。
問78:這樣分割合理嗎?
答78:不合理!應該按病人分,不是按切片分。
第103-105行:定義變換
transformer=transforms.Compose([transforms.Resize((96,96)),
])
問79:Compose是什么?
答79:組合多個變換,按順序執行。
問80:為什么縮放到96×96?
答80:原始512×512太大,3D卷積內存消耗巨大。
問81:96這個數字有什么特殊?
答81:是32的倍數,適合多次下采樣(96→48→24→12→6)。
第107-109行:創建數據集
train_ds=D3UnetData(image_list,label_list,transformer)
test_ds=D3UnetData_test(image_list,label_list,transformer)
len(train_ds)
問82:ds是什么縮寫?
答82:dataset的縮寫。
問83:為什么train和test用同樣的image_list?
答83:這是錯誤!訓練集和測試集應該是不同的病人。
第112-113行:創建DataLoader
train_dl=DataLoader(train_ds,batch_size=2,shuffle=True)
test_dl=DataLoader(test_ds,batch_size=2,shuffle=True)
問84:dl是什么縮寫?
答84:dataloader的縮寫。
問85:batch_size=2為什么這么小?
答85:3D醫學圖像占用內存大,[2,1,50,96,96]已經很大了。
問86:測試集為什么要shuffle?
答86:不應該!測試集應該shuffle=False保持順序。
第117-118行:測試加載
img,label=next(iter(train_dl))
print(img.shape,label.shape)
問87:iter是什么?
答87:創建迭代器對象,可以用next獲取下一個。
問88:next做什么?
答88:從迭代器獲取一個批次。
第121-122行:可視化
img_show=img[0,0,25,:,:].numpy()
plt.imshow(img_show,cmap='gray')
問89:[0,0,25,:,:]各維度是什么?
答89:[批次0,通道0,第25層,所有高,所有寬]。
問90:為什么選第25層?
答90:50層的中間,最可能看到肝臟。
問91:cmap='gray’是什么?
答91:colormap灰度顏色映射,CT圖像是灰度的。
第125-126行:顯示標簽
label_show=label[0,25,:,:].numpy()
plt.imshow(label_show,cmap='gray')
問92:標簽為什么沒有通道維度?
答92:標簽是類別索引[batch,depth,height,width]。
第129-145行:DoubleConv類
class DoubleConv(nn.Module):def __init__(self,in_channels,out_channels,num_groups=8):
問93:nn.Module是什么?
答93:PyTorch所有神經網絡層的基類。
問94:in_channels和out_channels是什么?
答94:輸入和輸出的特征圖數量(通道數)。
self.double_conv=nn.Sequential(nn.Conv3d(in_channels,out_channels,kernel_size=3,stride=1,padding=1),
問95:Sequential是什么?
答95:順序容器,依次執行內部的層。
問96:Conv3d的kernel_size=3是什么意思?
答96:3×3×3的立方體卷積核。
問97:stride=1表示什么?
答97:卷積核每次移動1個像素。
問98:padding=1的作用?
答98:邊緣填充1圈0,保持尺寸不變。
nn.GroupNorm(num_groups=num_groups,num_channels=out_channels),
問99:為什么不用BatchNorm3d?
答99:批次太小(2),統計不穩定,GroupNorm不依賴批次大小。
問100:歸一化的目的是什么?
答100:穩定訓練,防止梯度消失或爆炸。
nn.ReLU(inplace=True),
問101:ReLU是什么?
答101:Rectified Linear Unit,max(0,x)激活函數。
問102:inplace=True是什么意思?
答102:原地操作,覆蓋輸入省內存。
第148-151行:測試DoubleConv
img.shape
net=DoubleConv(1,64,num_groups=8)
out=net(img)
print(out.shape)
問103:1→64通道變化意味著什么?
答103:從單通道CT圖提取64種不同的特征。
第154-161行:Down類(下采樣)
class Down(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()
問104:super().init()做什么?
答104:調用父類Module的初始化,注冊參數。
self.encoder = nn.Sequential(nn.MaxPool3d(2, 2),DoubleConv(in_channels, out_channels)
)
問105:MaxPool3d(2,2)是什么操作?
答105:2×2×2窗口取最大值,尺寸減半。
問106:為什么先池化再卷積?
答106:U-Net的設計,減少分辨率同時增加通道數。
第164-179行:Up類(上采樣)
class Up(nn.Module):def __init__(self, in_channels, out_channels, trilinear=True):
問107:trilinear是什么?
答107:三線性插值,3D圖像的平滑上采樣方法。
if trilinear:self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
else:self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
問108:Upsample和ConvTranspose3d的區別?
答108:Upsample插值無參數,ConvTranspose3d有可學習參數。
問109://是什么運算?
答109:整數除法,向下取整。
問110:align_corners=True是什么?
答110:對齊角點像素,影響插值計算方式。
def forward(self, x1, x2):x1 = self.up(x1)
問111:x1和x2分別是什么?
答111:x1是深層特征要上采樣,x2是淺層特征要拼接。
diffZ = x2.size()[2] - x1.size()[2]
diffY = x2.size()[3] - x1.size()[3]
diffX = x2.size()[4] - x1.size()[4]
問112:為什么要計算差值?
答112:上采樣可能有尺寸誤差,需要padding對齊。
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2, diffZ // 2, diffZ - diffZ // 2])
問113:這個padding順序是什么?
答113:[左右,上下,前后]每個維度的padding量。
問114:為什么用diffX - diffX // 2
?
答114:處理奇數差值,如diff=3時,一邊pad 1,另一邊pad 2。
x = torch.cat([x2, x1], dim=1)
問115:dim=1是哪個維度?
答115:通道維度,拼接特征圖。
第182-187行:Out類
class Out(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 1)
問116:kernel_size=1的作用?
答116:1×1×1卷積,只改變通道數不改變空間尺寸。
問117:這里out_channels應該是多少?
答117:2,對應背景和肝臟兩類。
第190-216行:UNet3d類
class UNet3d(nn.Module):def __init__(self, in_channels, n_classes, n_channels):
問118:三個參數分別是什么?
答118:輸入通道(1)、類別數(2)、基礎通道數(24)。
self.conv = DoubleConv(in_channels, n_channels)
self.enc1 = Down(n_channels, 2 * n_channels)
self.enc2 = Down(2 * n_channels, 4 * n_channels)
self.enc3 = Down(4 * n_channels, 8 * n_channels)
self.enc4 = Down(8 * n_channels, 8 * n_channels)
問119:通道數為什么是24,48,96,192,192?
答119:每層翻倍增加特征,最后一層保持防止爆內存。
問120:enc是什么縮寫?
答120:encoder編碼器,提取特征。
self.dec1 = Up(16 * n_channels, 4 * n_channels)
問121:為什么是16×n_channels?
答121:8(enc4)+8(enc3跳躍連接)=16。
def forward(self, x):x1 = self.conv(x)x2 = self.enc1(x1)x3 = self.enc2(x2)x4 = self.enc3(x3)x5 = self.enc4(x4)
問122:x1到x5的尺寸變化?
答122:
- x1: [2,24,50,96,96]
- x2: [2,48,25,48,48]
- x3: [2,96,12,24,24]
- x4: [2,192,6,12,12]
- x5: [2,192,3,6,6]
mask = self.dec1(x5, x4)
mask = self.dec2(mask, x3)
mask = self.dec3(mask, x2)
mask = self.dec4(mask, x1)
mask = self.out(mask)
問123:為什么叫mask?
答123:分割結果是掩碼,標記每個像素的類別。
第219-223行:創建模型
model=UNet3d(1,2,24).cuda()
img,label=next(iter(train_dl))
print(img.shape,label.shape)
img=img.cuda()
pred=model(img)
問124:.cuda()做什么?
答124:把模型和數據移到GPU上。
問125:24這個基礎通道數怎么選的?
答125:平衡性能和內存,太小欠擬合,太大爆顯存。
第228-229行:定義損失和優化器
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.00001)
問126:CrossEntropyLoss包含什么操作?
答126:LogSoftmax + NLLLoss,多分類標準損失。
問127:Adam是什么?
答127:Adaptive Moment Estimation,自適應學習率優化器。
問128:lr=0.00001為什么這么小?
答128:醫學圖像精度要求高,小學習率穩定訓練。
第232-289行:train函數
from tqdm import tqdm
def train(epoch, model, trainloader, testloader):
問129:tqdm是什么?
答129:進度條庫,顯示訓練進度。
correct = 0
total = 0
running_loss = 0
epoch_iou = []
問130:這些變量分別統計什么?
答130:正確像素數、總像素數、累計損失、每批次IOU。
model.train()
問131:model.train()做什么?
答131:啟用dropout和batch normalization的訓練模式。
for x, y in tqdm(trainloader):x, y = x.to('cuda'), y.to('cuda')
問132:為什么用.to(‘cuda’)而不是.cuda()?
答132:.to()更通用,可以指定設備字符串。
y_pred = model(x)
loss = loss_fn(y_pred, y)
問133:y_pred的形狀是什么?
答133:[2,2,50,96,96],批次×類別×深×高×寬。
optimizer.zero_grad()
loss.backward()
optimizer.step()
問134:這三步分別做什么?
答134:清除梯度、反向傳播計算梯度、更新參數。
問135:為什么要zero_grad?
答135:PyTorch梯度會累積,不清零會疊加。
with torch.no_grad():y_pred = torch.argmax(y_pred, dim=1)
問136:no_grad()的作用?
答136:禁用梯度計算,節省內存。
問137:argmax在dim=1做什么?
答137:在類別維度取最大值索引,得到預測類別。
correct += (y_pred == y).sum().item()
total += y.size(0)
問138:.item()是什么?
答138:將單元素張量轉為Python數值。
intersection = torch.logical_and(y, y_pred)
union = torch.logical_or(y, y_pred)
batch_iou = torch.sum(intersection) / torch.sum(union)
問139:logical_and和logical_or是什么運算?
答139:邏輯與(交集)和邏輯或(并集)。
問140:IOU的值域是什么?
答140:0到1,1表示完全重合。
epoch_loss = running_loss / len(trainloader.dataset)
epoch_acc = correct / (total*96*96*50)
問141:為什么除以total×96×96×50?
答141:總像素數=批次數×每批像素數。
model.eval()
問142:eval()模式改變什么?
答142:關閉dropout,batch norm用運行統計。
if np.mean(epoch_test_iou)>0.9:static_dict=model.state_dict()torch.save(static_dict,'./checkpoint/{}_trainIOU_{}_testIOU_{}.pth'.format(epoch,round(np.mean(epoch_iou), 3),round(np.mean(epoch_test_iou),3)))
問143:state_dict包含什么?
答143:模型所有參數的字典。
問144:.pth是什么格式?
答144:PyTorch的模型文件擴展名。
問145:round(x, 3)做什么?
答145:保留3位小數。
第292-301行:訓練循環
epochs = 100
train_loss = []
train_acc = []
test_loss = []
test_acc = []
問146:為什么創建這些列表但沒用?
答146:可能原計劃畫損失曲線,但沒實現。
for epoch in range(epochs):train(epoch, model, train_dl, test_dl)
問147:range(epochs)生成什么?
答147:0到99的整數序列。
第305-318行:預測和可視化
img,label=next(iter(train_dl))
print(img.shape,label.shape)
img=img.to('cuda')
pred=model(img)
問148:為什么又預測一次?
答148:訓練后查看效果。
label_show=label[0,20,:,:]
plt.imshow(label_show,cmap='gray')
問149:為什么選第20層?
答149:隨意選擇的中間層查看。
preds=pred.cpu()
問150:為什么要.cpu()?
答150:matplotlib不能直接顯示GPU張量。
plt.imshow(torch.argmax(preds.permute(1,2,0), axis=-1).detach().numpy(),cmap='gray')
問151:這行代碼有什么問題?
答151:維度不對!preds是5維的,不能直接permute(1,2,0)。
plt.imshow(torch.argmax(pred.permute(1,2,0), axis=-1).detach().numpy())
問152:這行和上一行的區別?
答152:用了pred而不是preds,但pred在GPU上會報錯。
問153:detach()的作用?
答153:斷開計算圖,返回不需要梯度的張量。