Nahrát soubory do „pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model“

This commit is contained in:
Lukáš Pokrývka 2020-11-12 15:21:33 +00:00
parent 35f3889338
commit 526fe93ff9
5 changed files with 770 additions and 0 deletions

View File

@ -0,0 +1,76 @@
import argparse
import datetime
import os
import socket
import sys
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from util.util import enumerateWithEstimate
from p2ch13.dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
from util.logconf import logging
from util.util import xyz2irc
from p2ch13.model_seg import UNetWrapper, SegmentationAugmentation
from p2ch13.train_seg import LunaTrainingApp
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
class BenchmarkLuna2dSegmentationDataset(TrainingLuna2dSegmentationDataset):
def __len__(self):
# return 500
return 5000
return 1000
class LunaBenchmarkApp(LunaTrainingApp):
def initTrainDl(self):
train_ds = BenchmarkLuna2dSegmentationDataset(
val_stride=10,
isValSet_bool=False,
contextSlices_count=3,
# augmentation_dict=self.augmentation_dict,
)
batch_size = self.cli_args.batch_size
if self.use_cuda:
batch_size *= torch.cuda.device_count()
train_dl = DataLoader(
train_ds,
batch_size=batch_size,
num_workers=self.cli_args.num_workers,
pin_memory=self.use_cuda,
)
return train_dl
def main(self):
log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
train_dl = self.initTrainDl()
for epoch_ndx in range(1, 2):
log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
epoch_ndx,
self.cli_args.epochs,
len(train_dl),
len([]),
self.cli_args.batch_size,
(torch.cuda.device_count() if self.use_cuda else 1),
))
self.doTraining(epoch_ndx, train_dl)
if __name__ == '__main__':
LunaBenchmarkApp().main()

View File

