forked from KEMT/zpwiki
		
	
		
			
				
	
	
		
			402 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			402 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import copy
 | |
| import csv
 | |
| import functools
 | |
| import glob
 | |
| import math
 | |
| import os
 | |
| import random
 | |
| 
 | |
| from collections import namedtuple
 | |
| 
 | |
| import SimpleITK as sitk
 | |
| import numpy as np
 | |
| import scipy.ndimage.morphology as morph
 | |
| 
 | |
| import torch
 | |
| import torch.cuda
 | |
| import torch.nn.functional as F
 | |
| from torch.utils.data import Dataset
 | |
| 
 | |
| from util.disk import getCache
 | |
| from util.util import XyzTuple, xyz2irc
 | |
| from util.logconf import logging
 | |
| 
 | |
| log = logging.getLogger(__name__)
 | |
| # log.setLevel(logging.WARN)
 | |
| # log.setLevel(logging.INFO)
 | |
| log.setLevel(logging.DEBUG)
 | |
| 
 | |
| raw_cache = getCache('part2ch13_raw')
 | |
| 
 | |
| MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
 | |
| 
 | |
| CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')
 | |
| 
 | |
| @functools.lru_cache(1)
 | |
| def getCandidateInfoList(requireOnDisk_bool=True):
 | |
|     # We construct a set with all series_uids that are present on disk.
 | |
|     # This will let us use the data, even if we haven't downloaded all of
 | |
|     # the subsets yet.
 | |
|     mhd_list = glob.glob('data-unversioned/subset*/*.mhd')
 | |
|     presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
 | |
| 
 | |
|     candidateInfo_list = []
 | |
|     with open('data/annotations_with_malignancy.csv', "r") as f:
 | |
|         for row in list(csv.reader(f))[1:]:
 | |
|             series_uid = row[0]
 | |
|             annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
 | |
|             annotationDiameter_mm = float(row[4])
 | |
|             isMal_bool = {'False': False, 'True': True}[row[5]]
 | |
| 
 | |
|             candidateInfo_list.append(
 | |
|                 CandidateInfoTuple(
 | |
|                     True,
 | |
|                     True,
 | |
|                     isMal_bool,
 | |
|                     annotationDiameter_mm,
 | |
|                     series_uid,
 | |
|                     annotationCenter_xyz,
 | |
|                 )
 | |
|             )
 | |
| 
 | |
|     with open('data/candidates.csv', "r") as f:
 | |
|         for row in list(csv.reader(f))[1:]:
 | |
|             series_uid = row[0]
 | |
| 
 | |
|             if series_uid not in presentOnDisk_set and requireOnDisk_bool:
 | |
|                 continue
 | |
| 
 | |
|             isNodule_bool = bool(int(row[4]))
 | |
|             candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
 | |
| 
 | |
|             if not isNodule_bool:
 | |
|                 candidateInfo_list.append(
 | |
|                     CandidateInfoTuple(
 | |
|                         False,
 | |
|                         False,
 | |
|                         False,
 | |
|                         0.0,
 | |
|                         series_uid,
 | |
|                         candidateCenter_xyz,
 | |
|                     )
 | |
|                 )
 | |
| 
 | |
|     candidateInfo_list.sort(reverse=True)
 | |
|     return candidateInfo_list
 | |
| 
 | |
| @functools.lru_cache(1)
 | |
| def getCandidateInfoDict(requireOnDisk_bool=True):
 | |
|     candidateInfo_list = getCandidateInfoList(requireOnDisk_bool)
 | |
|     candidateInfo_dict = {}
 | |
| 
 | |
|     for candidateInfo_tup in candidateInfo_list:
 | |
|         candidateInfo_dict.setdefault(candidateInfo_tup.series_uid,
 | |
|                                       []).append(candidateInfo_tup)
 | |
| 
 | |
|     return candidateInfo_dict
 | |
| 
 | |
| class Ct:
 | |
|     def __init__(self, series_uid):
 | |
|         mhd_path = glob.glob(
 | |
|             'data-unversioned/subset*/{}.mhd'.format(series_uid)
 | |
|         )[0]
 | |
| 
 | |
|         ct_mhd = sitk.ReadImage(mhd_path)
 | |
|         self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
 | |
| 
 | |
|         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
 | |
|         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
 | |
| 
 | |
|         self.series_uid = series_uid
 | |
| 
 | |
|         self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
 | |
