forked from KEMT/zpwiki
		
	Nahrát soubory do „pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util“
This commit is contained in:
		
							parent
							
								
									526fe93ff9
								
							
						
					
					
						commit
						4dca468612
					
				
										
											Binary file not shown.
										
									
								
							@ -0,0 +1,331 @@
 | 
				
			|||||||
 | 
					import math
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import scipy.ndimage
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.autograd import Function
 | 
				
			||||||
 | 
					from torch.autograd.function import once_differentiable
 | 
				
			||||||
 | 
					import torch.backends.cudnn as cudnn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from util.logconf import logging
 | 
				
			||||||
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					# log.setLevel(logging.WARN)
 | 
				
			||||||
 | 
					# log.setLevel(logging.INFO)
 | 
				
			||||||
 | 
					log.setLevel(logging.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cropToShape(image, new_shape, center_list=None, fill=0.0):
 | 
				
			||||||
 | 
					    # log.debug([image.shape, new_shape, center_list])
 | 
				
			||||||
 | 
					    # assert len(image.shape) == 3, repr(image.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if center_list is None:
 | 
				
			||||||
 | 
					        center_list = [int(image.shape[i] / 2) for i in range(3)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    crop_list = []
 | 
				
			||||||
 | 
					    for i in range(0, 3):
 | 
				
			||||||
 | 
					        crop_int = center_list[i]
 | 
				
			||||||
 | 
					        if image.shape[i] > new_shape[i] and crop_int is not None:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # We can't just do crop_int +/- shape/2 since shape might be odd
 | 
				
			||||||
 | 
					            # and ints round down.
 | 
				
			||||||
 | 
					            start_int = crop_int - int(new_shape[i]/2)
 | 
				
			||||||
 | 
					            end_int = start_int + new_shape[i]
 | 
				
			||||||
 | 
					            crop_list.append(slice(max(0, start_int), end_int))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            crop_list.append(slice(0, image.shape[i]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # log.debug([image.shape, crop_list])
 | 
				
			||||||
 | 
					    image = image[crop_list]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    crop_list = []
 | 
				
			||||||
 | 
					    for i in range(0, 3):
 | 
				
			||||||
 | 
					        if image.shape[i] < new_shape[i]:
 | 
				
			||||||
 | 
					            crop_int = int((new_shape[i] - image.shape[i]) / 2)
 | 
				
			||||||
 | 
					            crop_list.append(slice(crop_int, crop_int + image.shape[i]))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            crop_list.append(slice(0, image.shape[i]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # log.debug([image.shape, crop_list])
 | 
				
			||||||
 | 
					    new_image = np.zeros(new_shape, dtype=image.dtype)
 | 
				
			||||||
 | 
					    new_image[:] = fill
 | 
				
			||||||
 | 
					    new_image[crop_list] = image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return new_image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def zoomToShape(image, new_shape, square=True):
 | 
				
			||||||
 | 
					    # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if square and image.shape[0] != image.shape[1]:
 | 
				
			||||||
 | 
					        crop_int = min(image.shape[0], image.shape[1])
 | 
				
			||||||
 | 
					        new_shape = [crop_int, crop_int, image.shape[2]]
 | 
				
			||||||
 | 
					        image = cropToShape(image, new_shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    zoom_shape = [new_shape[i] / image.shape[i] for i in range(3)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with warnings.catch_warnings():
 | 
				
			||||||
 | 
					        warnings.simplefilter("ignore")
 | 
				
			||||||
 | 
					        image = scipy.ndimage.interpolation.zoom(
 | 
				
			||||||
 | 
					            image, zoom_shape,
 | 
				
			||||||
 | 
					            output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return image
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def randomOffset(image_list, offset_rows=0.125, offset_cols=0.125):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    center_list = [int(image_list[0].shape[i] / 2) for i in range(3)]
 | 
				
			||||||
 | 
					    center_list[0] += int(offset_rows * (random.random() - 0.5) * 2)
 | 
				
			||||||
 | 
					    center_list[1] += int(offset_cols * (random.random() - 0.5) * 2)
 | 
				
			||||||
 | 
					    center_list[2] = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    new_list = []
 | 
				
			||||||
 | 
					    for image in image_list:
 | 
				
			||||||
 | 
					        new_image = cropToShape(image, image.shape, center_list)
 | 
				
			||||||
 | 
					        new_list.append(new_image)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return new_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def randomZoom(image_list, scale=None, scale_min=0.8, scale_max=1.3):
 | 
				
			||||||
 | 
					    if scale is None:
 | 
				
			||||||
 | 
					        scale = scale_min + (scale_max - scale_min) * random.random()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    new_list = []
 | 
				
			||||||
 | 
					    for image in image_list:
 | 
				
			||||||
 | 
					        # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with warnings.catch_warnings():
 | 
				
			||||||
 | 
					            warnings.simplefilter("ignore")
 | 
				
			||||||
 | 
					            # log.info([image.shape])
 | 
				
			||||||
 | 
					            zimage = scipy.ndimage.interpolation.zoom(
 | 
				
			||||||
 | 
					                image, [scale, scale, 1.0],
 | 
				
			||||||
 | 
					                output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
 | 
				
			||||||
 | 
					        image = cropToShape(zimage, image.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        new_list.append(image)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return new_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_randomFlip_transform_list = [
 | 
				
			||||||
 | 
					    # lambda a: np.rot90(a, axes=(0, 1)),
 | 
				
			||||||
 | 
					    # lambda a: np.flip(a, 0),
 | 
				
			||||||
 | 
					    lambda a: np.flip(a, 1),
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def randomFlip(image_list, transform_bits=None):
 | 
				
			||||||
 | 
					    if transform_bits is None:
 | 
				
			||||||
 | 
					        transform_bits = random.randrange(0, 2 ** len(_randomFlip_transform_list))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    new_list = []
 | 
				
			||||||
 | 
					    for image in image_list:
 | 
				
			||||||
 | 
					        # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for n in range(len(_randomFlip_transform_list)):
 | 
				
			||||||
 | 
					            if transform_bits & 2**n:
 | 
				
			||||||
 | 
					                # prhist(image, 'before')
 | 
				
			||||||
 | 
					                image = _randomFlip_transform_list[n](image)
 | 
				
			||||||
 | 
					                # prhist(image, 'after ')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        new_list.append(image)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return new_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def randomSpin(image_list, angle=None, range_tup=None, axes=(0, 1)):
 | 
				
			||||||
 | 
					    if range_tup is None:
 | 
				
			||||||
 | 
					        range_tup = (0, 360)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if angle is None:
 | 
				
			||||||
 | 
					        angle = range_tup[0] + (range_tup[1] - range_tup[0]) * random.random()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    new_list = []
 | 
				
			||||||
 | 
					    for image in image_list:
 | 
				
			||||||
 | 
					        # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        image = scipy.ndimage.interpolation.rotate(
 | 
				
			||||||
 | 
					                image, angle, axes=axes, reshape=False,
 | 
				
			||||||
 | 
					                output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        new_list.append(image)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return new_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def randomNoise(image_list, noise_min=-0.1, noise_max=0.1):
 | 
				
			||||||
 | 
					    noise = np.zeros_like(image_list[0])
 | 
				
			||||||
 | 
					    noise += (noise_max - noise_min) * np.random.random_sample(image_list[0].shape) + noise_min
 | 
				
			||||||
 | 
					    noise *= 5
 | 
				
			||||||
 | 
					    noise = scipy.ndimage.filters.gaussian_filter(noise, 3)
 | 
				
			||||||
 | 
					    # noise += (noise_max - noise_min) * np.random.random_sample(image_hsv.shape) + noise_min
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    new_list = []
 | 
				
			||||||
 | 
					    for image_hsv in image_list:
 | 
				
			||||||
 | 
					        image_hsv = image_hsv + noise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        new_list.append(image_hsv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return new_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def randomHsvShift(image_list, h=None, s=None, v=None,
 | 
				
			||||||
 | 
					                   h_min=-0.1, h_max=0.1,
 | 
				
			||||||
 | 
					                   s_min=0.5, s_max=2.0,
 | 
				
			||||||
 | 
					                   v_min=0.5, v_max=2.0):
 | 
				
			||||||
 | 
					    if h is None:
 | 
				
			||||||
 | 
					        h = h_min + (h_max - h_min) * random.random()
 | 
				
			||||||
 | 
					    if s is None:
 | 
				
			||||||
 | 
					        s = s_min + (s_max - s_min) * random.random()
 | 
				
			||||||
 | 
					    if v is None:
 | 
				
			||||||
 | 
					        v = v_min + (v_max - v_min) * random.random()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    new_list = []
 | 
				
			||||||
 | 
					    for image_hsv in image_list:
 | 
				
			||||||
 | 
					        # assert image_hsv.shape[-1] == 3, repr(image_hsv.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        image_hsv[:,:,0::3] += h
 | 
				
			||||||
 | 
					        image_hsv[:,:,1::3] = image_hsv[:,:,1::3] ** s
 | 
				
			||||||
 | 
					        image_hsv[:,:,2::3] = image_hsv[:,:,2::3] ** v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        new_list.append(image_hsv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return clampHsv(new_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def clampHsv(image_list):
 | 
				
			||||||
 | 
					    new_list = []
 | 
				
			||||||
 | 
					    for image_hsv in image_list:
 | 
				
			||||||
 | 
					        image_hsv = image_hsv.clone()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Hue wraps around
 | 
				
			||||||
 | 
					        image_hsv[:,:,0][image_hsv[:,:,0] > 1] -= 1
 | 
				
			||||||
 | 
					        image_hsv[:,:,0][image_hsv[:,:,0] < 0] += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Everything else clamps between 0 and 1
 | 
				
			||||||
 | 
					        image_hsv[image_hsv > 1] = 1
 | 
				
			||||||
 | 
					        image_hsv[image_hsv < 0] = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        new_list.append(image_hsv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return new_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# def torch_augment(input):
 | 
				
			||||||
 | 
					#     theta = random.random() * math.pi * 2
 | 
				
			||||||
 | 
					#     s = math.sin(theta)
 | 
				
			||||||
 | 
					#     c = math.cos(theta)
 | 
				
			||||||
 | 
					#     c1 = 1 - c
 | 
				
			||||||
 | 
					#     axis_vector = torch.rand(3, device='cpu', dtype=torch.float64)
 | 
				
			||||||
 | 
					#     axis_vector -= 0.5
 | 
				
			||||||
 | 
					#     axis_vector /= axis_vector.abs().sum()
 | 
				
			||||||
 | 
					#     l, m, n = axis_vector
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     matrix = torch.tensor([
 | 
				
			||||||
 | 
					#         [l*l*c1 +   c, m*l*c1 - n*s, n*l*c1 + m*s, 0],
 | 
				
			||||||
 | 
					#         [l*m*c1 + n*s, m*m*c1 +   c, n*m*c1 - l*s, 0],
 | 
				
			||||||
 | 
					#         [l*n*c1 - m*s, m*n*c1 + l*s, n*n*c1 +   c, 0],
 | 
				
			||||||
 | 
					#         [0, 0, 0, 1],
 | 
				
			||||||
 | 
					#     ], device=input.device, dtype=torch.float32)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     return th_affine3d(input, matrix)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# following from https://github.com/ncullen93/torchsample/blob/master/torchsample/utils.py
 | 
				
			||||||
 | 
					# MIT licensed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# def th_affine3d(input, matrix):
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					#     3D Affine image transform on torch.Tensor
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					#     A = matrix[:3,:3]
 | 
				
			||||||
 | 
					#     b = matrix[:3,3]
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     # make a meshgrid of normal coordinates
 | 
				
			||||||
 | 
					#     coords = th_iterproduct(input.size(-3), input.size(-2), input.size(-1), dtype=torch.float32)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     # shift the coordinates so center is the origin
 | 
				
			||||||
 | 
					#     coords[:,0] = coords[:,0] - (input.size(-3) / 2. - 0.5)
 | 
				
			||||||
 | 
					#     coords[:,1] = coords[:,1] - (input.size(-2) / 2. - 0.5)
 | 
				
			||||||
 | 
					#     coords[:,2] = coords[:,2] - (input.size(-1) / 2. - 0.5)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     # apply the coordinate transformation
 | 
				
			||||||
 | 
					#     new_coords = coords.mm(A.t().contiguous()) + b.expand_as(coords)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     # shift the coordinates back so origin is origin
 | 
				
			||||||
 | 
					#     new_coords[:,0] = new_coords[:,0] + (input.size(-3) / 2. - 0.5)
 | 
				
			||||||
 | 
					#     new_coords[:,1] = new_coords[:,1] + (input.size(-2) / 2. - 0.5)
 | 
				
			||||||
 | 
					#     new_coords[:,2] = new_coords[:,2] + (input.size(-1) / 2. - 0.5)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     # map new coordinates using bilinear interpolation
 | 
				
			||||||
 | 
					#     input_transformed = th_trilinear_interp3d(input, new_coords)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     return input_transformed
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# def th_trilinear_interp3d(input, coords):
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					#     trilinear interpolation of 3D torch.Tensor image
 | 
				
			||||||
 | 
					#     """
 | 
				
			||||||
 | 
					#     # take clamp then floor/ceil of x coords
 | 
				
			||||||
 | 
					#     x = torch.clamp(coords[:,0], 0, input.size(-3)-2)
 | 
				
			||||||
 | 
					#     x0 = x.floor()
 | 
				
			||||||
 | 
					#     x1 = x0 + 1
 | 
				
			||||||
 | 
					#     # take clamp then floor/ceil of y coords
 | 
				
			||||||
 | 
					#     y = torch.clamp(coords[:,1], 0, input.size(-2)-2)
 | 
				
			||||||
 | 
					#     y0 = y.floor()
 | 
				
			||||||
 | 
					#     y1 = y0 + 1
 | 
				
			||||||
 | 
					#     # take clamp then floor/ceil of z coords
 | 
				
			||||||
 | 
					#     z = torch.clamp(coords[:,2], 0, input.size(-1)-2)
 | 
				
			||||||
 | 
					#     z0 = z.floor()
 | 
				
			||||||
 | 
					#     z1 = z0 + 1
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     stride = torch.tensor(input.stride()[-3:], dtype=torch.int64, device=input.device)
 | 
				
			||||||
 | 
					#     x0_ix = x0.mul(stride[0]).long()
 | 
				
			||||||
 | 
					#     x1_ix = x1.mul(stride[0]).long()
 | 
				
			||||||
 | 
					#     y0_ix = y0.mul(stride[1]).long()
 | 
				
			||||||
 | 
					#     y1_ix = y1.mul(stride[1]).long()
 | 
				
			||||||
 | 
					#     z0_ix = z0.mul(stride[2]).long()
 | 
				
			||||||
 | 
					#     z1_ix = z1.mul(stride[2]).long()
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     # input_flat = th_flatten(input)
 | 
				
			||||||
 | 
					#     input_flat = x.contiguous().view(x[0], x[1], -1)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     vals_000 = input_flat[:, :, x0_ix+y0_ix+z0_ix]
 | 
				
			||||||
 | 
					#     vals_001 = input_flat[:, :, x0_ix+y0_ix+z1_ix]
 | 
				
			||||||
 | 
					#     vals_010 = input_flat[:, :, x0_ix+y1_ix+z0_ix]
 | 
				
			||||||
 | 
					#     vals_011 = input_flat[:, :, x0_ix+y1_ix+z1_ix]
 | 
				
			||||||
 | 
					#     vals_100 = input_flat[:, :, x1_ix+y0_ix+z0_ix]
 | 
				
			||||||
 | 
					#     vals_101 = input_flat[:, :, x1_ix+y0_ix+z1_ix]
 | 
				
			||||||
 | 
					#     vals_110 = input_flat[:, :, x1_ix+y1_ix+z0_ix]
 | 
				
			||||||
 | 
					#     vals_111 = input_flat[:, :, x1_ix+y1_ix+z1_ix]
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     xd = x - x0
 | 
				
			||||||
 | 
					#     yd = y - y0
 | 
				
			||||||
 | 
					#     zd = z - z0
 | 
				
			||||||
 | 
					#     xm1 = 1 - xd
 | 
				
			||||||
 | 
					#     ym1 = 1 - yd
 | 
				
			||||||
 | 
					#     zm1 = 1 - zd
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     x_mapped = (
 | 
				
			||||||
 | 
					#             vals_000.mul(xm1).mul(ym1).mul(zm1) +
 | 
				
			||||||
 | 
					#             vals_001.mul(xm1).mul(ym1).mul(zd) +
 | 
				
			||||||
 | 
					#             vals_010.mul(xm1).mul(yd).mul(zm1) +
 | 
				
			||||||
 | 
					#             vals_011.mul(xm1).mul(yd).mul(zd) +
 | 
				
			||||||
 | 
					#             vals_100.mul(xd).mul(ym1).mul(zm1) +
 | 
				
			||||||
 | 
					#             vals_101.mul(xd).mul(ym1).mul(zd) +
 | 
				
			||||||
 | 
					#             vals_110.mul(xd).mul(yd).mul(zm1) +
 | 
				
			||||||
 | 
					#             vals_111.mul(xd).mul(yd).mul(zd)
 | 
				
			||||||
 | 
					#     )
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     return x_mapped.view_as(input)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# def th_iterproduct(*args, dtype=None):
 | 
				
			||||||
 | 
					#     return torch.from_numpy(np.indices(args).reshape((len(args),-1)).T)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# def th_flatten(x):
 | 
				
			||||||
 | 
					#     """Flatten tensor"""
 | 
				
			||||||
 | 
					#     return x.contiguous().view(x[0], x[1], -1)
 | 
				
			||||||
@ -0,0 +1,136 @@
 | 
				
			|||||||
 | 
					import gzip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from diskcache import FanoutCache, Disk
 | 
				
			||||||
 | 
					from diskcache.core import BytesType, MODE_BINARY, BytesIO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from util.logconf import logging
 | 
				
			||||||
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					# log.setLevel(logging.WARN)
 | 
				
			||||||
 | 
					log.setLevel(logging.INFO)
 | 
				
			||||||
 | 
					# log.setLevel(logging.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GzipDisk(Disk):
 | 
				
			||||||
 | 
					    def store(self, value, read, key=None):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Override from base class diskcache.Disk.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Chunking is due to needing to work on pythons < 2.7.13:
 | 
				
			||||||
 | 
					        - Issue #27130: In the "zlib" module, fix handling of large buffers
 | 
				
			||||||
 | 
					          (typically 2 or 4 GiB).  Previously, inputs were limited to 2 GiB, and
 | 
				
			||||||
 | 
					          compression and decompression operations did not properly handle results of
 | 
				
			||||||
 | 
					          2 or 4 GiB.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param value: value to convert
 | 
				
			||||||
 | 
					        :param bool read: True when value is file-like object
 | 
				
			||||||
 | 
					        :return: (size, mode, filename, value) tuple for Cache table
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        # pylint: disable=unidiomatic-typecheck
 | 
				
			||||||
 | 
					        if type(value) is BytesType:
 | 
				
			||||||
 | 
					            if read:
 | 
				
			||||||
 | 
					                value = value.read()
 | 
				
			||||||
 | 
					                read = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            str_io = BytesIO()
 | 
				
			||||||
 | 
					            gz_file = gzip.GzipFile(mode='wb', compresslevel=1, fileobj=str_io)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for offset in range(0, len(value), 2**30):
 | 
				
			||||||
 | 
					                gz_file.write(value[offset:offset+2**30])
 | 
				
			||||||
 | 
					            gz_file.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            value = str_io.getvalue()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return super(GzipDisk, self).store(value, read)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def fetch(self, mode, filename, value, read):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Override from base class diskcache.Disk.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Chunking is due to needing to work on pythons < 2.7.13:
 | 
				
			||||||
 | 
					        - Issue #27130: In the "zlib" module, fix handling of large buffers
 | 
				
			||||||
 | 
					          (typically 2 or 4 GiB).  Previously, inputs were limited to 2 GiB, and
 | 
				
			||||||
 | 
					          compression and decompression operations did not properly handle results of
 | 
				
			||||||
 | 
					          2 or 4 GiB.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :param int mode: value mode raw, binary, text, or pickle
 | 
				
			||||||
 | 
					        :param str filename: filename of corresponding value
 | 
				
			||||||
 | 
					        :param value: database value
 | 
				
			||||||
 | 
					        :param bool read: when True, return an open file handle
 | 
				
			||||||
 | 
					        :return: corresponding Python value
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        value = super(GzipDisk, self).fetch(mode, filename, value, read)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if mode == MODE_BINARY:
 | 
				
			||||||
 | 
					            str_io = BytesIO(value)
 | 
				
			||||||
 | 
					            gz_file = gzip.GzipFile(mode='rb', fileobj=str_io)
 | 
				
			||||||
 | 
					            read_csio = BytesIO()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            while True:
 | 
				
			||||||
 | 
					                uncompressed_data = gz_file.read(2**30)
 | 
				
			||||||
 | 
					                if uncompressed_data:
 | 
				
			||||||
 | 
					                    read_csio.write(uncompressed_data)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            value = read_csio.getvalue()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def getCache(scope_str):
 | 
				
			||||||
 | 
					    return FanoutCache('data-unversioned/cache/' + scope_str,
 | 
				
			||||||
 | 
					                       disk=GzipDisk,
 | 
				
			||||||
 | 
					                       shards=64,
 | 
				
			||||||
 | 
					                       timeout=1,
 | 
				
			||||||
 | 
					                       size_limit=3e11,
 | 
				
			||||||
 | 
					                       # disk_min_file_size=2**20,
 | 
				
			||||||
 | 
					                       )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# def disk_cache(base_path, memsize=2):
 | 
				
			||||||
 | 
					#     def disk_cache_decorator(f):
 | 
				
			||||||
 | 
					#         @functools.wraps(f)
 | 
				
			||||||
 | 
					#         def wrapper(*args, **kwargs):
 | 
				
			||||||
 | 
					#             args_str = repr(args) + repr(sorted(kwargs.items()))
 | 
				
			||||||
 | 
					#             file_str = hashlib.md5(args_str.encode('utf8')).hexdigest()
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             cache_path = os.path.join(base_path, f.__name__, file_str + '.pkl.gz')
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             if not os.path.exists(os.path.dirname(cache_path)):
 | 
				
			||||||
 | 
					#                 os.makedirs(os.path.dirname(cache_path), exist_ok=True)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             if os.path.exists(cache_path):
 | 
				
			||||||
 | 
					#                 return pickle_loadgz(cache_path)
 | 
				
			||||||
 | 
					#             else:
 | 
				
			||||||
 | 
					#                 ret = f(*args, **kwargs)
 | 
				
			||||||
 | 
					#                 pickle_dumpgz(cache_path, ret)
 | 
				
			||||||
 | 
					#                 return ret
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         return wrapper
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     return disk_cache_decorator
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# def pickle_dumpgz(file_path, obj):
 | 
				
			||||||
 | 
					#     log.debug("Writing {}".format(file_path))
 | 
				
			||||||
 | 
					#     with open(file_path, 'wb') as file_obj:
 | 
				
			||||||
 | 
					#         with gzip.GzipFile(mode='wb', compresslevel=1, fileobj=file_obj) as gz_file:
 | 
				
			||||||
 | 
					#             pickle.dump(obj, gz_file, pickle.HIGHEST_PROTOCOL)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# def pickle_loadgz(file_path):
 | 
				
			||||||
 | 
					#     log.debug("Reading {}".format(file_path))
 | 
				
			||||||
 | 
					#     with open(file_path, 'rb') as file_obj:
 | 
				
			||||||
 | 
					#         with gzip.GzipFile(mode='rb', fileobj=file_obj) as gz_file:
 | 
				
			||||||
 | 
					#             return pickle.load(gz_file)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# def dtpath(dt=None):
 | 
				
			||||||
 | 
					#     if dt is None:
 | 
				
			||||||
 | 
					#         dt = datetime.datetime.now()
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     return str(dt).rsplit('.', 1)[0].replace(' ', '--').replace(':', '.')
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# def safepath(s):
 | 
				
			||||||
 | 
					#     s = s.replace(' ', '_')
 | 
				
			||||||
 | 
					#     return re.sub('[^A-Za-z0-9_.-]', '', s)
 | 
				
			||||||
@ -0,0 +1,19 @@
 | 
				
			|||||||
 | 
					import logging
 | 
				
			||||||
 | 
					import logging.handlers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					root_logger = logging.getLogger()
 | 
				
			||||||
 | 
					root_logger.setLevel(logging.INFO)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Some libraries attempt to add their own root logger handlers. This is
 | 
				
			||||||
 | 
					# annoying and so we get rid of them.
 | 
				
			||||||
 | 
					for handler in list(root_logger.handlers):
 | 
				
			||||||
 | 
					    root_logger.removeHandler(handler)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					logfmt_str = "%(asctime)s %(levelname)-8s pid:%(process)d %(name)s:%(lineno)03d:%(funcName)s %(message)s"
 | 
				
			||||||
 | 
					formatter = logging.Formatter(logfmt_str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					streamHandler = logging.StreamHandler()
 | 
				
			||||||
 | 
					streamHandler.setFormatter(formatter)
 | 
				
			||||||
 | 
					streamHandler.setLevel(logging.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					root_logger.addHandler(streamHandler)
 | 
				
			||||||
@ -0,0 +1,143 @@
 | 
				
			|||||||
 | 
					# From https://github.com/jvanvugt/pytorch-unet
 | 
				
			||||||
 | 
					# https://raw.githubusercontent.com/jvanvugt/pytorch-unet/master/unet.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# MIT License
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright (c) 2018 Joris
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
				
			||||||
 | 
					# of this software and associated documentation files (the "Software"), to deal
 | 
				
			||||||
 | 
					# in the Software without restriction, including without limitation the rights
 | 
				
			||||||
 | 
					# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
				
			||||||
 | 
					# copies of the Software, and to permit persons to whom the Software is
 | 
				
			||||||
 | 
					# furnished to do so, subject to the following conditions:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# The above copyright notice and this permission notice shall be included in all
 | 
				
			||||||
 | 
					# copies or substantial portions of the Software.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
				
			||||||
 | 
					# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
				
			||||||
 | 
					# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
				
			||||||
 | 
					# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
				
			||||||
 | 
					# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
				
			||||||
 | 
					# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
				
			||||||
 | 
					# SOFTWARE.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Adapted from https://discuss.pytorch.org/t/unet-implementation/426
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UNet(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, in_channels=1, n_classes=2, depth=5, wf=6, padding=False,
 | 
				
			||||||
 | 
					                 batch_norm=False, up_mode='upconv'):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Implementation of
 | 
				
			||||||
 | 
					        U-Net: Convolutional Networks for Biomedical Image Segmentation
 | 
				
			||||||
 | 
					        (Ronneberger et al., 2015)
 | 
				
			||||||
 | 
					        https://arxiv.org/abs/1505.04597
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Using the default arguments will yield the exact version used
 | 
				
			||||||
 | 
					        in the original paper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            in_channels (int): number of input channels
 | 
				
			||||||
 | 
					            n_classes (int): number of output channels
 | 
				
			||||||
 | 
					            depth (int): depth of the network
 | 
				
			||||||
 | 
					            wf (int): number of filters in the first layer is 2**wf
 | 
				
			||||||
 | 
					            padding (bool): if True, apply padding such that the input shape
 | 
				
			||||||
 | 
					                            is the same as the output.
 | 
				
			||||||
 | 
					                            This may introduce artifacts
 | 
				
			||||||
 | 
					            batch_norm (bool): Use BatchNorm after layers with an
 | 
				
			||||||
 | 
					                               activation function
 | 
				
			||||||
 | 
					            up_mode (str): one of 'upconv' or 'upsample'.
 | 
				
			||||||
 | 
					                           'upconv' will use transposed convolutions for
 | 
				
			||||||
 | 
					                           learned upsampling.
 | 
				
			||||||
 | 
					                           'upsample' will use bilinear upsampling.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        super(UNet, self).__init__()
 | 
				
			||||||
 | 
					        assert up_mode in ('upconv', 'upsample')
 | 
				
			||||||
 | 
					        self.padding = padding
 | 
				
			||||||
 | 
					        self.depth = depth
 | 
				
			||||||
 | 
					        prev_channels = in_channels
 | 
				
			||||||
 | 
					        self.down_path = nn.ModuleList()
 | 
				
			||||||
 | 
					        for i in range(depth):
 | 
				
			||||||
 | 
					            self.down_path.append(UNetConvBlock(prev_channels, 2**(wf+i),
 | 
				
			||||||
 | 
					                                                padding, batch_norm))
 | 
				
			||||||
 | 
					            prev_channels = 2**(wf+i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.up_path = nn.ModuleList()
 | 
				
			||||||
 | 
					        for i in reversed(range(depth - 1)):
 | 
				
			||||||
 | 
					            self.up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode,
 | 
				
			||||||
 | 
					                                            padding, batch_norm))
 | 
				
			||||||
 | 
					            prev_channels = 2**(wf+i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        blocks = []
 | 
				
			||||||
 | 
					        for i, down in enumerate(self.down_path):
 | 
				
			||||||
 | 
					            x = down(x)
 | 
				
			||||||
 | 
					            if i != len(self.down_path)-1:
 | 
				
			||||||
 | 
					                blocks.append(x)
 | 
				
			||||||
 | 
					                x = F.avg_pool2d(x, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i, up in enumerate(self.up_path):
 | 
				
			||||||
 | 
					            x = up(x, blocks[-i-1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.last(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UNetConvBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, in_size, out_size, padding, batch_norm):
 | 
				
			||||||
 | 
					        super(UNetConvBlock, self).__init__()
 | 
				
			||||||
 | 
					        block = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        block.append(nn.Conv2d(in_size, out_size, kernel_size=3,
 | 
				
			||||||
 | 
					                               padding=int(padding)))
 | 
				
			||||||
 | 
					        block.append(nn.ReLU())
 | 
				
			||||||
 | 
					        # block.append(nn.LeakyReLU())
 | 
				
			||||||
 | 
					        if batch_norm:
 | 
				
			||||||
 | 
					            block.append(nn.BatchNorm2d(out_size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        block.append(nn.Conv2d(out_size, out_size, kernel_size=3,
 | 
				
			||||||
 | 
					                               padding=int(padding)))
 | 
				
			||||||
 | 
					        block.append(nn.ReLU())
 | 
				
			||||||
 | 
					        # block.append(nn.LeakyReLU())
 | 
				
			||||||
 | 
					        if batch_norm:
 | 
				
			||||||
 | 
					            block.append(nn.BatchNorm2d(out_size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.block = nn.Sequential(*block)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        out = self.block(x)
 | 
				
			||||||
 | 
					        return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UNetUpBlock(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
 | 
				
			||||||
 | 
					        super(UNetUpBlock, self).__init__()
 | 
				
			||||||
 | 
					        if up_mode == 'upconv':
 | 
				
			||||||
 | 
					            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2,
 | 
				
			||||||
 | 
					                                         stride=2)
 | 
				
			||||||
 | 
					        elif up_mode == 'upsample':
 | 
				
			||||||
 | 
					            self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2),
 | 
				
			||||||
 | 
					                                    nn.Conv2d(in_size, out_size, kernel_size=1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def center_crop(self, layer, target_size):
 | 
				
			||||||
 | 
					        _, _, layer_height, layer_width = layer.size()
 | 
				
			||||||
 | 
					        diff_y = (layer_height - target_size[0]) // 2
 | 
				
			||||||
 | 
					        diff_x = (layer_width - target_size[1]) // 2
 | 
				
			||||||
 | 
					        return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x, bridge):
 | 
				
			||||||
 | 
					        up = self.up(x)
 | 
				
			||||||
 | 
					        crop1 = self.center_crop(bridge, up.shape[2:])
 | 
				
			||||||
 | 
					        out = torch.cat([up, crop1], 1)
 | 
				
			||||||
 | 
					        out = self.conv_block(out)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return out
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user