Source code for easycv.utils.test_util

# Copyright (c) Alibaba, Inc. and its affiliates.
"""Contains functions which are convenient for unit testing."""
import logging
import os
import shutil
import socket
import subprocess
import uuid
from multiprocessing import Process

import numpy as np
import torch

# Default root under which get_tmp_dir() creates per-test directories.
# get_tmp_dir() rebinds this name to the TEST_DIR environment variable
# whenever that variable is set to a non-empty value.
TEST_DIR = '/tmp/ev_pytorch_test'


def get_tmp_dir():
    """Create a fresh, uniquely named temporary directory and return its path.

    If the ``TEST_DIR`` environment variable holds a non-empty value, the
    module-level ``TEST_DIR`` is rebound to it first, so this call and all
    later ones create their directories below that root.

    Returns:
        str: path of the newly created directory, ``<TEST_DIR>/<uuid4 hex>``.
    """
    global TEST_DIR
    env_root = os.environ.get('TEST_DIR', '')
    if env_root != '':
        TEST_DIR = env_root

    tmp_path = os.path.join(TEST_DIR, uuid.uuid4().hex)
    # Always hand back an empty directory, even if the path already exists.
    if os.path.exists(tmp_path):
        shutil.rmtree(tmp_path)
    os.makedirs(tmp_path)
    return tmp_path
def clear_all_tmp_dirs():
    """Remove the ``TEST_DIR`` root together with everything below it.

    A root that does not exist (no temporary directory was ever created, or
    it has already been cleared) is treated as already clean instead of
    raising ``FileNotFoundError``.
    """
    try:
        shutil.rmtree(TEST_DIR)
    except FileNotFoundError:
        # Nothing to delete: the cleanup goal is already met.
        pass
def replace_data_for_test(cfg):
    """Replace real data with test data.

    Currently a placeholder: the config is left untouched.

    Args:
        cfg: Config object
    """
    return None
# Function decorator to run a function in a subprocess.
# Use it if the function will start a tf session, because tensorflow
# gpu memory will not be cleared until the process exits.
def RunAsSubprocess(f):
    """Decorator that runs ``f`` in a separate process and checks its exit code.

    The wrapper starts a ``multiprocessing.Process`` with the call arguments,
    waits up to 600 seconds for it, and discards the return value of ``f``
    (the wrapper always returns ``None``).

    Args:
        f: callable to execute in the child process.

    Returns:
        The wrapping function, carrying the name and docstring of ``f``.

    Raises:
        AssertionError: raised by the wrapper when the child exits with a
            non-zero code or has to be stopped after the timeout.
    """
    # Local import: keeps this block independent of the file-level imports.
    import functools

    @functools.wraps(f)
    def wrapped_f(*args, **kw):
        p = Process(target=f, args=args, kwargs=kw)
        p.start()
        p.join(timeout=600)
        if p.is_alive():
            # Timed out: stop the child instead of leaving it running after
            # the failure has been reported.
            p.terminate()
            p.join()
        if p.exitcode != 0:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError('subprocess run failed: %s' % f.__name__)

    return wrapped_f
def clean_up(test_dir):
    """Recursively delete ``test_dir``; passing ``None`` is a no-op.

    Args:
        test_dir: path of the directory to remove, or ``None``.
    """
    if test_dir is None:
        return
    shutil.rmtree(test_dir)
def run_in_subprocess(cmd):
    """Run a shell command and forward its output to ``logging.info``.

    stderr is merged into stdout; every non-blank output line is stripped
    and logged. This is a best-effort helper: a non-zero exit status (or any
    other error while running the command) is logged and NOT re-raised.

    Args:
        cmd: command line, executed through the shell (``shell=True``).

    Returns:
        None
    """
    try:
        with subprocess.Popen(
                cmd,
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT) as proc:
            # Read until EOF. A blank line in the middle of the output must
            # not end the loop, otherwise the remaining lines would be lost.
            for raw_line in proc.stdout:
                line = raw_line.decode('utf-8', 'ignore').strip()
                if line:
                    logging.info(line)

            return_code = proc.wait()
            if return_code:
                # Report the command itself, not the Popen object, so the
                # logged message is readable.
                raise subprocess.CalledProcessError(return_code, cmd)
    except Exception as e:
        logging.info(e)