|         self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
 | |
|         self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
 | |
| 
 | |
|         candidateInfo_list = getCandidateInfoDict()[self.series_uid]
 | |
| 
 | |
|         self.positiveInfo_list = [
 | |
|             candidate_tup
 | |
|             for candidate_tup in candidateInfo_list
 | |
|             if candidate_tup.isNodule_bool
 | |
|         ]
 | |
|         self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
 | |
|         self.positive_indexes = (self.positive_mask.sum(axis=(1,2))
 | |
|                                  .nonzero()[0].tolist())
 | |
| 
 | |
|     def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
 | |
|         boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool)
 | |
| 
 | |
|         for candidateInfo_tup in positiveInfo_list:
 | |
|             center_irc = xyz2irc(
 | |
|                 candidateInfo_tup.center_xyz,
 | |
|                 self.origin_xyz,
 | |
|                 self.vxSize_xyz,
 | |
|                 self.direction_a,
 | |
|             )
 | |
|             ci = int(center_irc.index)
 | |
|             cr = int(center_irc.row)
 | |
|             cc = int(center_irc.col)
 | |
| 
 | |
|             index_radius = 2
 | |
|             try:
 | |
|                 while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \
 | |
|                         self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
 | |
|                     index_radius += 1
 | |
|             except IndexError:
 | |
|                 index_radius -= 1
 | |
| 
 | |
|             row_radius = 2
 | |
|             try:
 | |
|                 while self.hu_a[ci, cr + row_radius, cc] > threshold_hu and \
 | |
|                         self.hu_a[ci, cr - row_radius, cc] > threshold_hu:
 | |
|                     row_radius += 1
 | |
|             except IndexError:
 | |
|                 row_radius -= 1
 | |
| 
 | |
|             col_radius = 2
 | |
|             try:
 | |
|                 while self.hu_a[ci, cr, cc + col_radius] > threshold_hu and \
 | |
|                         self.hu_a[ci, cr, cc - col_radius] > threshold_hu:
 | |
|                     col_radius += 1
 | |
|             except IndexError:
 | |
|                 col_radius -= 1
 | |
| 
 | |
|             # assert index_radius > 0, repr([candidateInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]])
 | |
|             # assert row_radius > 0
 | |
|             # assert col_radius > 0
 | |
| 
 | |
|             boundingBox_a[
 | |
|                  ci - index_radius: ci + index_radius + 1,
 | |
|                  cr - row_radius: cr + row_radius + 1,
 | |
|                  cc - col_radius: cc + col_radius + 1] = True
 | |
| 
 | |
|         mask_a = boundingBox_a & (self.hu_a > threshold_hu)
 | |
| 
 | |
|         return mask_a
 | |
| 
 | |
|     def getRawCandidate(self, center_xyz, width_irc):
 | |
|         center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
 | |
|                              self.direction_a)
 | |
| 
 | |
|         slice_list = []
 | |
|         for axis, center_val in enumerate(center_irc):
 | |
|             start_ndx = int(round(center_val - width_irc[axis]/2))
 | |
|             end_ndx = int(start_ndx + width_irc[axis])
 | |
| 
 | |
|             assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
 | |
| 
 | |
|             if start_ndx < 0:
 | |
|                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
 | |
|                 #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
 | |
|                 start_ndx = 0
 | |
|                 end_ndx = int(width_irc[axis])
 | |
| 
 | |
|             if end_ndx > self.hu_a.shape[axis]:
 | |
|                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
 | |
|                 #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
 | |
|                 end_ndx = self.hu_a.shape[axis]
 | |
|                 start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
 | |
| 
 | |
|             slice_list.append(slice(start_ndx, end_ndx))
 | |
| 
 | |
|         ct_chunk = self.hu_a[tuple(slice_list)]
 | |
|         pos_chunk = self.positive_mask[tuple(slice_list)]
 | |
| 
 | |
|         return ct_chunk, pos_chunk, center_irc
 | |
| 
 | |
| @functools.lru_cache(1, typed=True)
 | |
| def getCt(series_uid):
 | |
|     return Ct(series_uid)
 | |
| 
 | |
| @raw_cache.memoize(typed=True)
 | |
| def getCtRawCandidate(series_uid, center_xyz, width_irc):
 | |
|     ct = getCt(series_uid)
 | |
|     ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz,
 | |
|                                                          width_irc)
 | |
