zpwiki/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/augmentation.py

332 lines
10 KiB
Python
Raw Permalink Normal View History

import math
import random
import warnings
import numpy as np
import scipy.ndimage
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import torch.backends.cudnn as cudnn
from util.logconf import logging
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
def cropToShape(image, new_shape, center_list=None, fill=0.0):
# log.debug([image.shape, new_shape, center_list])
# assert len(image.shape) == 3, repr(image.shape)
if center_list is None:
center_list = [int(image.shape[i] / 2) for i in range(3)]
crop_list = []
for i in range(0, 3):
crop_int = center_list[i]
if image.shape[i] > new_shape[i] and crop_int is not None:
# We can't just do crop_int +/- shape/2 since shape might be odd
# and ints round down.
start_int = crop_int - int(new_shape[i]/2)
end_int = start_int + new_shape[i]
crop_list.append(slice(max(0, start_int), end_int))
else:
crop_list.append(slice(0, image.shape[i]))
# log.debug([image.shape, crop_list])
image = image[crop_list]
crop_list = []
for i in range(0, 3):
if image.shape[i] < new_shape[i]:
crop_int = int((new_shape[i] - image.shape[i]) / 2)
crop_list.append(slice(crop_int, crop_int + image.shape[i]))
else:
crop_list.append(slice(0, image.shape[i]))
# log.debug([image.shape, crop_list])
new_image = np.zeros(new_shape, dtype=image.dtype)
new_image[:] = fill
new_image[crop_list] = image
return new_image
def zoomToShape(image, new_shape, square=True):
# assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
if square and image.shape[0] != image.shape[1]:
crop_int = min(image.shape[0], image.shape[1])
new_shape = [crop_int, crop_int, image.shape[2]]
image = cropToShape(image, new_shape)
zoom_shape = [new_shape[i] / image.shape[i] for i in range(3)]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
image = scipy.ndimage.interpolation.zoom(
image, zoom_shape,
output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
return image
def randomOffset(image_list, offset_rows=0.125, offset_cols=0.125):
center_list = [int(image_list[0].shape[i] / 2) for i in range(3)]
center_list[0] += int(offset_rows * (random.random() - 0.5) * 2)
center_list[1] += int(offset_cols * (random.random() - 0.5) * 2)
center_list[2] = None
new_list = []
for image in image_list:
new_image = cropToShape(image, image.shape, center_list)
new_list.append(new_image)
return new_list
def randomZoom(image_list, scale=None, scale_min=0.8, scale_max=1.3):
if scale is None:
scale = scale_min + (scale_max - scale_min) * random.random()
new_list = []
for image in image_list:
# assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# log.info([image.shape])
zimage = scipy.ndimage.interpolation.zoom(
image, [scale, scale, 1.0],
output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
image = cropToShape(zimage, image.shape)
new_list.append(image)
return new_list
_randomFlip_transform_list = [
# lambda a: np.rot90(a, axes=(0, 1)),
# lambda a: np.flip(a, 0),
lambda a: np.flip(a, 1),
]
def randomFlip(image_list, transform_bits=None):
if transform_bits is None:
transform_bits = random.randrange(0, 2 ** len(_randomFlip_transform_list))
new_list = []
for image in image_list:
# assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
for n in range(len(_randomFlip_transform_list)):
if transform_bits & 2**n:
# prhist(image, 'before')
image = _randomFlip_transform_list[n](image)
# prhist(image, 'after ')
new_list.append(image)
return new_list
def randomSpin(image_list, angle=None, range_tup=None, axes=(0, 1)):
if range_tup is None:
range_tup = (0, 360)
if angle is None:
angle = range_tup[0] + (range_tup[1] - range_tup[0]) * random.random()
new_list = []
for image in image_list:
# assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
image = scipy.ndimage.interpolation.rotate(
image, angle, axes=axes, reshape=False,
output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
new_list.append(image)
return new_list
def randomNoise(image_list, noise_min=-0.1, noise_max=0.1):
noise = np.zeros_like(image_list[0])
noise += (noise_max - noise_min) * np.random.random_sample(image_list[0].shape) + noise_min
noise *= 5
noise = scipy.ndimage.filters.gaussian_filter(noise, 3)
# noise += (noise_max - noise_min) * np.random.random_sample(image_hsv.shape) + noise_min
new_list = []
for image_hsv in image_list:
image_hsv = image_hsv + noise
new_list.append(image_hsv)
return new_list
def randomHsvShift(image_list, h=None, s=None, v=None,
h_min=-0.1, h_max=0.1,
s_min=0.5, s_max=2.0,
v_min=0.5, v_max=2.0):
if h is None:
h = h_min + (h_max - h_min) * random.random()
if s is None:
s = s_min + (s_max - s_min) * random.random()
if v is None:
v = v_min + (v_max - v_min) * random.random()
new_list = []
for image_hsv in image_list:
# assert image_hsv.shape[-1] == 3, repr(image_hsv.shape)
image_hsv[:,:,0::3] += h
image_hsv[:,:,1::3] = image_hsv[:,:,1::3] ** s
image_hsv[:,:,2::3] = image_hsv[:,:,2::3] ** v
new_list.append(image_hsv)
return clampHsv(new_list)
def clampHsv(image_list):
new_list = []
for image_hsv in image_list:
image_hsv = image_hsv.clone()
# Hue wraps around
image_hsv[:,:,0][image_hsv[:,:,0] > 1] -= 1
image_hsv[:,:,0][image_hsv[:,:,0] < 0] += 1
# Everything else clamps between 0 and 1
image_hsv[image_hsv > 1] = 1
image_hsv[image_hsv < 0] = 0
new_list.append(image_hsv)
return new_list
# def torch_augment(input):
# theta = random.random() * math.pi * 2
# s = math.sin(theta)
# c = math.cos(theta)
# c1 = 1 - c
# axis_vector = torch.rand(3, device='cpu', dtype=torch.float64)
# axis_vector -= 0.5
# axis_vector /= axis_vector.abs().sum()
# l, m, n = axis_vector
#
# matrix = torch.tensor([
# [l*l*c1 + c, m*l*c1 - n*s, n*l*c1 + m*s, 0],
# [l*m*c1 + n*s, m*m*c1 + c, n*m*c1 - l*s, 0],
# [l*n*c1 - m*s, m*n*c1 + l*s, n*n*c1 + c, 0],
# [0, 0, 0, 1],
# ], device=input.device, dtype=torch.float32)
#
# return th_affine3d(input, matrix)
# following from https://github.com/ncullen93/torchsample/blob/master/torchsample/utils.py
# MIT licensed
# def th_affine3d(input, matrix):
# """
# 3D Affine image transform on torch.Tensor
# """
# A = matrix[:3,:3]
# b = matrix[:3,3]
#
# # make a meshgrid of normal coordinates
# coords = th_iterproduct(input.size(-3), input.size(-2), input.size(-1), dtype=torch.float32)
#
# # shift the coordinates so center is the origin
# coords[:,0] = coords[:,0] - (input.size(-3) / 2. - 0.5)
# coords[:,1] = coords[:,1] - (input.size(-2) / 2. - 0.5)
# coords[:,2] = coords[:,2] - (input.size(-1) / 2. - 0.5)
#
# # apply the coordinate transformation
# new_coords = coords.mm(A.t().contiguous()) + b.expand_as(coords)
#
# # shift the coordinates back so origin is origin
# new_coords[:,0] = new_coords[:,0] + (input.size(-3) / 2. - 0.5)
# new_coords[:,1] = new_coords[:,1] + (input.size(-2) / 2. - 0.5)
# new_coords[:,2] = new_coords[:,2] + (input.size(-1) / 2. - 0.5)
#
# # map new coordinates using bilinear interpolation
# input_transformed = th_trilinear_interp3d(input, new_coords)
#
# return input_transformed
#
#
# def th_trilinear_interp3d(input, coords):
# """
# trilinear interpolation of 3D torch.Tensor image
# """
# # take clamp then floor/ceil of x coords
# x = torch.clamp(coords[:,0], 0, input.size(-3)-2)
# x0 = x.floor()
# x1 = x0 + 1
# # take clamp then floor/ceil of y coords
# y = torch.clamp(coords[:,1], 0, input.size(-2)-2)
# y0 = y.floor()
# y1 = y0 + 1
# # take clamp then floor/ceil of z coords
# z = torch.clamp(coords[:,2], 0, input.size(-1)-2)
# z0 = z.floor()
# z1 = z0 + 1
#
# stride = torch.tensor(input.stride()[-3:], dtype=torch.int64, device=input.device)
# x0_ix = x0.mul(stride[0]).long()
# x1_ix = x1.mul(stride[0]).long()
# y0_ix = y0.mul(stride[1]).long()
# y1_ix = y1.mul(stride[1]).long()
# z0_ix = z0.mul(stride[2]).long()
# z1_ix = z1.mul(stride[2]).long()
#
# # input_flat = th_flatten(input)
# input_flat = x.contiguous().view(x[0], x[1], -1)
#
# vals_000 = input_flat[:, :, x0_ix+y0_ix+z0_ix]
# vals_001 = input_flat[:, :, x0_ix+y0_ix+z1_ix]
# vals_010 = input_flat[:, :, x0_ix+y1_ix+z0_ix]
# vals_011 = input_flat[:, :, x0_ix+y1_ix+z1_ix]
# vals_100 = input_flat[:, :, x1_ix+y0_ix+z0_ix]
# vals_101 = input_flat[:, :, x1_ix+y0_ix+z1_ix]
# vals_110 = input_flat[:, :, x1_ix+y1_ix+z0_ix]
# vals_111 = input_flat[:, :, x1_ix+y1_ix+z1_ix]
#
# xd = x - x0
# yd = y - y0
# zd = z - z0
# xm1 = 1 - xd
# ym1 = 1 - yd
# zm1 = 1 - zd
#
# x_mapped = (
# vals_000.mul(xm1).mul(ym1).mul(zm1) +
# vals_001.mul(xm1).mul(ym1).mul(zd) +
# vals_010.mul(xm1).mul(yd).mul(zm1) +
# vals_011.mul(xm1).mul(yd).mul(zd) +
# vals_100.mul(xd).mul(ym1).mul(zm1) +
# vals_101.mul(xd).mul(ym1).mul(zd) +
# vals_110.mul(xd).mul(yd).mul(zm1) +
# vals_111.mul(xd).mul(yd).mul(zd)
# )
#
# return x_mapped.view_as(input)
#
# def th_iterproduct(*args, dtype=None):
# return torch.from_numpy(np.indices(args).reshape((len(args),-1)).T)
#
# def th_flatten(x):
# """Flatten tensor"""
# return x.contiguous().view(x[0], x[1], -1)