diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/__init__.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/__init__.py
new file mode 100644
index 00000000..06d74050
Binary files /dev/null and b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/__init__.py differ
diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/benchmark_seg.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/benchmark_seg.py
new file mode 100644
index 00000000..05384204
--- /dev/null
+++ b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/benchmark_seg.py
@@ -0,0 +1,76 @@
+import argparse
+import datetime
+import os
+import socket
+import sys
+
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from p2ch13.dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
+from util.logconf import logging
+from util.util import xyz2irc
+from p2ch13.model_seg import UNetWrapper, SegmentationAugmentation
+from p2ch13.train_seg import LunaTrainingApp
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+class BenchmarkLuna2dSegmentationDataset(TrainingLuna2dSegmentationDataset):
+    def __len__(self):
+        # return 500
+        return 5000
+        # return 1000  (unreachable after the return above; kept as a comment)
+
+class LunaBenchmarkApp(LunaTrainingApp):
+    def initTrainDl(self):
+        train_ds = BenchmarkLuna2dSegmentationDataset(
+            val_stride=10,
+            isValSet_bool=False,
+            contextSlices_count=3,
+            # augmentation_dict=self.augmentation_dict,
+        )
+
+        batch_size = self.cli_args.batch_size
+        if self.use_cuda:
+            batch_size *= torch.cuda.device_count()
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=batch_size,
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+
+        for epoch_ndx in range(1, 2):
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                0,  # no validation loader in this benchmark
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            self.doTraining(epoch_ndx, train_dl)
+
+
+if __name__ == '__main__':
+    LunaBenchmarkApp().main()
diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/dsets.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/dsets.py
new file mode 100644
index 00000000..f16b1386
--- /dev/null
+++ b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/dsets.py
@@ -0,0 +1,414 @@
+import copy
+import csv
+import functools
+import glob
+import math
+import os
+import random
+
+from collections import namedtuple
+
+import SimpleITK as sitk
+import numpy as np
+import scipy.ndimage.morphology as morph
+
+import torch
+import torch.cuda
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+
+from util.disk import getCache
+from util.util import XyzTuple, xyz2irc
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+raw_cache = getCache('part2ch13_raw')
+
+MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
+
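+# Field order matters here: getCandidateInfoList() sorts these tuples with
+# reverse=True, so nodules (isNodule_bool=True) sort ahead of non-nodules.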
+CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')
+
+@functools.lru_cache(1)
+def getCandidateInfoList(requireOnDisk_bool=True):
+    # We construct a set with all series_uids that are present on disk.
+    # This will let us use the data, even if we haven't downloaded all of
+    # the subsets yet.
+    mhd_list = glob.glob('data-unversioned/subset*/*.mhd')
+    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+
+    candidateInfo_list = []
+    with open('data/annotations_with_malignancy.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
+            annotationDiameter_mm = float(row[4])
+            isMal_bool = {'False': False, 'True': True}[row[5]]
+
+            candidateInfo_list.append(
+                CandidateInfoTuple(
+                    True,
+                    True,
+                    isMal_bool,
+                    annotationDiameter_mm,
+                    series_uid,
+                    annotationCenter_xyz,
+                )
+            )
+
+    with open('data/candidates.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+
+            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
+                continue
+
+            isNodule_bool = bool(int(row[4]))
+            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
+
+            if not isNodule_bool:
+                candidateInfo_list.append(
+                    CandidateInfoTuple(
+                        False,
+                        False,
+                        False,
+                        0.0,
+                        series_uid,
+                        candidateCenter_xyz,
+                    )
+                )
+
+    candidateInfo_list.sort(reverse=True)
+    return candidateInfo_list
+
+@functools.lru_cache(1)
+def getCandidateInfoDict(requireOnDisk_bool=True):
+    candidateInfo_list = getCandidateInfoList(requireOnDisk_bool)
+    candidateInfo_dict = {}
+
+    for candidateInfo_tup in candidateInfo_list:
+        candidateInfo_dict.setdefault(candidateInfo_tup.series_uid,
+                                      []).append(candidateInfo_tup)
+
+    return candidateInfo_dict
+
+class Ct:
+    def __init__(self, series_uid):
+        mhd_path = glob.glob(
+            'data-unversioned/subset*/{}.mhd'.format(series_uid)
+        )[0]
+
+        ct_mhd = sitk.ReadImage(mhd_path)
+        self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+
+        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
+        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
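+        # (Clipping to the [-1000, 1000] HU band is deferred to the consumers:
+        # getCtRawCandidate() and getitem_fullSlice() below.)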
+
+        self.series_uid = series_uid
+
+        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
+        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
+        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
+
+        candidateInfo_list = getCandidateInfoDict()[self.series_uid]
+
+        self.positiveInfo_list = [
+            candidate_tup
+            for candidate_tup in candidateInfo_list
+            if candidate_tup.isNodule_bool
+        ]
+        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
+        self.positive_indexes = (self.positive_mask.sum(axis=(1,2))
+                                 .nonzero()[0].tolist())
+
+    def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
+        boundingBox_a = np.zeros_like(self.hu_a, dtype=bool)  # np.bool is removed in newer NumPy
+
+        for candidateInfo_tup in positiveInfo_list:
+            center_irc = xyz2irc(
+                candidateInfo_tup.center_xyz,
+                self.origin_xyz,
+                self.vxSize_xyz,
+                self.direction_a,
+            )
+            ci = int(center_irc.index)
+            cr = int(center_irc.row)
+            cc = int(center_irc.col)
+
+            index_radius = 2
+            try:
+                while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \
+                        self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
+                    index_radius += 1
+            except IndexError:
+                index_radius -= 1
+
+            row_radius = 2
+            try:
+                while self.hu_a[ci, cr + row_radius, cc] > threshold_hu and \
+                        self.hu_a[ci, cr - row_radius, cc] > threshold_hu:
+                    row_radius += 1
+            except IndexError:
+                row_radius -= 1
+
+            col_radius = 2
+            try:
+                while self.hu_a[ci, cr, cc + col_radius] > threshold_hu and \
+                        self.hu_a[ci, cr, cc - col_radius] > threshold_hu:
+                    col_radius += 1
+            except IndexError:
+                col_radius -= 1
+
+            # assert index_radius > 0, repr([candidateInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]])
+            # assert row_radius > 0
+            # assert col_radius > 0
+
+            boundingBox_a[
+                 ci - index_radius: ci + index_radius + 1,
+                 cr - row_radius: cr + row_radius + 1,
+                 cc - col_radius: cc + col_radius + 1] = True
+
+        mask_a = boundingBox_a & (self.hu_a > threshold_hu)
+
+        return mask_a
+
+    def getRawCandidate(self, center_xyz, width_irc):
+        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
+                             self.direction_a)
+
+        slice_list = []
+        for axis, center_val in enumerate(center_irc):
+            start_ndx = int(round(center_val - width_irc[axis]/2))
+            end_ndx = int(start_ndx + width_irc[axis])
+
+            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+
+            if start_ndx < 0:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                start_ndx = 0
+                end_ndx = int(width_irc[axis])
+
+            if end_ndx > self.hu_a.shape[axis]:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                end_ndx = self.hu_a.shape[axis]
+                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
+
+            slice_list.append(slice(start_ndx, end_ndx))
+
+        ct_chunk = self.hu_a[tuple(slice_list)]
+        pos_chunk = self.positive_mask[tuple(slice_list)]
+
+        return ct_chunk, pos_chunk, center_irc
+
+@functools.lru_cache(1, typed=True)
+def getCt(series_uid):
+    return Ct(series_uid)
+
+@raw_cache.memoize(typed=True)
+def getCtRawCandidate(series_uid, center_xyz, width_irc):
+    ct = getCt(series_uid)
+    ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz,
+                                                         width_irc)
+    ct_chunk.clip(-1000, 1000, ct_chunk)
+    return ct_chunk, pos_chunk, center_irc
+
+@raw_cache.memoize(typed=True)
+def getCtSampleSize(series_uid):
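+    # Instantiates Ct directly rather than through the lru_cached getCt(),
+    # so the full volume is not pinned in the in-memory cache; only the
+    # cheap shape/slice-index result below is memoized to disk by raw_cache.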
+    ct = Ct(series_uid)
+    return int(ct.hu_a.shape[0]), ct.positive_indexes
+
+
+class Luna2dSegmentationDataset(Dataset):
+    def __init__(self,
+                 val_stride=0,
+                 isValSet_bool=None,
+                 series_uid=None,
+                 contextSlices_count=3,
+                 fullCt_bool=False,
+            ):
+        self.contextSlices_count = contextSlices_count
+        self.fullCt_bool = fullCt_bool
+
+        if series_uid:
+            self.series_list = [series_uid]
+        else:
+            self.series_list = sorted(getCandidateInfoDict().keys())
+
+        if isValSet_bool:
+            assert val_stride > 0, val_stride
+            self.series_list = self.series_list[::val_stride]
+            assert self.series_list
+        elif val_stride > 0:
+            del self.series_list[::val_stride]
+            assert self.series_list
+
+        self.sample_list = []
+        for series_uid in self.series_list:
+            index_count, positive_indexes = getCtSampleSize(series_uid)
+
+            if self.fullCt_bool:
+                self.sample_list += [(series_uid, slice_ndx)
+                                     for slice_ndx in range(index_count)]
+            else:
+                self.sample_list += [(series_uid, slice_ndx)
+                                     for slice_ndx in positive_indexes]
+
+        self.candidateInfo_list = getCandidateInfoList()
+
+        series_set = set(self.series_list)
+        self.candidateInfo_list = [cit for cit in self.candidateInfo_list
+                                   if cit.series_uid in series_set]
+
+        self.pos_list = [nt for nt in self.candidateInfo_list
+                         if nt.isNodule_bool]
+
+        log.info("{!r}: {} {} series, {} slices, {} nodules".format(
+            self,
+            len(self.series_list),
+            {None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
+            len(self.sample_list),
+            len(self.pos_list),
+        ))
+
+    def __len__(self):
+        return len(self.sample_list)
+
+    def __getitem__(self, ndx):
+        series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
+        return self.getitem_fullSlice(series_uid, slice_ndx)
+
+    def getitem_fullSlice(self, series_uid, slice_ndx):
+        ct = getCt(series_uid)
+        ct_t = torch.zeros((self.contextSlices_count * 2 + 1, 512, 512))
+
+        start_ndx = slice_ndx - self.contextSlices_count
+        end_ndx = slice_ndx + self.contextSlices_count + 1
+        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
+            context_ndx = max(context_ndx, 0)
+            context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
+            ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
+
+        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
+        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
+        # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
+        # The upper bound nukes any weird hotspots and clamps bone down
+        ct_t.clamp_(-1000, 1000)
+
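+        # Only the center slice carries the label; the other
+        # 2 * contextSlices_count slices are input-only context channels.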
+        pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)
+
+        return ct_t, pos_t, ct.series_uid, slice_ndx
+
+
+class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.ratio_int = 2
+
+    def __len__(self):
+        return 300000
+
+    def shuffleSamples(self):
+        random.shuffle(self.candidateInfo_list)
+        random.shuffle(self.pos_list)
+
+    def __getitem__(self, ndx):
+        candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
+        return self.getitem_trainingCrop(candidateInfo_tup)
+
+    def getitem_trainingCrop(self, candidateInfo_tup):
+        ct_a, pos_a, center_irc = getCtRawCandidate(
+            candidateInfo_tup.series_uid,
+            candidateInfo_tup.center_xyz,
+            (7, 96, 96),
+        )
+        pos_a = pos_a[3:4]
+
+        row_offset = random.randrange(0,32)
+        col_offset = random.randrange(0,32)
+        ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64,
+                                     col_offset:col_offset+64]).to(torch.float32)
+        pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64,
+                                       col_offset:col_offset+64]).to(torch.long)
+
+        slice_ndx = center_irc.index
+
+        return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx
+
+class PrepcacheLunaDataset(Dataset):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.candidateInfo_list = getCandidateInfoList()
+        self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
+
+        self.seen_set = set()
+        self.candidateInfo_list.sort(key=lambda x: x.series_uid)
+
+    def __len__(self):
+        return len(self.candidateInfo_list)
+
+    def __getitem__(self, ndx):
+        # candidate_t, pos_t, series_uid, center_t = super().__getitem__(ndx)
+
+        candidateInfo_tup = self.candidateInfo_list[ndx]
+        getCtRawCandidate(candidateInfo_tup.series_uid, candidateInfo_tup.center_xyz, (7, 96, 96))
+
+        series_uid = candidateInfo_tup.series_uid
+        if series_uid not in self.seen_set:
+            self.seen_set.add(series_uid)
+
+            getCtSampleSize(series_uid)
+            # ct = getCt(series_uid)
+            # for mask_ndx in ct.positive_indexes:
+            #     build2dLungMask(series_uid, mask_ndx)
+
+        return 0, 1 #candidate_t, pos_t, series_uid, center_t
+
+
+class TvTrainingLuna2dSegmentationDataset(torch.utils.data.Dataset):
+    def __init__(self, isValSet_bool=False, val_stride=10, contextSlices_count=3):
+        assert contextSlices_count == 3
+        data = torch.load('./imgs_and_masks.pt')
+        suids = list(set(data['suids']))
+        trn_mask_suids = torch.arange(len(suids)) % val_stride < (val_stride - 1)
+        trn_suids = {s for i, s in zip(trn_mask_suids, suids) if i}
+        trn_mask = torch.tensor([(s in trn_suids) for s in data["suids"]])
+        if not isValSet_bool:
+            self.imgs = data["imgs"][trn_mask]
+            self.masks = data["masks"][trn_mask]
+            self.suids = [s for s, i in zip(data["suids"], trn_mask) if i]
+        else:
+            self.imgs = data["imgs"][~trn_mask]
+            self.masks = data["masks"][~trn_mask]
+            self.suids = [s for s, i in zip(data["suids"], trn_mask) if not i]
+        # discard spurious hotspots and clamp bone
+        self.imgs.clamp_(-1000, 1000)
+        self.imgs /= 1000
+
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def __getitem__(self, i):
+        oh, ow = torch.randint(0, 32, (2,))
+        sl = self.masks.size(1)//2
+        return (self.imgs[i, :, oh: oh + 64, ow: ow + 64],
+                1,
+                self.masks[i, sl: sl+1, oh: oh + 64, ow: ow + 64].to(torch.float32),
+                self.suids[i],
+                9999)
diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/model.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/model.py
new file mode 100644
index 00000000..20cecbb9
--- /dev/null
+++ b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/model.py
@@ -0,0 +1,224 @@
+import math
+import random
+from collections import namedtuple
+
+import torch
+from torch import nn as nn
+import torch.nn.functional as F
+
+from util.logconf import logging
+from util.unet import UNet
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+class UNetWrapper(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
+        self.unet = UNet(**kwargs)
+        self.final = nn.Sigmoid()
+
+        self._init_weights()
+
+    def _init_weights(self):
+        init_set = {
+            nn.Conv2d,
+            nn.Conv3d,
+            nn.ConvTranspose2d,
+            nn.ConvTranspose3d,
+            nn.Linear,
+        }
+        for m in self.modules():
+            if type(m) in init_set:
+                nn.init.kaiming_normal_(
+                    m.weight.data, mode='fan_out', nonlinearity='relu', a=0
+                )
+                if m.bias is not None:
+                    fan_in, fan_out = \
+                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+        # nn.init.constant_(self.unet.last.bias, -4)
+        # nn.init.constant_(self.unet.last.bias, 4)
+
+
+    def forward(self, input_batch):
+        bn_output = self.input_batchnorm(input_batch)
+        un_output = self.unet(bn_output)
+        fn_output = self.final(un_output)
+        return fn_output
+
+class SegmentationAugmentation(nn.Module):
+    def __init__(
+            self, flip=None, offset=None, scale=None, rotate=None, noise=None
+    ):
+        super().__init__()
+
+        self.flip = flip
+        self.offset = offset
+        self.scale = scale
+        self.rotate = rotate
+        self.noise = noise
+
+    def forward(self, input_g, label_g):
+        transform_t = self._build2dTransformMatrix()
+        transform_t = transform_t.expand(input_g.shape[0], -1, -1)
+        transform_t = transform_t.to(input_g.device, torch.float32)
+        affine_t = F.affine_grid(transform_t[:,:2],
+                input_g.size(), align_corners=False)
+
+        augmented_input_g = F.grid_sample(input_g,
+                affine_t, padding_mode='border',
+                align_corners=False)
+        augmented_label_g = F.grid_sample(label_g.to(torch.float32),
+                affine_t, padding_mode='border',
+                align_corners=False)
+
+        if self.noise:
+            noise_t = torch.randn_like(augmented_input_g)
+            noise_t *= self.noise
+
+            augmented_input_g += noise_t
+
+        return augmented_input_g, augmented_label_g > 0.5
+
+    def _build2dTransformMatrix(self):
+        transform_t = torch.eye(3)
+
+        for i in range(2):
+            if self.flip:
+                if random.random() > 0.5:
+                    transform_t[i,i] *= -1
+
+            if self.offset:
+                offset_float = self.offset
+                random_float = (random.random() * 2 - 1)
+                transform_t[i,2] = offset_float * random_float  # fixed: was [2,i], which transform_t[:,:2] discards; affine_grid expects translation in the last column
+
+            if self.scale:
+                scale_float = self.scale
+                random_float = (random.random() * 2 - 1)
+                transform_t[i,i] *= 1.0 + scale_float * random_float
+
+            if self.rotate:
+                angle_rad = random.random() * math.pi * 2
+                s = math.sin(angle_rad)
+                c = math.cos(angle_rad)
+
+                rotation_t = torch.tensor([
+                    [c, -s, 0],
+                    [s, c, 0],
+                    [0, 0, 1]])
+
+                transform_t @= rotation_t
+
+        return transform_t
+
+
+# MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
+#
+# class SegmentationMask(nn.Module):
+#     def __init__(self):
+#         super().__init__()
+#
+#         self.conv_list = nn.ModuleList([
+#             self._make_circle_conv(radius) for radius in range(1, 8)
+#         ])
+#
+#     def _make_circle_conv(self, radius):
+#         diameter = 1 + radius * 2
+#
+#         a = torch.linspace(-1, 1, steps=diameter)**2
+#         b = (a[None] + a[:, None])**0.5
+#
+#         circle_weights = (b <= 1.0).to(torch.float32)
+#
+#         conv = nn.Conv2d(1, 1, kernel_size=diameter, padding=radius, bias=False)
+#         conv.weight.data.fill_(1)
+#         conv.weight.data *= circle_weights / circle_weights.sum()
+#
+#         return conv
+#
+#
+#     def erode(self, input_mask, radius, threshold=1):
+#         conv = self.conv_list[radius - 1]
+#         input_float = input_mask.to(torch.float32)
+#         result = conv(input_float)
+#
+#         # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
+#         # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
+#
+#         return result >= threshold
+#
+#     def deposit(self, input_mask, radius, threshold=0):
+#         conv = self.conv_list[radius - 1]
+#         input_float = input_mask.to(torch.float32)
+#         result = conv(input_float)
+#
+#         # log.debug(['deposit in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
+#         # log.debug(['deposit out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
+#
+#         return result > threshold
+#
+#     def fill_cavity(self, input_mask):
+#         cumsum = input_mask.cumsum(-1)
+#         filled_mask = (cumsum > 0)
+#         filled_mask &= (cumsum < cumsum[..., -1:])
+#         cumsum = input_mask.cumsum(-2)
+#         filled_mask &= (cumsum > 0)
+#         filled_mask &= (cumsum < cumsum[..., -1:, :])
+#
+#         return filled_mask
+#
+#
+#     def forward(self, input_g, raw_pos_g):
+#         gcc_g = input_g + 1
+#
+#         with torch.no_grad():
+#             # log.info(['gcc_g', gcc_g.min(), gcc_g.mean(), gcc_g.max()])
+#
+#             raw_dense_mask = gcc_g > 0.7
+#             dense_mask = self.deposit(raw_dense_mask, 2)
+#             dense_mask = self.erode(dense_mask, 6)
+#             dense_mask = self.deposit(dense_mask, 4)
+#
+#             body_mask = self.fill_cavity(dense_mask)
+#             air_mask = self.deposit(body_mask & ~dense_mask, 5)
+#             air_mask = self.erode(air_mask, 6)
+#
+#             lung_mask = self.deposit(air_mask, 5)
+#
+#             raw_candidate_mask = gcc_g > 0.4
+#             raw_candidate_mask &= air_mask
+#             candidate_mask = self.erode(raw_candidate_mask, 1)
+#             candidate_mask = self.deposit(candidate_mask, 1)
+#
+#             pos_mask = self.deposit((raw_pos_g > 0.5) & lung_mask, 2)
+#
+#             neg_mask = self.deposit(candidate_mask, 1)
+#             neg_mask &= ~pos_mask
+#             neg_mask &= lung_mask
+#
+#             # label_g = (neg_mask | pos_mask).to(torch.float32)
+#             label_g = (pos_mask).to(torch.float32)
+#             neg_g = neg_mask.to(torch.float32)
+#             pos_g = pos_mask.to(torch.float32)
+#
+#             mask_dict = {
+#                 'raw_dense_mask': raw_dense_mask,
+#                 'dense_mask': dense_mask,
+#                 'body_mask': body_mask,
+#                 'air_mask': air_mask,
+#                 'raw_candidate_mask': raw_candidate_mask,
+#                 'candidate_mask': candidate_mask,
+#                 'lung_mask': lung_mask,
+#                 'neg_mask': neg_mask,
+#                 'pos_mask': pos_mask,
+#             }
+#
+#         return label_g, neg_g, pos_g, lung_mask, mask_dict
diff --git a/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/prepcache.py b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/prepcache.py
new file mode 100644
index 00000000..9e867cde
--- /dev/null
+++ b/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/prepcache.py
@@ -0,0 +1,70 @@
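+# Warms the on-disk cache for training: iterating the DataLoader once forces
+# getCtRawCandidate()/getCtSampleSize() results to be memoized by raw_cache.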
+import timing
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import PrepcacheLunaDataset, getCtSampleSize
+from util.logconf import logging
+# from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaPrepCacheApp:
+    def __init__(self, sys_argv=None):  # dropped a stray @classmethod that made self the class object
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=1024,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        # parser.add_argument('--scaled',
+        #     help="Scale the CT chunks to square voxels.",
+        #     default=False,
+        #     action='store_true',
+        # )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            PrepcacheLunaDataset(
+                # sortby_str='series_uid',
+            ),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Stuffing cache",
+            start_ndx=self.prep_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            pass
+
+
+if __name__ == '__main__':
+    LunaPrepCacheApp().main()