# Source code for easycv.hooks.export_hook

# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from mmcv.runner import Hook
from mmcv.runner.dist_utils import master_only

from easycv.utils.config_tools import validate_export_config
from .registry import HOOKS


@HOOKS.register_module
class ExportHook(Hook):
    """Export the model when training on PAI.

    Looks up a checkpoint written to ``work_dir`` (named via
    ``ckpt_filename_tmpl``) and hands it to ``easycv.apis.export.export``,
    writing the result next to it (named via ``export_ckpt_filename_tmpl``).
    Export always happens once at the end of the run, and optionally after
    every training epoch.
    """

    def __init__(
        self,
        cfg,
        ckpt_filename_tmpl='epoch_{}.pth',
        export_ckpt_filename_tmpl='epoch_{}_export.pt',
        export_after_each_ckpt=False,
    ):
        """
        Args:
            cfg: config dict; validated with ``validate_export_config`` and
                must provide ``work_dir``.
            ckpt_filename_tmpl: checkpoint filename template, formatted with
                the epoch number.
            export_ckpt_filename_tmpl: filename template of the exported
                model, formatted with the epoch number.
            export_after_each_ckpt: if True (or if ``cfg`` sets
                ``export_after_each_ckpt``), also export after every
                training epoch, not only at the end of the run.
        """
        self.cfg = validate_export_config(cfg)
        self.work_dir = cfg.work_dir
        self.ckpt_filename_tmpl = ckpt_filename_tmpl
        self.export_ckpt_filename_tmpl = export_ckpt_filename_tmpl
        self.export_after_each_ckpt = export_after_each_ckpt or cfg.get(
            'export_after_each_ckpt', False)
        # Epoch number of the most recent successful export; used so that
        # after_run does not export the same final checkpoint twice.
        self._last_exported_epoch = None

    def export_model(self, runner, epoch):
        """Export the checkpoint of ``epoch`` if it exists in ``work_dir``.

        Args:
            runner: the mmcv runner; only its logger is used here.
            epoch: epoch number substituted into the filename templates.
        """
        ckpt_fname = self.ckpt_filename_tmpl.format(epoch)
        export_ckpt_fname = self.export_ckpt_filename_tmpl.format(epoch)
        local_ckpt = os.path.join(self.work_dir, ckpt_fname)
        export_local_ckpt = os.path.join(self.work_dir, export_ckpt_fname)

        if not os.path.exists(local_ckpt):
            runner.logger.warning(f'{local_ckpt} does not exists, skip export')
            return

        runner.logger.info(f'export {local_ckpt} to {export_local_ckpt}')
        # Imported lazily, as in the original hook, so the export API is only
        # loaded when an export actually happens.
        from easycv.apis.export import export
        export(self.cfg, local_ckpt, export_local_ckpt)
        self._last_exported_epoch = epoch

    @master_only
    def after_train_iter(self, runner):
        pass

    @master_only
    def after_train_epoch(self, runner):
        """Optionally export the checkpoint saved for the epoch just run."""
        if not self.export_after_each_ckpt:
            return
        # mmcv calls after_train_epoch hooks *before* incrementing
        # runner.epoch, while CheckpointHook names its file with
        # ``runner.epoch + 1``; use the same number to find that checkpoint.
        self.export_model(runner, runner.epoch + 1)

    @master_only
    def after_run(self, runner):
        """Export the final checkpoint unless it was already exported."""
        # By now runner.epoch equals the number of completed epochs, which
        # matches the last checkpoint's filename.
        if self._last_exported_epoch == runner.epoch:
            return
        self.export_model(runner, runner.epoch)