forked from KEMT/zpwiki
		
	Nahrát soubory do „pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model“
This commit is contained in:
		
							parent
							
								
									35f3889338
								
							
						
					
					
						commit
						526fe93ff9
					
				
										
											Binary file not shown.
										
									
								
							| @ -0,0 +1,76 @@ | ||||
| import argparse | ||||
| import datetime | ||||
| import os | ||||
| import socket | ||||
| import sys | ||||
| 
 | ||||
| import numpy as np | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| 
 | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import torch.optim | ||||
| 
 | ||||
| from torch.optim import SGD, Adam | ||||
| from torch.utils.data import DataLoader | ||||
| 
 | ||||
| from util.util import enumerateWithEstimate | ||||
| from p2ch13.dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt | ||||
| from util.logconf import logging | ||||
| from util.util import xyz2irc | ||||
| from p2ch13.model_seg import UNetWrapper, SegmentationAugmentation | ||||
| from p2ch13.train_seg import LunaTrainingApp | ||||
| 
 | ||||
# Module-level logger for the benchmark script, reporting at DEBUG.
# Switch the level to logging.INFO or logging.WARN for quieter output.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
| 
 | ||||
class BenchmarkLuna2dSegmentationDataset(TrainingLuna2dSegmentationDataset):
    """Training dataset that reports a fixed, smaller length for benchmarking.

    Only ``__len__`` is overridden; sample lookup is inherited unchanged
    from ``TrainingLuna2dSegmentationDataset``.
    """

    def __len__(self):
        """Return the number of samples that make up one benchmark epoch."""
        return 5000
| 
 | ||||
class LunaBenchmarkApp(LunaTrainingApp):
    """Training app variant that runs one epoch over the benchmark dataset.

    ``cli_args``, ``use_cuda`` and ``doTraining`` are inherited from
    ``LunaTrainingApp`` (not shown here).
    """

    def initTrainDl(self):
        """Build and return the DataLoader over the benchmark dataset."""
        # The per-device batch size is multiplied by the number of GPUs
        # when CUDA is in use.
        device_factor = torch.cuda.device_count() if self.use_cuda else 1
        effective_batch_size = self.cli_args.batch_size * device_factor

        benchmark_ds = BenchmarkLuna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=False,
            contextSlices_count=3,
        )

        return DataLoader(
            benchmark_ds,
            batch_size=effective_batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

    def main(self):
        """Log the configuration, then run a single training epoch."""
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        device_count = torch.cuda.device_count() if self.use_cuda else 1

        # Exactly one epoch (epoch_ndx == 1); there is no validation pass,
        # so the validation batch count in the log line is always 0.
        for epoch_ndx in range(1, 2):
            epoch_msg = "Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                0,
                self.cli_args.batch_size,
                device_count,
            )
            log.info(epoch_msg)

            self.doTraining(epoch_ndx, train_dl)
| 
 | ||||
| 
 | ||||
if __name__ == '__main__':
    # Script entry point: build the benchmark app and run it.
    benchmark_app = LunaBenchmarkApp()
    benchmark_app.main()
| @ -0,0 +1,401 @@ | ||||
| import copy | ||||
| import csv | ||||
| import functools | ||||
| import glob | ||||
| import math | ||||
| import os | ||||
| import random | ||||
| 
 | ||||
| from collections import namedtuple | ||||
| 
 | ||||
| import SimpleITK as sitk | ||||
| import numpy as np | ||||
| import scipy.ndimage.morphology as morph | ||||
| 
 | ||||
| import torch | ||||
| import torch.cuda | ||||
| import torch.nn.functional as F | ||||
| from torch.utils.data import Dataset | ||||
| 
 | ||||
| from util.disk import getCache | ||||
| from util.util import XyzTuple, xyz2irc | ||||
| from util.logconf import logging | ||||
| 
 | ||||
# Module-level logger for the dataset helpers, reporting at DEBUG.
# Switch the level to logging.INFO or logging.WARN for quieter output.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
| 
 | ||||
| raw_cache = getCache('part2ch13_raw') | ||||
| 
 | ||||
# Record holding the nine intermediate/final masks, in this field order.
MaskTuple = namedtuple(
    'MaskTuple',
    [
        'raw_dense_mask',
        'dense_mask',
        'body_mask',
        'air_mask',
        'raw_candidate_mask',
        'candidate_mask',
        'lung_mask',
        'neg_mask',
        'pos_mask',
    ],
)
| 
 | ||||
CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')


