import copy
import csv
import functools
import glob
import math
import os
import random
from collections import namedtuple

import SimpleITK as sitk
import numpy as np
# NOTE(review): `morph` is not referenced in this file, and the
# scipy.ndimage.morphology namespace is deprecated in recent SciPy releases.
import scipy.ndimage.morphology as morph
import torch
import torch.cuda
import torch.nn.functional as F
from torch.utils.data import Dataset

from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

raw_cache = getCache('part2ch13_raw')

MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')

CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')


@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
    """Build the combined list of annotated nodules and non-nodule candidates.

    Positive entries come from ``data/annotations_with_malignancy.csv``
    (isNodule_bool and hasAnnotation_bool set, malignancy read from column 5).
    Negative entries are the rows of ``data/candidates.csv`` whose class
    column is 0; candidate rows flagged as nodules are skipped, so positives
    come only from the annotation file.

    Args:
        requireOnDisk_bool: when True, drop every entry (annotation or
            candidate) whose series has no ``.mhd`` file under
            ``data-unversioned/subset*/``.

    Returns:
        A list of CandidateInfoTuple sorted in descending tuple order, so
        nodules sort ahead of non-nodules.  The list is cached by
        ``lru_cache`` and shared between callers; do not mutate it in place.
    """
    # We construct a set with all series_uids that are present on disk.
    # This will let us use the data, even if we haven't downloaded all of
    # the subsets yet.
    mhd_list = glob.glob('data-unversioned/subset*/*.mhd')
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    candidateInfo_list = []
    with open('data/annotations_with_malignancy.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]
            # Apply the same on-disk filter as the candidates loop below.
            # Without it, annotations for missing series leak into the list
            # and Ct(series_uid) later fails to find the .mhd file.
            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue

            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
            annotationDiameter_mm = float(row[4])
            isMal_bool = {'False': False, 'True': True}[row[5]]

            candidateInfo_list.append(
                CandidateInfoTuple(
                    True,
                    True,
                    isMal_bool,
                    annotationDiameter_mm,
                    series_uid,
                    annotationCenter_xyz,
                )
            )

    with open('data/candidates.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]

            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])

            if not isNodule_bool:
                candidateInfo_list.append(
                    CandidateInfoTuple(
                        False,
                        False,
                        False,
                        0.0,
                        series_uid,
                        candidateCenter_xyz,
                    )
                )

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list


@functools.lru_cache(1)
def getCandidateInfoDict(requireOnDisk_bool=True):
    """Group ``getCandidateInfoList(requireOnDisk_bool)`` by series_uid.

    Within each series the entries keep their order from the list.  The
    result is cached by ``lru_cache`` and shared between callers.
    """
    candidateInfo_list = getCandidateInfoList(requireOnDisk_bool)
    candidateInfo_dict = {}

    for candidateInfo_tup in candidateInfo_list:
        candidateInfo_dict.setdefault(candidateInfo_tup.series_uid,
                                      []).append(candidateInfo_tup)

    return candidateInfo_dict


