import math
import random
from collections import namedtuple

import torch
from torch import nn as nn
import torch.nn.functional as F

from util.logconf import logging
from util.unet import UNet

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)


class UNetWrapper(nn.Module):
    """Wraps util.unet.UNet with batch normalization of the input channels and
    a final nn.Sigmoid, so the output can be read as per-pixel probabilities."""

    def __init__(self, **kwargs):
        super().__init__()

        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
        self.unet = UNet(**kwargs)
        self.final = nn.Sigmoid()

        self._init_weights()

    def _init_weights(self):
        # Kaiming-initialize every convolutional and linear submodule.
        init_set = {
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.Linear,
        }
        for m in self.modules():
            if type(m) in init_set:
                nn.init.kaiming_normal_(
                    m.weight.data, mode='fan_out', nonlinearity='relu', a=0
                )
                if m.bias is not None:
                    fan_in, fan_out = \
                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    nn.init.normal_(m.bias, -bound, bound)

        # nn.init.constant_(self.unet.last.bias, -4)
        # nn.init.constant_(self.unet.last.bias, 4)

    def forward(self, input_batch):
        bn_output = self.input_batchnorm(input_batch)
        un_output = self.unet(bn_output)
        fn_output = self.final(un_output)
        return fn_output
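
# Note (added, not from the original file): the output range is (0, 1) because
# of the final nn.Sigmoid; the number of output channels and whether spatial
# size is preserved are determined by the keyword arguments forwarded to
# util.unet.UNet (e.g. n_classes and padding in the assumed UNet signature).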


class SegmentationAugmentation(nn.Module):
    """Applies one random 2D affine transform (plus optional noise) to a batch;
    the same sampling grid is used for the input batch and its label masks so
    the two stay aligned."""

    def __init__(
            self, flip=None, offset=None, scale=None, rotate=None, noise=None
    ):
        super().__init__()

        self.flip = flip
        self.offset = offset
        self.scale = scale
        self.rotate = rotate
        self.noise = noise

    def forward(self, input_g, label_g):
        transform_t = self._build2dTransformMatrix()
        # One random transform per call, shared by every sample in the batch.
        transform_t = transform_t.expand(input_g.shape[0], -1, -1)
        transform_t = transform_t.to(input_g.device, torch.float32)
        affine_t = F.affine_grid(transform_t[:,:2],
                input_g.size(), align_corners=False)

        augmented_input_g = F.grid_sample(input_g,
                affine_t, padding_mode='border',
                align_corners=False)
        augmented_label_g = F.grid_sample(label_g.to(torch.float32),
                affine_t, padding_mode='border',
                align_corners=False)

        if self.noise:
            # Additive Gaussian noise is applied to the input only.
            noise_t = torch.randn_like(augmented_input_g)
            noise_t *= self.noise

            augmented_input_g += noise_t

        # Re-binarize the interpolated label mask.
        return augmented_input_g, augmented_label_g > 0.5

    def _build2dTransformMatrix(self):
        transform_t = torch.eye(3)

        for i in range(2):
            if self.flip:
                if random.random() > 0.5:
                    transform_t[i,i] *= -1

            if self.offset:
                offset_float = self.offset
                random_float = (random.random() * 2 - 1)
                # The translation terms go in the last column so that they
                # survive the [:, :2] row slice taken in forward().
                transform_t[i,2] = offset_float * random_float

            if self.scale:
                scale_float = self.scale
                random_float = (random.random() * 2 - 1)
                transform_t[i,i] *= 1.0 + scale_float * random_float

        if self.rotate:
            angle_rad = random.random() * math.pi * 2
            s = math.sin(angle_rad)
            c = math.cos(angle_rad)

            rotation_t = torch.tensor([
                [c, -s, 0],
                [s, c, 0],
                [0, 0, 1]])

            transform_t @= rotation_t

        return transform_t
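
# Usage sketch (added illustration, not code from the original file; the
# parameter values are arbitrary): flip enables random axis flips, offset and
# scale give the maximum fractional translation/scaling, rotate enables a
# random rotation, and noise is the standard deviation of the additive noise:
#
#   augmentation_model = SegmentationAugmentation(
#       flip=True, offset=0.03, scale=0.2, rotate=True, noise=0.1)
#   with torch.no_grad():
#       augmented_input_g, augmented_label_g = \
#           augmentation_model(input_g, label_g)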


# MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
#
# class SegmentationMask(nn.Module):
#     def __init__(self):
#         super().__init__()
#
#         self.conv_list = nn.ModuleList([
#             self._make_circle_conv(radius) for radius in range(1, 8)
#         ])
#
#     def _make_circle_conv(self, radius):
#         diameter = 1 + radius * 2
#
#         a = torch.linspace(-1, 1, steps=diameter)**2
#         b = (a[None] + a[:, None])**0.5
#
#         circle_weights = (b <= 1.0).to(torch.float32)
#
#         conv = nn.Conv2d(1, 1, kernel_size=diameter, padding=radius, bias=False)
#         conv.weight.data.fill_(1)
#         conv.weight.data *= circle_weights / circle_weights.sum()
#
#         return conv
#
#
#     def erode(self, input_mask, radius, threshold=1):
#         conv = self.conv_list[radius - 1]
#         input_float = input_mask.to(torch.float32)
#         result = conv(input_float)
#
#         # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
#         # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
#
#         return result >= threshold
#
#     def deposit(self, input_mask, radius, threshold=0):
#         conv = self.conv_list[radius - 1]
#         input_float = input_mask.to(torch.float32)
#         result = conv(input_float)
#
#         # log.debug(['deposit in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
#         # log.debug(['deposit out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
#
#         return result > threshold
#
#     def fill_cavity(self, input_mask):
#         cumsum = input_mask.cumsum(-1)
#         filled_mask = (cumsum > 0)
#         filled_mask &= (cumsum < cumsum[..., -1:])
#         cumsum = input_mask.cumsum(-2)
#         filled_mask &= (cumsum > 0)
#         filled_mask &= (cumsum < cumsum[..., -1:, :])
#
#         return filled_mask
#
#
#     def forward(self, input_g, raw_pos_g):
#         gcc_g = input_g + 1
#
#         with torch.no_grad():
#             # log.info(['gcc_g', gcc_g.min(), gcc_g.mean(), gcc_g.max()])
#
#             raw_dense_mask = gcc_g > 0.7
#             dense_mask = self.deposit(raw_dense_mask, 2)
#             dense_mask = self.erode(dense_mask, 6)
#             dense_mask = self.deposit(dense_mask, 4)
#
#             body_mask = self.fill_cavity(dense_mask)
#             air_mask = self.deposit(body_mask & ~dense_mask, 5)
#             air_mask = self.erode(air_mask, 6)
#
#             lung_mask = self.deposit(air_mask, 5)
#
#             raw_candidate_mask = gcc_g > 0.4
#             raw_candidate_mask &= air_mask
#             candidate_mask = self.erode(raw_candidate_mask, 1)
#             candidate_mask = self.deposit(candidate_mask, 1)
#
#             pos_mask = self.deposit((raw_pos_g > 0.5) & lung_mask, 2)
#
#             neg_mask = self.deposit(candidate_mask, 1)
#             neg_mask &= ~pos_mask
#             neg_mask &= lung_mask
#
#             # label_g = (neg_mask | pos_mask).to(torch.float32)
#             label_g = (pos_mask).to(torch.float32)
#             neg_g = neg_mask.to(torch.float32)
#             pos_g = pos_mask.to(torch.float32)
#
#         mask_dict = {
#             'raw_dense_mask': raw_dense_mask,
#             'dense_mask': dense_mask,
#             'body_mask': body_mask,
#             'air_mask': air_mask,
#             'raw_candidate_mask': raw_candidate_mask,
#             'candidate_mask': candidate_mask,
#             'lung_mask': lung_mask,
#             'neg_mask': neg_mask,
#             'pos_mask': pos_mask,
#         }
#
#         return label_g, neg_g, pos_g, lung_mask, mask_dict
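

# ---------------------------------------------------------------------------
# Minimal smoke test (added sketch, not part of the original file). The UNet
# keyword arguments below are assumptions about util.unet.UNet's signature and
# about sensible values; adjust them to match the vendored UNet if they differ.
if __name__ == '__main__':
    seg_model = UNetWrapper(
        in_channels=7, n_classes=1, depth=3, wf=4,
        padding=True, batch_norm=True, up_mode='upconv',
    )
    augmentation_model = SegmentationAugmentation(
        flip=True, offset=0.03, scale=0.2, rotate=True, noise=0.1,
    )

    input_g = torch.randn(2, 7, 64, 64)        # batch of 7-channel 64x64 slices
    label_g = torch.rand(2, 1, 64, 64) > 0.5   # boolean masks, same spatial size

    with torch.no_grad():
        aug_input_g, aug_label_g = augmentation_model(input_g, label_g)
        prediction_g = seg_model(aug_input_g)

    log.debug([aug_input_g.shape, aug_label_g.shape, prediction_g.shape])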