2021年認證杯SPSSPRO杯數學建模
A題 醫學圖像的配準
原題再現:
??圖像的配準是圖像處理領域中的一個典型問題和技術難點,其目的在于比較或融合同一對象在不同條件下獲取的圖像。例如為了更好地綜合多種信息來辨識不同組織或病變,醫生可能使用多種儀器對患者的同一部位進行成像。在綜合多幅圖像時,首先需要將它們嚴格對齊,使得圖上同一個坐標的位置對應的是真實對象的同一個點,這個過程稱之為配準。現在的許多醫學成像技術,包括 CT、MRI、PET 等,最終生成的是人體的斷層影像。在這里,我們主要關心的是斷層成像的配準問題。
??我們考慮對一個患者的腹部進行斷層成像。由于人體組織是柔軟的,所以即使使用同一臺成像設備,兩次成像的結果也并不完全一致。最終輸出時還會對圖像進行自動放縮,所以輸出圖片的大小也并不完全相同。想要精確配準,需要將其中一次的成像結果進行某種仿射變換(或非線性變換),以盡可能地匹配另一次的結果(或將兩次結果都映射到同一個標準模板中)。求得合適的變換就是圖像配準的核心任務。
??第二階段問題:多模態的配準是指對來自不同設備的圖像進行配準,例如對CT和MRI圖像進行配準。有的組織或病變部位在單一的成像技術下與周邊組織的區分不明顯,所以我們可以對多種成像技術得到的圖片進行融合處理,讓每個像素點的成像結果表現為一個多維向量(每個分量都是一種成像技術的成像結果),這樣可以更好地識別組織或病變的細節。多模態的配準則是圖像融合的第一步。
??現在我們有對患者同一身體部位(同一時間)的CT、MRI和PET成像結果。但在圖像融合處理時遇到了兩個問題:首先,每一種成像技術對不同組織的區分能力是不同的,例如有些不同的組織在CT下看起來區別不大,但在MRI下區分卻十分明顯;另外一些組織在MRI下區別不大,但在CT下區分卻十分明顯。所以對不同的成像設備而言,即使是同一個位置的成像結果,也并非完全相似,這給配準帶來了難度。第二,在進行斷層成像時,雖然對每個設備而言,我們能夠確切地知道每個斷層的位置,但不同設備掃描的斷層位置并不完全相同。請你設計一個有效的方法,對這樣的成像結果進行圖像的融合。
整體求解過程概述(摘要)
??一種基于多模態醫學圖像的圖像融合方法可以顯著提高融合圖像的質量。一種有效的圖像融合技術通過保留從源圖像中收集到的所有可行的和顯著的信息而不引入任何缺陷或不必要的扭曲來產生輸出圖像。
??大多數深度學習方法采用所謂的單流高到低、低到高的網絡結構,能夠獲得滿意的整體配準結果。然而,一些嚴重變形的局部區域的精確定位,這是精確定位手術目標的關鍵,卻經常被忽視。因此,這些方法對一些難以對齊的區域不敏感,例如畸形肝葉的圖像配準融合。針對這一問題,我們提出一種新的無監督配準網絡,即全分辨率殘差配準網絡(F3RNet),用于嚴重變形器官的變形配準。該方法以殘差學習的方式結合了兩個并行處理流。一個流利用全分辨率信息,促進準確的體素級注冊。另一個流學習深度多尺度殘差表示以獲得魯棒特征識別與提取。此外,我們還使用了提升小波變換,混合融合等方法對圖像特征進行精確分類。最后對三維卷積進行因式分解,得到圖像配準與融合結果。我們選用了腹部和肺部的CT-MRI數據集對所提方法進行驗證,實驗結果表明所提方法可以獲取更高質量的配準融合圖像,同時又能顯著提高配準融合效率。
??為了驗證我們提出方法的有效性,我們選用自備的醫學CT圖像和MRI圖像數據集進行了評估,本文配準模型在肺部圖像與腦部圖像的配準值有一定程度的提升。從模型的運行的結果可以看出,當AUC值為0.883時,得到Dice稀疏為0.575,準確率為0.884,靈敏度為0.647,特異性為0.929,F1得分為0.640。模型中還進行了病變程度與匹配程度的相關性分析,在保證快捷的同時方便簡潔地解決了問題,具有一定的推廣意義。
問題分析:
??考慮到在進行疾病診斷時,不同的醫學圖像在進行疾病診斷時所發揮的作用時不同的,為更全面的了解病人的病情,往往需要多種醫學圖像提供不同的信息,通過這些信息的綜合分析,能夠為臨床診斷治療提供全面的信息。深度學習的發展為多模態醫學配準和圖像融合提供了新的思路,利用深度學習的模型尋找待配準圖像對像素點間的空間對應關系,使固定圖像與浮動圖像的像素點在空間解剖結構上對齊,根據對齊關系進行利用融合特性對這些提取的特征進行融合,能夠更清晰的反應病變的情況。該方案的具體解決措施如下:
??首先,提出一種全分辨率殘差配準網絡用于配準不同醫學圖像中的相關特征,通過設計全分辨率和多尺度殘差塊這兩個并行的網絡,全分辨率網絡在密集網絡表現較好,能夠通過常規殘差塊提取不同醫學圖像中的低階特征。多尺度殘差塊利用連續池化和卷積操作來增加識別范圍,善于捕捉高級特征,從而提高識別性能。這兩個并行網絡能夠全面提取不同醫學圖像的特征,不會因為醫學圖像分辨率不同導致特征提取不全面的情況。
??其次,在提取完全特征的基礎上,使用小波變換融合和混合圖像融合相結合的方法進行醫學圖像的融合。小波變換融合利用小波變換對源參考的醫學圖像進行分解,利用定量度量度量技術的性能,并比較技術的效率,以找到合適的融合規則。再使用混合圖像融合技術可以從有噪聲、失真的圖像中提取原始圖像特征,從而獲得改進和增強的圖像質量。
模型假設:
??在數學建模的過程中,為了使模型簡單明確,簡化運算過程,在不影響模型意義與計算精度的前提下,建立了如下假設:
??(1)假設同一組的醫學圖像來自于同一位患者。
??(2)假設用來采集每一張醫學圖像的設備都能正常采集醫學圖像。
??(3)假設采集到的醫學圖像沒有人為修改與損壞,并且能夠直接使用。
??(4)假設同一組診斷圖像為患者同一部位的醫學圖像。
??(5)假設患者在兩次成像之間身體沒有發生其他部分的病變,或者存在導致成像偏差的身體異狀。
??(6)假設每臺計算機輔助醫療系統的參數設定與配置都保持一致。
論文縮略圖:
程序代碼:
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import pathlib
import torch
import math
import warnings
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageColor
__all__=["make_grid","save_image", "draw_bounding_boxes", "draw_segmentation_masks"]
def make_grid(tensor: Union[torch.Tensor, List[torch.Tensor]],nrow: int = 8,padding: int = 2,normalize: bool = False,value_range: Optional[Tuple[int, int]] = None,scale_each: bool = False,pad_value: int = 0,**kwargs
) -> torch.Tensor:if not (torch.is_tensor(tensor) or(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')if "range" in kwargs.keys():warning = "range will be deprecated, please use value_range instead."warnings.warn(warning)value_range = kwargs["range"]# if list of tensors, convert to a 4D mini-batch Tensorif isinstance(tensor, list):tensor = torch.stack(tensor, dim=0)if tensor.dim() == 2: # single image H x Wtensor = tensor.unsqueeze(0)if tensor.dim() == 3: # single imageif tensor.size(0) == 1: # if single-channel, convert to 3-channeltensor = torch.cat((tensor, tensor, tensor), 0)tensor = tensor.unsqueeze(0)if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel imagestensor = torch.cat((tensor, tensor, tensor), 1)if normalize is True:tensor = tensor.clone() # avoid modifying tensor in-placeif value_range is not None:assert isinstance(value_range, tuple), \"value_range has to be a tuple (min, max) if specified. min and max are numbers"def norm_ip(img, low, high):img.clamp_(min=low, max=high)img.sub_(low).div_(max(high - low, 1e-5))def norm_range(t, value_range):if value_range is not None:norm_ip(t, value_range[0], value_range[1])else:norm_ip(t, float(t.min()), float(t.max()))if scale_each is True:for t in tensor: # loop over mini-batch dimensionnorm_range(t, value_range)else:norm_range(tensor, value_range)if tensor.size(0) == 1:return tensor.squeeze(0)# make the mini-batch of images into a gridnmaps = tensor.size(0)xmaps = min(nrow, nmaps)ymaps = int(math.ceil(float(nmaps) / xmaps))height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)num_channels = tensor.size(1)grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)k = 0for y in range(ymaps):for x in range(xmaps):if k >= nmaps:break# Tensor.copy_() is a valid method but seems to be missing from the stubs# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]2, x * width + padding, width - padding).copy_(tensor[k])k = k + 1return grid
@torch.no_grad()
def save_image(tensor: Union[torch.Tensor, List[torch.Tensor]],fp: Union[Text, pathlib.Path, BinaryIO],format: Optional[str] = None,**kwargs
) -> None:grid = make_grid(tensor, **kwargs)# Add 0.5 after unnormalizing to [0, 255] to round to nearest integerndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()im = Image.fromarray(ndarr)im.save(fp, format=format)
def draw_bounding_boxes(image: torch.Tensor,boxes: torch.Tensor,labels: Optional[List[str]] = None,colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,fill: Optional[bool] = False,width: int = 1,font: Optional[str] = None,font_size: int = 10
) -> torch.Tensor:if not isinstance(image, torch.Tensor):raise TypeError(f"Tensor expected, got {type(image)}")elif image.dtype != torch.uint8:raise ValueError(f"Tensor uint8 expected, got {image.dtype}")elif image.dim() != 3:raise ValueError("Pass individual images, not batches")ndarr = image.permute(1, 2, 0).numpy()img_to_draw = Image.fromarray(ndarr)img_boxes = boxes.to(torch.int64).tolist()if fill:draw = ImageDraw.Draw(img_to_draw, "RGBA")else:draw = ImageDraw.Draw(img_to_draw)txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)for i, bbox in enumerate(img_boxes):if colors is None:color = Noneelse:color = colors[i]if fill:if color is None:fill_color = (255, 255, 255, 100)elif isinstance(color, str):# This will automatically raise Error if rgb cannot be parsed.fill_color = ImageColor.getrgb(color) + (100,)elif isinstance(color, tuple):fill_color = color + (100,)draw.rectangle(bbox, width=width, outline=color, fill=fill_color)else:draw.rectangle(bbox, width=width, outline=color)if labels is not None:draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
def draw_segmentation_masks(image: torch.Tensor,masks: torch.Tensor,alpha: float = 0.2,colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
) -> torch.Tensor:if not isinstance(image, torch.Tensor):raise TypeError(f"Tensor expected, got {type(image)}")elif image.dtype != torch.uint8:raise ValueError(f"Tensor uint8 expected, got {image.dtype}")elif image.dim() != 3:raise ValueError("Pass individual images, not batches")elif image.size()[0] != 3:raise ValueError("Pass an RGB image. Other Image formats are not supported")num_masks = masks.size()[0]masks = masks.argmax(0)if colors is None:palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])colors_t = torch.as_tensor([i for i in range(num_masks)])[:, None] * palettecolor_arr = (colors_t % 255).numpy().astype("uint8")else:color_list = []for color in colors:if isinstance(color, str):# This will automatically raise Error if rgb cannot be parsed.fill_color = ImageColor.getrgb(color)color_list.append(fill_color)elif isinstance(color, tuple):color_list.append(color)color_arr = np.array(color_list).astype("uint8")_, h, w = image.size()img_to_draw = Image.fromarray(masks.byte().cpu().numpy()).resize((w, h))img_to_draw.putpalette(color_arr)img_to_draw = torch.from_numpy(np.array(img_to_draw.convert('RGB')))img_to_draw = img_to_draw.permute((2, 0, 1))return (image.float() * alpha + img_to_draw.float() * (1.0 - alpha)).to(dtype=torch.uint8)