class Ct:
    """One CT scan loaded from its ``.mhd`` file plus a mask of its nodules.

    Attributes:
        hu_a: float32 voxel array read via SimpleITK and indexed as
            ``[index, row, col]`` (the irc order produced by ``xyz2irc``).
        series_uid: identifier of the scan.
        origin_xyz, vxSize_xyz, direction_a: scan geometry passed to
            ``xyz2irc`` to locate candidates in voxel space.
        positiveInfo_list: the nodule entries for this series.
        positive_mask: bool array shaped like ``hu_a`` marking the
            above-threshold voxels inside each nodule's bounding box.
        positive_indexes: axis-0 indexes of slices that contain at least one
            positive voxel.
    """

    def __init__(self, series_uid):
        mhd_path = glob.glob(
            'data-unversioned/subset*/{}.mhd'.format(series_uid)
        )[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.

        self.series_uid = series_uid

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

        candidateInfo_list = getCandidateInfoDict()[self.series_uid]

        self.positiveInfo_list = [
            candidate_tup
            for candidate_tup in candidateInfo_list
            if candidate_tup.isNodule_bool
        ]
        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
        self.positive_indexes = (self.positive_mask.sum(axis=(1, 2))
                                 .nonzero()[0].tolist())

    @staticmethod
    def _findRadius(hu_line, center_ndx, threshold_hu):
        """Grow a symmetric half-width around ``center_ndx`` along 1D ``hu_line``.

        Starting at 2, the radius increases while the voxels at both
        ``center_ndx + radius`` and ``center_ndx - radius`` exceed
        ``threshold_hu``; the first radius that fails is returned.  If the
        search would step outside ``hu_line`` on either side, the last radius
        before that step (``radius - 1``) is returned instead.
        """
        radius = 2
        while True:
            upper_ndx = center_ndx + radius
            if upper_ndx >= len(hu_line):
                return radius - 1
            if not hu_line[upper_ndx] > threshold_hu:
                return radius

            lower_ndx = center_ndx - radius
            # Checked explicitly: a negative numpy index would silently wrap
            # around to the far end of the array instead of raising.
            if lower_ndx < 0:
                return radius - 1
            if not hu_line[lower_ndx] > threshold_hu:
                return radius

            radius += 1

    def buildAnnotationMask(self, positiveInfo_list, threshold_hu=-700):
        """Return a bool mask of the voxels belonging to the given nodules.

        For each nodule, a bounding box is grown from its center voxel along
        each axis (see ``_findRadius``); the union of those boxes is then
        intersected with ``hu_a > threshold_hu``.  Annotations whose center
        falls outside the volume are logged and skipped.

        Args:
            positiveInfo_list: CandidateInfoTuple entries whose
                ``center_xyz`` locates each nodule.
            threshold_hu: HU cutoff; voxels at or below it stop the radius
                search and are excluded from the mask.

        Returns:
            A bool ndarray with the same shape as ``hu_a``.
        """
        # Builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
        boundingBox_a = np.zeros_like(self.hu_a, dtype=bool)

        for candidateInfo_tup in positiveInfo_list:
            center_irc = xyz2irc(
                candidateInfo_tup.center_xyz,
                self.origin_xyz,
                self.vxSize_xyz,
                self.direction_a,
            )
            ci = int(center_irc.index)
            cr = int(center_irc.row)
            cc = int(center_irc.col)

            if not all(0 <= ndx < size
                       for ndx, size in zip((ci, cr, cc), self.hu_a.shape)):
                log.warning(
                    "%s: annotation center %r lies outside CT of shape %r; skipping",
                    self.series_uid, (ci, cr, cc), self.hu_a.shape,
                )
                continue

            index_radius = self._findRadius(self.hu_a[:, cr, cc], ci, threshold_hu)
            row_radius = self._findRadius(self.hu_a[ci, :, cc], cr, threshold_hu)
            col_radius = self._findRadius(self.hu_a[ci, cr, :], cc, threshold_hu)

            # assert index_radius > 0, repr([candidateInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]])
            # assert row_radius > 0
            # assert col_radius > 0

            # Clamp the lower edges at 0: a negative slice start would wrap to
            # the end of the axis and typically select an empty range.
            boundingBox_a[
                max(ci - index_radius, 0): ci + index_radius + 1,
                max(cr - row_radius, 0): cr + row_radius + 1,
                max(cc - col_radius, 0): cc + col_radius + 1,
            ] = True

        mask_a = boundingBox_a & (self.hu_a > threshold_hu)

        return mask_a

    def getRawCandidate(self, center_xyz, width_irc):
        """Crop a ``width_irc``-sized box centered on ``center_xyz``.

        When the box would extend past an edge of the volume along an axis,
        it is shifted back inside (keeping its width) rather than truncated.

        Returns:
            ``(ct_chunk, pos_chunk, center_irc)`` where the chunks are views
            (basic slices) into ``hu_a`` and ``positive_mask``.
        """
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
                             self.direction_a)

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

            if start_ndx < 0:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.hu_a.shape[axis]:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]
        pos_chunk = self.positive_mask[tuple(slice_list)]

        return ct_chunk, pos_chunk, center_irc


@functools.lru_cache(1, typed=True)
def getCt(series_uid):
    """Return the ``Ct`` for ``series_uid``; only the most recent one is cached."""
    return Ct(series_uid)


