zpwiki/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/prepcache.py

70 lines
1.8 KiB
Python
Raw Permalink Normal View History

import timing
import argparse
import sys
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD
from torch.utils.data import DataLoader
from util.util import enumerateWithEstimate
from .dsets import PrepcacheLunaDataset, getCtSampleSize
from util.logconf import logging
# from .model import LunaModel
# Module-level logger named after this module. INFO is the active level;
# the WARN and DEBUG alternatives are kept commented out for quick switching.
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)
class LunaPrepCacheApp:
    """Command-line app that iterates once over ``PrepcacheLunaDataset``.

    Every batch is pulled through a ``DataLoader`` and then discarded. The
    progress label ("Stuffing cache") suggests the useful work is a caching
    side effect inside the dataset; that dataset lives in ``.dsets`` and is
    not visible here -- TODO confirm against its implementation.
    """

    def __init__(self, sys_argv=None):
        """Parse command-line options and store them in ``self.cli_args``.

        Args:
            sys_argv: List of argument strings to parse. When ``None``,
                ``sys.argv[1:]`` is used.
        """
        # This must be a plain instance method. Decorating ``__init__`` with
        # ``@classmethod`` binds ``self`` to the class, so ``cli_args`` would
        # be stored on the class and shared (and overwritten) by every
        # instance created afterwards.
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--batch-size',
            help='Batch size to use for training',
            default=1024,
            type=int,
        )
        parser.add_argument(
            '--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )

        self.cli_args = parser.parse_args(sys_argv)

    def main(self):
        """Build the data loader and exhaust it, discarding every batch."""
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            PrepcacheLunaDataset(),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        # enumerateWithEstimate is a project helper (util.util); presumably
        # it wraps enumerate() with progress/ETA logging and ``start_ndx``
        # delays the estimate until the workers are warmed up -- confirm there.
        batch_iter = enumerateWithEstimate(
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )

        # The batches themselves are not needed; only the act of loading
        # them matters, so the loop body is intentionally empty.
        for _ in batch_iter:
            pass
# Script entry point: build the app from the process arguments and run it.
if __name__ == '__main__':
    app = LunaPrepCacheApp()
    app.main()