forked from KEMT/zpwiki
		
	Nahrát soubory do „pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model“
This commit is contained in:
		
							parent
							
								
									35f3889338
								
							
						
					
					
						commit
						526fe93ff9
					
				
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1,76 @@ | |||||||
|  | import argparse | ||||||
|  | import datetime | ||||||
|  | import os | ||||||
|  | import socket | ||||||
|  | import sys | ||||||
|  | 
 | ||||||
|  | import numpy as np | ||||||
|  | from torch.utils.tensorboard import SummaryWriter | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torch.optim | ||||||
|  | 
 | ||||||
|  | from torch.optim import SGD, Adam | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  | 
 | ||||||
|  | from util.util import enumerateWithEstimate | ||||||
|  | from p2ch13.dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt | ||||||
|  | from util.logconf import logging | ||||||
|  | from util.util import xyz2irc | ||||||
|  | from p2ch13.model_seg import UNetWrapper, SegmentationAugmentation | ||||||
|  | from p2ch13.train_seg import LunaTrainingApp | ||||||
|  | 
 | ||||||
|  | log = logging.getLogger(__name__) | ||||||
|  | # log.setLevel(logging.WARN) | ||||||
|  | # log.setLevel(logging.INFO) | ||||||
|  | log.setLevel(logging.DEBUG) | ||||||
|  | 
 | ||||||
|  | class BenchmarkLuna2dSegmentationDataset(TrainingLuna2dSegmentationDataset): | ||||||
|  |     def __len__(self): | ||||||
|  |         # return 500 | ||||||
|  |         return 5000 | ||||||
|  |         return 1000 | ||||||
|  | 
 | ||||||
|  | class LunaBenchmarkApp(LunaTrainingApp): | ||||||
|  |     def initTrainDl(self): | ||||||
|  |         train_ds = BenchmarkLuna2dSegmentationDataset( | ||||||
|  |             val_stride=10, | ||||||
|  |             isValSet_bool=False, | ||||||
|  |             contextSlices_count=3, | ||||||
|  |             # augmentation_dict=self.augmentation_dict, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         batch_size = self.cli_args.batch_size | ||||||
|  |         if self.use_cuda: | ||||||
|  |             batch_size *= torch.cuda.device_count() | ||||||
|  | 
 | ||||||
|  |         train_dl = DataLoader( | ||||||
|  |             train_ds, | ||||||
|  |             batch_size=batch_size, | ||||||
|  |             num_workers=self.cli_args.num_workers, | ||||||
|  |             pin_memory=self.use_cuda, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         return train_dl | ||||||
|  | 
 | ||||||
|  |     def main(self): | ||||||
|  |         log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) | ||||||
|  | 
 | ||||||
|  |         train_dl = self.initTrainDl() | ||||||
|  | 
 | ||||||
|  |         for epoch_ndx in range(1, 2): | ||||||
|  |             log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format( | ||||||
|  |                 epoch_ndx, | ||||||
|  |                 self.cli_args.epochs, | ||||||
|  |                 len(train_dl), | ||||||
|  |                 len([]), | ||||||
|  |                 self.cli_args.batch_size, | ||||||
|  |                 (torch.cuda.device_count() if self.use_cuda else 1), | ||||||
|  |             )) | ||||||
|  | 
 | ||||||
|  |             self.doTraining(epoch_ndx, train_dl) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     LunaBenchmarkApp().main() | ||||||
| @ -0,0 +1,401 @@ | |||||||
|  | import copy | ||||||
|  | import csv | ||||||
|  | import functools | ||||||
|  | import glob | ||||||
|  | import math | ||||||
|  | import os | ||||||
|  | import random | ||||||
|  | 
 | ||||||
|  | from collections import namedtuple | ||||||
|  | 
 | ||||||
|  | import SimpleITK as sitk | ||||||
|  | import numpy as np | ||||||
|  | import scipy.ndimage.morphology as morph | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | import torch.cuda | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from torch.utils.data import Dataset | ||||||
|  | 
 | ||||||
|  | from util.disk import getCache | ||||||
|  | from util.util import XyzTuple, xyz2irc | ||||||
|  | from util.logconf import logging | ||||||
|  | 
 | ||||||
|  | log = logging.getLogger(__name__) | ||||||
|  | # log.setLevel(logging.WARN) | ||||||
|  | # log.setLevel(logging.INFO) | ||||||
|  | log.setLevel(logging.DEBUG) | ||||||
|  | 
 | ||||||
|  | raw_cache = getCache('part2ch13_raw') | ||||||
|  | 
 | ||||||
|  | MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask') | ||||||
|  | 
 | ||||||
|  | CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz') | ||||||
|  | 
 | ||||||
