Source code for easycv.toolkit.quantize.quantize_utils

# Copyright (c) Alibaba, Inc. and its affiliates.
import itertools

import torch
from mmcv.parallel import scatter_kwargs
from mmcv.runner import get_dist_info

# Default quantization settings: 'device' selects where quantization runs
# and 'backend' names the quantization implementation. Key order is kept
# as device first, then backend.
quantize_config = dict(
    device='cpu',
    backend='PyTorch',
)


def calib(model, data_loader, num_iters=3):
    """Run calibration forward passes of ``model`` over ``data_loader``.

    Takes the first ``num_iters`` batches from ``data_loader`` and scatters
    each one to CPU with mmcv's ``scatter_kwargs``. It then calls
    ``model(img)`` on the batch's ``'img'`` entry inside
    ``torch.no_grad()``. Model outputs are discarded; the passes exist only
    for their side effects on ``model``. Presumably these are quantization
    observers collecting statistics — confirm against the caller.

    Args:
        model: Callable invoked as ``model(img)``.
        data_loader: Iterable of batches. Each batch is passed as the
            keyword data to ``scatter_kwargs`` and must provide an
            ``'img'`` entry.
        num_iters (int): Number of batches to consume. Defaults to 3, the
            count previously hard-coded in this function. Consumption stops
            early if the loader yields fewer batches.

    Raises:
        ValueError: If ``num_iters`` is negative.
    """
    if num_iters < 0:
        raise ValueError(
            'num_iters must be non-negative, got {}'.format(num_iters))

    # islice stops after num_iters batches without pulling an extra one from
    # the loader, matching the original early ``return`` after batch 3.
    for data in itertools.islice(data_loader, num_iters):
        # target_gpus=[-1] asks mmcv to scatter onto CPU. Positional inputs
        # are None, so only the scattered kwargs (one dict per target) are
        # used.
        _, kwargs = scatter_kwargs(None, data, [-1])
        with torch.no_grad():
            # NOTE(review): assumes 'img' carries an extra leading singleton
            # dim (e.g. batch_size=1 wrapping an already-batched tensor) —
            # confirm against the dataset/loader config.
            img = kwargs[0]['img'].squeeze(dim=0)
            model(img)