From 4dca46861209192e8ad3cee59e39cae7a94bda54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Pokr=C3=BDvka?= Date: Thu, 12 Nov 2020 15:22:35 +0000 Subject: [PATCH] =?UTF-8?q?Nahr=C3=A1t=20soubory=20do=20=E2=80=9Epages/stu?= =?UTF-8?q?dents/2016/lukas=5Fpokryvka/dp2021/lungCancer/util=E2=80=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dp2021/lungCancer/util/__init__.py | Bin 0 -> 1024 bytes .../dp2021/lungCancer/util/augmentation.py | 331 ++++++++++++++++++ .../dp2021/lungCancer/util/disk.py | 136 +++++++ .../dp2021/lungCancer/util/logconf.py | 19 + .../dp2021/lungCancer/util/unet.py | 143 ++++++++ 5 files changed, 629 insertions(+) create mode 100644 pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/__init__.py create mode 100644 pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/augmentation.py create mode 100644 pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/disk.py create mode 100644 pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/logconf.py create mode 100644 pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/unet.py diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/__init__.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06d7405020018ddf3cacee90fd4af10487da3d20 GIT binary patch literal 1024 ScmZQz7zLvtFd70QH3R?z00031 literal 0 HcmV?d00001 diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/augmentation.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/augmentation.py new file mode 100644 index 00000000..c5345e84 --- /dev/null +++ b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/augmentation.py @@ -0,0 +1,331 @@ +import math +import random +import warnings + +import numpy as np +import scipy.ndimage + +import torch +from torch.autograd import Function +from torch.autograd.function import once_differentiable +import torch.backends.cudnn as cudnn + +from util.logconf import logging +log = logging.getLogger(__name__) +# log.setLevel(logging.WARN) +# log.setLevel(logging.INFO) +log.setLevel(logging.DEBUG) + +def cropToShape(image, new_shape, center_list=None, fill=0.0): + # log.debug([image.shape, new_shape, center_list]) + # assert len(image.shape) == 3, repr(image.shape) + + if center_list is None: + center_list = [int(image.shape[i] / 2) for i in range(3)] + + crop_list = [] + for i in range(0, 3): + crop_int = center_list[i] + if image.shape[i] > new_shape[i] and crop_int is not None: + + # We can't just do crop_int +/- shape/2 since shape might be odd + # and ints round down. + start_int = crop_int - int(new_shape[i]/2) + end_int = start_int + new_shape[i] + crop_list.append(slice(max(0, start_int), end_int)) + else: + crop_list.append(slice(0, image.shape[i])) + + # log.debug([image.shape, crop_list]) + image = image[crop_list] + + crop_list = [] + for i in range(0, 3): + if image.shape[i] < new_shape[i]: + crop_int = int((new_shape[i] - image.shape[i]) / 2) + crop_list.append(slice(crop_int, crop_int + image.shape[i])) + else: + crop_list.append(slice(0, image.shape[i])) + + # log.debug([image.shape, crop_list]) + new_image = np.zeros(new_shape, dtype=image.dtype) + new_image[:] = fill + new_image[crop_list] = image + + return new_image + + +def zoomToShape(image, new_shape, square=True): + # assert image.shape[-1] in {1, 3, 4}, repr(image.shape) + + if square and image.shape[0] != image.shape[1]: + crop_int = min(image.shape[0], image.shape[1]) + new_shape = [crop_int, crop_int, image.shape[2]] + image = cropToShape(image, new_shape) + + zoom_shape = [new_shape[i] / image.shape[i] for i in range(3)] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + image = scipy.ndimage.interpolation.zoom( + image, zoom_shape, + output=None, order=0, mode='nearest', cval=0.0, prefilter=True) + + return image + +def randomOffset(image_list, offset_rows=0.125, offset_cols=0.125): + + center_list = [int(image_list[0].shape[i] / 2) for i in range(3)] + center_list[0] += int(offset_rows * (random.random() - 0.5) * 2) + center_list[1] += int(offset_cols * (random.random() - 0.5) * 2) + center_list[2] = None + + new_list = [] + for image in image_list: + new_image = cropToShape(image, image.shape, center_list) + new_list.append(new_image) + + return new_list + + +def randomZoom(image_list, scale=None, scale_min=0.8, scale_max=1.3): + if scale is None: + scale = scale_min + (scale_max - scale_min) * random.random() + + new_list = [] + for image in image_list: + # assert image.shape[-1] in {1, 3, 4}, repr(image.shape) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # log.info([image.shape]) + zimage = scipy.ndimage.interpolation.zoom( + image, [scale, scale, 1.0], + output=None, order=0, mode='nearest', cval=0.0, prefilter=True) + image = cropToShape(zimage, image.shape) + + new_list.append(image) + + return new_list + + +_randomFlip_transform_list = [ + # lambda a: np.rot90(a, axes=(0, 1)), + # lambda a: np.flip(a, 0), + lambda a: np.flip(a, 1), +] + +def randomFlip(image_list, transform_bits=None): + if transform_bits is None: + transform_bits = random.randrange(0, 2 ** len(_randomFlip_transform_list)) + + new_list = [] + for image in image_list: + # assert image.shape[-1] in {1, 3, 4}, repr(image.shape) + + for n in range(len(_randomFlip_transform_list)): + if transform_bits & 2**n: + # prhist(image, 'before') + image = _randomFlip_transform_list[n](image) + # prhist(image, 'after ') + + new_list.append(image) + + return new_list + + +def randomSpin(image_list, angle=None, range_tup=None, axes=(0, 1)): + if range_tup is None: + range_tup = (0, 360) + + if angle is None: + angle = range_tup[0] + (range_tup[1] - range_tup[0]) * random.random() + + new_list = [] + for image in image_list: + # assert image.shape[-1] in {1, 3, 4}, repr(image.shape) + + image = scipy.ndimage.interpolation.rotate( + image, angle, axes=axes, reshape=False, + output=None, order=0, mode='nearest', cval=0.0, prefilter=True) + + new_list.append(image) + + return new_list + + +def randomNoise(image_list, noise_min=-0.1, noise_max=0.1): + noise = np.zeros_like(image_list[0]) + noise += (noise_max - noise_min) * np.random.random_sample(image_list[0].shape) + noise_min + noise *= 5 + noise = scipy.ndimage.filters.gaussian_filter(noise, 3) + # noise += (noise_max - noise_min) * np.random.random_sample(image_hsv.shape) + noise_min + + new_list = [] + for image_hsv in image_list: + image_hsv = image_hsv + noise + + new_list.append(image_hsv) + + return new_list + + +def randomHsvShift(image_list, h=None, s=None, v=None, + h_min=-0.1, h_max=0.1, + s_min=0.5, s_max=2.0, + v_min=0.5, v_max=2.0): + if h is None: + h = h_min + (h_max - h_min) * random.random() + if s is None: + s = s_min + (s_max - s_min) * random.random() + if v is None: + v = v_min + (v_max - v_min) * random.random() + + new_list = [] + for image_hsv in image_list: + # assert image_hsv.shape[-1] == 3, repr(image_hsv.shape) + + image_hsv[:,:,0::3] += h + image_hsv[:,:,1::3] = image_hsv[:,:,1::3] ** s + image_hsv[:,:,2::3] = image_hsv[:,:,2::3] ** v + + new_list.append(image_hsv) + + return clampHsv(new_list) + + +def clampHsv(image_list): + new_list = [] + for image_hsv in image_list: + image_hsv = image_hsv.clone() + + # Hue wraps around + image_hsv[:,:,0][image_hsv[:,:,0] > 1] -= 1 + image_hsv[:,:,0][image_hsv[:,:,0] < 0] += 1 + + # Everything else clamps between 0 and 1 + image_hsv[image_hsv > 1] = 1 + image_hsv[image_hsv < 0] = 0 + + new_list.append(image_hsv) + + return new_list + + +# def torch_augment(input): +# theta = random.random() * math.pi * 2 +# s = math.sin(theta) +# c = math.cos(theta) +# c1 = 1 - c +# axis_vector = torch.rand(3, device='cpu', dtype=torch.float64) +# axis_vector -= 0.5 +# axis_vector /= axis_vector.abs().sum() +# l, m, n = axis_vector +# +# matrix = torch.tensor([ +# [l*l*c1 + c, m*l*c1 - n*s, n*l*c1 + m*s, 0], +# [l*m*c1 + n*s, m*m*c1 + c, n*m*c1 - l*s, 0], +# [l*n*c1 - m*s, m*n*c1 + l*s, n*n*c1 + c, 0], +# [0, 0, 0, 1], +# ], device=input.device, dtype=torch.float32) +# +# return th_affine3d(input, matrix) + + + + +# following from https://github.com/ncullen93/torchsample/blob/master/torchsample/utils.py +# MIT licensed + +# def th_affine3d(input, matrix): +# """ +# 3D Affine image transform on torch.Tensor +# """ +# A = matrix[:3,:3] +# b = matrix[:3,3] +# +# # make a meshgrid of normal coordinates +# coords = th_iterproduct(input.size(-3), input.size(-2), input.size(-1), dtype=torch.float32) +# +# # shift the coordinates so center is the origin +# coords[:,0] = coords[:,0] - (input.size(-3) / 2. - 0.5) +# coords[:,1] = coords[:,1] - (input.size(-2) / 2. - 0.5) +# coords[:,2] = coords[:,2] - (input.size(-1) / 2. - 0.5) +# +# # apply the coordinate transformation +# new_coords = coords.mm(A.t().contiguous()) + b.expand_as(coords) +# +# # shift the coordinates back so origin is origin +# new_coords[:,0] = new_coords[:,0] + (input.size(-3) / 2. - 0.5) +# new_coords[:,1] = new_coords[:,1] + (input.size(-2) / 2. - 0.5) +# new_coords[:,2] = new_coords[:,2] + (input.size(-1) / 2. - 0.5) +# +# # map new coordinates using bilinear interpolation +# input_transformed = th_trilinear_interp3d(input, new_coords) +# +# return input_transformed +# +# +# def th_trilinear_interp3d(input, coords): +# """ +# trilinear interpolation of 3D torch.Tensor image +# """ +# # take clamp then floor/ceil of x coords +# x = torch.clamp(coords[:,0], 0, input.size(-3)-2) +# x0 = x.floor() +# x1 = x0 + 1 +# # take clamp then floor/ceil of y coords +# y = torch.clamp(coords[:,1], 0, input.size(-2)-2) +# y0 = y.floor() +# y1 = y0 + 1 +# # take clamp then floor/ceil of z coords +# z = torch.clamp(coords[:,2], 0, input.size(-1)-2) +# z0 = z.floor() +# z1 = z0 + 1 +# +# stride = torch.tensor(input.stride()[-3:], dtype=torch.int64, device=input.device) +# x0_ix = x0.mul(stride[0]).long() +# x1_ix = x1.mul(stride[0]).long() +# y0_ix = y0.mul(stride[1]).long() +# y1_ix = y1.mul(stride[1]).long() +# z0_ix = z0.mul(stride[2]).long() +# z1_ix = z1.mul(stride[2]).long() +# +# # input_flat = th_flatten(input) +# input_flat = x.contiguous().view(x[0], x[1], -1) +# +# vals_000 = input_flat[:, :, x0_ix+y0_ix+z0_ix] +# vals_001 = input_flat[:, :, x0_ix+y0_ix+z1_ix] +# vals_010 = input_flat[:, :, x0_ix+y1_ix+z0_ix] +# vals_011 = input_flat[:, :, x0_ix+y1_ix+z1_ix] +# vals_100 = input_flat[:, :, x1_ix+y0_ix+z0_ix] +# vals_101 = input_flat[:, :, x1_ix+y0_ix+z1_ix] +# vals_110 = input_flat[:, :, x1_ix+y1_ix+z0_ix] +# vals_111 = input_flat[:, :, x1_ix+y1_ix+z1_ix] +# +# xd = x - x0 +# yd = y - y0 +# zd = z - z0 +# xm1 = 1 - xd +# ym1 = 1 - yd +# zm1 = 1 - zd +# +# x_mapped = ( +# vals_000.mul(xm1).mul(ym1).mul(zm1) + +# vals_001.mul(xm1).mul(ym1).mul(zd) + +# vals_010.mul(xm1).mul(yd).mul(zm1) + +# vals_011.mul(xm1).mul(yd).mul(zd) + +# vals_100.mul(xd).mul(ym1).mul(zm1) + +# vals_101.mul(xd).mul(ym1).mul(zd) + +# vals_110.mul(xd).mul(yd).mul(zm1) + +# vals_111.mul(xd).mul(yd).mul(zd) +# ) +# +# return x_mapped.view_as(input) +# +# def th_iterproduct(*args, dtype=None): +# return torch.from_numpy(np.indices(args).reshape((len(args),-1)).T) +# +# def th_flatten(x): +# """Flatten tensor""" +# return x.contiguous().view(x[0], x[1], -1) diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/disk.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/disk.py new file mode 100644 index 00000000..091d2bb6 --- /dev/null +++ b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/disk.py @@ -0,0 +1,136 @@ +import gzip + +from diskcache import FanoutCache, Disk +from diskcache.core import BytesType, MODE_BINARY, BytesIO + +from util.logconf import logging +log = logging.getLogger(__name__) +# log.setLevel(logging.WARN) +log.setLevel(logging.INFO) +# log.setLevel(logging.DEBUG) + + +class GzipDisk(Disk): + def store(self, value, read, key=None): + """ + Override from base class diskcache.Disk. + + Chunking is due to needing to work on pythons < 2.7.13: + - Issue #27130: In the "zlib" module, fix handling of large buffers + (typically 2 or 4 GiB). Previously, inputs were limited to 2 GiB, and + compression and decompression operations did not properly handle results of + 2 or 4 GiB. + + :param value: value to convert + :param bool read: True when value is file-like object + :return: (size, mode, filename, value) tuple for Cache table + """ + # pylint: disable=unidiomatic-typecheck + if type(value) is BytesType: + if read: + value = value.read() + read = False + + str_io = BytesIO() + gz_file = gzip.GzipFile(mode='wb', compresslevel=1, fileobj=str_io) + + for offset in range(0, len(value), 2**30): + gz_file.write(value[offset:offset+2**30]) + gz_file.close() + + value = str_io.getvalue() + + return super(GzipDisk, self).store(value, read) + + + def fetch(self, mode, filename, value, read): + """ + Override from base class diskcache.Disk. + + Chunking is due to needing to work on pythons < 2.7.13: + - Issue #27130: In the "zlib" module, fix handling of large buffers + (typically 2 or 4 GiB). Previously, inputs were limited to 2 GiB, and + compression and decompression operations did not properly handle results of + 2 or 4 GiB. + + :param int mode: value mode raw, binary, text, or pickle + :param str filename: filename of corresponding value + :param value: database value + :param bool read: when True, return an open file handle + :return: corresponding Python value + """ + value = super(GzipDisk, self).fetch(mode, filename, value, read) + + if mode == MODE_BINARY: + str_io = BytesIO(value) + gz_file = gzip.GzipFile(mode='rb', fileobj=str_io) + read_csio = BytesIO() + + while True: + uncompressed_data = gz_file.read(2**30) + if uncompressed_data: + read_csio.write(uncompressed_data) + else: + break + + value = read_csio.getvalue() + + return value + +def getCache(scope_str): + return FanoutCache('data-unversioned/cache/' + scope_str, + disk=GzipDisk, + shards=64, + timeout=1, + size_limit=3e11, + # disk_min_file_size=2**20, + ) + +# def disk_cache(base_path, memsize=2): +# def disk_cache_decorator(f): +# @functools.wraps(f) +# def wrapper(*args, **kwargs): +# args_str = repr(args) + repr(sorted(kwargs.items())) +# file_str = hashlib.md5(args_str.encode('utf8')).hexdigest() +# +# cache_path = os.path.join(base_path, f.__name__, file_str + '.pkl.gz') +# +# if not os.path.exists(os.path.dirname(cache_path)): +# os.makedirs(os.path.dirname(cache_path), exist_ok=True) +# +# if os.path.exists(cache_path): +# return pickle_loadgz(cache_path) +# else: +# ret = f(*args, **kwargs) +# pickle_dumpgz(cache_path, ret) +# return ret +# +# return wrapper +# +# return disk_cache_decorator +# +# +# def pickle_dumpgz(file_path, obj): +# log.debug("Writing {}".format(file_path)) +# with open(file_path, 'wb') as file_obj: +# with gzip.GzipFile(mode='wb', compresslevel=1, fileobj=file_obj) as gz_file: +# pickle.dump(obj, gz_file, pickle.HIGHEST_PROTOCOL) +# +# +# def pickle_loadgz(file_path): +# log.debug("Reading {}".format(file_path)) +# with open(file_path, 'rb') as file_obj: +# with gzip.GzipFile(mode='rb', fileobj=file_obj) as gz_file: +# return pickle.load(gz_file) +# +# +# def dtpath(dt=None): +# if dt is None: +# dt = datetime.datetime.now() +# +# return str(dt).rsplit('.', 1)[0].replace(' ', '--').replace(':', '.') +# +# +# def safepath(s): +# s = s.replace(' ', '_') +# return re.sub('[^A-Za-z0-9_.-]', '', s) diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/logconf.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/logconf.py new file mode 100644 index 00000000..65f7b9da --- /dev/null +++ b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/logconf.py @@ -0,0 +1,19 @@ +import logging +import logging.handlers + +root_logger = logging.getLogger() +root_logger.setLevel(logging.INFO) + +# Some libraries attempt to add their own root logger handlers. This is +# annoying and so we get rid of them. +for handler in list(root_logger.handlers): + root_logger.removeHandler(handler) + +logfmt_str = "%(asctime)s %(levelname)-8s pid:%(process)d %(name)s:%(lineno)03d:%(funcName)s %(message)s" +formatter = logging.Formatter(logfmt_str) + +streamHandler = logging.StreamHandler() +streamHandler.setFormatter(formatter) +streamHandler.setLevel(logging.DEBUG) + +root_logger.addHandler(streamHandler) diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/unet.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/unet.py new file mode 100644 index 00000000..9e16a525 --- /dev/null +++ b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/util/unet.py @@ -0,0 +1,143 @@ +# From https://github.com/jvanvugt/pytorch-unet +# https://raw.githubusercontent.com/jvanvugt/pytorch-unet/master/unet.py + +# MIT License +# +# Copyright (c) 2018 Joris +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Adapted from https://discuss.pytorch.org/t/unet-implementation/426 + +import torch +from torch import nn +import torch.nn.functional as F + + +class UNet(nn.Module): + def __init__(self, in_channels=1, n_classes=2, depth=5, wf=6, padding=False, + batch_norm=False, up_mode='upconv'): + """ + Implementation of + U-Net: Convolutional Networks for Biomedical Image Segmentation + (Ronneberger et al., 2015) + https://arxiv.org/abs/1505.04597 + + Using the default arguments will yield the exact version used + in the original paper + + Args: + in_channels (int): number of input channels + n_classes (int): number of output channels + depth (int): depth of the network + wf (int): number of filters in the first layer is 2**wf + padding (bool): if True, apply padding such that the input shape + is the same as the output. + This may introduce artifacts + batch_norm (bool): Use BatchNorm after layers with an + activation function + up_mode (str): one of 'upconv' or 'upsample'. + 'upconv' will use transposed convolutions for + learned upsampling. + 'upsample' will use bilinear upsampling. + """ + super(UNet, self).__init__() + assert up_mode in ('upconv', 'upsample') + self.padding = padding + self.depth = depth + prev_channels = in_channels + self.down_path = nn.ModuleList() + for i in range(depth): + self.down_path.append(UNetConvBlock(prev_channels, 2**(wf+i), + padding, batch_norm)) + prev_channels = 2**(wf+i) + + self.up_path = nn.ModuleList() + for i in reversed(range(depth - 1)): + self.up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode, + padding, batch_norm)) + prev_channels = 2**(wf+i) + + self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1) + + def forward(self, x): + blocks = [] + for i, down in enumerate(self.down_path): + x = down(x) + if i != len(self.down_path)-1: + blocks.append(x) + x = F.avg_pool2d(x, 2) + + for i, up in enumerate(self.up_path): + x = up(x, blocks[-i-1]) + + return self.last(x) + + +class UNetConvBlock(nn.Module): + def __init__(self, in_size, out_size, padding, batch_norm): + super(UNetConvBlock, self).__init__() + block = [] + + block.append(nn.Conv2d(in_size, out_size, kernel_size=3, + padding=int(padding))) + block.append(nn.ReLU()) + # block.append(nn.LeakyReLU()) + if batch_norm: + block.append(nn.BatchNorm2d(out_size)) + + block.append(nn.Conv2d(out_size, out_size, kernel_size=3, + padding=int(padding))) + block.append(nn.ReLU()) + # block.append(nn.LeakyReLU()) + if batch_norm: + block.append(nn.BatchNorm2d(out_size)) + + self.block = nn.Sequential(*block) + + def forward(self, x): + out = self.block(x) + return out + + +class UNetUpBlock(nn.Module): + def __init__(self, in_size, out_size, up_mode, padding, batch_norm): + super(UNetUpBlock, self).__init__() + if up_mode == 'upconv': + self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, + stride=2) + elif up_mode == 'upsample': + self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2), + nn.Conv2d(in_size, out_size, kernel_size=1)) + + self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm) + + def center_crop(self, layer, target_size): + _, _, layer_height, layer_width = layer.size() + diff_y = (layer_height - target_size[0]) // 2 + diff_x = (layer_width - target_size[1]) // 2 + return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])] + + def forward(self, x, bridge): + up = self.up(x) + crop1 = self.center_crop(bridge, up.shape[2:]) + out = torch.cat([up, crop1], 1) + out = self.conv_block(out) + + return out