|  | @functools.lru_cache(1) | ||||||
|  | def getCandidateInfoList(requireOnDisk_bool=True): | ||||||
|  |     # We construct a set with all series_uids that are present on disk. | ||||||
|  |     # This will let us use the data, even if we haven't downloaded all of | ||||||
|  |     # the subsets yet. | ||||||
|  |     mhd_list = glob.glob('data-unversioned/subset*/*.mhd') | ||||||
|  |     presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list} | ||||||
|  | 
 | ||||||
|  |     candidateInfo_list = [] | ||||||
|  |     with open('data/annotations_with_malignancy.csv', "r") as f: | ||||||
|  |         for row in list(csv.reader(f))[1:]: | ||||||
|  |             series_uid = row[0] | ||||||
|  |             annotationCenter_xyz = tuple([float(x) for x in row[1:4]]) | ||||||
|  |             annotationDiameter_mm = float(row[4]) | ||||||
|  |             isMal_bool = {'False': False, 'True': True}[row[5]] | ||||||
|  | 
 | ||||||
|  |             candidateInfo_list.append( | ||||||
|  |                 CandidateInfoTuple( | ||||||
|  |                     True, | ||||||
|  |                     True, | ||||||
|  |                     isMal_bool, | ||||||
|  |                     annotationDiameter_mm, | ||||||
|  |                     series_uid, | ||||||
|  |                     annotationCenter_xyz, | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |     with open('data/candidates.csv', "r") as f: | ||||||
|  |         for row in list(csv.reader(f))[1:]: | ||||||
|  |             series_uid = row[0] | ||||||
|  | 
 | ||||||
|  |             if series_uid not in presentOnDisk_set and requireOnDisk_bool: | ||||||
|  |                 continue | ||||||
|  | 
 | ||||||
|  |             isNodule_bool = bool(int(row[4])) | ||||||
|  |             candidateCenter_xyz = tuple([float(x) for x in row[1:4]]) | ||||||
|  | 
 | ||||||
|  |             if not isNodule_bool: | ||||||
|  |                 candidateInfo_list.append( | ||||||
|  |                     CandidateInfoTuple( | ||||||
|  |                         False, | ||||||
|  |                         False, | ||||||
|  |                         False, | ||||||
|  |                         0.0, | ||||||
|  |                         series_uid, | ||||||
|  |                         candidateCenter_xyz, | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |     candidateInfo_list.sort(reverse=True) | ||||||
|  |     return candidateInfo_list | ||||||
|  | 
 | ||||||
|  | @functools.lru_cache(1) | ||||||
|  | def getCandidateInfoDict(requireOnDisk_bool=True): | ||||||
|  |     candidateInfo_list = getCandidateInfoList(requireOnDisk_bool) | ||||||
|  |     candidateInfo_dict = {} | ||||||
|  | 
 | ||||||
|  |     for candidateInfo_tup in candidateInfo_list: | ||||||
|  |         candidateInfo_dict.setdefault(candidateInfo_tup.series_uid, | ||||||
|  |                                       []).append(candidateInfo_tup) | ||||||
|  | 
 | ||||||
|  |     return candidateInfo_dict | ||||||
|  | 
 | ||||||
|  | class Ct: | ||||||
|  |     def __init__(self, series_uid): | ||||||
|  |         mhd_path = glob.glob( | ||||||
|  |             'data-unversioned/subset*/{}.mhd'.format(series_uid) | ||||||
|  |         )[0] | ||||||
|  | 
 | ||||||
|  |         ct_mhd = sitk.ReadImage(mhd_path) | ||||||
|  |         self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32) | ||||||
|  | 
 | ||||||
|  |         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale | ||||||
|  |         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0. | ||||||
|  | 
 | ||||||
|  |         self.series_uid = series_uid | ||||||
|  | 
 | ||||||
|  |         self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin()) | ||||||
|  |         self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing()) | ||||||
|  |         self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3) | ||||||
|  | 
 | ||||||
|  |         candidateInfo_list = getCandidateInfoDict()[self.series_uid] | ||||||
|  | 
 | ||||||
|  |         self.positiveInfo_list = [ | ||||||
|  |             candidate_tup | ||||||
|  |             for candidate_tup in candidateInfo_list | ||||||
|  |             if candidate_tup.isNodule_bool | ||||||
|  |         ] | ||||||
|  |         self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list) | ||||||
|  |         self.positive_indexes = (self.positive_mask.sum(axis=(1,2)) | ||||||
|  |                                  .nonzero()[0].tolist()) | ||||||
|  | 
 | ||||||
|  |     def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700): | ||||||
|  |         boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool) | ||||||
|  | 
 | ||||||