@ -0,0 +1,401 @@
import copy
import csv
import functools
import glob
import math
import os
import random
from collections import namedtuple
import SimpleITK as sitk
import numpy as np
import scipy.ndimage.morphology as morph
import torch
import torch.cuda
import torch.nn.functional as F
from torch.utils.data import Dataset
from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
raw_cache = getCache('part2ch13_raw')
MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')
@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
# We construct a set with all series_uids that are present on disk.
# This will let us use the data, even if we haven't downloaded all of
# the subsets yet.
mhd_list = glob.glob('data-unversioned/subset*/*.mhd')
presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
candidateInfo_list = []
with open('data/annotations_with_malignancy.csv', "r") as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
annotationDiameter_mm = float(row[4])
isMal_bool = {'False': False, 'True': True}[row[5]]
candidateInfo_list.append(
CandidateInfoTuple(
True,
True,
isMal_bool,
annotationDiameter_mm,
series_uid,
annotationCenter_xyz,
)
)
with open('data/candidates.csv', "r") as f:
for row in list(csv.reader(f))[1:]:
series_uid = row[0]
if series_uid not in presentOnDisk_set and requireOnDisk_bool:
continue
isNodule_bool = bool(int(row[4]))
candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
if not isNodule_bool:
candidateInfo_list.append(
CandidateInfoTuple(
False,
False,
False,
0.0,
series_uid,
candidateCenter_xyz,
)
)
candidateInfo_list.sort(reverse=True)
return candidateInfo_list
@functools.lru_cache(1)
def getCandidateInfoDict(requireOnDisk_bool=True):
candidateInfo_list = getCandidateInfoList(requireOnDisk_bool)
candidateInfo_dict = {}
for candidateInfo_tup in candidateInfo_list:
candidateInfo_dict.setdefault(candidateInfo_tup.series_uid,
[]).append(candidateInfo_tup)
return candidateInfo_dict
class Ct:
def __init__(self, series_uid):
mhd_path = glob.glob(
'data-unversioned/subset*/{}.mhd'.format(series_uid)
)[0]
ct_mhd = sitk.ReadImage(mhd_path)
self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
# CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
# HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
self.series_uid = series_uid
self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
candidateInfo_list = getCandidateInfoDict()[self.series_uid]
self.positiveInfo_list = [
candidate_tup
for candidate_tup in candidateInfo_list
if candidate_tup.isNodule_bool
]
self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
self.positive_indexes = (self.positive_mask.sum(axis=(1,2))
.nonzero()[0].tolist())
def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool)
for candidateInfo_tup in positiveInfo_list:
center_irc = xyz2irc(
candidateInfo_tup.center_xyz,
self.origin_xyz,
self.vxSize_xyz,
self.direction_a,
)
ci = int(center_irc.index)
cr = int(center_irc.row)
cc = int(center_irc.col)
index_radius = 2
try:
while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \
self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
index_radius += 1
except IndexError:
index_radius -= 1
row_radius = 2
try:
while self.hu_a[ci, cr + row_radius, cc] > threshold_hu and \
self.hu_a[ci, cr - row_radius, cc] > threshold_hu:
row_radius += 1
except IndexError:
row_radius -= 1
col_radius = 2
try:
while self.hu_a[ci, cr, cc + col_radius] > threshold_hu and \
self.hu_a[ci, cr, cc - col_radius] > threshold_hu:
col_radius += 1
except IndexError:
col_radius -= 1
# assert index_radius > 0, repr([candidateInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]])
# assert row_radius > 0
# assert col_radius > 0
boundingBox_a[
ci - index_radius: ci + index_radius + 1,
cr - row_radius: cr + row_radius + 1,
cc - col_radius: cc + col_radius + 1] = True
mask_a = boundingBox_a & (self.hu_a > threshold_hu)
return mask_a
def getRawCandidate(self, center_xyz, width_irc):
center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
self.direction_a)
slice_list = []
for axis, center_val in enumerate(center_irc):
start_ndx = int(round(center_val - width_irc[axis]/2))
end_ndx = int(start_ndx + width_irc[axis])
assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
if start_ndx < 0:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
start_ndx = 0
end_ndx = int(width_irc[axis])
if end_ndx > self.hu_a.shape[axis]:
# log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
# self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
end_ndx = self.hu_a.shape[axis]
start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
slice_list.append(slice(start_ndx, end_ndx))
ct_chunk = self.hu_a[tuple(slice_list)]
pos_chunk = self.positive_mask[tuple(slice_list)]
return ct_chunk, pos_chunk, center_irc
@functools.lru_cache(1, typed=True)
def getCt(series_uid):
return Ct(series_uid)
@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
ct = getCt(series_uid)
ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz,
width_irc)
ct_chunk.clip(-1000, 1000, ct_chunk)
return ct_chunk, pos_chunk, center_irc
@raw_cache.memoize(typed=True)
def getCtSampleSize(series_uid):
ct = Ct(series_uid)
return int(ct.hu_a.shape[0]), ct.positive_indexes
class Luna2dSegmentationDataset(Dataset):
def __init__(self,
val_stride=0,
isValSet_bool=None,
series_uid=None,
contextSlices_count=3,
fullCt_bool=False,
):
self.contextSlices_count = contextSlices_count
self.fullCt_bool = fullCt_bool
if series_uid:
self.series_list = [series_uid]
else:
self.series_list = sorted(getCandidateInfoDict().keys())
if isValSet_bool:
assert val_stride > 0, val_stride
self.series_list = self.series_list[::val_stride]
assert self.series_list
elif val_stride > 0:
del self.series_list[::val_stride]
assert self.series_list
self.sample_list = []
for series_uid in self.series_list:
index_count, positive_indexes = getCtSampleSize(series_uid)
if self.fullCt_bool:
self.sample_list += [(series_uid, slice_ndx)
for slice_ndx in range(index_count)]
else:
self.sample_list += [(series_uid, slice_ndx)
for slice_ndx in positive_indexes]
self.candidateInfo_list = getCandidateInfoList()
series_set = set(self.series_list)
self.candidateInfo_list = [cit for cit in self.candidateInfo_list
if cit.series_uid in series_set]
self.pos_list = [nt for nt in self.candidateInfo_list
if nt.isNodule_bool]
log.info("{!r}: {} {} series, {} slices, {} nodules".format(
self,
len(self.series_list),
{None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
len(self.sample_list),
len(self.pos_list),
))
def __len__(self):
return len(self.sample_list)
def __getitem__(self, ndx):
series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
return self.getitem_fullSlice(series_uid, slice_ndx)
def getitem_fullSlice(self, series_uid, slice_ndx):
ct = getCt(series_uid)
ct_t = torch.zeros((self.contextSlices_count * 2 + 1, 512, 512))
start_ndx = slice_ndx - self.contextSlices_count
end_ndx = slice_ndx + self.contextSlices_count + 1
for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
context_ndx = max(context_ndx, 0)
context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
# CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
# HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
# The lower bound gets rid of negative density stuff used to indicate out-of-FOV
# The upper bound nukes any weird hotspots and clamps bone down
ct_t.clamp_(-1000, 1000)
pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)
return ct_t, pos_t, ct.series_uid, slice_ndx
class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ratio_int = 2
def __len__(self):
return 300000
def shuffleSamples(self):
random.shuffle(self.candidateInfo_list)
random.shuffle(self.pos_list)
def __getitem__(self, ndx):
candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
return self.getitem_trainingCrop(candidateInfo_tup)
def getitem_trainingCrop(self, candidateInfo_tup):
ct_a, pos_a, center_irc = getCtRawCandidate(
candidateInfo_tup.series_uid,
candidateInfo_tup.center_xyz,
(7, 96, 96),
)
pos_a = pos_a[3:4]
row_offset = random.randrange(0,32)
col_offset = random.randrange(0,32)
ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64,
col_offset:col_offset+64]).to(torch.float32)
pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64,
col_offset:col_offset+64]).to(torch.long)
slice_ndx = center_irc.index
return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx
class PrepcacheLunaDataset(Dataset):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.candidateInfo_list = getCandidateInfoList()
self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
self.seen_set = set()
self.candidateInfo_list.sort(key=lambda x: x.series_uid)
def __len__(self):
return len(self.candidateInfo_list)
def __getitem__(self, ndx):
# candidate_t, pos_t, series_uid, center_t = super().__getitem__(ndx)
candidateInfo_tup = self.candidateInfo_list[ndx]
getCtRawCandidate(candidateInfo_tup.series_uid, candidateInfo_tup.center_xyz, (7, 96, 96))
series_uid = candidateInfo_tup.series_uid
if series_uid not in self.seen_set:
self.seen_set.add(series_uid)
getCtSampleSize(series_uid)
# ct = getCt(series_uid)
# for mask_ndx in ct.positive_indexes:
# build2dLungMask(series_uid, mask_ndx)
return 0, 1 #candidate_t, pos_t, series_uid, center_t
class TvTrainingLuna2dSegmentationDataset(torch.utils.data.Dataset):
def __init__(self, isValSet_bool=False, val_stride=10, contextSlices_count=3):
assert contextSlices_count == 3
data = torch.load('./imgs_and_masks.pt')
suids = list(set(data['suids']))
trn_mask_suids = torch.arange(len(suids)) % val_stride < (val_stride - 1)
trn_suids = {s for i, s in zip(trn_mask_suids, suids) if i}
trn_mask = torch.tensor([(s in trn_suids) for s in data["suids"]])
if not isValSet_bool:
self.imgs = data["imgs"][trn_mask]
self.masks = data["masks"][trn_mask]
self.suids = [s for s, i in zip(data["suids"], trn_mask) if i]
else:
self.imgs = data["imgs"][~trn_mask]
self.masks = data["masks"][~trn_mask]
self.suids = [s for s, i in zip(data["suids"], trn_mask) if not i]
# discard spurious hotspots and clamp bone
self.imgs.clamp_(-1000, 1000)
self.imgs /= 1000
def __len__(self):
return len(self.imgs)
def __getitem__(self, i):
oh, ow = torch.randint(0, 32, (2,))
sl = self.masks.size(1)//2
return self.imgs[i, :, oh: oh + 64, ow: ow + 64], 1, self.masks[i, sl: sl+1, oh: oh + 64, ow: ow + 64].to(torch.float32), self.suids[i], 9999

