zpwiki/pages/students/2016/lukas_pokryvka/dp2021/lungCancer/model/benchmark_seg.py

77 lines
2.0 KiB
Python

import argparse
import datetime
import os
import socket
import sys
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
import torch.optim
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from util.util import enumerateWithEstimate
from p2ch13.dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
from util.logconf import logging
from util.util import xyz2irc
from p2ch13.model_seg import UNetWrapper, SegmentationAugmentation
from p2ch13.train_seg import LunaTrainingApp
# Module-level logger, named after this module. `logging` here is the name
# imported from util.logconf above, not a direct `import logging`.
log = logging.getLogger(__name__)
# Alternative verbosity levels, currently disabled:
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
# Active level: emit everything from DEBUG upwards for this module's logger.
log.setLevel(logging.DEBUG)
class BenchmarkLuna2dSegmentationDataset(TrainingLuna2dSegmentationDataset):
    """Training segmentation dataset that reports a fixed benchmark length.

    Only ``__len__`` is overridden; everything else comes from
    ``TrainingLuna2dSegmentationDataset``.
    """

    def __len__(self):
        """Return the fixed number of samples reported for a benchmark epoch.

        The original body had a second, unreachable ``return 1000`` after this
        one; 5000 was always the effective value, so it is the only one kept.
        """
        return 5000
class LunaBenchmarkApp(LunaTrainingApp):
    """Training-app variant that runs one training pass over the benchmark dataset.

    Relies on ``self.cli_args``, ``self.use_cuda`` and ``self.doTraining``,
    none of which are defined in this class (presumably they are provided by
    ``LunaTrainingApp`` -- confirm against the parent).
    """

    def initTrainDl(self):
        """Build and return the training ``DataLoader`` for the benchmark dataset.

        The command-line batch size is multiplied by the number of CUDA
        devices when CUDA is in use; otherwise it is used unchanged.
        """
        benchmark_ds = BenchmarkLuna2dSegmentationDataset(
            val_stride=10,
            isValSet_bool=False,
            contextSlices_count=3,
            # Augmentation is left off: augmentation_dict is not passed.
        )

        device_count = torch.cuda.device_count() if self.use_cuda else 1
        total_batch_size = self.cli_args.batch_size * device_count

        return DataLoader(
            benchmark_ds,
            batch_size=total_batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

    def main(self):
        """Log the start-up arguments, then run a single training epoch."""
        app_name = type(self).__name__
        log.info("Starting {}, {}".format(app_name, self.cli_args))

        loader = self.initTrainDl()

        # range(1, 2) yields only epoch 1: the benchmark runs exactly one pass.
        for epoch_ndx in range(1, 2):
            epoch_message = "Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(loader),
                0,  # no validation loader is built here, so its batch count is 0
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            )
            log.info(epoch_message)

            self.doTraining(epoch_ndx, loader)
if __name__ == "__main__":
    # Script entry point: construct the benchmark app and run it.
    benchmark_app = LunaBenchmarkApp()
    benchmark_app.main()