最近工作涉及到修改分布式训练代码,以前半懂非懂,这次改的时候漏了一些细节,带来不必要的麻烦,索性花点时间搞明白。
Pytorch 分布式训练主要有两种方式:
torch.nn.DataParallel ==> 简称 DP
torch.nn.parallel.DistributedDataParallel ==> 简称DDP
其中 DP 只用于单机多卡,DDP 可以用于单机多卡也可用于多机多卡,后者现在也是Pytorch训练的主流用法,DP写法比较简单,但即使在单机多卡情况下也比 DDP 慢。
可参考:https://pytorch.org/docs/stable/nn.html#dataparallel-layers-multi-gpu-distributed 。
本文主要介绍DP和DDP的使用方式。
DP
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | import torch import torch.nn as nn # 构造模型 net = model(imput_size, output_size) # 模型放在GPU上 net = net.cuda() net=nn.DataParallel(net) # 数据放在GPU上 inputs, labels = inputs.cuda(), labels.cuda() result = net(inputs) # 其他和正常模型训练无差别 |
关于Dataparallel, 摘取主要源码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | class DataParallel(Module): def __init__(self, module, device_ids=None, output_device=None, dim=0): super(DataParallel, self).__init__() # 如果没有GPU可用,直接返回 if not torch.cuda.is_available(): self.module = module self.device_ids = [] return # 如果有GPU,但没有指定的话,device_ids为所有可用GPU if device_ids is None: device_ids = list(range(torch.cuda.device_count())) # 默认输出在0号卡上 if output_device is None: output_device = device_ids[0] |
总结
如果不设定好要使用的device_ids的话, 程序会自动找到这个机器上面可以用的所有的显卡用于训练。
如果想要限制使用的显卡数,怎么办呢?
在代码最前面使用:
1 2 3 4 5 6 7 8 9 10 11 | os.environ['CUDA_VISIBLE_DEVICES'] == '0,5' # 限制代码能看到的GPU个数,这里表示指定只使用实际的0号和5号GPU # 注意:这里的赋值必须是字符串,list会报错 # 这时候device_count = 2 device_ids = range(torch.cuda.device_count()) # device_ids = [0,1] 这里的0就是上述指定的'0'号卡,1对应'5'号卡。 net = nn.DataParallel(net,device_ids) # !!!模型和数据都由主gpu(0号卡)分发。 |
值得注意的是,在使用
例如上面我们设定的是
也就是说程序所使用的显卡编号实际上是经过了一次映射之后才会映射到真正的显卡编号上面的, 例如这里的程序看到的1对应实际的5。
但是Dataparallel会带来显存的使用不平衡,具体分析见参考链接[2],而且碰到大的任务,时间和能力上都很受限。
DDP
为了弥补Dataparallel的不足,有torch.nn.parallel.DistributedDataParallel,这也是现在Pytorch分布式训练主推的。
DDP支持单机多卡和多机多卡,每张卡都有一个进程,这就涉及到进程通信,多进程通信初始化,是DDP使用最复杂的地方。
具体看下:
1 | torch.distributed.init_process_group( ) |
详见:https://pytorch.org/docs/stable/distributed.html
常用参数:
-
backend: 后端, 实际上是多个机器之间交换数据的协议,官方和很多用户都强烈推荐’nccl’作为backend。
-
init_method: 机器之间交换数据需要指定一个主节点, 这个参数用来指定主节点的。
-
world_size: 参与job的进程数, 实际就是GPU的个数;
-
rank: 进程组中每个进程的唯一标识符。比如一个节点8张卡,world_size为8,每张卡的rank是对应的0-7的连续整数。
-
顺便解释下local_rank: 假设有两个节点/机器,每个节点有8张卡,总共16张卡,对应16个进程。global rank是指0-15,对于节点1,local_rank为0-7,对于节点2,local_rank也是0-7。
初始化init_method的方法有两种, 一种是使用TCP进行初始化, 另外一种是使用共享文件系统进行初始化。
Pytorch作者推荐了这种初始化方式,来源见水印和参考链接,
我们平常在集群上操作,可以通过os.environ获取每个进程的节点ip信息,全局rank以及local rank。
关于获取节点信息的详细代码:
1 2 3 4 5 6 7 | import os os.environ['SLURM_NTASKS'] # 可用作world size os.environ['SLURM_NODEID'] # node id os.environ['SLURM_PROCID'] # 可用作全局rank os.environ['SLURM_LOCALID'] # local_rank os.environ['SLURM_NODELIST'] # 从中取得一个ip作为通讯ip |
单机多卡,主机只有一个相对没有那么复杂,按照官网推荐的设置就好。
因此,torch中DDP的使用如下方式:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | import os import re import torch import torch.nn as nn import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP # 1. 获取环境信息 rank = int(os.environ['SLURM_PROCID']) world_size = int(os.environ['SLURM_NTASKS']) local_rank = int(os.environ['SLURM_LOCALID']) node_list = str(os.environ['SLURM_NODELIST']) # 对ip进行操作 node_parts = re.findall('[0-9]+', node_list) host_ip = '{}.{}.{}.{}'.format(node_parts[1], node_parts[2], node_parts[3], node_parts[4]) # 注意端口一定要没有被使用 port = "23456" # 使用TCP初始化方法 init_method = 'tcp://{}:{}'.format(host_ip, port) # 多进程初始化,初始化通信环境 dist.init_process_group("nccl", init_method=init_method, world_size=world_size, rank=rank) # 指定每个节点上的device torch.cuda.set_device(local_rank) model = model.cuda() # 当前模型所在local_rank model = DDP(model, device_ids=[local_rank]) # 指定当前卡上的GPU号 input = input.cuda() output = model(input) # 此后训练流程与普通模型无异 |
最近官方表述中加了一个store参数,更新了下使用方法,大差不差。
具体参考:https://pytorch.org/docs/stable/distributed.html
使用TCP进行初始化,需要读取ip,我们在集群上通过os.environ可以很方便完成初始化。
我平常提交任务的slurm指令这样写:
1 2 3 4 | # 单机多卡 # 8个任务对应8个进程,每个节点上跑8个任务 srun -n8 --gres=gpu:8 --ntasks-per-node=8 python train.py |
1 2 3 4 5 | # 多机多卡 # 16个任务对应16个进程,每个节点最多跑8个任务/进程,每张卡占满8个GPU # 因此这里是申请了16/8=2个节点,即在两个机器上跑。 srun -n16 --gres=gpu:8 --ntasks-per-node=8 python train.py |
参考:
[1]https://blog.csdn.net/weixin_40087578/article/details/87186613
[2]https://zhuanlan.zhihu.com/p/86441879
[3]https://zhuanlan.zhihu.com/p/68717029