forked from KEMT/zpwiki
		
	Nahrát soubory do „pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model“
This commit is contained in:
		
							parent
							
								
									35f3889338
								
							
						
					
					
						commit
						526fe93ff9
					
				
										
											Binary file not shown.
										
									
								
							@ -0,0 +1,76 @@
 | 
				
			|||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					import datetime
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import socket
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from torch.utils.tensorboard import SummaryWriter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					import torch.optim
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from torch.optim import SGD, Adam
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from util.util import enumerateWithEstimate
 | 
				
			||||||
 | 
					from p2ch13.dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
 | 
				
			||||||
 | 
					from util.logconf import logging
 | 
				
			||||||
 | 
					from util.util import xyz2irc
 | 
				
			||||||
 | 
					from p2ch13.model_seg import UNetWrapper, SegmentationAugmentation
 | 
				
			||||||
 | 
					from p2ch13.train_seg import LunaTrainingApp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					# log.setLevel(logging.WARN)
 | 
				
			||||||
 | 
					# log.setLevel(logging.INFO)
 | 
				
			||||||
 | 
					log.setLevel(logging.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BenchmarkLuna2dSegmentationDataset(TrainingLuna2dSegmentationDataset):
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        # return 500
 | 
				
			||||||
 | 
					        return 5000
 | 
				
			||||||
 | 
					        return 1000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LunaBenchmarkApp(LunaTrainingApp):
 | 
				
			||||||
 | 
					    def initTrainDl(self):
 | 
				
			||||||
 | 
					        train_ds = BenchmarkLuna2dSegmentationDataset(
 | 
				
			||||||
 | 
					            val_stride=10,
 | 
				
			||||||
 | 
					            isValSet_bool=False,
 | 
				
			||||||
 | 
					            contextSlices_count=3,
 | 
				
			||||||
 | 
					            # augmentation_dict=self.augmentation_dict,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_size = self.cli_args.batch_size
 | 
				
			||||||
 | 
					        if self.use_cuda:
 | 
				
			||||||
 | 
					            batch_size *= torch.cuda.device_count()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        train_dl = DataLoader(
 | 
				
			||||||
 | 
					            train_ds,
 | 
				
			||||||
 | 
					            batch_size=batch_size,
 | 
				
			||||||
 | 
					            num_workers=self.cli_args.num_workers,
 | 
				
			||||||
 | 
					            pin_memory=self.use_cuda,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return train_dl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def main(self):
 | 
				
			||||||
 | 
					        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        train_dl = self.initTrainDl()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for epoch_ndx in range(1, 2):
 | 
				
			||||||
 | 
					            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
 | 
				
			||||||
 | 
					                epoch_ndx,
 | 
				
			||||||
 | 
					                self.cli_args.epochs,
 | 
				
			||||||
 | 
					                len(train_dl),
 | 
				
			||||||
 | 
					                len([]),
 | 
				
			||||||
 | 
					                self.cli_args.batch_size,
 | 
				
			||||||
 | 
					                (torch.cuda.device_count() if self.use_cuda else 1),
 | 
				
			||||||
 | 
					            ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            self.doTraining(epoch_ndx, train_dl)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    LunaBenchmarkApp().main()
 | 
				
			||||||
@ -0,0 +1,401 @@
 | 
				
			|||||||
 | 
					import copy
 | 
				
			||||||
 | 
					import csv
 | 
				
			||||||
 | 
					import functools
 | 
				
			||||||
 | 
					import glob
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from collections import namedtuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import SimpleITK as sitk
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import scipy.ndimage.morphology as morph
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.cuda
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					from torch.utils.data import Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from util.disk import getCache
 | 
				
			||||||
 | 
					from util.util import XyzTuple, xyz2irc
 | 
				
			||||||
 | 
					from util.logconf import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					# log.setLevel(logging.WARN)
 | 
				
			||||||
 | 
					# log.setLevel(logging.INFO)
 | 
				
			||||||
 | 
					log.setLevel(logging.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					raw_cache = getCache('part2ch13_raw')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@functools.lru_cache(1)
 | 
				
			||||||
 | 
					def getCandidateInfoList(requireOnDisk_bool=True):
 | 
				
			||||||
 | 
					    # We construct a set with all series_uids that are present on disk.
 | 
				
			||||||
 | 
					    # This will let us use the data, even if we haven't downloaded all of
 | 
				
			||||||
 | 
					    # the subsets yet.
 | 
				
			||||||
 | 
					    mhd_list = glob.glob('data-unversioned/subset*/*.mhd')
 | 
				
			||||||
 | 
					    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    candidateInfo_list = []
 | 
				
			||||||
 | 
					    with open('data/annotations_with_malignancy.csv', "r") as f:
 | 
				
			||||||
 | 
					        for row in list(csv.reader(f))[1:]:
 | 
				
			||||||
 | 
					            series_uid = row[0]
 | 
				
			||||||
 | 
					            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
 | 
				
			||||||
 | 
					            annotationDiameter_mm = float(row[4])
 | 
				
			||||||
 | 
					            isMal_bool = {'False': False, 'True': True}[row[5]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            candidateInfo_list.append(
 | 
				
			||||||
 | 
					                CandidateInfoTuple(
 | 
				
			||||||
 | 
					                    True,
 | 
				
			||||||
 | 
					                    True,
 | 
				
			||||||
 | 
					                    isMal_bool,
 | 
				
			||||||
 | 
					                    annotationDiameter_mm,
 | 
				
			||||||
 | 
					                    series_uid,
 | 
				
			||||||
 | 
					                    annotationCenter_xyz,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with open('data/candidates.csv', "r") as f:
 | 
				
			||||||
 | 
					        for row in list(csv.reader(f))[1:]:
 | 
				
			||||||
 | 
					            series_uid = row[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            isNodule_bool = bool(int(row[4]))
 | 
				
			||||||
 | 
					            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not isNodule_bool:
 | 
				
			||||||
 | 
					                candidateInfo_list.append(
 | 
				
			||||||
 | 
					                    CandidateInfoTuple(
 | 
				
			||||||
 | 
					                        False,
 | 
				
			||||||
 | 
					                        False,
 | 
				
			||||||
 | 
					                        False,
 | 
				
			||||||
 | 
					                        0.0,
 | 
				
			||||||
 | 
					                        series_uid,
 | 
				
			||||||
 | 
					                        candidateCenter_xyz,
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    candidateInfo_list.sort(reverse=True)
 | 
				
			||||||
 | 
					    return candidateInfo_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@functools.lru_cache(1)
 | 
				
			||||||
 | 
					def getCandidateInfoDict(requireOnDisk_bool=True):
 | 
				
			||||||
 | 
					    candidateInfo_list = getCandidateInfoList(requireOnDisk_bool)
 | 
				
			||||||
 | 
					    candidateInfo_dict = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for candidateInfo_tup in candidateInfo_list:
 | 
				
			||||||
 | 
					        candidateInfo_dict.setdefault(candidateInfo_tup.series_uid,
 | 
				
			||||||
 | 
					                                      []).append(candidateInfo_tup)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return candidateInfo_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Ct:
 | 
				
			||||||
 | 
					    def __init__(self, series_uid):
 | 
				
			||||||
 | 
					        mhd_path = glob.glob(
 | 
				
			||||||
 | 
					            'data-unversioned/subset*/{}.mhd'.format(series_uid)
 | 
				
			||||||
 | 
					        )[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ct_mhd = sitk.ReadImage(mhd_path)
 | 
				
			||||||
 | 
					        self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
 | 
				
			||||||
 | 
					        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.series_uid = series_uid
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
 | 
				
			||||||
 | 
					        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
 | 
				
			||||||
 | 
					        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        candidateInfo_list = getCandidateInfoDict()[self.series_uid]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.positiveInfo_list = [
 | 
				
			||||||
 | 
					            candidate_tup
 | 
				
			||||||
 | 
					            for candidate_tup in candidateInfo_list
 | 
				
			||||||
 | 
					            if candidate_tup.isNodule_bool
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
 | 
				
			||||||
 | 
					        self.positive_indexes = (self.positive_mask.sum(axis=(1,2))
 | 
				
			||||||
 | 
					                                 .nonzero()[0].tolist())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
 | 
				
			||||||
 | 
					        boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for candidateInfo_tup in positiveInfo_list:
 | 
				
			||||||
 | 
					            center_irc = xyz2irc(
 | 
				
			||||||
 | 
					                candidateInfo_tup.center_xyz,
 | 
				
			||||||
 | 
					                self.origin_xyz,
 | 
				
			||||||
 | 
					                self.vxSize_xyz,
 | 
				
			||||||
 | 
					                self.direction_a,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            ci = int(center_irc.index)
 | 
				
			||||||
 | 
					            cr = int(center_irc.row)
 | 
				
			||||||
 | 
					            cc = int(center_irc.col)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            index_radius = 2
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \
 | 
				
			||||||
 | 
					                        self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
 | 
				
			||||||
 | 
					                    index_radius += 1
 | 
				
			||||||
 | 
					            except IndexError:
 | 
				
			||||||
 | 
					                index_radius -= 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            row_radius = 2
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                while self.hu_a[ci, cr + row_radius, cc] > threshold_hu and \
 | 
				
			||||||
 | 
					                        self.hu_a[ci, cr - row_radius, cc] > threshold_hu:
 | 
				
			||||||
 | 
					                    row_radius += 1
 | 
				
			||||||
 | 
					            except IndexError:
 | 
				
			||||||
 | 
					                row_radius -= 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            col_radius = 2
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                while self.hu_a[ci, cr, cc + col_radius] > threshold_hu and \
 | 
				
			||||||
 | 
					                        self.hu_a[ci, cr, cc - col_radius] > threshold_hu:
 | 
				
			||||||
 | 
					                    col_radius += 1
 | 
				
			||||||
 | 
					            except IndexError:
 | 
				
			||||||
 | 
					                col_radius -= 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # assert index_radius > 0, repr([candidateInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]])
 | 
				
			||||||
 | 
					            # assert row_radius > 0
 | 
				
			||||||
 | 
					            # assert col_radius > 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            boundingBox_a[
 | 
				
			||||||
 | 
					                 ci - index_radius: ci + index_radius + 1,
 | 
				
			||||||
 | 
					                 cr - row_radius: cr + row_radius + 1,
 | 
				
			||||||
 | 
					                 cc - col_radius: cc + col_radius + 1] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        mask_a = boundingBox_a & (self.hu_a > threshold_hu)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return mask_a
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def getRawCandidate(self, center_xyz, width_irc):
 | 
				
			||||||
 | 
					        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
 | 
				
			||||||
 | 
					                             self.direction_a)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        slice_list = []
 | 
				
			||||||
 | 
					        for axis, center_val in enumerate(center_irc):
 | 
				
			||||||
 | 
					            start_ndx = int(round(center_val - width_irc[axis]/2))
 | 
				
			||||||
 | 
					            end_ndx = int(start_ndx + width_irc[axis])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if start_ndx < 0:
 | 
				
			||||||
 | 
					                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
 | 
				
			||||||
 | 
					                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
 | 
				
			||||||
 | 
					                start_ndx = 0
 | 
				
			||||||
 | 
					                end_ndx = int(width_irc[axis])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if end_ndx > self.hu_a.shape[axis]:
 | 
				
			||||||
 | 
					                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
 | 
				
			||||||
 | 
					                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
 | 
				
			||||||
 | 
					                end_ndx = self.hu_a.shape[axis]
 | 
				
			||||||
 | 
					                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            slice_list.append(slice(start_ndx, end_ndx))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ct_chunk = self.hu_a[tuple(slice_list)]
 | 
				
			||||||
 | 
					        pos_chunk = self.positive_mask[tuple(slice_list)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return ct_chunk, pos_chunk, center_irc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@functools.lru_cache(1, typed=True)
 | 
				
			||||||
 | 
					def getCt(series_uid):
 | 
				
			||||||
 | 
					    return Ct(series_uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@raw_cache.memoize(typed=True)
 | 
				
			||||||
 | 
					def getCtRawCandidate(series_uid, center_xyz, width_irc):
 | 
				
			||||||
 | 
					    ct = getCt(series_uid)
 | 
				
			||||||
 | 
					    ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz,
 | 
				
			||||||
 | 
					                                                         width_irc)
 | 
				
			||||||
 | 
					    ct_chunk.clip(-1000, 1000, ct_chunk)
 | 
				
			||||||
 | 
					    return ct_chunk, pos_chunk, center_irc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@raw_cache.memoize(typed=True)
 | 
				
			||||||
 | 
					def getCtSampleSize(series_uid):
 | 
				
			||||||
 | 
					    ct = Ct(series_uid)
 | 
				
			||||||
 | 
					    return int(ct.hu_a.shape[0]), ct.positive_indexes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Luna2dSegmentationDataset(Dataset):
 | 
				
			||||||
 | 
					    def __init__(self,
 | 
				
			||||||
 | 
					                 val_stride=0,
 | 
				
			||||||
 | 
					                 isValSet_bool=None,
 | 
				
			||||||
 | 
					                 series_uid=None,
 | 
				
			||||||
 | 
					                 contextSlices_count=3,
 | 
				
			||||||
 | 
					                 fullCt_bool=False,
 | 
				
			||||||
 | 
					            ):
 | 
				
			||||||
 | 
					        self.contextSlices_count = contextSlices_count
 | 
				
			||||||
 | 
					        self.fullCt_bool = fullCt_bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if series_uid:
 | 
				
			||||||
 | 
					            self.series_list = [series_uid]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.series_list = sorted(getCandidateInfoDict().keys())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if isValSet_bool:
 | 
				
			||||||
 | 
					            assert val_stride > 0, val_stride
 | 
				
			||||||
 | 
					            self.series_list = self.series_list[::val_stride]
 | 
				
			||||||
 | 
					            assert self.series_list
 | 
				
			||||||
 | 
					        elif val_stride > 0:
 | 
				
			||||||
 | 
					            del self.series_list[::val_stride]
 | 
				
			||||||
 | 
					            assert self.series_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.sample_list = []
 | 
				
			||||||
 | 
					        for series_uid in self.series_list:
 | 
				
			||||||
 | 
					            index_count, positive_indexes = getCtSampleSize(series_uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.fullCt_bool:
 | 
				
			||||||
 | 
					                self.sample_list += [(series_uid, slice_ndx)
 | 
				
			||||||
 | 
					                                     for slice_ndx in range(index_count)]
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.sample_list += [(series_uid, slice_ndx)
 | 
				
			||||||
 | 
					                                     for slice_ndx in positive_indexes]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.candidateInfo_list = getCandidateInfoList()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        series_set = set(self.series_list)
 | 
				
			||||||
 | 
					        self.candidateInfo_list = [cit for cit in self.candidateInfo_list
 | 
				
			||||||
 | 
					                                   if cit.series_uid in series_set]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.pos_list = [nt for nt in self.candidateInfo_list
 | 
				
			||||||
 | 
					                            if nt.isNodule_bool]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        log.info("{!r}: {} {} series, {} slices, {} nodules".format(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            len(self.series_list),
 | 
				
			||||||
 | 
					            {None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
 | 
				
			||||||
 | 
					            len(self.sample_list),
 | 
				
			||||||
 | 
					            len(self.pos_list),
 | 
				
			||||||
 | 
					        ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        return len(self.sample_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getitem__(self, ndx):
 | 
				
			||||||
 | 
					        series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
 | 
				
			||||||
 | 
					        return self.getitem_fullSlice(series_uid, slice_ndx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def getitem_fullSlice(self, series_uid, slice_ndx):
 | 
				
			||||||
 | 
					        ct = getCt(series_uid)
 | 
				
			||||||
 | 
					        ct_t = torch.zeros((self.contextSlices_count * 2 + 1, 512, 512))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        start_ndx = slice_ndx - self.contextSlices_count
 | 
				
			||||||
 | 
					        end_ndx = slice_ndx + self.contextSlices_count + 1
 | 
				
			||||||
 | 
					        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
 | 
				
			||||||
 | 
					            context_ndx = max(context_ndx, 0)
 | 
				
			||||||
 | 
					            context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
 | 
				
			||||||
 | 
					            ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
 | 
				
			||||||
 | 
					        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
 | 
				
			||||||
 | 
					        # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
 | 
				
			||||||
 | 
					        # The upper bound nukes any weird hotspots and clamps bone down
 | 
				
			||||||
 | 
					        ct_t.clamp_(-1000, 1000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return ct_t, pos_t, ct.series_uid, slice_ndx
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
 | 
				
			||||||
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.ratio_int = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        return 300000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def shuffleSamples(self):
 | 
				
			||||||
 | 
					        random.shuffle(self.candidateInfo_list)
 | 
				
			||||||
 | 
					        random.shuffle(self.pos_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getitem__(self, ndx):
 | 
				
			||||||
 | 
					        candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
 | 
				
			||||||
 | 
					        return self.getitem_trainingCrop(candidateInfo_tup)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def getitem_trainingCrop(self, candidateInfo_tup):
 | 
				
			||||||
 | 
					        ct_a, pos_a, center_irc = getCtRawCandidate(
 | 
				
			||||||
 | 
					            candidateInfo_tup.series_uid,
 | 
				
			||||||
 | 
					            candidateInfo_tup.center_xyz,
 | 
				
			||||||
 | 
					            (7, 96, 96),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        pos_a = pos_a[3:4]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        row_offset = random.randrange(0,32)
 | 
				
			||||||
 | 
					        col_offset = random.randrange(0,32)
 | 
				
			||||||
 | 
					        ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64,
 | 
				
			||||||
 | 
					                                     col_offset:col_offset+64]).to(torch.float32)
 | 
				
			||||||
 | 
					        pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64,
 | 
				
			||||||
 | 
					                                       col_offset:col_offset+64]).to(torch.long)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        slice_ndx = center_irc.index
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PrepcacheLunaDataset(Dataset):
 | 
				
			||||||
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.candidateInfo_list = getCandidateInfoList()
 | 
				
			||||||
 | 
					        self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.seen_set = set()
 | 
				
			||||||
 | 
					        self.candidateInfo_list.sort(key=lambda x: x.series_uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        return len(self.candidateInfo_list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getitem__(self, ndx):
 | 
				
			||||||
 | 
					        # candidate_t, pos_t, series_uid, center_t = super().__getitem__(ndx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        candidateInfo_tup = self.candidateInfo_list[ndx]
 | 
				
			||||||
 | 
					        getCtRawCandidate(candidateInfo_tup.series_uid, candidateInfo_tup.center_xyz, (7, 96, 96))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        series_uid = candidateInfo_tup.series_uid
 | 
				
			||||||
 | 
					        if series_uid not in self.seen_set:
 | 
				
			||||||
 | 
					            self.seen_set.add(series_uid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            getCtSampleSize(series_uid)
 | 
				
			||||||
 | 
					            # ct = getCt(series_uid)
 | 
				
			||||||
 | 
					            # for mask_ndx in ct.positive_indexes:
 | 
				
			||||||
 | 
					            #     build2dLungMask(series_uid, mask_ndx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return 0, 1 #candidate_t, pos_t, series_uid, center_t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TvTrainingLuna2dSegmentationDataset(torch.utils.data.Dataset):
 | 
				
			||||||
 | 
					    def __init__(self, isValSet_bool=False, val_stride=10, contextSlices_count=3):
 | 
				
			||||||
 | 
					        assert contextSlices_count == 3
 | 
				
			||||||
 | 
					        data = torch.load('./imgs_and_masks.pt')
 | 
				
			||||||
 | 
					        suids = list(set(data['suids']))
 | 
				
			||||||
 | 
					        trn_mask_suids = torch.arange(len(suids)) % val_stride < (val_stride - 1)
 | 
				
			||||||
 | 
					        trn_suids = {s for i, s in zip(trn_mask_suids, suids) if i}
 | 
				
			||||||
 | 
					        trn_mask = torch.tensor([(s in trn_suids) for s in data["suids"]])
 | 
				
			||||||
 | 
					        if not isValSet_bool:
 | 
				
			||||||
 | 
					            self.imgs = data["imgs"][trn_mask]
 | 
				
			||||||
 | 
					            self.masks = data["masks"][trn_mask]
 | 
				
			||||||
 | 
					            self.suids = [s for s, i in zip(data["suids"], trn_mask) if i]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.imgs = data["imgs"][~trn_mask]
 | 
				
			||||||
 | 
					            self.masks = data["masks"][~trn_mask]
 | 
				
			||||||
 | 
					            self.suids = [s for s, i in zip(data["suids"], trn_mask) if not i]
 | 
				
			||||||
 | 
					        # discard spurious hotspots and clamp bone
 | 
				
			||||||
 | 
					        self.imgs.clamp_(-1000, 1000)
 | 
				
			||||||
 | 
					        self.imgs /= 1000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __len__(self):
 | 
				
			||||||
 | 
					        return len(self.imgs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getitem__(self, i):
 | 
				
			||||||
 | 
					        oh, ow = torch.randint(0, 32, (2,))
 | 
				
			||||||
 | 
					        sl = self.masks.size(1)//2
 | 
				
			||||||
 | 
					        return self.imgs[i, :, oh: oh + 64, ow: ow + 64], 1, self.masks[i, sl: sl+1, oh: oh + 64, ow: ow + 64].to(torch.float32), self.suids[i], 9999
 | 
				
			||||||
@ -0,0 +1,224 @@
 | 
				
			|||||||
 | 
					import math
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					from collections import namedtuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch import nn as nn
 | 
				
			||||||
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from util.logconf import logging
 | 
				
			||||||
 | 
					from util.unet import UNet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					# log.setLevel(logging.WARN)
 | 
				
			||||||
 | 
					# log.setLevel(logging.INFO)
 | 
				
			||||||
 | 
					log.setLevel(logging.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UNetWrapper(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
 | 
				
			||||||
 | 
					        self.unet = UNet(**kwargs)
 | 
				
			||||||
 | 
					        self.final = nn.Sigmoid()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._init_weights()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _init_weights(self):
 | 
				
			||||||
 | 
					        init_set = {
 | 
				
			||||||
 | 
					            nn.Conv2d,
 | 
				
			||||||
 | 
					            nn.Conv3d,
 | 
				
			||||||
 | 
					            nn.ConvTranspose2d,
 | 
				
			||||||
 | 
					            nn.ConvTranspose3d,
 | 
				
			||||||
 | 
					            nn.Linear,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        for m in self.modules():
 | 
				
			||||||
 | 
					            if type(m) in init_set:
 | 
				
			||||||
 | 
					                nn.init.kaiming_normal_(
 | 
				
			||||||
 | 
					                    m.weight.data, mode='fan_out', nonlinearity='relu', a=0
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                if m.bias is not None:
 | 
				
			||||||
 | 
					                    fan_in, fan_out = \
 | 
				
			||||||
 | 
					                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
 | 
				
			||||||
 | 
					                    bound = 1 / math.sqrt(fan_out)
 | 
				
			||||||
 | 
					                    nn.init.normal_(m.bias, -bound, bound)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # nn.init.constant_(self.unet.last.bias, -4)
 | 
				
			||||||
 | 
					        # nn.init.constant_(self.unet.last.bias, 4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, input_batch):
 | 
				
			||||||
 | 
					        bn_output = self.input_batchnorm(input_batch)
 | 
				
			||||||
 | 
					        un_output = self.unet(bn_output)
 | 
				
			||||||
 | 
					        fn_output = self.final(un_output)
 | 
				
			||||||
 | 
					        return fn_output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SegmentationAugmentation(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					            self, flip=None, offset=None, scale=None, rotate=None, noise=None
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.flip = flip
 | 
				
			||||||
 | 
					        self.offset = offset
 | 
				
			||||||
 | 
					        self.scale = scale
 | 
				
			||||||
 | 
					        self.rotate = rotate
 | 
				
			||||||
 | 
					        self.noise = noise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, input_g, label_g):
 | 
				
			||||||
 | 
					        transform_t = self._build2dTransformMatrix()
 | 
				
			||||||
 | 
					        transform_t = transform_t.expand(input_g.shape[0], -1, -1)
 | 
				
			||||||
 | 
					        transform_t = transform_t.to(input_g.device, torch.float32)
 | 
				
			||||||
 | 
					        affine_t = F.affine_grid(transform_t[:,:2],
 | 
				
			||||||
 | 
					                input_g.size(), align_corners=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        augmented_input_g = F.grid_sample(input_g,
 | 
				
			||||||
 | 
					                affine_t, padding_mode='border',
 | 
				
			||||||
 | 
					                align_corners=False)
 | 
				
			||||||
 | 
					        augmented_label_g = F.grid_sample(label_g.to(torch.float32),
 | 
				
			||||||
 | 
					                affine_t, padding_mode='border',
 | 
				
			||||||
 | 
					                align_corners=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.noise:
 | 
				
			||||||
 | 
					            noise_t = torch.randn_like(augmented_input_g)
 | 
				
			||||||
 | 
					            noise_t *= self.noise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            augmented_input_g += noise_t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return augmented_input_g, augmented_label_g > 0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _build2dTransformMatrix(self):
 | 
				
			||||||
 | 
					        transform_t = torch.eye(3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i in range(2):
 | 
				
			||||||
 | 
					            if self.flip:
 | 
				
			||||||
 | 
					                if random.random() > 0.5:
 | 
				
			||||||
 | 
					                    transform_t[i,i] *= -1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.offset:
 | 
				
			||||||
 | 
					                offset_float = self.offset
 | 
				
			||||||
 | 
					                random_float = (random.random() * 2 - 1)
 | 
				
			||||||
 | 
					                transform_t[2,i] = offset_float * random_float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if self.scale:
 | 
				
			||||||
 | 
					                scale_float = self.scale
 | 
				
			||||||
 | 
					                random_float = (random.random() * 2 - 1)
 | 
				
			||||||
 | 
					                transform_t[i,i] *= 1.0 + scale_float * random_float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.rotate:
 | 
				
			||||||
 | 
					            angle_rad = random.random() * math.pi * 2
 | 
				
			||||||
 | 
					            s = math.sin(angle_rad)
 | 
				
			||||||
 | 
					            c = math.cos(angle_rad)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            rotation_t = torch.tensor([
 | 
				
			||||||
 | 
					                [c, -s, 0],
 | 
				
			||||||
 | 
					                [s, c, 0],
 | 
				
			||||||
 | 
					                [0, 0, 1]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            transform_t @= rotation_t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return transform_t
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# class SegmentationMask(nn.Module):
 | 
				
			||||||
 | 
					#     def __init__(self):
 | 
				
			||||||
 | 
					#         super().__init__()
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         self.conv_list = nn.ModuleList([
 | 
				
			||||||
 | 
					#             self._make_circle_conv(radius) for radius in range(1, 8)
 | 
				
			||||||
 | 
					#         ])
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     def _make_circle_conv(self, radius):
 | 
				
			||||||
 | 
					#         diameter = 1 + radius * 2
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         a = torch.linspace(-1, 1, steps=diameter)**2
 | 
				
			||||||
 | 
					#         b = (a[None] + a[:, None])**0.5
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         circle_weights = (b <= 1.0).to(torch.float32)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         conv = nn.Conv2d(1, 1, kernel_size=diameter, padding=radius, bias=False)
 | 
				
			||||||
 | 
					#         conv.weight.data.fill_(1)
 | 
				
			||||||
 | 
					#         conv.weight.data *= circle_weights / circle_weights.sum()
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         return conv
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     def erode(self, input_mask, radius, threshold=1):
 | 
				
			||||||
 | 
					#         conv = self.conv_list[radius - 1]
 | 
				
			||||||
 | 
					#         input_float = input_mask.to(torch.float32)
 | 
				
			||||||
 | 
					#         result = conv(input_float)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
 | 
				
			||||||
 | 
					#         # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         return result >= threshold
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     def deposit(self, input_mask, radius, threshold=0):
 | 
				
			||||||
 | 
					#         conv = self.conv_list[radius - 1]
 | 
				
			||||||
 | 
					#         input_float = input_mask.to(torch.float32)
 | 
				
			||||||
 | 
					#         result = conv(input_float)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         # log.debug(['deposit in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
 | 
				
			||||||
 | 
					#         # log.debug(['deposit out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         return result > threshold
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     def fill_cavity(self, input_mask):
 | 
				
			||||||
 | 
					#         cumsum = input_mask.cumsum(-1)
 | 
				
			||||||
 | 
					#         filled_mask = (cumsum > 0)
 | 
				
			||||||
 | 
					#         filled_mask &= (cumsum < cumsum[..., -1:])
 | 
				
			||||||
 | 
					#         cumsum = input_mask.cumsum(-2)
 | 
				
			||||||
 | 
					#         filled_mask &= (cumsum > 0)
 | 
				
			||||||
 | 
					#         filled_mask &= (cumsum < cumsum[..., -1:, :])
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         return filled_mask
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     def forward(self, input_g, raw_pos_g):
 | 
				
			||||||
 | 
					#         gcc_g = input_g + 1
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         with torch.no_grad():
 | 
				
			||||||
 | 
					#             # log.info(['gcc_g', gcc_g.min(), gcc_g.mean(), gcc_g.max()])
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             raw_dense_mask = gcc_g > 0.7
 | 
				
			||||||
 | 
					#             dense_mask = self.deposit(raw_dense_mask, 2)
 | 
				
			||||||
 | 
					#             dense_mask = self.erode(dense_mask, 6)
 | 
				
			||||||
 | 
					#             dense_mask = self.deposit(dense_mask, 4)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             body_mask = self.fill_cavity(dense_mask)
 | 
				
			||||||
 | 
					#             air_mask = self.deposit(body_mask & ~dense_mask, 5)
 | 
				
			||||||
 | 
					#             air_mask = self.erode(air_mask, 6)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             lung_mask = self.deposit(air_mask, 5)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             raw_candidate_mask = gcc_g > 0.4
 | 
				
			||||||
 | 
					#             raw_candidate_mask &= air_mask
 | 
				
			||||||
 | 
					#             candidate_mask = self.erode(raw_candidate_mask, 1)
 | 
				
			||||||
 | 
					#             candidate_mask = self.deposit(candidate_mask, 1)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             pos_mask = self.deposit((raw_pos_g > 0.5) & lung_mask, 2)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             neg_mask = self.deposit(candidate_mask, 1)
 | 
				
			||||||
 | 
					#             neg_mask &= ~pos_mask
 | 
				
			||||||
 | 
					#             neg_mask &= lung_mask
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#             # label_g = (neg_mask | pos_mask).to(torch.float32)
 | 
				
			||||||
 | 
					#             label_g = (pos_mask).to(torch.float32)
 | 
				
			||||||
 | 
					#             neg_g = neg_mask.to(torch.float32)
 | 
				
			||||||
 | 
					#             pos_g = pos_mask.to(torch.float32)
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         mask_dict = {
 | 
				
			||||||
 | 
					#             'raw_dense_mask': raw_dense_mask,
 | 
				
			||||||
 | 
					#             'dense_mask': dense_mask,
 | 
				
			||||||
 | 
					#             'body_mask': body_mask,
 | 
				
			||||||
 | 
					#             'air_mask': air_mask,
 | 
				
			||||||
 | 
					#             'raw_candidate_mask': raw_candidate_mask,
 | 
				
			||||||
 | 
					#             'candidate_mask': candidate_mask,
 | 
				
			||||||
 | 
					#             'lung_mask': lung_mask,
 | 
				
			||||||
 | 
					#             'neg_mask': neg_mask,
 | 
				
			||||||
 | 
					#             'pos_mask': pos_mask,
 | 
				
			||||||
 | 
					#         }
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#         return label_g, neg_g, pos_g, lung_mask, mask_dict
 | 
				
			||||||
@ -0,0 +1,69 @@
 | 
				
			|||||||
 | 
					import timing
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch.nn as nn
 | 
				
			||||||
 | 
					from torch.autograd import Variable
 | 
				
			||||||
 | 
					from torch.optim import SGD
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from util.util import enumerateWithEstimate
 | 
				
			||||||
 | 
					from .dsets import PrepcacheLunaDataset, getCtSampleSize
 | 
				
			||||||
 | 
					from util.logconf import logging
 | 
				
			||||||
 | 
					# from .model import LunaModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					# log.setLevel(logging.WARN)
 | 
				
			||||||
 | 
					log.setLevel(logging.INFO)
 | 
				
			||||||
 | 
					# log.setLevel(logging.DEBUG)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LunaPrepCacheApp:
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def __init__(self, sys_argv=None):
 | 
				
			||||||
 | 
					        if sys_argv is None:
 | 
				
			||||||
 | 
					            sys_argv = sys.argv[1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					        parser.add_argument('--batch-size',
 | 
				
			||||||
 | 
					            help='Batch size to use for training',
 | 
				
			||||||
 | 
					            default=1024,
 | 
				
			||||||
 | 
					            type=int,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        parser.add_argument('--num-workers',
 | 
				
			||||||
 | 
					            help='Number of worker processes for background data loading',
 | 
				
			||||||
 | 
					            default=8,
 | 
				
			||||||
 | 
					            type=int,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # parser.add_argument('--scaled',
 | 
				
			||||||
 | 
					        #     help="Scale the CT chunks to square voxels.",
 | 
				
			||||||
 | 
					        #     default=False,
 | 
				
			||||||
 | 
					        #     action='store_true',
 | 
				
			||||||
 | 
					        # )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.cli_args = parser.parse_args(sys_argv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def main(self):
 | 
				
			||||||
 | 
					        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.prep_dl = DataLoader(
 | 
				
			||||||
 | 
					            PrepcacheLunaDataset(
 | 
				
			||||||
 | 
					                # sortby_str='series_uid',
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            batch_size=self.cli_args.batch_size,
 | 
				
			||||||
 | 
					            num_workers=self.cli_args.num_workers,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        batch_iter = enumerateWithEstimate(
 | 
				
			||||||
 | 
					            self.prep_dl,
 | 
				
			||||||
 | 
					            "Stuffing cache",
 | 
				
			||||||
 | 
					            start_ndx=self.prep_dl.num_workers,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        for batch_ndx, batch_tup in batch_iter:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    LunaPrepCacheApp().main()
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user