@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
    """Load every nodule annotation and non-nodule candidate from the CSVs.

    Annotated nodules come from ``data/annotations_with_malignancy.csv``;
    non-nodule candidates come from ``data/candidates.csv`` (rows flagged
    as nodules there are skipped, since the annotations file covers them).

    Args:
        requireOnDisk_bool: When true, drop every row whose series has no
            ``.mhd`` file under ``data-unversioned/subset*/``.  This filter
            is applied to BOTH files, so a partially downloaded data set
            never yields a series that cannot be opened afterwards.

    Returns:
        A list of ``CandidateInfoTuple`` sorted in descending tuple order
        (nodules first).  The list is cached and shared between callers:
        treat it as read-only and copy it before sorting or shuffling.
    """
    # Series that are actually present on disk; the series_uid is the file
    # name with its 4-character ".mhd" suffix removed.
    mhd_list = glob.glob('data-unversioned/subset*/*.mhd')
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    def isUsable(series_uid):
        return not requireOnDisk_bool or series_uid in presentOnDisk_set

    candidateInfo_list = []
    with open('data/annotations_with_malignancy.csv', "r") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            series_uid = row[0]

            if not isUsable(series_uid):
                continue

            annotationCenter_xyz = tuple(float(x) for x in row[1:4])
            annotationDiameter_mm = float(row[4])
            # Strict lookup: anything other than 'True'/'False' is an error.
            isMal_bool = {'False': False, 'True': True}[row[5]]

            candidateInfo_list.append(
                CandidateInfoTuple(
                    True,
                    True,
                    isMal_bool,
                    annotationDiameter_mm,
                    series_uid,
                    annotationCenter_xyz,
                )
            )

    with open('data/candidates.csv', "r") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            series_uid = row[0]

            if not isUsable(series_uid):
                continue

            isNodule_bool = bool(int(row[4]))
            if isNodule_bool:
                # Real nodules were already added from the annotations file.
                continue

            candidateCenter_xyz = tuple(float(x) for x in row[1:4])
            candidateInfo_list.append(
                CandidateInfoTuple(
                    False,
                    False,
                    False,
                    0.0,
                    series_uid,
                    candidateCenter_xyz,
                )
            )

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list
| 
 | ||||
@functools.lru_cache(1)
def getCandidateInfoDict(requireOnDisk_bool=True):
    """Group the candidate list by series.

    Returns a dict mapping each ``series_uid`` to the list of its
    candidate tuples, kept in the order ``getCandidateInfoList`` yields
    them.  The result is cached, so callers share one dict.
    """
    by_series = {}

    for info_tup in getCandidateInfoList(requireOnDisk_bool):
        if info_tup.series_uid not in by_series:
            by_series[info_tup.series_uid] = []
        by_series[info_tup.series_uid].append(info_tup)

    return by_series
| 
 | ||||