|  |         for candidateInfo_tup in positiveInfo_list: | ||||||
|  |             center_irc = xyz2irc( | ||||||
|  |                 candidateInfo_tup.center_xyz, | ||||||
|  |                 self.origin_xyz, | ||||||
|  |                 self.vxSize_xyz, | ||||||
|  |                 self.direction_a, | ||||||
|  |             ) | ||||||
|  |             ci = int(center_irc.index) | ||||||
|  |             cr = int(center_irc.row) | ||||||
|  |             cc = int(center_irc.col) | ||||||
|  | 
 | ||||||
|  |             index_radius = 2 | ||||||
|  |             try: | ||||||
|  |                 while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \ | ||||||
|  |                         self.hu_a[ci - index_radius, cr, cc] > threshold_hu: | ||||||
|  |                     index_radius += 1 | ||||||
|  |             except IndexError: | ||||||
|  |                 index_radius -= 1 | ||||||
|  | 
 | ||||||
|  |             row_radius = 2 | ||||||
|  |             try: | ||||||
|  |                 while self.hu_a[ci, cr + row_radius, cc] > threshold_hu and \ | ||||||
|  |                         self.hu_a[ci, cr - row_radius, cc] > threshold_hu: | ||||||
|  |                     row_radius += 1 | ||||||
|  |             except IndexError: | ||||||
|  |                 row_radius -= 1 | ||||||
|  | 
 | ||||||
|  |             col_radius = 2 | ||||||
|  |             try: | ||||||
|  |                 while self.hu_a[ci, cr, cc + col_radius] > threshold_hu and \ | ||||||
|  |                         self.hu_a[ci, cr, cc - col_radius] > threshold_hu: | ||||||
|  |                     col_radius += 1 | ||||||
|  |             except IndexError: | ||||||
|  |                 col_radius -= 1 | ||||||
|  | 
 | ||||||
|  |             # assert index_radius > 0, repr([candidateInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]]) | ||||||
|  |             # assert row_radius > 0 | ||||||
|  |             # assert col_radius > 0 | ||||||
|  | 
 | ||||||
|  |             boundingBox_a[ | ||||||
|  |                  ci - index_radius: ci + index_radius + 1, | ||||||
|  |                  cr - row_radius: cr + row_radius + 1, | ||||||
|  |                  cc - col_radius: cc + col_radius + 1] = True | ||||||
|  | 
 | ||||||
|  |         mask_a = boundingBox_a & (self.hu_a > threshold_hu) | ||||||
|  | 
 | ||||||
|  |         return mask_a | ||||||
|  | 
 | ||||||
|  |     def getRawCandidate(self, center_xyz, width_irc): | ||||||
|  |         center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, | ||||||
|  |                              self.direction_a) | ||||||
|  | 
 | ||||||
|  |         slice_list = [] | ||||||
|  |         for axis, center_val in enumerate(center_irc): | ||||||
|  |             start_ndx = int(round(center_val - width_irc[axis]/2)) | ||||||
|  |             end_ndx = int(start_ndx + width_irc[axis]) | ||||||
|  | 
 | ||||||
|  |             assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis]) | ||||||
|  | 
 | ||||||