|     ct_chunk.clip(-1000, 1000, ct_chunk)
 | |
|     return ct_chunk, pos_chunk, center_irc
 | |
| 
 | |
| @raw_cache.memoize(typed=True)
 | |
| def getCtSampleSize(series_uid):
 | |
|     ct = Ct(series_uid)
 | |
|     return int(ct.hu_a.shape[0]), ct.positive_indexes
 | |
| 
 | |
| 
 | |
| class Luna2dSegmentationDataset(Dataset):
 | |
|     def __init__(self,
 | |
|                  val_stride=0,
 | |
|                  isValSet_bool=None,
 | |
|                  series_uid=None,
 | |
|                  contextSlices_count=3,
 | |
|                  fullCt_bool=False,
 | |
|             ):
 | |
|         self.contextSlices_count = contextSlices_count
 | |
|         self.fullCt_bool = fullCt_bool
 | |
| 
 | |
|         if series_uid:
 | |
|             self.series_list = [series_uid]
 | |
|         else:
 | |
|             self.series_list = sorted(getCandidateInfoDict().keys())
 | |
| 
 | |
|         if isValSet_bool:
 | |
|             assert val_stride > 0, val_stride
 | |
|             self.series_list = self.series_list[::val_stride]
 | |
|             assert self.series_list
 | |
|         elif val_stride > 0:
 | |
|             del self.series_list[::val_stride]
 | |
|             assert self.series_list
 | |
| 
 | |
|         self.sample_list = []
 | |
|         for series_uid in self.series_list:
 | |
|             index_count, positive_indexes = getCtSampleSize(series_uid)
 | |
| 
 | |
|             if self.fullCt_bool:
 | |
|                 self.sample_list += [(series_uid, slice_ndx)
 | |
|                                      for slice_ndx in range(index_count)]
 | |
|             else:
 | |
|                 self.sample_list += [(series_uid, slice_ndx)
 | |
|                                      for slice_ndx in positive_indexes]
 | |
| 
 | |
|         self.candidateInfo_list = getCandidateInfoList()
 | |
| 
 | |
|         series_set = set(self.series_list)
 | |
|         self.candidateInfo_list = [cit for cit in self.candidateInfo_list
 | |
|                                    if cit.series_uid in series_set]
 | |
| 
 | |
|         self.pos_list = [nt for nt in self.candidateInfo_list
 | |
|                             if nt.isNodule_bool]
 | |
| 
 | |
|         log.info("{!r}: {} {} series, {} slices, {} nodules".format(
 | |
|             self,
 | |
|             len(self.series_list),
 | |
|             {None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
 | |
|             len(self.sample_list),
 | |
|             len(self.pos_list),
 | |
|         ))
 | |
| 
 | |
|     def __len__(self):
 | |
|         return len(self.sample_list)
 | |
| 
 | |
|     def __getitem__(self, ndx):
 | |
|         series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
 | |
|         return self.getitem_fullSlice(series_uid, slice_ndx)
 | |
| 
 | |
|     def getitem_fullSlice(self, series_uid, slice_ndx):
 | |
|         ct = getCt(series_uid)
 | |
|         ct_t = torch.zeros((self.contextSlices_count * 2 + 1, 512, 512))
 | |
| 
 | |
|         start_ndx = slice_ndx - self.contextSlices_count
 | |
|         end_ndx = slice_ndx + self.contextSlices_count + 1
 | |
|         for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
 | |
|             context_ndx = max(context_ndx, 0)
 | |
|             context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
 | |
|             ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
 | |
| 
 | |
|         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
 | |
|         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
 | |
|         # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
 | |
|         # The upper bound nukes any weird hotspots and clamps bone down
 | |
|         ct_t.clamp_(-1000, 1000)
 | |
| 
 | |
|         pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)
 | |
| 
 | |
|         return ct_t, pos_t, ct.series_uid, slice_ndx
 | |
| 
 | |
| 
 | |
| class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
 | |
|     def __init__(self, *args, **kwargs):
 | |
|         super().__init__(*args, **kwargs)
 | |
| 
 | |
|         self.ratio_int = 2
 | |
| 
 | |
|     def __len__(self):
 | |
|         return 300000
 | |
| 
 | |
|     def shuffleSamples(self):
 | |
|         random.shuffle(self.candidateInfo_list)
 | |
|         random.shuffle(self.pos_list)
 | |
| 
 | |