class Ct:
    """A single CT series loaded from disk, with its nodule annotation mask.

    Attributes:
        series_uid: Identifier the volume was loaded for.
        hu_a: float32 voxel array, indexed as ``[index, row, col]``.
        origin_xyz, vxSize_xyz: ``XyzTuple``s built from the image's origin
            and spacing.
        direction_a: The image's direction cosines as a 3x3 array.
        positiveInfo_list: Candidates of this series flagged as nodules.
        positive_mask: Boolean array shaped like ``hu_a``; see
            ``buildAnnotationMask``.
        positive_indexes: Indexes along axis 0 whose slice contains at
            least one set voxel of ``positive_mask``.
    """

    def __init__(self, series_uid):
        """Read ``<series_uid>.mhd`` and derive the annotation mask.

        Raises:
            IndexError: If no matching file exists under
                ``data-unversioned/subset*/``.
            KeyError: If the series has no entry in the candidate dict.
        """
        mhd_path = glob.glob(
            'data-unversioned/subset*/{}.mhd'.format(series_uid)
        )[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.

        self.series_uid = series_uid

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

        candidateInfo_list = getCandidateInfoDict()[self.series_uid]

        self.positiveInfo_list = [
            candidate_tup
            for candidate_tup in candidateInfo_list
            if candidate_tup.isNodule_bool
        ]
        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
        self.positive_indexes = (self.positive_mask.sum(axis=(1,2))
                                 .nonzero()[0].tolist())

    def _measureRadius(self, center_irc, axis, threshold_hu):
        """Return how far the dense region around a voxel extends on one axis.

        Starting two voxels out, the radius grows while the voxels on BOTH
        sides of the center stay above ``threshold_hu``.  When the probe
        leaves the volume on either side the radius is backed off by one.

        Args:
            center_irc: ``(index, row, col)`` integer voxel coordinates.
            axis: Axis (0, 1 or 2) along which to probe.
            threshold_hu: Density above which a voxel counts as dense.
        """
        center = center_irc[axis]
        hi_ndx = list(center_irc)
        lo_ndx = list(center_irc)

        radius = 2
        try:
            while True:
                hi_ndx[axis] = center + radius
                if not self.hu_a[tuple(hi_ndx)] > threshold_hu:
                    break

                # A negative index would not raise; it would silently read
                # from the far end of the volume.  Treat it exactly like
                # running off the upper edge.
                if center - radius < 0:
                    radius -= 1
                    break

                lo_ndx[axis] = center - radius
                if not self.hu_a[tuple(lo_ndx)] > threshold_hu:
                    break

                radius += 1
        except IndexError:
            # Probed past the upper edge of the volume.
            radius -= 1

        return radius

    def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
        """Build a boolean mask marking the dense voxels of each nodule.

        For every candidate a bounding box is grown around its center
        along each axis (see ``_measureRadius``); the mask is the union of
        those boxes restricted to voxels denser than ``threshold_hu``.

        Args:
            positiveInfo_list: Candidate tuples providing ``center_xyz``.
            threshold_hu: Density cut-off applied to ``hu_a``.

        Returns:
            Boolean array with the same shape as ``hu_a``.
        """
        # Builtin ``bool``: the ``np.bool`` alias no longer exists in
        # current NumPy releases.
        boundingBox_a = np.zeros_like(self.hu_a, dtype=bool)

        for candidateInfo_tup in positiveInfo_list:
            center_irc = xyz2irc(
                candidateInfo_tup.center_xyz,
                self.origin_xyz,
                self.vxSize_xyz,
                self.direction_a,
            )
            center_tup = (
                int(center_irc.index),
                int(center_irc.row),
                int(center_irc.col),
            )

            box_slices = []
            for axis, center in enumerate(center_tup):
                radius = self._measureRadius(center_tup, axis, threshold_hu)
                # Clamp the start: a negative slice start would be counted
                # from the end of the axis and select the wrong voxels.
                box_slices.append(
                    slice(max(center - radius, 0), center + radius + 1)
                )

            boundingBox_a[tuple(box_slices)] = True

        mask_a = boundingBox_a & (self.hu_a > threshold_hu)

        return mask_a

    def getRawCandidate(self, center_xyz, width_irc):
        """Cut a ``width_irc``-sized window around ``center_xyz``.

        The window is shifted as needed so it does not extend past either
        end of an axis.

        Args:
            center_xyz: Center of the window, converted with ``xyz2irc``.
            width_irc: Window size along (index, row, col).

        Returns:
            ``(ct_chunk, pos_chunk, center_irc)``: the matching slices of
            ``hu_a`` and ``positive_mask`` (basic slices, not copies) and
            the converted center.

        Raises:
            AssertionError: If the center lies outside the volume.
        """
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
                             self.direction_a)

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

            if start_ndx < 0:
                # Window starts before the volume: pin it to the low edge.
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.hu_a.shape[axis]:
                # Window ends past the volume: pin it to the high edge.
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]
        pos_chunk = self.positive_mask[tuple(slice_list)]

        return ct_chunk, pos_chunk, center_irc
| 
 | ||||
@functools.lru_cache(maxsize=1, typed=True)
def getCt(series_uid):
    """Return the ``Ct`` for ``series_uid``.

    Only the most recently requested volume is kept in the cache, so
    repeated requests for the same series reuse one loaded instance.
    """
    ct = Ct(series_uid)
    return ct
| 
 | ||||
@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    """Return the clipped CT window around a candidate, memoized in ``raw_cache``.

    Returns:
        ``(ct_chunk, pos_chunk, center_irc)`` as produced by
        ``Ct.getRawCandidate``, with ``ct_chunk`` limited to
        [-1000, 1000].
    """
    chunk_tup = getCt(series_uid).getRawCandidate(center_xyz, width_irc)
    ct_chunk, pos_chunk, center_irc = chunk_tup

    # The clip is done in place.  ``ct_chunk`` comes from slicing the
    # volume, so this presumably also clips that region of the cached
    # ``Ct.hu_a`` -- NOTE(review): confirm this side effect is intended.
    np.clip(ct_chunk, -1000, 1000, out=ct_chunk)

    return ct_chunk, pos_chunk, center_irc
| 
 | ||||