|  |             if start_ndx < 0: | ||||||
|  |                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format( | ||||||
|  |                 #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc)) | ||||||
|  |                 start_ndx = 0 | ||||||
|  |                 end_ndx = int(width_irc[axis]) | ||||||
|  | 
 | ||||||
|  |             if end_ndx > self.hu_a.shape[axis]: | ||||||
|  |                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format( | ||||||
|  |                 #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc)) | ||||||
|  |                 end_ndx = self.hu_a.shape[axis] | ||||||
|  |                 start_ndx = int(self.hu_a.shape[axis] - width_irc[axis]) | ||||||
|  | 
 | ||||||
|  |             slice_list.append(slice(start_ndx, end_ndx)) | ||||||
|  | 
 | ||||||
|  |         ct_chunk = self.hu_a[tuple(slice_list)] | ||||||
|  |         pos_chunk = self.positive_mask[tuple(slice_list)] | ||||||
|  | 
 | ||||||
|  |         return ct_chunk, pos_chunk, center_irc | ||||||
|  | 
 | ||||||
|  | @functools.lru_cache(1, typed=True) | ||||||
|  | def getCt(series_uid): | ||||||
|  |     return Ct(series_uid) | ||||||
|  | 
 | ||||||
|  | @raw_cache.memoize(typed=True) | ||||||
|  | def getCtRawCandidate(series_uid, center_xyz, width_irc): | ||||||
|  |     ct = getCt(series_uid) | ||||||
|  |     ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz, | ||||||
|  |                                                          width_irc) | ||||||
|  |     ct_chunk.clip(-1000, 1000, ct_chunk) | ||||||
|  |     return ct_chunk, pos_chunk, center_irc | ||||||
|  | 
 | ||||||
|  | @raw_cache.memoize(typed=True) | ||||||
|  | def getCtSampleSize(series_uid): | ||||||
|  |     ct = Ct(series_uid) | ||||||
|  |     return int(ct.hu_a.shape[0]), ct.positive_indexes | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Luna2dSegmentationDataset(Dataset): | ||||||
|  |     def __init__(self, | ||||||
|  |                  val_stride=0, | ||||||
|  |                  isValSet_bool=None, | ||||||
|  |                  series_uid=None, | ||||||
|  |                  contextSlices_count=3, | ||||||
|  |                  fullCt_bool=False, | ||||||
|  |             ): | ||||||
|  |         self.contextSlices_count = contextSlices_count | ||||||
|  |         self.fullCt_bool = fullCt_bool | ||||||
|  | 
 | ||||||
|  |         if series_uid: | ||||||
|  |             self.series_list = [series_uid] | ||||||
|  |         else: | ||||||
|  |             self.series_list = sorted(getCandidateInfoDict().keys()) | ||||||
|  | 
 | ||||||
|  |         if isValSet_bool: | ||||||
|  |             assert val_stride > 0, val_stride | ||||||
|  |             self.series_list = self.series_list[::val_stride] | ||||||
|  |             assert self.series_list | ||||||
|  |         elif val_stride > 0: | ||||||
|  |             del self.series_list[::val_stride] | ||||||
|  |             assert self.series_list | ||||||
|  | 
 | ||||||
|  |         self.sample_list = [] | ||||||
|  |         for series_uid in self.series_list: | ||||||
|  |             index_count, positive_indexes = getCtSampleSize(series_uid) | ||||||
|  | 
 | ||||||
|  |             if self.fullCt_bool: | ||||||
|  |                 self.sample_list += [(series_uid, slice_ndx) | ||||||
|  |                                      for slice_ndx in range(index_count)] | ||||||
|  |             else: | ||||||
|  |                 self.sample_list += [(series_uid, slice_ndx) | ||||||
|  |                                      for slice_ndx in positive_indexes] | ||||||
|  | 
 | ||||||
|  |         self.candidateInfo_list = getCandidateInfoList() | ||||||
|  | 
 | ||||||
|  |         series_set = set(self.series_list) | ||||||
|  |         self.candidateInfo_list = [cit for cit in self.candidateInfo_list | ||||||
|  |                                    if cit.series_uid in series_set] | ||||||
|  | 
 | ||||||
|  |         self.pos_list = [nt for nt in self.candidateInfo_list | ||||||
|  |                             if nt.isNodule_bool] | ||||||
|  | 
 | ||||||
|  |         log.info("{!r}: {} {} series, {} slices, {} nodules".format( | ||||||
|  |             self, | ||||||
|  |             len(self.series_list), | ||||||
|  |             {None: 'general', True: 'validation', False: 'training'}[isValSet_bool], | ||||||
|  |             len(self.sample_list), | ||||||
|  |             len(self.pos_list), | ||||||
|  |         )) | ||||||
|  | 
 | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.sample_list) | ||||||
|  | 
 | ||||||
|  |     def __getitem__(self, ndx): | ||||||
|  |         series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)] | ||||||
|  |         return self.getitem_fullSlice(series_uid, slice_ndx) | ||||||
|  | 
 | ||||||
|  |     def getitem_fullSlice(self, series_uid, slice_ndx): | ||||||
|  |         ct = getCt(series_uid) | ||||||
|  |         ct_t = torch.zeros((self.contextSlices_count * 2 + 1, 512, 512)) | ||||||
|  | 
 | ||||||
|  |         start_ndx = slice_ndx - self.contextSlices_count | ||||||
|  |         end_ndx = slice_ndx + self.contextSlices_count + 1 | ||||||
|  |         for i, context_ndx in enumerate(range(start_ndx, end_ndx)): | ||||||
|  |             context_ndx = max(context_ndx, 0) | ||||||
|  |             context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1) | ||||||
|  |             ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32)) | ||||||
|  | 
 | ||||||
|  |         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale | ||||||
|  |         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0. | ||||||
|  |         # The lower bound gets rid of negative density stuff used to indicate out-of-FOV | ||||||
|  |         # The upper bound nukes any weird hotspots and clamps bone down | ||||||
|  |         ct_t.clamp_(-1000, 1000) | ||||||
|  | 
 | ||||||
|  |         pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0) | ||||||
|  | 
 | ||||||
|  |         return ct_t, pos_t, ct.series_uid, slice_ndx | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset): | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  | 
 | ||||||
|  |         self.ratio_int = 2 | ||||||
|  | 
 | ||||||
|  |     def __len__(self): | ||||||
|  |         return 300000 | ||||||
|  | 
 | ||||||
|  |     def shuffleSamples(self): | ||||||
|  |         random.shuffle(self.candidateInfo_list) | ||||||
|  |         random.shuffle(self.pos_list) | ||||||
|  | 
 | ||||||