@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    """Return ``(ct_chunk, pos_chunk, center_irc)`` for a crop of one CT.

    ``ct_chunk`` is clipped to [-1000, 1000] HU.  Results are memoized via
    ``raw_cache.memoize``.
    """
    ct = getCt(series_uid)
    ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
    # getRawCandidate returns views into the lru-cached Ct's hu_a; clip into a
    # new array rather than in place so the shared cached volume is untouched.
    ct_chunk = ct_chunk.clip(-1000, 1000)
    return ct_chunk, pos_chunk, center_irc


@raw_cache.memoize(typed=True)
def getCtSampleSize(series_uid):
    """Return ``(slice_count, positive_indexes)`` for a CT, memoized via ``raw_cache``."""
    ct = Ct(series_uid)
    return int(ct.hu_a.shape[0]), ct.positive_indexes


class Luna2dSegmentationDataset(Dataset):
    """Per-slice segmentation samples drawn from whole CT scans.

    Each sample is a stack of ``2 * contextSlices_count + 1`` adjacent CT
    slices centered on one slice, paired with that slice's nodule mask.

    Args:
        val_stride: every ``val_stride``-th series (in sorted order) forms
            the validation split; 0 disables splitting.
        isValSet_bool: True keeps only the validation series; False/None
            keeps the remaining series when ``val_stride > 0``.
        series_uid: restrict the dataset to this single series.
        contextSlices_count: number of neighbouring slices on each side.
        fullCt_bool: sample every slice of each CT instead of only slices
            containing positive voxels.
    """

    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 series_uid=None,
                 contextSlices_count=3,
                 fullCt_bool=False,
            ):
        self.contextSlices_count = contextSlices_count
        self.fullCt_bool = fullCt_bool

        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(getCandidateInfoDict().keys())

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.series_list = self.series_list[::val_stride]
            assert self.series_list
        elif val_stride > 0:
            del self.series_list[::val_stride]
            assert self.series_list

        self.sample_list = []
        for series_uid in self.series_list:
            index_count, positive_indexes = getCtSampleSize(series_uid)

            if self.fullCt_bool:
                self.sample_list += [(series_uid, slice_ndx)
                                     for slice_ndx in range(index_count)]
            else:
                self.sample_list += [(series_uid, slice_ndx)
                                     for slice_ndx in positive_indexes]

        self.candidateInfo_list = getCandidateInfoList()

        series_set = set(self.series_list)
        self.candidateInfo_list = [cit for cit in self.candidateInfo_list
                                   if cit.series_uid in series_set]

        self.pos_list = [nt for nt in self.candidateInfo_list
                         if nt.isNodule_bool]

        log.info("{!r}: {} {} series, {} slices, {} nodules".format(
            self,
            len(self.series_list),
            {None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
            len(self.sample_list),
            len(self.pos_list),
        ))

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, ndx):
        series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
        return self.getitem_fullSlice(series_uid, slice_ndx)

    def getitem_fullSlice(self, series_uid, slice_ndx):
        """Return ``(ct_t, pos_t, series_uid, slice_ndx)`` for one full slice.

        ``ct_t`` has shape ``(2 * contextSlices_count + 1, rows, cols)`` with
        neighbouring slices (indexes clamped to the volume) clipped to
        [-1000, 1000] HU; ``pos_t`` is the ``(1, rows, cols)`` bool mask of
        the center slice.
        """
        ct = getCt(series_uid)
        # Size from the CT itself rather than a fixed 512x512 so scans with
        # other in-plane dimensions also work.
        ct_t = torch.zeros((self.contextSlices_count * 2 + 1, *ct.hu_a.shape[1:]))

        start_ndx = slice_ndx - self.contextSlices_count
        end_ndx = slice_ndx + self.contextSlices_count + 1
        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
            context_ndx = max(context_ndx, 0)
            context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
            ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
        # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
        # The upper bound nukes any weird hotspots and clamps bone down
        ct_t.clamp_(-1000, 1000)

        pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)

        return ct_t, pos_t, ct.series_uid, slice_ndx