@raw_cache.memoize(typed=True)
def getCtSampleSize(series_uid):
    """Return ``(slice_count, positive_indexes)`` for a series.

    ``slice_count`` is the length of the volume along axis 0 and
    ``positive_indexes`` lists the slices that contain annotated voxels.
    The result is memoized in ``raw_cache``; a fresh ``Ct`` is built
    directly rather than through ``getCt``.
    """
    volume = Ct(series_uid)
    slice_count = int(volume.hu_a.shape[0])
    return slice_count, volume.positive_indexes
| 
 | ||||
| 
 | ||||
class Luna2dSegmentationDataset(Dataset):
    """Dataset of full CT slices paired with their nodule masks.

    Each sample is one ``(series_uid, slice_ndx)`` pair; the item returned
    for it stacks the slice with ``contextSlices_count`` neighbours on each
    side as channels.
    """

    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 series_uid=None,
                 contextSlices_count=3,
                 fullCt_bool=False,
            ):
        """Select the series and enumerate the slices to serve.

        Args:
            val_stride: Every ``val_stride``-th series belongs to the
                validation split; 0 disables splitting.
            isValSet_bool: True keeps only the validation series, False
                (with ``val_stride > 0``) removes them, None keeps all.
            series_uid: Restrict the dataset to this single series.
            contextSlices_count: Neighbouring slices added on each side.
            fullCt_bool: Serve every slice of each series instead of only
                the slices containing annotated voxels.
        """
        self.contextSlices_count = contextSlices_count
        self.fullCt_bool = fullCt_bool

        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(getCandidateInfoDict().keys())

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.series_list = self.series_list[::val_stride]
            assert self.series_list
        elif val_stride > 0:
            del self.series_list[::val_stride]
            assert self.series_list

        self.sample_list = []
        for uid in self.series_list:
            index_count, positive_indexes = getCtSampleSize(uid)
            if self.fullCt_bool:
                wanted_slices = range(index_count)
            else:
                wanted_slices = positive_indexes
            self.sample_list.extend(
                (uid, slice_ndx) for slice_ndx in wanted_slices
            )

        # Keep only the candidates that belong to the selected series.
        series_set = set(self.series_list)
        self.candidateInfo_list = [
            cit for cit in getCandidateInfoList()
            if cit.series_uid in series_set
        ]
        self.pos_list = [
            nt for nt in self.candidateInfo_list if nt.isNodule_bool
        ]

        split_name = {
            None: 'general',
            True: 'validation',
            False: 'training',
        }[isValSet_bool]
        log.info("{!r}: {} {} series, {} slices, {} nodules".format(
            self,
            len(self.series_list),
            split_name,
            len(self.sample_list),
            len(self.pos_list),
        ))

    def __len__(self):
        """Return the number of (series, slice) samples."""
        return len(self.sample_list)

    def __getitem__(self, ndx):
        """Return the full-slice item for sample ``ndx`` (wraps around)."""
        wrapped_ndx = ndx % len(self.sample_list)
        series_uid, slice_ndx = self.sample_list[wrapped_ndx]
        return self.getitem_fullSlice(series_uid, slice_ndx)

    def getitem_fullSlice(self, series_uid, slice_ndx):
        """Build the item for one slice of one series.

        Returns:
            ``(ct_t, pos_t, series_uid, slice_ndx)`` where ``ct_t`` holds
            ``2 * contextSlices_count + 1`` slices of 512x512 voxels
            clamped to [-1000, 1000], and ``pos_t`` is the annotation mask
            of the center slice with a leading channel dimension.
        """
        ct = getCt(series_uid)
        last_ndx = ct.hu_a.shape[0] - 1
        channel_count = self.contextSlices_count * 2 + 1
        ct_t = torch.zeros((channel_count, 512, 512))

        first_ndx = slice_ndx - self.contextSlices_count
        for i in range(channel_count):
            # Context slices beyond either end repeat the edge slice.
            context_ndx = min(max(first_ndx + i, 0), last_ndx)
            slice_a = ct.hu_a[context_ndx].astype(np.float32)
            ct_t[i] = torch.from_numpy(slice_a)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
        # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
        # The upper bound nukes any weird hotspots and clamps bone down
        ct_t.clamp_(-1000, 1000)

        pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)

        return ct_t, pos_t, ct.series_uid, slice_ndx
| 
 | ||||
| 
 | ||||