|  |     def __getitem__(self, ndx): | ||||||
|  |         candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)] | ||||||
|  |         return self.getitem_trainingCrop(candidateInfo_tup) | ||||||
|  | 
 | ||||||
|  |     def getitem_trainingCrop(self, candidateInfo_tup): | ||||||
|  |         ct_a, pos_a, center_irc = getCtRawCandidate( | ||||||
|  |             candidateInfo_tup.series_uid, | ||||||
|  |             candidateInfo_tup.center_xyz, | ||||||
|  |             (7, 96, 96), | ||||||
|  |         ) | ||||||
|  |         pos_a = pos_a[3:4] | ||||||
|  | 
 | ||||||
|  |         row_offset = random.randrange(0,32) | ||||||
|  |         col_offset = random.randrange(0,32) | ||||||
|  |         ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64, | ||||||
|  |                                      col_offset:col_offset+64]).to(torch.float32) | ||||||
|  |         pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64, | ||||||
|  |                                        col_offset:col_offset+64]).to(torch.long) | ||||||
|  | 
 | ||||||
|  |         slice_ndx = center_irc.index | ||||||
|  | 
 | ||||||
|  |         return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx | ||||||
|  | 
 | ||||||
|  | class PrepcacheLunaDataset(Dataset): | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  | 
 | ||||||
|  |         self.candidateInfo_list = getCandidateInfoList() | ||||||
|  |         self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool] | ||||||
|  | 
 | ||||||
|  |         self.seen_set = set() | ||||||
|  |         self.candidateInfo_list.sort(key=lambda x: x.series_uid) | ||||||
|  | 
 | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.candidateInfo_list) | ||||||
|  | 
 | ||||||
|  |     def __getitem__(self, ndx): | ||||||
|  |         # candidate_t, pos_t, series_uid, center_t = super().__getitem__(ndx) | ||||||
|  | 
 | ||||||
|  |         candidateInfo_tup = self.candidateInfo_list[ndx] | ||||||
|  |         getCtRawCandidate(candidateInfo_tup.series_uid, candidateInfo_tup.center_xyz, (7, 96, 96)) | ||||||
|  | 
 | ||||||
|  |         series_uid = candidateInfo_tup.series_uid | ||||||
|  |         if series_uid not in self.seen_set: | ||||||
|  |             self.seen_set.add(series_uid) | ||||||
|  | 
 | ||||||
|  |             getCtSampleSize(series_uid) | ||||||
|  |             # ct = getCt(series_uid) | ||||||
|  |             # for mask_ndx in ct.positive_indexes: | ||||||
|  |             #     build2dLungMask(series_uid, mask_ndx) | ||||||
|  | 
 | ||||||
|  |         return 0, 1 #candidate_t, pos_t, series_uid, center_t | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class TvTrainingLuna2dSegmentationDataset(torch.utils.data.Dataset): | ||||||
|  |     def __init__(self, isValSet_bool=False, val_stride=10, contextSlices_count=3): | ||||||
|  |         assert contextSlices_count == 3 | ||||||
|  |         data = torch.load('./imgs_and_masks.pt') | ||||||
|  |         suids = list(set(data['suids'])) | ||||||
|  |         trn_mask_suids = torch.arange(len(suids)) % val_stride < (val_stride - 1) | ||||||
|  |         trn_suids = {s for i, s in zip(trn_mask_suids, suids) if i} | ||||||
|  |         trn_mask = torch.tensor([(s in trn_suids) for s in data["suids"]]) | ||||||
|  |         if not isValSet_bool: | ||||||
|  |             self.imgs = data["imgs"][trn_mask] | ||||||
|  |             self.masks = data["masks"][trn_mask] | ||||||
|  |             self.suids = [s for s, i in zip(data["suids"], trn_mask) if i] | ||||||
|  |         else: | ||||||
|  |             self.imgs = data["imgs"][~trn_mask] | ||||||
|  |             self.masks = data["masks"][~trn_mask] | ||||||
|  |             self.suids = [s for s, i in zip(data["suids"], trn_mask) if not i] | ||||||
|  |         # discard spurious hotspots and clamp bone | ||||||
|  |         self.imgs.clamp_(-1000, 1000) | ||||||
|  |         self.imgs /= 1000 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self.imgs) | ||||||
|  | 
 | ||||||
|  |     def __getitem__(self, i): | ||||||
|  |         oh, ow = torch.randint(0, 32, (2,)) | ||||||
|  |         sl = self.masks.size(1)//2 | ||||||
|  |         return self.imgs[i, :, oh: oh + 64, ow: ow + 64], 1, self.masks[i, sl: sl+1, oh: oh + 64, ow: ow + 64].to(torch.float32), self.suids[i], 9999 | ||||||
| @ -0,0 +1,224 @@ | |||||||
|  | import math | ||||||
|  | import random | ||||||
|  | from collections import namedtuple | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | from torch import nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | 
 | ||||||
