How this project works?
train.py 실행
main 일 경우, args에 arguments 넣어줌
https://engineer-mole.tistory.com/213 ArgumentParser가 뭔데? 실행할 때 뭔가 args로 넣어줄 것들에 대해서 default 및 설명 등등을 넣어둔 것이다
아래는 args를 출력해본 결과다. 딱히 중요한 건 없다. 그저 ArgumentParser라는 것 뿐 (참고: https://engineer-mole.tistory.com/213)
args = argparse.ArgumentParser(description='PyTorch Template')
args.add_argument('-c', '--config', default=None, type=str,
help='config file path (default: None)')
args.add_argument('-r', '--resume', default=None, type=str,
help='path to latest checkpoint (default: None)')
args.add_argument('-d', '--device', default=None, type=str,
help='indices of GPUs to enable (default: all)')
그리고 config를
config = ConfigParser.from_args(args, options)
를 통해 받는다
from_args는 다음과 같이 생겼다.
@classmethod
def from_args(cls, args, options=''):
"""
Initialize this class from some cli arguments. Used in train, test.
"""
for opt in options:
args.add_argument(*opt.flags, default=None, type=opt.type)
if not isinstance(args, tuple):
args = args.parse_args()
if args.device is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
if args.resume is not None:
resume = Path(args.resume)
cfg_fname = resume.parent / 'config.json'
else:
msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
assert args.config is not None, msg_no_cfg
resume = None
cfg_fname = Path(args.config)
config = read_json(cfg_fname)
if args.config and resume:
# update new config for fine-tuning
config.update(read_json(args.config))
# parse custom cli options into dictionary
modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
return cls(config, resume, modification)
cls를 리턴하며, “클래스”를 생성해서 리턴한다..! 우선 from_args는 class method기 때문에, ConfigParser 를 만들지 않고(인스턴스화 하지 않고) 해당 함수를 부를 수 있다! (참고: https://builtin.com/software-engineering-perspectives/python-cls https://velog.io/@rlath/cls-vs-self) ex)
# 클래스 함수가 아니라면..
config_parser = ConfigParser(~)
config_parser.from_args(~)
하지만 from_args는 class method기 때문에, 만들지 않고도 부를 수 있다!! 그리고 from_args의 return으로 해당 class의 인스턴스를 만들어서 return해주고 있다!!
결국 config는 ConfigParser의 인스턴스이며, config 정보를 모두 들고있다!!(from_args에서 불러왔기 때문이다)
이제 main(config)를 볼 차례다.
def main(config):
logger = config.get_logger('train')
# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()
# build model architecture, then print to console
model = config.init_obj('arch', module_arch)
logger.info(model)
# prepare for (multi-device) GPU training
device, device_ids = prepare_device(config['n_gpu'])
model = model.to(device)
if len(device_ids) > 1:
model = torch.nn.DataParallel(model, device_ids=device_ids)
# get function handles of loss and metrics
criterion = getattr(module_loss, config['loss'])
metrics = [getattr(module_metric, met) for met in config['metrics']]
# build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
trainer = Trainer(model, criterion, metrics, optimizer,
config=config,
device=device,
data_loader=data_loader,
valid_data_loader=valid_data_loader,
lr_scheduler=lr_scheduler)
trainer.train()
우선 다음과 같고, config.init_obj를 확인해 보자.
def init_obj(self, name, module, *args, **kwargs):
"""
Finds a function handle with the name given as 'type' in config, and returns the
instance initialized with corresponding arguments given.
`object = config.init_obj('name', module, a, b=1)`
is equivalent to
`object = module.name(a, b=1)`
"""
module_name = self[name]['type']
module_args = dict(self[name]['args'])
assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
module_args.update(kwargs)
return getattr(module, module_name)(*args, **module_args)
module_name과 module_args를 from_args()를 통해 얻어낸 config에서 가져온다. 여기서는 self[name]으로 바로 가져오지만, 사실 길고 긴 여정이 있다. self[name]은 def __getitem__에 의한것이고, def __getitem__은 다음과 같다.
def __getitem__(self, name):
"""Access items like ordinary dict."""
return self.config[name]
엥? self.config는 선언한 적이 없고, init 에서도 self._config로 config데이터를 가져왔는데?? 라고 한다면 정답이다. 실제로도 __init__에는 self._config = ~ 로 _config만 존재한다. 그렇다면 getitem에서 저 self.config는 어디서 나온놈인가?
@property
def config(self):
return self._config
이것으로 설명할 수 있다. (참고: https://www.daleseo.com/python-property/) 아무튼 결론적으로 self[name]을 통해 config에 누구보다 빠르게 접근할 수 있게 되었다.
마지막 단계가 남았는데, 그렇다면 getattr은 뭐하는 놈일까?
그렇다. 그냥 module.mudule_name 과 진배없다.
~~getattr(module, module_name) == module.module_name~~
그렇다면 뒤에 붙는 *args와 **module_args는 어떻게 설명하지?
여기서 저 “module” 이 뭔지 확인해야 한다.
다시 train.py로 돌아와보면
data_loader = config.init_obj('data_loader', module_data)
로, module_data 라는 놈이 넘어간다. 이 module_data는 누구인가?
뭔가 임포트 한것이구나를 알 수 있다. 그러면 data_loders가 뭔지 확인하자.
data_loaders는 다음과 같다.
from torchvision import datasets, transforms
from base import BaseDataLoader
class MnistDataLoader(BaseDataLoader):
"""
MNIST data loading demo using BaseDataLoader
"""
def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
trsfm = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
self.data_dir = data_dir
self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm)
super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
와우. data_loders에 있는 class이름(config의 ‘type’에 정의하겠지?)를 getattr(module, module_name)으로 불러오고, 해당 class를 인스턴스화 하기위해 (*args, **module_args)를 파라미터로 넘겨버리는 것인 것이다!!!!
즉, 긴 여정이었지만 풀어보자면, getattr(module, module_name)(*args, **module_args)는 getattr(data_loder.data_loders, {config에 정의된 모듈 이름})(*args, **module_args) 이고, data_loders에 정의되어 있는 어떤 모듈(클래스)를 인스턴스화 해서 return 해주는 것이다.
그에 따라 저 파라미터들을 넘겨주는데 있어서 args와 module_args로 편하게 넘겨준다고 생각된다.
architecture도 이하 동문이다. model.model에 모델을 정의해 두면 config파일을 수정하는 것 만으로도 간편하게 모델을 교체할 수 있게 되는 것이다!! Wow