class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
    """Training variant that serves random 64x64 crops around nodules."""

    def __init__(self, *args, **kwargs):
        """Accepts the same arguments as ``Luna2dSegmentationDataset``."""
        super().__init__(*args, **kwargs)

        self.ratio_int = 2

    def __len__(self):
        """Return a fixed epoch length, independent of the sample count."""
        return 300000

    def shuffleSamples(self):
        """Shuffle the candidate list and the nodule list in place."""
        random.shuffle(self.candidateInfo_list)
        random.shuffle(self.pos_list)

    def __getitem__(self, ndx):
        """Return a training crop for nodule ``ndx`` (wraps around)."""
        nodule_tup = self.pos_list[ndx % len(self.pos_list)]
        return self.getitem_trainingCrop(nodule_tup)

    def getitem_trainingCrop(self, candidateInfo_tup):
        """Cut a random 64x64 crop out of the 7x96x96 window of a nodule.

        Returns:
            ``(ct_t, pos_t, series_uid, slice_ndx)`` where ``ct_t`` is the
            float32 crop of all 7 slices, ``pos_t`` is the int64 mask crop
            of the middle slice only, and ``slice_ndx`` is the index
            component of the candidate's center.
        """
        ct_a, pos_a, center_irc = getCtRawCandidate(
            candidateInfo_tup.series_uid,
            candidateInfo_tup.center_xyz,
            (7, 96, 96),
        )

        # Row offset is drawn before column offset; both lie in 0..31 so
        # the 64-wide crop always fits inside the 96-wide window.
        row_offset = random.randrange(0, 32)
        col_offset = random.randrange(0, 32)
        crop = (
            slice(None),
            slice(row_offset, row_offset + 64),
            slice(col_offset, col_offset + 64),
        )

        ct_t = torch.from_numpy(ct_a[crop]).to(torch.float32)
        # Slice 3 is the middle of the 7 slices; 3:4 keeps the channel axis.
        middle_pos_a = pos_a[3:4]
        pos_t = torch.from_numpy(middle_pos_a[crop]).to(torch.long)

        return ct_t, pos_t, candidateInfo_tup.series_uid, center_irc.index
| 
 | ||||