|  | from util.logconf import logging | ||||||
|  | from util.unet import UNet | ||||||
|  | 
 | ||||||
|  | log = logging.getLogger(__name__) | ||||||
|  | # log.setLevel(logging.WARN) | ||||||
|  | # log.setLevel(logging.INFO) | ||||||
|  | log.setLevel(logging.DEBUG) | ||||||
|  | 
 | ||||||
|  | class UNetWrapper(nn.Module): | ||||||
|  |     def __init__(self, **kwargs): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels']) | ||||||
|  |         self.unet = UNet(**kwargs) | ||||||
|  |         self.final = nn.Sigmoid() | ||||||
|  | 
 | ||||||
|  |         self._init_weights() | ||||||
|  | 
 | ||||||
|  |     def _init_weights(self): | ||||||
|  |         init_set = { | ||||||
|  |             nn.Conv2d, | ||||||
|  |             nn.Conv3d, | ||||||
|  |             nn.ConvTranspose2d, | ||||||
|  |             nn.ConvTranspose3d, | ||||||
|  |             nn.Linear, | ||||||
|  |         } | ||||||
|  |         for m in self.modules(): | ||||||
|  |             if type(m) in init_set: | ||||||
|  |                 nn.init.kaiming_normal_( | ||||||
|  |                     m.weight.data, mode='fan_out', nonlinearity='relu', a=0 | ||||||
|  |                 ) | ||||||
|  |                 if m.bias is not None: | ||||||
|  |                     fan_in, fan_out = \ | ||||||
|  |                         nn.init._calculate_fan_in_and_fan_out(m.weight.data) | ||||||
|  |                     bound = 1 / math.sqrt(fan_out) | ||||||
|  |                     nn.init.normal_(m.bias, -bound, bound) | ||||||
|  | 
 | ||||||
|  |         # nn.init.constant_(self.unet.last.bias, -4) | ||||||
|  |         # nn.init.constant_(self.unet.last.bias, 4) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     def forward(self, input_batch): | ||||||
|  |         bn_output = self.input_batchnorm(input_batch) | ||||||
|  |         un_output = self.unet(bn_output) | ||||||
|  |         fn_output = self.final(un_output) | ||||||
|  |         return fn_output | ||||||
|  | 
 | ||||||