def dist_exec_wrapper(cmd,
                      nproc_per_node,
                      node_rank=0,
                      nnodes=1,
                      port='29527',
                      addr='127.0.0.1',
                      python_path=None):
    """Launch ``cmd`` once per local rank with distributed env variables set.

    Each child process gets ``WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``,
    ``RANK`` and ``LOCAL_RANK`` in its environment. Do not forget to
    initialize the distributed backend in the function or script run by
    ``cmd``:

    ```python
    from mmcv.runner import init_dist
    init_dist(launcher='pytorch')
    ```

    Args:
        cmd: shell command to execute for every local rank.
        nproc_per_node: number of processes to start on this node.
        node_rank: rank of this node among all nodes.
        nnodes: total number of nodes.
        port: master port; replaced by a random free port if already in use.
        addr: master address.
        python_path: optional path appended to ``PYTHONPATH`` of the children.

    Raises:
        subprocess.CalledProcessError: if a child exits with a non-zero code.
    """
    dist_world_size = nproc_per_node * nnodes

    cur_env = os.environ.copy()
    cur_env['WORLD_SIZE'] = str(dist_world_size)
    cur_env['MASTER_ADDR'] = addr

    if is_port_used(port):
        port = str(get_random_port())
        logging.warning('Given port is used, change to port %s' % port)
    # Environment values must be strings, so an int ``port`` is accepted too.
    cur_env['MASTER_PORT'] = str(port)

    if python_path:
        # Extend PYTHONPATH once (not once per rank, which duplicated the
        # entry) and tolerate the variable being unset in the parent.
        inherited = cur_env.get('PYTHONPATH')
        if inherited:
            cur_env['PYTHONPATH'] = os.pathsep.join((inherited, python_path))
        else:
            cur_env['PYTHONPATH'] = python_path

    processes = []
    try:
        for local_rank in range(nproc_per_node):
            # rank of each process
            dist_rank = nproc_per_node * node_rank + local_rank
            cur_env['RANK'] = str(dist_rank)
            cur_env['LOCAL_RANK'] = str(local_rank)
            processes.append(subprocess.Popen(cmd, env=cur_env, shell=True))

        for process in processes:
            process.wait()
            if process.returncode != 0:
                raise subprocess.CalledProcessError(
                    returncode=process.returncode, cmd=cmd)
    finally:
        # On failure, do not leave the remaining ranks running. This is a
        # no-op on success because every process has been waited for.
        # NOTE: with shell=True the signal goes to the launching shell.
        for process in processes:
            if process.poll() is None:
                process.terminate()
                process.wait()
def is_port_used(port, host='127.0.0.1'):
    """Check whether something accepts TCP connections on ``host:port``.

    Args:
        port: port number, as an int or a numeric string.
        host: IPv4 host to probe.

    Returns:
        bool: True if a TCP connection could be established. False if the
        connection failed or ``port``/``host`` is not a usable value.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.connect((host, int(port)))
        except (OSError, ValueError, OverflowError, TypeError):
            # OSError: connection refused / unreachable / timed out.
            # The others: port or host that cannot form a valid address.
            # KeyboardInterrupt and SystemExit are deliberately not caught.
            return False
        return True
def get_random_port():
    """Draw random ports in [5000, 10000) until one is not in use.

    Returns:
        The first drawn port for which ``is_port_used`` reports False, as
        returned by ``np.random.randint``.
    """
    port = np.random.randint(low=5000, high=10000)
    while is_port_used(port):
        # Occupied: draw again.
        port = np.random.randint(low=5000, high=10000)

    logging.info('Random port: %s' % port)
    return port
def pseudo_dist_init():
    """Initialize a single-process distributed group.

    Sets ``RANK=0``, ``WORLD_SIZE=1``, ``MASTER_ADDR=127.0.0.1`` and a random
    free ``MASTER_PORT`` in ``os.environ``, selects CUDA device 0 and then
    initializes the torch process group with the ``nccl`` backend.
    """
    fixed_settings = (
        ('RANK', '0'),
        ('WORLD_SIZE', '1'),
        ('MASTER_ADDR', '127.0.0.1'),
    )
    for key, value in fixed_settings:
        os.environ[key] = value
    os.environ['MASTER_PORT'] = str(get_random_port())

    torch.cuda.set_device(0)

    from torch import distributed as torch_dist
    torch_dist.init_process_group(backend='nccl')