class PrepcacheLunaDataset(Dataset):
    """Dataset used only for its side effect of filling the caches.

    Fetching an item calls the memoized ``getCtRawCandidate`` for that
    candidate and, the first time a series is seen by this instance,
    ``getCtSampleSize``.  The returned values are placeholders.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        all_candidates = getCandidateInfoList()
        self.pos_list = [nt for nt in all_candidates if nt.isNodule_bool]

        self.seen_set = set()
        # Group the candidates by series.  ``sorted`` makes a private copy:
        # the list returned by getCandidateInfoList() is cached and shared,
        # so sorting it in place would reorder it for every other caller.
        self.candidateInfo_list = sorted(
            all_candidates, key=lambda x: x.series_uid
        )

    def __len__(self):
        """Return the number of candidates to visit."""
        return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        """Populate the caches for candidate ``ndx``; return dummy values."""
        candidateInfo_tup = self.candidateInfo_list[ndx]
        # The window size must match the one requested at training time
        # for the memoized entry to be reused.
        getCtRawCandidate(candidateInfo_tup.series_uid, candidateInfo_tup.center_xyz, (7, 96, 96))

        series_uid = candidateInfo_tup.series_uid
        if series_uid not in self.seen_set:
            self.seen_set.add(series_uid)

            getCtSampleSize(series_uid)

        # Placeholder payload; callers only iterate for the side effects.
        return 0, 1
| 
 | ||||
| 
 | ||||
class TvTrainingLuna2dSegmentationDataset(torch.utils.data.Dataset):
    """Dataset over pre-extracted crops stored in ``./imgs_and_masks.pt``.

    The file is expected to hold a dict with ``"imgs"``, ``"masks"`` and
    ``"suids"`` entries of equal length.  Series are split into training
    and validation by position in the sorted list of unique series ids.
    """

    def __init__(self, isValSet_bool=False, val_stride=10, contextSlices_count=3):
        """Load the file and keep either the training or validation part.

        Args:
            isValSet_bool: Keep the validation samples instead of the
                training samples.
            val_stride: Of every ``val_stride`` consecutive series, the
                last one goes to validation.
            contextSlices_count: Must be 3.
        """
        assert contextSlices_count == 3
        # NOTE(review): torch.load unpickles the file; only use it on
        # files from a trusted source.
        data = torch.load('./imgs_and_masks.pt')
        # Sort the unique ids: a bare set iterates in hash order, which
        # changes between interpreter runs and would make the
        # train/validation split irreproducible.
        suids = sorted(set(data['suids']))
        trn_mask_suids = torch.arange(len(suids)) % val_stride < (val_stride - 1)
        trn_suids = {s for i, s in zip(trn_mask_suids, suids) if i}
        trn_mask = torch.tensor([(s in trn_suids) for s in data["suids"]])
        if not isValSet_bool:
            self.imgs = data["imgs"][trn_mask]
            self.masks = data["masks"][trn_mask]
            self.suids = [s for s, i in zip(data["suids"], trn_mask) if i]
        else:
            self.imgs = data["imgs"][~trn_mask]
            self.masks = data["masks"][~trn_mask]
            self.suids = [s for s, i in zip(data["suids"], trn_mask) if not i]
        # discard spurious hotspots and clamp bone
        self.imgs.clamp_(-1000, 1000)
        self.imgs /= 1000

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.imgs)

    def __getitem__(self, i):
        """Return a random 64x64 crop of sample ``i``.

        Returns:
            ``(img, 1, mask, series_uid, 9999)``: the image crop with all
            channels, a constant 1, the float32 crop of the middle mask
            channel, the series id and a constant 9999.
        """
        oh, ow = torch.randint(0, 32, (2,))
        sl = self.masks.size(1)//2
        return self.imgs[i, :, oh: oh + 64, ow: ow + 64], 1, self.masks[i, sl: sl+1, oh: oh + 64, ow: ow + 64].to(torch.float32), self.suids[i], 9999
| @ -0,0 +1,224 @@ | ||||
| import math | ||||
| import random | ||||
| from collections import namedtuple | ||||
| 
 | ||||
| import torch | ||||
| from torch import nn as nn | ||||
| import torch.nn.functional as F | ||||
| 
 | ||||
| from util.logconf import logging | ||||
| from util.unet import UNet | ||||
| 
 | ||||
# Module-level logger for the segmentation model, reporting at DEBUG.
# Switch the level to logging.INFO or logging.WARN for quieter output.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)
| 
 | ||||
class UNetWrapper(nn.Module):
    """``UNet`` preceded by input batch-norm and followed by a sigmoid."""

    def __init__(self, **kwargs):
        """Build the pipeline.

        Args:
            **kwargs: Passed unchanged to ``UNet``; ``in_channels`` is
                required, as it also sizes the input batch-norm.
        """
        super().__init__()

        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
        self.unet = UNet(**kwargs)
        self.final = nn.Sigmoid()

        self._init_weights()

    def _init_weights(self):
        """Re-initialise the weights and biases of conv and linear layers."""
        # Matched by exact type, so subclasses of these layers are skipped.
        init_types = (
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
            nn.Linear,
        )
        for module in self.modules():
            if type(module) not in init_types:
                continue

            nn.init.kaiming_normal_(
                module.weight.data, mode='fan_out', nonlinearity='relu', a=0
            )
            if module.bias is None:
                continue

            _, fan_out = nn.init._calculate_fan_in_and_fan_out(
                module.weight.data
            )
            bound = 1 / math.sqrt(fan_out)
            # NOTE(review): normal_ takes (mean, std), so the biases are
            # drawn around -bound with std bound -- confirm that a range of
            # (-bound, bound) was not what was meant.
            nn.init.normal_(module.bias, -bound, bound)

    def forward(self, input_batch):
        """Apply batch-norm, the U-Net and the sigmoid, in that order."""
        return self.final(self.unet(self.input_batchnorm(input_batch)))
| 
 | ||||
class SegmentationAugmentation(nn.Module):
    """Apply one random 2D affine transform (plus optional noise) to a batch.

    The same transform is applied to the input and to its label so the two
    stay aligned.  Each option is disabled when falsy.

    Args:
        flip: Enable random mirroring of each axis.
        offset: Maximum translation, in the normalised [-1, 1] grid
            coordinates used by ``affine_grid``.
        scale: Maximum relative scale change per axis.
        rotate: Enable rotation by a random angle.
        noise: Standard deviation of Gaussian noise added to the input.
    """

    def __init__(
            self, flip=None, offset=None, scale=None, rotate=None, noise=None
    ):
        super().__init__()

        self.flip = flip
        self.offset = offset
        self.scale = scale
        self.rotate = rotate
        self.noise = noise

    def forward(self, input_g, label_g):
        """Augment a batch.

        Args:
            input_g: Input batch; its size defines the sampling grid.
            label_g: Label batch with the same batch and spatial size.

        Returns:
            ``(augmented_input_g, augmented_label_g)`` where the label is
            a boolean tensor (resampled values thresholded at 0.5).
        """
        transform_t = self._build2dTransformMatrix()
        # One shared transform for the whole batch.
        transform_t = transform_t.expand(input_g.shape[0], -1, -1)
        transform_t = transform_t.to(input_g.device, torch.float32)
        # affine_grid takes the top two rows (N x 2 x 3) of the 3x3 matrix.
        affine_t = F.affine_grid(transform_t[:,:2],
                input_g.size(), align_corners=False)

        augmented_input_g = F.grid_sample(input_g,
                affine_t, padding_mode='border',
                align_corners=False)
        augmented_label_g = F.grid_sample(label_g.to(torch.float32),
                affine_t, padding_mode='border',
                align_corners=False)

        if self.noise:
            noise_t = torch.randn_like(augmented_input_g)
            noise_t *= self.noise

            augmented_input_g += noise_t

        return augmented_input_g, augmented_label_g > 0.5

    def _build2dTransformMatrix(self):
        """Return a random 3x3 homogeneous transform for the enabled options."""
        transform_t = torch.eye(3)

        for i in range(2):
            if self.flip:
                if random.random() > 0.5:
                    transform_t[i,i] *= -1

            if self.offset:
                offset_float = self.offset
                random_float = (random.random() * 2 - 1)
                # Translation lives in the last COLUMN.  Writing it to the
                # bottom row (index [2, i]) has no effect, because only the
                # top two rows are handed to affine_grid.
                transform_t[i,2] = offset_float * random_float

            if self.scale:
                scale_float = self.scale
                random_float = (random.random() * 2 - 1)
                transform_t[i,i] *= 1.0 + scale_float * random_float

        if self.rotate:
            angle_rad = random.random() * math.pi * 2
            s = math.sin(angle_rad)
            c = math.cos(angle_rad)

            rotation_t = torch.tensor([
                [c, -s, 0],
                [s, c, 0],
                [0, 0, 1]])

            transform_t @= rotation_t

        return transform_t
| 
 | ||||
| 
 | ||||
| # MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask') | ||||
| # | ||||
| # class SegmentationMask(nn.Module): | ||||
| #     def __init__(self): | ||||
| #         super().__init__() | ||||
| # | ||||
| #         self.conv_list = nn.ModuleList([ | ||||
| #             self._make_circle_conv(radius) for radius in range(1, 8) | ||||
| #         ]) | ||||
| # | ||||
| #     def _make_circle_conv(self, radius): | ||||
| #         diameter = 1 + radius * 2 | ||||
| # | ||||
| #         a = torch.linspace(-1, 1, steps=diameter)**2 | ||||
| #         b = (a[None] + a[:, None])**0.5 | ||||
| # | ||||
| #         circle_weights = (b <= 1.0).to(torch.float32) | ||||
| # | ||||
| #         conv = nn.Conv2d(1, 1, kernel_size=diameter, padding=radius, bias=False) | ||||
| #         conv.weight.data.fill_(1) | ||||
| #         conv.weight.data *= circle_weights / circle_weights.sum() | ||||
| # | ||||
| #         return conv | ||||
| # | ||||
| # | ||||
| #     def erode(self, input_mask, radius, threshold=1): | ||||
| #         conv = self.conv_list[radius - 1] | ||||
| #         input_float = input_mask.to(torch.float32) | ||||
| #         result = conv(input_float) | ||||
| # | ||||
| #         # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()]) | ||||
| #         # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()]) | ||||
| # | ||||
| #         return result >= threshold | ||||
| # | ||||
| #     def deposit(self, input_mask, radius, threshold=0): | ||||
| #         conv = self.conv_list[radius - 1] | ||||
| #         input_float = input_mask.to(torch.float32) | ||||
| #         result = conv(input_float) | ||||
| # | ||||
| #         # log.debug(['deposit in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()]) | ||||
| #         # log.debug(['deposit out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()]) | ||||
| # | ||||
| #         return result > threshold | ||||
| # | ||||
| #     def fill_cavity(self, input_mask): | ||||
| #         cumsum = input_mask.cumsum(-1) | ||||
| #         filled_mask = (cumsum > 0) | ||||
| #         filled_mask &= (cumsum < cumsum[..., -1:]) | ||||
| #         cumsum = input_mask.cumsum(-2) | ||||
| #         filled_mask &= (cumsum > 0) | ||||
| #         filled_mask &= (cumsum < cumsum[..., -1:, :]) | ||||
| # | ||||
| #         return filled_mask | ||||
| # | ||||
| # | ||||
| #     def forward(self, input_g, raw_pos_g): | ||||
| #         gcc_g = input_g + 1 | ||||
| # | ||||
| #         with torch.no_grad(): | ||||
| #             # log.info(['gcc_g', gcc_g.min(), gcc_g.mean(), gcc_g.max()]) | ||||
| # | ||||
| #             raw_dense_mask = gcc_g > 0.7 | ||||
| #             dense_mask = self.deposit(raw_dense_mask, 2) | ||||
| #             dense_mask = self.erode(dense_mask, 6) | ||||
| #             dense_mask = self.deposit(dense_mask, 4) | ||||
| # | ||||
| #             body_mask = self.fill_cavity(dense_mask) | ||||
| #             air_mask = self.deposit(body_mask & ~dense_mask, 5) | ||||
| #             air_mask = self.erode(air_mask, 6) | ||||
| # | ||||
| #             lung_mask = self.deposit(air_mask, 5) | ||||
| # | ||||
| #             raw_candidate_mask = gcc_g > 0.4 | ||||
| #             raw_candidate_mask &= air_mask | ||||
| #             candidate_mask = self.erode(raw_candidate_mask, 1) | ||||
| #             candidate_mask = self.deposit(candidate_mask, 1) | ||||
| # | ||||
| #             pos_mask = self.deposit((raw_pos_g > 0.5) & lung_mask, 2) | ||||
| # | ||||
| #             neg_mask = self.deposit(candidate_mask, 1) | ||||
| #             neg_mask &= ~pos_mask | ||||
| #             neg_mask &= lung_mask | ||||
| # | ||||
| #             # label_g = (neg_mask | pos_mask).to(torch.float32) | ||||
| #             label_g = (pos_mask).to(torch.float32) | ||||
| #             neg_g = neg_mask.to(torch.float32) | ||||
| #             pos_g = pos_mask.to(torch.float32) | ||||
| # | ||||
| #         mask_dict = { | ||||
| #             'raw_dense_mask': raw_dense_mask, | ||||
| #             'dense_mask': dense_mask, | ||||
| #             'body_mask': body_mask, | ||||
| #             'air_mask': air_mask, | ||||
| #             'raw_candidate_mask': raw_candidate_mask, | ||||
| #             'candidate_mask': candidate_mask, | ||||
| #             'lung_mask': lung_mask, | ||||
| #             'neg_mask': neg_mask, | ||||
| #             'pos_mask': pos_mask, | ||||
| #         } | ||||
| # | ||||
| #         return label_g, neg_g, pos_g, lung_mask, mask_dict | ||||
| @ -0,0 +1,69 @@ | ||||
| import timing | ||||
| import argparse | ||||
| import sys | ||||
| 
 | ||||
| import numpy as np | ||||
| 
 | ||||
| import torch.nn as nn | ||||
| from torch.autograd import Variable | ||||
| from torch.optim import SGD | ||||
| from torch.utils.data import DataLoader | ||||
| 
 | ||||
| from util.util import enumerateWithEstimate | ||||
| from .dsets import PrepcacheLunaDataset, getCtSampleSize | ||||
| from util.logconf import logging | ||||
| # from .model import LunaModel | ||||
| 
 | ||||
# Module-level logger for the cache-preparation script, reporting at INFO.
# Switch the level to logging.WARN or logging.DEBUG as needed.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
| 
 | ||||
| 
 | ||||
class LunaPrepCacheApp:
    """Command-line app that iterates ``PrepcacheLunaDataset`` once.

    Iterating the dataset is what fills the caches; the batches themselves
    are discarded.
    """

    def __init__(self, sys_argv=None):
        """Parse the command line.

        This is a regular instance initialiser: the parsed arguments are
        stored on the instance, so separate apps do not overwrite each
        other's ``cli_args``.

        Args:
            sys_argv: Argument list to parse; defaults to
                ``sys.argv[1:]`` when None.
        """
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=1024,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )

        self.cli_args = parser.parse_args(sys_argv)

    def main(self):
        """Walk the whole prep-cache dataset, reporting progress."""
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            PrepcacheLunaDataset(),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        batch_iter = enumerateWithEstimate(
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )
        # Nothing to do per batch: loading it is the whole point.
        for _batch_ndx, _batch_tup in batch_iter:
            pass
| 
 | ||||
| 
 | ||||
if __name__ == '__main__':
    # Script entry point: build the cache-preparation app and run it.
    prep_app = LunaPrepCacheApp()
    prep_app.main()
		Loading…
	
		Reference in New Issue
	
	Block a user