|  | class SegmentationAugmentation(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |             self, flip=None, offset=None, scale=None, rotate=None, noise=None | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.flip = flip | ||||||
|  |         self.offset = offset | ||||||
|  |         self.scale = scale | ||||||
|  |         self.rotate = rotate | ||||||
|  |         self.noise = noise | ||||||
|  | 
 | ||||||
|  |     def forward(self, input_g, label_g): | ||||||
|  |         transform_t = self._build2dTransformMatrix() | ||||||
|  |         transform_t = transform_t.expand(input_g.shape[0], -1, -1) | ||||||
|  |         transform_t = transform_t.to(input_g.device, torch.float32) | ||||||
|  |         affine_t = F.affine_grid(transform_t[:,:2], | ||||||
|  |                 input_g.size(), align_corners=False) | ||||||
|  | 
 | ||||||
|  |         augmented_input_g = F.grid_sample(input_g, | ||||||
|  |                 affine_t, padding_mode='border', | ||||||
|  |                 align_corners=False) | ||||||
|  |         augmented_label_g = F.grid_sample(label_g.to(torch.float32), | ||||||
|  |                 affine_t, padding_mode='border', | ||||||
|  |                 align_corners=False) | ||||||
|  | 
 | ||||||
|  |         if self.noise: | ||||||
|  |             noise_t = torch.randn_like(augmented_input_g) | ||||||
|  |             noise_t *= self.noise | ||||||
|  | 
 | ||||||
|  |             augmented_input_g += noise_t | ||||||
|  | 
 | ||||||
|  |         return augmented_input_g, augmented_label_g > 0.5 | ||||||
|  | 
 | ||||||
|  |     def _build2dTransformMatrix(self): | ||||||
|  |         transform_t = torch.eye(3) | ||||||
|  | 
 | ||||||
|  |         for i in range(2): | ||||||
|  |             if self.flip: | ||||||
|  |                 if random.random() > 0.5: | ||||||
|  |                     transform_t[i,i] *= -1 | ||||||
|  | 
 | ||||||
|  |             if self.offset: | ||||||
|  |                 offset_float = self.offset | ||||||
|  |                 random_float = (random.random() * 2 - 1) | ||||||
|  |                 transform_t[2,i] = offset_float * random_float | ||||||
|  | 
 | ||||||
|  |             if self.scale: | ||||||
|  |                 scale_float = self.scale | ||||||
|  |                 random_float = (random.random() * 2 - 1) | ||||||
|  |                 transform_t[i,i] *= 1.0 + scale_float * random_float | ||||||
|  | 
 | ||||||
|  |         if self.rotate: | ||||||
|  |             angle_rad = random.random() * math.pi * 2 | ||||||
|  |             s = math.sin(angle_rad) | ||||||
|  |             c = math.cos(angle_rad) | ||||||
|  | 
 | ||||||
|  |             rotation_t = torch.tensor([ | ||||||
|  |                 [c, -s, 0], | ||||||
|  |                 [s, c, 0], | ||||||
|  |                 [0, 0, 1]]) | ||||||
|  | 
 | ||||||
|  |             transform_t @= rotation_t | ||||||
|  | 
 | ||||||
|  |         return transform_t | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask') | ||||||
|  | # | ||||||
|  | # class SegmentationMask(nn.Module): | ||||||
|  | #     def __init__(self): | ||||||
|  | #         super().__init__() | ||||||
|  | # | ||||||
|  | #         self.conv_list = nn.ModuleList([ | ||||||
|  | #             self._make_circle_conv(radius) for radius in range(1, 8) | ||||||
|  | #         ]) | ||||||
|  | # | ||||||
|  | #     def _make_circle_conv(self, radius): | ||||||
|  | #         diameter = 1 + radius * 2 | ||||||
|  | # | ||||||
|  | #         a = torch.linspace(-1, 1, steps=diameter)**2 | ||||||
|  | #         b = (a[None] + a[:, None])**0.5 | ||||||
|  | # | ||||||
|  | #         circle_weights = (b <= 1.0).to(torch.float32) | ||||||
|  | # | ||||||
|  | #         conv = nn.Conv2d(1, 1, kernel_size=diameter, padding=radius, bias=False) | ||||||
|  | #         conv.weight.data.fill_(1) | ||||||
|  | #         conv.weight.data *= circle_weights / circle_weights.sum() | ||||||
|  | # | ||||||
|  | #         return conv | ||||||
|  | # | ||||||
|  | # | ||||||
|  | #     def erode(self, input_mask, radius, threshold=1): | ||||||
|  | #         conv = self.conv_list[radius - 1] | ||||||
|  | #         input_float = input_mask.to(torch.float32) | ||||||
|  | #         result = conv(input_float) | ||||||
|  | # | ||||||
|  | #         # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()]) | ||||||
|  | #         # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()]) | ||||||
|  | # | ||||||
|  | #         return result >= threshold | ||||||
|  | # | ||||||
|  | #     def deposit(self, input_mask, radius, threshold=0): | ||||||
|  | #         conv = self.conv_list[radius - 1] | ||||||
|  | #         input_float = input_mask.to(torch.float32) | ||||||
|  | #         result = conv(input_float) | ||||||
|  | # | ||||||
|  | #         # log.debug(['deposit in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()]) | ||||||
|  | #         # log.debug(['deposit out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()]) | ||||||
|  | # | ||||||
|  | #         return result > threshold | ||||||
|  | # | ||||||
|  | #     def fill_cavity(self, input_mask): | ||||||
|  | #         cumsum = input_mask.cumsum(-1) | ||||||
|  | #         filled_mask = (cumsum > 0) | ||||||
|  | #         filled_mask &= (cumsum < cumsum[..., -1:]) | ||||||
|  | #         cumsum = input_mask.cumsum(-2) | ||||||
|  | #         filled_mask &= (cumsum > 0) | ||||||
|  | #         filled_mask &= (cumsum < cumsum[..., -1:, :]) | ||||||
|  | # | ||||||
|  | #         return filled_mask | ||||||
|  | # | ||||||
|  | # | ||||||
|  | #     def forward(self, input_g, raw_pos_g): | ||||||
|  | #         gcc_g = input_g + 1 | ||||||
|  | # | ||||||
|  | #         with torch.no_grad(): | ||||||
|  | #             # log.info(['gcc_g', gcc_g.min(), gcc_g.mean(), gcc_g.max()]) | ||||||
|  | # | ||||||
|  | #             raw_dense_mask = gcc_g > 0.7 | ||||||
|  | #             dense_mask = self.deposit(raw_dense_mask, 2) | ||||||
|  | #             dense_mask = self.erode(dense_mask, 6) | ||||||
|  | #             dense_mask = self.deposit(dense_mask, 4) | ||||||
|  | # | ||||||
|  | #             body_mask = self.fill_cavity(dense_mask) | ||||||
|  | #             air_mask = self.deposit(body_mask & ~dense_mask, 5) | ||||||
|  | #             air_mask = self.erode(air_mask, 6) | ||||||
|  | # | ||||||
|  | #             lung_mask = self.deposit(air_mask, 5) | ||||||
|  | # | ||||||
|  | #             raw_candidate_mask = gcc_g > 0.4 | ||||||
|  | #             raw_candidate_mask &= air_mask | ||||||
|  | #             candidate_mask = self.erode(raw_candidate_mask, 1) | ||||||
|  | #             candidate_mask = self.deposit(candidate_mask, 1) | ||||||
|  | # | ||||||
|  | #             pos_mask = self.deposit((raw_pos_g > 0.5) & lung_mask, 2) | ||||||
|  | # | ||||||
|  | #             neg_mask = self.deposit(candidate_mask, 1) | ||||||
|  | #             neg_mask &= ~pos_mask | ||||||
|  | #             neg_mask &= lung_mask | ||||||
|  | # | ||||||
|  | #             # label_g = (neg_mask | pos_mask).to(torch.float32) | ||||||
|  | #             label_g = (pos_mask).to(torch.float32) | ||||||
|  | #             neg_g = neg_mask.to(torch.float32) | ||||||
|  | #             pos_g = pos_mask.to(torch.float32) | ||||||
|  | # | ||||||
|  | #         mask_dict = { | ||||||
|  | #             'raw_dense_mask': raw_dense_mask, | ||||||
|  | #             'dense_mask': dense_mask, | ||||||
|  | #             'body_mask': body_mask, | ||||||
|  | #             'air_mask': air_mask, | ||||||
|  | #             'raw_candidate_mask': raw_candidate_mask, | ||||||
|  | #             'candidate_mask': candidate_mask, | ||||||
|  | #             'lung_mask': lung_mask, | ||||||
|  | #             'neg_mask': neg_mask, | ||||||
|  | #             'pos_mask': pos_mask, | ||||||
|  | #         } | ||||||
|  | # | ||||||
|  | #         return label_g, neg_g, pos_g, lung_mask, mask_dict | ||||||
| @ -0,0 +1,69 @@ | |||||||
|  | import timing | ||||||
|  | import argparse | ||||||
|  | import sys | ||||||
|  | 
 | ||||||
|  | import numpy as np | ||||||
|  | 
 | ||||||
|  | import torch.nn as nn | ||||||
|  | from torch.autograd import Variable | ||||||
|  | from torch.optim import SGD | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  | 
 | ||||||
|  | from util.util import enumerateWithEstimate | ||||||
|  | from .dsets import PrepcacheLunaDataset, getCtSampleSize | ||||||
|  | from util.logconf import logging | ||||||
|  | # from .model import LunaModel | ||||||
|  | 
 | ||||||
|  | log = logging.getLogger(__name__) | ||||||
|  | # log.setLevel(logging.WARN) | ||||||
|  | log.setLevel(logging.INFO) | ||||||
|  | # log.setLevel(logging.DEBUG) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class LunaPrepCacheApp: | ||||||
|  |     @classmethod | ||||||
|  |     def __init__(self, sys_argv=None): | ||||||
|  |         if sys_argv is None: | ||||||
|  |             sys_argv = sys.argv[1:] | ||||||
|  | 
 | ||||||
|  |         parser = argparse.ArgumentParser() | ||||||
|  |         parser.add_argument('--batch-size', | ||||||
|  |             help='Batch size to use for training', | ||||||
|  |             default=1024, | ||||||
|  |             type=int, | ||||||
|  |         ) | ||||||
|  |         parser.add_argument('--num-workers', | ||||||
|  |             help='Number of worker processes for background data loading', | ||||||
|  |             default=8, | ||||||
|  |             type=int, | ||||||
|  |         ) | ||||||
|  |         # parser.add_argument('--scaled', | ||||||
|  |         #     help="Scale the CT chunks to square voxels.", | ||||||
|  |         #     default=False, | ||||||
|  |         #     action='store_true', | ||||||
|  |         # ) | ||||||
|  | 
 | ||||||
|  |         self.cli_args = parser.parse_args(sys_argv) | ||||||
|  | 
 | ||||||
|  |     def main(self): | ||||||
|  |         log.info("Starting {}, {}".format(type(self).__name__, self.cli_args)) | ||||||
|  | 
 | ||||||
|  |         self.prep_dl = DataLoader( | ||||||
|  |             PrepcacheLunaDataset( | ||||||
|  |                 # sortby_str='series_uid', | ||||||
|  |             ), | ||||||
|  |             batch_size=self.cli_args.batch_size, | ||||||
|  |             num_workers=self.cli_args.num_workers, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         batch_iter = enumerateWithEstimate( | ||||||
|  |             self.prep_dl, | ||||||
|  |             "Stuffing cache", | ||||||
|  |             start_ndx=self.prep_dl.num_workers, | ||||||
|  |         ) | ||||||
|  |         for batch_ndx, batch_tup in batch_iter: | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     LunaPrepCacheApp().main() | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user