class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
    """Training variant that yields random 64x64 crops around nodule centers.

    ``__len__`` is a fixed 300000 regardless of dataset size, and
    ``__getitem__`` cycles through ``pos_list``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # NOTE(review): not referenced elsewhere in this file.
        self.ratio_int = 2

    def __len__(self):
        return 300000

    def shuffleSamples(self):
        random.shuffle(self.candidateInfo_list)
        random.shuffle(self.pos_list)

    def __getitem__(self, ndx):
        candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
        return self.getitem_trainingCrop(candidateInfo_tup)

    def getitem_trainingCrop(self, candidateInfo_tup):
        """Return a random 64x64 crop from a 7x96x96 box around a nodule.

        Returns ``(ct_t, pos_t, series_uid, slice_ndx)``: ``ct_t`` is the
        float32 7-slice crop, ``pos_t`` the long mask of the middle slice
        (shape ``(1, 64, 64)``), and ``slice_ndx`` the center index.
        """
        ct_a, pos_a, center_irc = getCtRawCandidate(
            candidateInfo_tup.series_uid,
            candidateInfo_tup.center_xyz,
            (7, 96, 96),
        )
        # Keep only the mask of the middle (4th of 7) slice.
        pos_a = pos_a[3:4]

        row_offset = random.randrange(0, 32)
        col_offset = random.randrange(0, 32)
        ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64,
                                     col_offset:col_offset+64]).to(torch.float32)
        pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64,
                                       col_offset:col_offset+64]).to(torch.long)

        slice_ndx = center_irc.index

        return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx


class PrepcacheLunaDataset(Dataset):
    """Walks every candidate once so the disk caches get populated.

    ``__getitem__`` calls ``getCtRawCandidate`` for each candidate and
    ``getCtSampleSize`` once per series, then returns a dummy ``(0, 1)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.candidateInfo_list = getCandidateInfoList()
        self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]

        self.seen_set = set()
        # sorted() works on a copy: sorting the lru-cached list in place
        # would reorder it for every other caller of getCandidateInfoList().
        self.candidateInfo_list = sorted(self.candidateInfo_list,
                                         key=lambda x: x.series_uid)

    def __len__(self):
        return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        # candidate_t, pos_t, series_uid, center_t = super().__getitem__(ndx)

        candidateInfo_tup = self.candidateInfo_list[ndx]
        getCtRawCandidate(candidateInfo_tup.series_uid,
                          candidateInfo_tup.center_xyz, (7, 96, 96))

        series_uid = candidateInfo_tup.series_uid
        if series_uid not in self.seen_set:
            self.seen_set.add(series_uid)

            getCtSampleSize(series_uid)
            # ct = getCt(series_uid)
            # for mask_ndx in ct.positive_indexes:
            #     build2dLungMask(series_uid, mask_ndx)

        return 0, 1 #candidate_t, pos_t, series_uid, center_t


class TvTrainingLuna2dSegmentationDataset(torch.utils.data.Dataset):
    """Dataset over precomputed image/mask stacks loaded from ``./imgs_and_masks.pt``.

    The file's ``'suids'``, ``'imgs'`` and ``'masks'`` entries are split by
    series: every ``val_stride``-th series (in sorted order) is validation.
    Images are clamped to [-1000, 1000] and scaled by 1/1000.
    """

    def __init__(self, isValSet_bool=False, val_stride=10, contextSlices_count=3):
        assert contextSlices_count == 3
        data = torch.load('./imgs_and_masks.pt')
        # Sort so the train/val split is reproducible: iteration order of a
        # set of strings depends on per-process hash randomization.
        suids = sorted(set(data['suids']))
        trn_mask_suids = torch.arange(len(suids)) % val_stride < (val_stride - 1)
        trn_suids = {s for i, s in zip(trn_mask_suids, suids) if i}
        trn_mask = torch.tensor([(s in trn_suids) for s in data["suids"]])
        if not isValSet_bool:
            self.imgs = data["imgs"][trn_mask]
            self.masks = data["masks"][trn_mask]
            self.suids = [s for s, i in zip(data["suids"], trn_mask) if i]
        else:
            self.imgs = data["imgs"][~trn_mask]
            self.masks = data["masks"][~trn_mask]
            self.suids = [s for s, i in zip(data["suids"], trn_mask) if not i]
        # discard spurious hotspots and clamp bone
        self.imgs.clamp_(-1000, 1000)
        self.imgs /= 1000

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, i):
        """Return a random 64x64 crop: (image, 1, middle-slice mask as float32, suid, 9999)."""
        oh, ow = torch.randint(0, 32, (2,))
        sl = self.masks.size(1)//2
        return self.imgs[i, :, oh: oh + 64, ow: ow + 64], 1, self.masks[i, sl: sl+1, oh: oh + 64, ow: ow + 64].to(torch.float32), self.suids[i], 9999