|     def __getitem__(self, ndx):
 | |
|         candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
 | |
|         return self.getitem_trainingCrop(candidateInfo_tup)
 | |
| 
 | |
|     def getitem_trainingCrop(self, candidateInfo_tup):
 | |
|         ct_a, pos_a, center_irc = getCtRawCandidate(
 | |
|             candidateInfo_tup.series_uid,
 | |
|             candidateInfo_tup.center_xyz,
 | |
|             (7, 96, 96),
 | |
|         )
 | |
|         pos_a = pos_a[3:4]
 | |
| 
 | |
|         row_offset = random.randrange(0,32)
 | |
|         col_offset = random.randrange(0,32)
 | |
|         ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64,
 | |
|                                      col_offset:col_offset+64]).to(torch.float32)
 | |
|         pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64,
 | |
|                                        col_offset:col_offset+64]).to(torch.long)
 | |
| 
 | |
|         slice_ndx = center_irc.index
 | |
| 
 | |
|         return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx
 | |
| 
 | |
| class PrepcacheLunaDataset(Dataset):
 | |
|     def __init__(self, *args, **kwargs):
 | |
|         super().__init__(*args, **kwargs)
 | |
| 
 | |
|         self.candidateInfo_list = getCandidateInfoList()
 | |
|         self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
 | |
| 
 | |
|         self.seen_set = set()
 | |
|         self.candidateInfo_list.sort(key=lambda x: x.series_uid)
 | |
| 
 | |
|     def __len__(self):
 | |
|         return len(self.candidateInfo_list)
 | |
| 
 | |
|     def __getitem__(self, ndx):
 | |
|         # candidate_t, pos_t, series_uid, center_t = super().__getitem__(ndx)
 | |
| 
 | |
|         candidateInfo_tup = self.candidateInfo_list[ndx]
 | |
|         getCtRawCandidate(candidateInfo_tup.series_uid, candidateInfo_tup.center_xyz, (7, 96, 96))
 | |
| 
 | |
|         series_uid = candidateInfo_tup.series_uid
 | |
|         if series_uid not in self.seen_set:
 | |
|             self.seen_set.add(series_uid)
 | |
| 
 | |
|             getCtSampleSize(series_uid)
 | |
|             # ct = getCt(series_uid)
 | |
|             # for mask_ndx in ct.positive_indexes:
 | |
|             #     build2dLungMask(series_uid, mask_ndx)
 | |
| 
 | |
|         return 0, 1 #candidate_t, pos_t, series_uid, center_t
 | |
| 
 | |
| 
 | |
| class TvTrainingLuna2dSegmentationDataset(torch.utils.data.Dataset):
 | |
|     def __init__(self, isValSet_bool=False, val_stride=10, contextSlices_count=3):
 | |
|         assert contextSlices_count == 3
 | |
|         data = torch.load('./imgs_and_masks.pt')
 | |
|         suids = list(set(data['suids']))
 | |
|         trn_mask_suids = torch.arange(len(suids)) % val_stride < (val_stride - 1)
 | |
|         trn_suids = {s for i, s in zip(trn_mask_suids, suids) if i}
 | |
|         trn_mask = torch.tensor([(s in trn_suids) for s in data["suids"]])
 | |
|         if not isValSet_bool:
 | |
|             self.imgs = data["imgs"][trn_mask]
 | |
|             self.masks = data["masks"][trn_mask]
 | |
|             self.suids = [s for s, i in zip(data["suids"], trn_mask) if i]
 | |
|         else:
 | |
|             self.imgs = data["imgs"][~trn_mask]
 | |
|             self.masks = data["masks"][~trn_mask]
 | |
|             self.suids = [s for s, i in zip(data["suids"], trn_mask) if not i]
 | |
|         # discard spurious hotspots and clamp bone
 | |
|         self.imgs.clamp_(-1000, 1000)
 | |
|         self.imgs /= 1000
 | |
| 
 | |
| 
 | |
|     def __len__(self):
 | |
|         return len(self.imgs)
 | |
| 
 | |
|     def __getitem__(self, i):
 | |
|         oh, ow = torch.randint(0, 32, (2,))
 | |
|         sl = self.masks.size(1)//2
 | |
|         return self.imgs[i, :, oh: oh + 64, ow: ow + 64], 1, self.masks[i, sl: sl+1, oh: oh + 64, ow: ow + 64].to(torch.float32), self.suids[i], 9999
 |