View File

@ -0,0 +1,224 @@
import math
import random
from collections import namedtuple
import torch
from torch import nn as nn
import torch.nn.functional as F
from util.logconf import logging
from util.unet import UNet
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)
class UNetWrapper(nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
self.unet = UNet(**kwargs)
self.final = nn.Sigmoid()
self._init_weights()
def _init_weights(self):
init_set = {
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.Linear,
}
for m in self.modules():
if type(m) in init_set:
nn.init.kaiming_normal_(
m.weight.data, mode='fan_out', nonlinearity='relu', a=0
)
if m.bias is not None:
fan_in, fan_out = \
nn.init._calculate_fan_in_and_fan_out(m.weight.data)
bound = 1 / math.sqrt(fan_out)
nn.init.normal_(m.bias, -bound, bound)
# nn.init.constant_(self.unet.last.bias, -4)
# nn.init.constant_(self.unet.last.bias, 4)
def forward(self, input_batch):
bn_output = self.input_batchnorm(input_batch)
un_output = self.unet(bn_output)
fn_output = self.final(un_output)
return fn_output
class SegmentationAugmentation(nn.Module):
def __init__(
self, flip=None, offset=None, scale=None, rotate=None, noise=None
):
super().__init__()
self.flip = flip
self.offset = offset
self.scale = scale
self.rotate = rotate
self.noise = noise
def forward(self, input_g, label_g):
transform_t = self._build2dTransformMatrix()
transform_t = transform_t.expand(input_g.shape[0], -1, -1)
transform_t = transform_t.to(input_g.device, torch.float32)
affine_t = F.affine_grid(transform_t[:,:2],
input_g.size(), align_corners=False)
augmented_input_g = F.grid_sample(input_g,
affine_t, padding_mode='border',
align_corners=False)
augmented_label_g = F.grid_sample(label_g.to(torch.float32),
affine_t, padding_mode='border',
align_corners=False)
if self.noise:
noise_t = torch.randn_like(augmented_input_g)
noise_t *= self.noise
augmented_input_g += noise_t
return augmented_input_g, augmented_label_g > 0.5
def _build2dTransformMatrix(self):
transform_t = torch.eye(3)
for i in range(2):
if self.flip:
if random.random() > 0.5:
transform_t[i,i] *= -1
if self.offset:
offset_float = self.offset
random_float = (random.random() * 2 - 1)
transform_t[2,i] = offset_float * random_float
if self.scale:
scale_float = self.scale
random_float = (random.random() * 2 - 1)
transform_t[i,i] *= 1.0 + scale_float * random_float
if self.rotate:
angle_rad = random.random() * math.pi * 2
s = math.sin(angle_rad)
c = math.cos(angle_rad)
rotation_t = torch.tensor([
[c, -s, 0],
[s, c, 0],
[0, 0, 1]])
transform_t @= rotation_t
return transform_t
# MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
#
# class SegmentationMask(nn.Module):
# def __init__(self):
# super().__init__()
#
# self.conv_list = nn.ModuleList([
# self._make_circle_conv(radius) for radius in range(1, 8)
# ])
#
# def _make_circle_conv(self, radius):
# diameter = 1 + radius * 2
#
# a = torch.linspace(-1, 1, steps=diameter)**2
# b = (a[None] + a[:, None])**0.5
#
# circle_weights = (b <= 1.0).to(torch.float32)
#
# conv = nn.Conv2d(1, 1, kernel_size=diameter, padding=radius, bias=False)
# conv.weight.data.fill_(1)
# conv.weight.data *= circle_weights / circle_weights.sum()
#
# return conv
#
#
# def erode(self, input_mask, radius, threshold=1):
# conv = self.conv_list[radius - 1]
# input_float = input_mask.to(torch.float32)
# result = conv(input_float)
#
# # log.debug(['erode in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
# # log.debug(['erode out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
#
# return result >= threshold
#
# def deposit(self, input_mask, radius, threshold=0):
# conv = self.conv_list[radius - 1]
# input_float = input_mask.to(torch.float32)
# result = conv(input_float)
#
# # log.debug(['deposit in ', radius, threshold, input_float.min().item(), input_float.mean().item(), input_float.max().item()])
# # log.debug(['deposit out', radius, threshold, result.min().item(), result.mean().item(), result.max().item()])
#
# return result > threshold
#
# def fill_cavity(self, input_mask):
# cumsum = input_mask.cumsum(-1)
# filled_mask = (cumsum > 0)
# filled_mask &= (cumsum < cumsum[..., -1:])
# cumsum = input_mask.cumsum(-2)
# filled_mask &= (cumsum > 0)
# filled_mask &= (cumsum < cumsum[..., -1:, :])
#
# return filled_mask
#
#
# def forward(self, input_g, raw_pos_g):
# gcc_g = input_g + 1
#
# with torch.no_grad():
# # log.info(['gcc_g', gcc_g.min(), gcc_g.mean(), gcc_g.max()])
#
# raw_dense_mask = gcc_g > 0.7
# dense_mask = self.deposit(raw_dense_mask, 2)
# dense_mask = self.erode(dense_mask, 6)
# dense_mask = self.deposit(dense_mask, 4)
#
# body_mask = self.fill_cavity(dense_mask)
# air_mask = self.deposit(body_mask & ~dense_mask, 5)
# air_mask = self.erode(air_mask, 6)
#
# lung_mask = self.deposit(air_mask, 5)
#
# raw_candidate_mask = gcc_g > 0.4
# raw_candidate_mask &= air_mask
# candidate_mask = self.erode(raw_candidate_mask, 1)
# candidate_mask = self.deposit(candidate_mask, 1)
#
# pos_mask = self.deposit((raw_pos_g > 0.5) & lung_mask, 2)
#
# neg_mask = self.deposit(candidate_mask, 1)
# neg_mask &= ~pos_mask
# neg_mask &= lung_mask
#
# # label_g = (neg_mask | pos_mask).to(torch.float32)
# label_g = (pos_mask).to(torch.float32)
# neg_g = neg_mask.to(torch.float32)
# pos_g = pos_mask.to(torch.float32)
#
# mask_dict = {
# 'raw_dense_mask': raw_dense_mask,
# 'dense_mask': dense_mask,
# 'body_mask': body_mask,
# 'air_mask': air_mask,
# 'raw_candidate_mask': raw_candidate_mask,
# 'candidate_mask': candidate_mask,
# 'lung_mask': lung_mask,
# 'neg_mask': neg_mask,
# 'pos_mask': pos_mask,
# }
#
# return label_g, neg_g, pos_g, lung_mask, mask_dict

View File

@ -0,0 +1,69 @@
import timing
import argparse
import sys
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD
from torch.utils.data import DataLoader
from util.util import enumerateWithEstimate
from .dsets import PrepcacheLunaDataset, getCtSampleSize
from util.logconf import logging
# from .model import LunaModel
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)
class LunaPrepCacheApp:
@classmethod
def __init__(self, sys_argv=None):
if sys_argv is None:
sys_argv = sys.argv[1:]
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size',
help='Batch size to use for training',
default=1024,
type=int,
)
parser.add_argument('--num-workers',
help='Number of worker processes for background data loading',
default=8,
type=int,
)
# parser.add_argument('--scaled',
# help="Scale the CT chunks to square voxels.",
# default=False,
# action='store_true',
# )
self.cli_args = parser.parse_args(sys_argv)
def main(self):
log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
self.prep_dl = DataLoader(
PrepcacheLunaDataset(
# sortby_str='series_uid',
),
batch_size=self.cli_args.batch_size,
num_workers=self.cli_args.num_workers,
)
batch_iter = enumerateWithEstimate(
self.prep_dl,
"Stuffing cache",
start_ndx=self.prep_dl.num_workers,
)
for batch_ndx, batch_tup in batch_iter:
pass
if __name__ == '__main__':
LunaPrepCacheApp().main()