"""
DMLC submission script, SLURM version
"""
# pylint: disable=invalid-name
from __future__ import absolute_import

import logging
import subprocess
from threading import Thread
from . import tracker

def get_mpi_env(envs):
    """get the slurm command for setting the environment
    """
    cmd = ''
    for k, v in envs.items():
        cmd += '%s=%s ' % (k, str(v))
    return cmd
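# Example (hypothetical values): get_mpi_env({'DMLC_ROLE': 'worker', 'DMLC_NUM_WORKER': 2})
# returns a prefix such as "DMLC_ROLE=worker DMLC_NUM_WORKER=2 " that is placed in
# front of the srun command below, so the launched tasks inherit these variables.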


def submit(args):
    """Submission script with SLURM."""
    def mpi_submit(nworker, nserver, pass_envs):
        """Internal closure for job submission."""
        def run(prog):
            """run the program"""
            subprocess.check_call(prog, shell=True)

        cmd = ' '.join(args.command)

        # mark the cluster type so the launched DMLC processes know they run under SLURM
        pass_envs['DMLC_JOB_CLUSTER'] = 'slurm'

        if args.slurm_worker_nodes is None:
            # default: spread the workers over one node each
            nworker_nodes = nworker
        else:
            nworker_nodes = args.slurm_worker_nodes

        # start workers
        if nworker > 0:
            logging.info('Start %d workers by srun', nworker)
            pass_envs['DMLC_ROLE'] = 'worker'
            prog = '%s srun --share --exclusive=user -N %d -n %d %s' % (
                get_mpi_env(pass_envs), nworker_nodes, nworker, cmd)
            thread = Thread(target=run, args=(prog,))
            thread.daemon = True
            thread.start()

        if args.slurm_server_nodes is None:
            # default: spread the servers over one node each
            nserver_nodes = nserver
        else:
            nserver_nodes = args.slurm_server_nodes

        # start servers
        if nserver > 0:
            logging.info('Start %d servers by srun', nserver)
            pass_envs['DMLC_ROLE'] = 'server'
            prog = '%s srun --share --exclusive=user -N %d -n %d %s' % (
                get_mpi_env(pass_envs), nserver_nodes, nserver, cmd)
            thread = Thread(target=run, args=(prog,))
            thread.daemon = True
            thread.start()

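    # tracker.submit starts the DMLC tracker on this node and then invokes
    # mpi_submit(nworker, nserver, pass_envs), where pass_envs carries the
    # rendezvous variables (e.g. the tracker address) that every process needs.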
    tracker.submit(args.num_workers, args.num_servers,
                   fun_submit=mpi_submit,
                   pscmd=(' '.join(args.command)))
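
# A minimal usage sketch (assuming the dmlc-submit launcher from dmlc-core is used):
#   dmlc-submit --cluster slurm --num-workers 4 --num-servers 2 python train.py
# dmlc-submit parses these options into `args` and calls submit(args) above.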