Multiple SSH Connections in a Python 2.7 Script: Multiprocessing vs. Threading

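I have a script that takes a list of nodes as an argument (it could be 10 or even 50) and connects to each one over SSH to run a service restart command. Currently I use multiprocessing to parallelize the script (the batch size is also passed as an argument), but I have heard that the threading module could do the job faster and in a way that is easier to manage. (I use try..except KeyboardInterrupt with sys.exit() and pool.terminate(), but it does not stop the whole script because the workers run in separate processes.) Since I understand that threads are more lightweight and easier to manage for my case, I am trying to convert the script from multiprocessing to threading, but it does not work properly.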

The current code using multiprocessing (works):

def restart_service(node, initd_tup):
    """
    Get a node name as an argument, connect to it via SSH and run the service restart command..
    """
    command = 'service {0} restart'.format(initd_tup[node])
    logger.info('[{0}] Connecting to {0} in order to restart {1} service...'.format(node, initd_tup[node]))
    try:
        ssh.connect(node)
        stdin, stdout, stderr = ssh.exec_command(command)
        result = stdout.read()
        if not result:
            result_err = stderr.read()
            print '{0}{1}[{2}] ERROR: {3}{4}'.format(Color.BOLD, Color.RED, node, result_err, Color.END)
            logger.error('[{0}]  Result of command {1} output: {2}'.format(node, command, result_err))
        else:
            print '{0}{1}{2}[{3}]{4}\n{5}'.format(Color.BOLD, Color.UNDERLINE, Color.GREEN, node, Color.END, result)
            logger.info('[{0}]  Result of command {1} output: {2}'.format(node, command, result.replace("\n", "... ")))
        ssh.close()
    except paramiko.AuthenticationException:
        print "{0}{1}ERROR! SSH failed with Authentication Error. Make sure you run the script as root and try again..{2}".format(Color.BOLD, Color.RED, Color.END)
        logger.error('SSH Authentication failed, thrown error message to the user to make sure script is run with root permissions')
        pool.terminate()
    except socket.error as error:
        print("[{0}]{1}{2} ERROR! SSH failed with error: {3}{4}\n".format(node, Color.RED, Color.BOLD, error, Color.END))
        logger.error("[{0}] SSH failed with error: {1}".format(node, error))
    except KeyboardInterrupt:
        pool.terminate()
        general_utils.terminate(logger)


def convert_to_tuple(a_b):
    """Convert 'f([1,2])' to 'f(1,2)' call."""
    return restart_service(*a_b)


def iterate_nodes_and_call_exec_func(nodes_list):
    """
    Iterate over the list of nodes to process,
    create a list of nodes that shouldn't exceed the batch size provided (or 1 if not provided).
    Then using the multiprocessing module, call the restart_service func on x nodes in parallel (where x is the batch size).
    If batch_sleep arg was provided, call the sleep func and provide the batch_sleep argument between each batch.
    """
    global pool
    general_utils.banner('Initiating service restart')
    pool = multiprocessing.Pool(10)
    manager = multiprocessing.Manager()
    work = manager.dict()
    for line in nodes_list:
        work[line] = general_utils.get_initd(logger, args, line)
        if len(work) >= int(args.batch):
            pool.map(convert_to_tuple, itertools.izip(work.keys(), itertools.repeat(work)))
            work = {}
            if int(args.batch_sleep) > 0:
                logger.info('*** Sleeping for %d seconds before moving on to next batch ***', int(args.batch_sleep))
                general_utils.sleep_func(int(args.batch_sleep))
    if len(work) > 0:
        try:
            pool.map(convert_to_tuple, itertools.izip(work.keys(), itertools.repeat(work)))
        except KeyboardInterrupt:
            pool.terminate()
            general_utils.terminate(logger)

This is how I tried to do it with threading, but it does not work: when I set a batch size larger than 1, the script gets stuck and I have to kill it by force.

def parse_args():
    """Define the argument parser, and the arguments to accept.."""
    global args, parser
    parser = MyParser(description=__doc__)
    parser.add_argument('-H', '--host', help='List of hosts to process, separated by "," and NO SPACES!')
    parser.add_argument('--batch', help='Do requests in batches', default=1)
    args = parser.parse_args()

    # If no arguments were passed, print the help file and exit with ERROR..
    if len(sys.argv) == 1:
        parser.print_help()
        print '\n\nERROR: No arguments passed!\n'
        sys.exit(3)


def do_work(node):
    logger.info('[{0}]'.format(node))
    try:
        ssh.connect(node)
        stdin, stdout, stderr = ssh.exec_command('hostname ; date')
        print stdout.read()
        ssh.close()
    except:
        print 'ERROR!'
        sys.exit(2)


def worker():
    while True:
        item = q.get()
        do_work(item)
        q.task_done()


def iterate():
    for item in args.host.split(","):
        q.put(item)

    for i in range(int(args.batch)):
        t = Thread(target=worker)
        t.daemon = True
        t.start()

    q.join()


def main():
    parse_args()
    try:
        iterate()

    except KeyboardInterrupt:
        exit(1)

In the script log I see a warning generated by Paramiko:

2016-01-04 22:51:37,613 WARNING: Oops, unhandled type 3

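I tried to Google this unhandled type 3 error but found nothing related to my issue: the results are about two-factor authentication, or about connecting with a password and an SSH key at the same time, whereas I only load the host keys and do not provide any password to the SSH client.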

I would appreciate any help with this.

I managed to solve my problem using the parallel-ssh module.

Here is the code, fixed to perform the actions I wanted:

def iterate_nodes_and_call_exec_func(nodes):
    """
    Get a dict as an argument, containing linux services (initd) as the keys,
    and a list of nodes on which the linux service needs to be checked.
    Iterate over the list of nodes to process,
    create a list of nodes that shouldn't exceed the batch size provided (or 1 if not provided).
    Then using the parallel-ssh module, call the restart_service func on x nodes in parallel (where x is the batch size)
    and provide the linux service (initd) to process.
    If batch_sleep arg was provided, call the sleep func and provide the batch_sleep argument between each batch.
    """

    for initd in nodes.keys():
        work = dict()
        work[initd] = []
        count = 0
        for node in nodes[initd]:
            count += 1
            work[initd].append(node)
            # argparse yields strings unless type=int is set, so cast before comparing (as the question's code does)
            if len(work[initd]) == int(args.batch):
                restart_service(work[initd], initd)
                work[initd] = []
                if int(args.batch_sleep) > 0 and count < len(nodes[initd]):
                    logger.info('*** Sleeping for %d seconds before moving on to next batch ***', int(args.batch_sleep))
                    general_utils.sleep_func(int(args.batch_sleep))
        if len(work[initd]) > 0:
            restart_service(work[initd], initd)


def restart_service(nodes, initd):
    """
    Get a list of nodes and linux service as an argument,
    then connect by Parallel SSH module to the nodes and run the service restart command..
    """
    command = 'service {0} restart'.format(initd)
    logger.info('Connecting to {0} to restart the {1} service...'.format(nodes, initd))
    try:
        client = pssh.ParallelSSHClient(nodes, pool_size=args.batch, timeout=10, num_retries=1)
        output = client.run_command(command, sudo=True)
        for node in output:
            for line in output[node]['stdout']:
                if client.get_exit_code(output[node]) == 0:
                    print '[{0}]{1}{2}  {3}{4}'.format(node, Color.BOLD, Color.GREEN, line, Color.END)
                else:
                    print '[{0}]{1}{2}  ERROR! {3}{4}'.format(node, Color.BOLD, Color.RED, line, Color.END)
                    logger.error('[{0}]  Result of command {1} output: {2}'.format(node, command, line))

    except pssh.AuthenticationException:
        print "{0}{1}ERROR! SSH failed with Authentication Error. Make sure you run the script as root and try again..{2}".format(Color.BOLD, Color.RED, Color.END)
        logger.error('SSH Authentication failed, thrown error message to the user to make sure script is run with root permissions')
        sys.exit(2)

    except pssh.ConnectionErrorException as error:
        print("[{0}]{1}{2} ERROR! SSH failed with error: {3}{4}\n".format(error[1], Color.RED, Color.BOLD, error[3], Color.END))
        logger.error("[{0}] SSH Failed with error: {1}".format(error[1], error[3]))
        restart_service(nodes[nodes.index(error[1])+1:], initd)

    except KeyboardInterrupt:
        general_utils.terminate(logger)


def generate_nodes_by_initd_dict(nodes_list):
    """
    Get a list of nodes as an argument.
    Then by calling the get_initd function for each of the nodes,
    Build a dict based on linux services (initd) as keys and a list of nodes on which the initd
    needs to be processed as values. Then call the iterate_nodes_and_call_exec_func and provide the generated dict
     as its argument.
    """
    nodes = {}
    for node in nodes_list:
        initd = general_utils.get_initd(logger, args, node)
        if initd in nodes.keys():
            nodes[initd].append(node)
        else:
            nodes[initd] = [node, ]

    return iterate_nodes_and_call_exec_func(nodes)


def main():
    parse_args()
    try:
        general_utils.init_script('Service Restart', logger, log)
        log_args(logger, args)
        generate_nodes_by_initd_dict(general_utils.generate_nodes_list(args, logger, ['service', 'datacenter', 'lob']))

    except KeyboardInterrupt:
        general_utils.terminate(logger)

    finally:
        general_utils.wrap_up(logger)


if __name__ == '__main__':
    main()
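
The grouping and batching steps in the code above can be sketched in isolation, without any SSH calls. This is only an illustration under stated assumptions: group_by_service, batches and process_in_batches are hypothetical names that do not appear in the original script, get_service stands in for general_utils.get_initd, and handle stands in for restart_service.

```python
import time
from collections import defaultdict


def group_by_service(nodes_list, get_service):
    """Map each service name to the list of nodes that run it."""
    grouped = defaultdict(list)
    for node in nodes_list:
        grouped[get_service(node)].append(node)
    return dict(grouped)


def batches(items, size):
    """Yield consecutive slices of `items` holding at most `size` elements."""
    size = int(size)  # the batch size may arrive from argparse as a string
    for start in range(0, len(items), size):
        yield items[start:start + size]


def process_in_batches(nodes, size, handle, pause=0, sleep=time.sleep):
    """Call `handle(batch)` per batch, pausing between batches but not after the last."""
    chunks = list(batches(nodes, size))
    for index, chunk in enumerate(chunks):
        handle(chunk)
        if pause > 0 and index < len(chunks) - 1:
            sleep(pause)
```

Slicing the node list up front replaces the manual counter and the trailing "leftover batch" branch, and it makes the "no sleep after the last batch" rule a single index comparison.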

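In addition to using the pssh module, after more thorough troubleshooting I was also able to fix the original code posted in the question with the native threading module, by creating a new paramiko client for each thread instead of sharing the same client across all threads. So basically (only the do_work function from the original question is updated), here is the change: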

def do_work(node):
    logger.info('[{0}]'.format(node))
    try:
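        # Create a new client on every call so that threads never share one SSHClient.
        ssh = paramiko.SSHClient()
        # Assumption: host keys are loaded here the same way they were for the original shared client.
        ssh.load_system_host_keys()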
        ssh.connect(node)
        stdin, stdout, stderr = ssh.exec_command('hostname ; date')
        print stdout.read()
        ssh.close()
    except:
        print 'ERROR!'
        sys.exit(2)

Done this way, the native threading module works perfectly!
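
A minimal sketch of the per-thread client approach, combined with a worker loop that always calls task_done(). The names run_on_nodes and connect_and_run are hypothetical and not part of the original script; the paramiko calls are the ones already used in the snippets above plus load_system_host_keys(), which is assumed to match how the host keys were loaded. The command runner is passed in as a parameter so the threading logic can be tried without a network.

```python
import threading

try:
    import Queue as queue  # Python 2.7
except ImportError:
    import queue  # Python 3


def connect_and_run(node, command='hostname ; date'):
    """Hypothetical helper: open a fresh SSH client for this call only."""
    import paramiko  # imported lazily so the threading logic runs without it

    ssh = paramiko.SSHClient()  # one client per call, never shared across threads
    ssh.load_system_host_keys()
    try:
        ssh.connect(node)
        stdin, stdout, stderr = ssh.exec_command(command)
        return stdout.read()
    finally:
        ssh.close()


def run_on_nodes(nodes, batch, run=connect_and_run):
    """Run `run(node)` on every node using `batch` worker threads."""
    q = queue.Queue()
    results, errors = {}, {}
    lock = threading.Lock()

    def worker():
        while True:
            node = q.get()
            try:
                output = run(node)
                with lock:
                    results[node] = output
            except Exception as error:  # record the failure instead of calling sys.exit()
                with lock:
                    errors[node] = error
            finally:
                q.task_done()  # always mark the item done, or q.join() blocks forever

    for node in nodes:
        q.put(node)
    for _ in range(int(batch)):
        t = threading.Thread(target=worker)
        t.daemon = True
        t.start()
    q.join()
    return results, errors
```

One design note: sys.exit() inside a worker only raises SystemExit in that thread, so a failed node ends the thread before task_done() is reached and q.join() keeps waiting. Collecting failures in a dict and reporting them from the main thread avoids that hang.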