import math
import random
import warnings

import numpy as np
import scipy.ndimage

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import torch.backends.cudnn as cudnn

from util.logconf import logging
# Module-level logger (``logging`` is the util.logconf module imported above).
log = logging.getLogger(__name__)
# This module's verbosity is forced to DEBUG; swap in one of the
# commented-out levels below to quiet it.
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

def cropToShape(image, new_shape, center_list=None, fill=0.0):
    """Crop and/or pad ``image`` so that it has exactly ``new_shape``.

    Each axis is handled independently:

    * if the image is longer than the target along an axis, a window of
      ``new_shape[i]`` elements centred on ``center_list[i]`` is kept. The
      window is clipped to the array bounds, so near an edge it comes out
      shorter and is then padded by the next step;
    * if the (possibly cropped) image is shorter than the target, it is placed
      centred inside an output array pre-filled with ``fill``.

    Args:
        image: numpy array whose rank equals ``len(new_shape)``.
        new_shape: target shape.
        center_list: per-axis crop centre; defaults to ``shape[i] // 2``.
            An entry of ``None`` disables cropping along that axis (that axis
            must then already fit within ``new_shape[i]``).
        fill: value written into padded elements (cast to ``image.dtype``).

    Returns:
        A new array of shape ``new_shape`` and dtype ``image.dtype``.
    """
    ndim = len(new_shape)

    if center_list is None:
        center_list = [image.shape[i] // 2 for i in range(ndim)]

    # Crop step. NumPy needs a *tuple* of slices for multi-axis indexing;
    # indexing with a list of slices is an error in NumPy >= 1.23.
    crop_list = []
    for i in range(ndim):
        center = center_list[i]
        if image.shape[i] > new_shape[i] and center is not None:
            # We can't just do center +/- shape/2 since shape might be odd
            # and ints round down.
            start = center - int(new_shape[i] / 2)
            end = start + new_shape[i]
            crop_list.append(slice(max(0, start), end))
        else:
            crop_list.append(slice(0, image.shape[i]))
    image = image[tuple(crop_list)]

    # Pad step: centre whatever remains inside a fill-valued array.
    pad_list = []
    for i in range(ndim):
        if image.shape[i] < new_shape[i]:
            offset = int((new_shape[i] - image.shape[i]) / 2)
            pad_list.append(slice(offset, offset + image.shape[i]))
        else:
            pad_list.append(slice(0, image.shape[i]))

    new_image = np.full(new_shape, fill, dtype=image.dtype)
    new_image[tuple(pad_list)] = image

    return new_image


def zoomToShape(image, new_shape, square=True):
    """Resample ``image`` to ``new_shape`` using nearest-neighbour zoom.

    Args:
        image: numpy array; its first two axes are treated as rows/columns.
        new_shape: target shape, one entry per axis of ``image``.
        square: if true and the first two axes differ in length, both are
            first centre-cropped to the shorter length before zooming.

    Returns:
        The zoomed array. ``scipy.ndimage.zoom`` sizes each output axis as
        ``round(old * (new / old))``, which evaluates to ``new_shape[i]``.
    """
    if square and image.shape[0] != image.shape[1]:
        # Keep the crop size separate from ``new_shape`` so the caller's
        # requested output size is still honoured after cropping.
        side = min(image.shape[0], image.shape[1])
        row_start = image.shape[0] // 2 - side // 2
        col_start = image.shape[1] // 2 - side // 2
        image = image[row_start:row_start + side, col_start:col_start + side]

    zoom_shape = [new_dim / old_dim for new_dim, old_dim in zip(new_shape, image.shape)]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        image = scipy.ndimage.zoom(
            image, zoom_shape,
            output=None, order=0, mode='nearest', cval=0.0, prefilter=True)

    return image

def _shiftSlices(length, shift):
    """Return ``(src, dst)`` slices so ``out[dst] = image[src]`` yields
    ``out[i] == image[i + shift]`` wherever ``i + shift`` is in range."""
    if shift >= 0:
        return slice(min(shift, length), length), slice(0, max(length - shift, 0))
    return slice(0, max(length + shift, 0)), slice(min(-shift, length), length)


def randomOffset(image_list, offset_rows=0.125, offset_cols=0.125):
    """Translate every image in ``image_list`` by the same random offset.

    The row and column shifts are drawn (rows first) uniformly from
    ``[-offset * size, offset * size)``, where ``size`` is the length of the
    first image's corresponding axis, and truncated toward zero. Output pixel
    ``[r, c]`` takes input pixel ``[r + dr, c + dc]``; pixels shifted in from
    outside the image are 0.

    Args:
        image_list: numpy arrays with at least two axes; only the first two
            axes are shifted. All images are assumed to share the first
            image's row/column sizes.
        offset_rows: maximum row shift as a fraction of the row count.
        offset_cols: maximum column shift as a fraction of the column count.

    Returns:
        New list of arrays with the same shapes and dtypes as the inputs.
    """
    rows, cols = image_list[0].shape[:2]
    # Scale by the axis length: an unscaled fraction such as 0.125 would
    # always truncate to a zero-pixel shift.
    row_shift = int(offset_rows * rows * (random.random() - 0.5) * 2)
    col_shift = int(offset_cols * cols * (random.random() - 0.5) * 2)

    row_src, row_dst = _shiftSlices(rows, row_shift)
    col_src, col_dst = _shiftSlices(cols, col_shift)

    new_list = []
    for image in image_list:
        new_image = np.zeros_like(image)
        new_image[row_dst, col_dst] = image[row_src, col_src]
        new_list.append(new_image)

    return new_list


def randomZoom(image_list, scale=None, scale_min=0.8, scale_max=1.3):
    """Zoom every image by one shared factor, then crop/pad back to its shape.

    Args:
        image_list: numpy arrays; the two leading axes are zoomed and any
            remaining axes (presumably channels) are kept at scale 1.
        scale: zoom factor; if None it is drawn uniformly from
            ``[scale_min, scale_max)``.
        scale_min: lower bound for the random zoom factor.
        scale_max: upper bound for the random zoom factor.

    Returns:
        New list where each array is passed through ``cropToShape`` back to
        its original shape. When ``scale < 1``, the padded border takes
        ``cropToShape``'s default fill.
    """
    if scale is None:
        scale = scale_min + (scale_max - scale_min) * random.random()

    new_list = []
    for image in image_list:
        # Equivalent to [scale, scale, 1.0] for 3-D arrays.
        zoom_factors = [scale, scale] + [1.0] * (image.ndim - 2)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            zimage = scipy.ndimage.zoom(
                image, zoom_factors,
                output=None, order=0, mode='nearest', cval=0.0, prefilter=True)

        new_list.append(cropToShape(zimage, image.shape))

    return new_list


# Candidate transforms for randomFlip. Bit ``n`` of ``transform_bits``
# selects entry ``n``; only the axis-1 flip is currently enabled.
_randomFlip_transform_list = [
    # lambda a: np.rot90(a, axes=(0, 1)),
    # lambda a: np.flip(a, 0),
    lambda a: np.flip(a, 1),
]


def randomFlip(image_list, transform_bits=None):
    """Apply the same subset of ``_randomFlip_transform_list`` to every image.

    Args:
        image_list: images accepted by the transforms (e.g. numpy arrays).
        transform_bits: bitmask; if bit ``n`` is set, transform ``n`` is
            applied. If None, a mask is drawn uniformly from every subset.

    Returns:
        New list of transformed images. An image with no selected transform
        is passed through as the same object.
    """
    transform_count = len(_randomFlip_transform_list)
    if transform_bits is None:
        transform_bits = random.randrange(0, 2 ** transform_count)

    def _apply(img):
        for bit_index, transform in enumerate(_randomFlip_transform_list):
            if transform_bits & (1 << bit_index):
                img = transform(img)
        return img

    return [_apply(img) for img in image_list]


def randomSpin(image_list, angle=None, range_tup=None, axes=(0, 1)):
    """Rotate every image by the same angle within the plane given by ``axes``.

    Args:
        image_list: numpy arrays to rotate.
        angle: rotation in degrees; if None it is drawn uniformly from
            ``[range_tup[0], range_tup[1])``.
        range_tup: ``(low, high)`` degree range for the random angle;
            defaults to ``(0, 360)``.
        axes: the two axes defining the rotation plane.

    Returns:
        New list of rotated arrays with unchanged shapes (``reshape=False``),
        resampled with nearest-neighbour interpolation (``order=0``) and
        ``mode='nearest'`` at the borders.
    """
    if range_tup is None:
        range_tup = (0, 360)

    if angle is None:
        angle = range_tup[0] + (range_tup[1] - range_tup[0]) * random.random()

    # Use the public scipy.ndimage namespace; the ``interpolation`` submodule
    # is deprecated and warns on every access.
    return [
        scipy.ndimage.rotate(
            image, angle, axes=axes, reshape=False,
            output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
        for image in image_list
    ]


def randomNoise(image_list, noise_min=-0.1, noise_max=0.1, noise_scale=5, sigma=3):
    """Add one shared field of smoothed uniform noise to every image.

    A uniform field in ``[noise_min, noise_max)`` with the first image's shape
    (drawn from ``np.random``) is multiplied by ``noise_scale``, blurred with
    a Gaussian of standard deviation ``sigma``, and added to each image.

    Args:
        image_list: numpy arrays; all must broadcast against the first one.
        noise_min: lower bound of the raw uniform noise.
        noise_max: upper bound of the raw uniform noise.
        noise_scale: multiplier applied before blurring (previously fixed at 5).
        sigma: Gaussian blur standard deviation (previously fixed at 3).

    Returns:
        New list of ``image + noise`` arrays; the inputs are not modified.
    """
    noise = np.zeros_like(image_list[0])
    if not np.issubdtype(noise.dtype, np.floating):
        # Integer/bool buffers cannot take the fractional in-place add below.
        noise = noise.astype(np.float64)

    noise += (noise_max - noise_min) * np.random.random_sample(image_list[0].shape) + noise_min
    noise *= noise_scale
    noise = scipy.ndimage.gaussian_filter(noise, sigma)

    return [image + noise for image in image_list]


def randomHsvShift(image_list, h=None, s=None, v=None,
                   h_min=-0.1, h_max=0.1,
                   s_min=0.5, s_max=2.0,
                   v_min=0.5, v_max=2.0):
    """Apply one HSV adjustment to every image, then wrap/clamp via ``clampHsv``.

    Channels along axis 2 are treated as stacked H, S, V triplets.
    Indices ``0::3`` (hue) are shifted by ``h``, while ``1::3`` (saturation)
    and ``2::3`` (value) are raised to the powers ``s`` and ``v``.
    Any of ``h``/``s``/``v`` left as None is drawn uniformly from its
    ``[*_min, *_max)`` range, in the order h, s, v.

    Args:
        image_list: torch tensors (or numpy arrays) with at least three axes.
        h: hue offset, or None to draw it randomly.
        s: saturation exponent, or None to draw it randomly.
        v: value exponent, or None to draw it randomly.
        h_min, h_max: range for the random hue offset.
        s_min, s_max: range for the random saturation exponent.
        v_min, v_max: range for the random value exponent.

    Returns:
        The list produced by ``clampHsv`` on the adjusted copies. The caller's
        images are not modified.
    """
    if h is None:
        h = h_min + (h_max - h_min) * random.random()
    if s is None:
        s = s_min + (s_max - s_min) * random.random()
    if v is None:
        v = v_min + (v_max - v_min) * random.random()

    new_list = []
    for image_hsv in image_list:
        # Copy first: the updates below are in place and previously modified
        # the caller's tensors/arrays.
        if torch.is_tensor(image_hsv):
            image_hsv = image_hsv.clone()
        else:
            image_hsv = np.array(image_hsv, copy=True)

        image_hsv[:, :, 0::3] += h
        image_hsv[:, :, 1::3] = image_hsv[:, :, 1::3] ** s
        image_hsv[:, :, 2::3] = image_hsv[:, :, 2::3] ** v

        new_list.append(image_hsv)

    return clampHsv(new_list)


def clampHsv(image_list):
    """Return copies of HSV images with hue wrapped and all values clamped to [0, 1].

    Channels along axis 2 are treated as stacked H, S, V triplets, which is the
    same ``0::3`` / ``1::3`` / ``2::3`` layout ``randomHsvShift`` writes.
    Processing has two steps:

    * Every hue channel (0, 3, 6, ...) is wrapped once: values above 1 lose 1,
      and then values below 0 gain 1.
    * Every element is clamped to ``[0, 1]``.

    Args:
        image_list: torch tensors or numpy arrays with at least three axes.

    Returns:
        New list of copies; the inputs are not modified.
    """
    new_list = []
    for image_hsv in image_list:
        if torch.is_tensor(image_hsv):
            image_hsv = image_hsv.clone()
        else:
            image_hsv = np.array(image_hsv, copy=True)

        # Hue wraps around. ``hue`` is a basic-slice view, so the masked
        # in-place updates land in ``image_hsv``; wrapping all 0::3 channels
        # keeps stacked triplets consistent with randomHsvShift.
        hue = image_hsv[:, :, 0::3]
        hue[hue > 1] -= 1
        hue[hue < 0] += 1

        # Everything else clamps between 0 and 1
        image_hsv[image_hsv > 1] = 1
        image_hsv[image_hsv < 0] = 0

        new_list.append(image_hsv)

    return new_list


# def torch_augment(input):
#     theta = random.random() * math.pi * 2
#     s = math.sin(theta)
#     c = math.cos(theta)
#     c1 = 1 - c
#     axis_vector = torch.rand(3, device='cpu', dtype=torch.float64)
#     axis_vector -= 0.5
#     axis_vector /= axis_vector.abs().sum()
#     l, m, n = axis_vector
#
#     matrix = torch.tensor([
#         [l*l*c1 +   c, m*l*c1 - n*s, n*l*c1 + m*s, 0],
#         [l*m*c1 + n*s, m*m*c1 +   c, n*m*c1 - l*s, 0],
#         [l*n*c1 - m*s, m*n*c1 + l*s, n*n*c1 +   c, 0],
#         [0, 0, 0, 1],
#     ], device=input.device, dtype=torch.float32)
#
#     return th_affine3d(input, matrix)




# The commented-out helpers below are adapted from
# https://github.com/ncullen93/torchsample/blob/master/torchsample/utils.py
# (MIT licensed).

# def th_affine3d(input, matrix):
#     """
#     3D Affine image transform on torch.Tensor
#     """
#     A = matrix[:3,:3]
#     b = matrix[:3,3]
#
#     # make a meshgrid of normal coordinates
#     coords = th_iterproduct(input.size(-3), input.size(-2), input.size(-1), dtype=torch.float32)
#
#     # shift the coordinates so center is the origin
#     coords[:,0] = coords[:,0] - (input.size(-3) / 2. - 0.5)
#     coords[:,1] = coords[:,1] - (input.size(-2) / 2. - 0.5)
#     coords[:,2] = coords[:,2] - (input.size(-1) / 2. - 0.5)
#
#     # apply the coordinate transformation
#     new_coords = coords.mm(A.t().contiguous()) + b.expand_as(coords)
#
#     # shift the coordinates back so origin is origin
#     new_coords[:,0] = new_coords[:,0] + (input.size(-3) / 2. - 0.5)
#     new_coords[:,1] = new_coords[:,1] + (input.size(-2) / 2. - 0.5)
#     new_coords[:,2] = new_coords[:,2] + (input.size(-1) / 2. - 0.5)
#
#     # map new coordinates using bilinear interpolation
#     input_transformed = th_trilinear_interp3d(input, new_coords)
#
#     return input_transformed
#
#
# def th_trilinear_interp3d(input, coords):
#     """
#     trilinear interpolation of 3D torch.Tensor image
#     """
#     # take clamp then floor/ceil of x coords
#     x = torch.clamp(coords[:,0], 0, input.size(-3)-2)
#     x0 = x.floor()
#     x1 = x0 + 1
#     # take clamp then floor/ceil of y coords
#     y = torch.clamp(coords[:,1], 0, input.size(-2)-2)
#     y0 = y.floor()
#     y1 = y0 + 1
#     # take clamp then floor/ceil of z coords
#     z = torch.clamp(coords[:,2], 0, input.size(-1)-2)
#     z0 = z.floor()
#     z1 = z0 + 1
#
#     stride = torch.tensor(input.stride()[-3:], dtype=torch.int64, device=input.device)
#     x0_ix = x0.mul(stride[0]).long()
#     x1_ix = x1.mul(stride[0]).long()
#     y0_ix = y0.mul(stride[1]).long()
#     y1_ix = y1.mul(stride[1]).long()
#     z0_ix = z0.mul(stride[2]).long()
#     z1_ix = z1.mul(stride[2]).long()
#
#     # input_flat = th_flatten(input)
#     input_flat = x.contiguous().view(x[0], x[1], -1)
#
#     vals_000 = input_flat[:, :, x0_ix+y0_ix+z0_ix]
#     vals_001 = input_flat[:, :, x0_ix+y0_ix+z1_ix]
#     vals_010 = input_flat[:, :, x0_ix+y1_ix+z0_ix]
#     vals_011 = input_flat[:, :, x0_ix+y1_ix+z1_ix]
#     vals_100 = input_flat[:, :, x1_ix+y0_ix+z0_ix]
#     vals_101 = input_flat[:, :, x1_ix+y0_ix+z1_ix]
#     vals_110 = input_flat[:, :, x1_ix+y1_ix+z0_ix]
#     vals_111 = input_flat[:, :, x1_ix+y1_ix+z1_ix]
#
#     xd = x - x0
#     yd = y - y0
#     zd = z - z0
#     xm1 = 1 - xd
#     ym1 = 1 - yd
#     zm1 = 1 - zd
#
#     x_mapped = (
#             vals_000.mul(xm1).mul(ym1).mul(zm1) +
#             vals_001.mul(xm1).mul(ym1).mul(zd) +
#             vals_010.mul(xm1).mul(yd).mul(zm1) +
#             vals_011.mul(xm1).mul(yd).mul(zd) +
#             vals_100.mul(xd).mul(ym1).mul(zm1) +
#             vals_101.mul(xd).mul(ym1).mul(zd) +
#             vals_110.mul(xd).mul(yd).mul(zm1) +
#             vals_111.mul(xd).mul(yd).mul(zd)
#     )
#
#     return x_mapped.view_as(input)
#
# def th_iterproduct(*args, dtype=None):
#     return torch.from_numpy(np.indices(args).reshape((len(args),-1)).T)
#
# def th_flatten(x):
#     """Flatten tensor"""
#     return x.contiguous().view(x[0], x[1], -1)