colocate的作用是使多个Worker共享相同的资源池。当然,目前verl中所有模型的Worker都共享相同的资源池:global_pool
。这篇博客主要通过例子和源代码理解verl中colocate的实现,需要一些前置知识。建议先阅读
【AI Infra】【RLHF框架】一、VeRL中基于Ray的执行流程源码解析
一、一个例子
这里简单修改了verl
的单元测试作为示例,先直观感受下colocate的作用。
import rayfrom verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_clsfrom verl import DataProto@ray.remote
class Actor(Worker):def __init__(self) -> None:super().__init__()@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)def add(self, data: DataProto):data.batch['a'] = data.batch['a'].to("cuda")data.batch['a'] += self.rankreturn data@ray.remote
class Critic(Worker):def __init__(self, config) -> None:super().__init__()self.config = config@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)def sub(self, data: DataProto):data.batch['a'] = data.batch['a'].to("cuda")data.batch['a'] -= self.config['b']return datadef test_colocated_workers():ray.init()import torch# 构建一个DataProto,其中属性a是维度为10的零向量。data = DataProto.from_dict({'a': torch.zeros(10)})print(data.batch["a"])# 利用RayClassWithInitArgs将自定义的worker和参数封装起来actor_cls = RayClassWithInitArgs(cls=Actor)critic_cls = RayClassWithInitArgs(cls=Critic, config={'b': 10})# 定义资源池,仅包含一个2GPU的节点resource_pool = RayResourcePool(process_on_nodes=[2])# 利用create_colocated_worker_cls将自定义的两个worker绑定到WorkerDict上cls_dict = {'actor': actor_cls, 'critic': critic_cls}ray_cls_with_init = create_colocated_worker_cls(cls_dict)# 启动WorkerDictwg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)# 分别获取actor和critic的workergroupspawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())colocated_actor_wg = spawn_wg['actor']colocated_critic_wg = spawn_wg['critic']# actor执行add、critic执行subactor_output = colocated_actor_wg.add(data)critic_output = colocated_critic_wg.sub(data)# actor_output.batch["a"]==[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]# critic_output.batch["a"]==[-10, -10, -10, -10, -10, -10, -10, -10, -10, -10]print(actor_output.batch["a"])print(critic_output.batch["a"])ray.shutdown()if __name__ == '__main__':test_colocated_workers()
1. Actor和Critic的解释
Actor
和Critic
的定义比较简单,在add
和sub
方法上使用了装饰器@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
。这个装饰器的作用就是数据并行,即将数据chunk后交由定义worker的两个实例分别计算后,再将结果合并。下面是Actor
执行过程的示意图,Crtic
类似:
2. Actor和Critic绑定至WorkerDict
通过打印一些元信息,可知ray_cls_with_init
是RayClassWithInitArgs
,其中持有cls是一个定义为WorkerDict
的类,其基类同样是Worker
。
cls_dict = {'actor': actor_cls, 'critic': critic_cls}
ray_cls_with_init = create_colocated_worker_cls(cls_dict)
print(type(ray_cls_with_init)) # RayClassWithInitArgs
print(ray_cls_with_init.cls.__ray_actor_class__) # WorkerDict
print(ray_cls_with_init.cls.__ray_actor_class__.__base__) # Worker
print(ray_cls_with_init.cls.actor_add)
print(ray_cls_with_init.cls.critic_sub)
3. 启动WorkerDict并执行操作
# 启动WorkerDict
wg_dict = RayWorkerGroup(resource_pool=resource_pool,ray_cls_with_init=ray_cls_with_init)
spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())
# 获得actor和critic的RayWorkerGroup
colocated_actor_wg = spawn_wg['actor']
colocated_critic_wg = spawn_wg['critic']
# 执行
actor_output = colocated_actor_wg.add(data)
critic_output = colocated_critic_wg.sub(data)
正如上一篇文章介绍,通过RayWorkerGroup
启动WorkerDict
。spawn
方法会返回一个RayWorkerGroup
的字典,在这个例子中spawn_wg
的值为:
{'actor': <verl.single_controller.ray.base.RayWorkerGroup object at 0x7f2efc719790>, 'critic': <verl.single_controller.ray.base.RayWorkerGroup object at 0x7f2f0e125dd0>}
获得actor
和critic
的RayWorkerGroup
后直接执行操作即可。
二、create_colocated_worker_cls
源码解析
先了看一下稍微简化了一些的源码:
# 原始代码位于verl/single_controller/ray/base.py
def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]):# {"actor": Actor, "critic": Critic}cls_dict = {}# {'actor': {'args': (...), 'kwargs': {}}, 'critic': {'args': (...), 'kwargs': {...}}}init_args_dict = {}# worker_cls是指Actor和Critic,其实就是Workerworker_cls = Nonefor key, cls in class_dict.items():worker_cls = cls.cls.__ray_actor_class__.__base__cls_dict[key] = cls.clsinit_args_dict[key] = {'args': cls.args, 'kwargs': cls.kwargs}class WorkerDict(worker_cls):def __init__(self):super().__init__()self.worker_dict = {}for key, user_defined_cls in cls_dict.items():# 去除掉ray.remote的包装,这里user_defined_cls就是Actor和Critc两个类user_defined_cls = _unwrap_ray_remote(user_defined_cls)with patch.dict(os.environ, {'DISABLE_WORKER_INIT': '1'}):self.worker_dict[key] = user_defined_cls(*init_args_dict[key].get('args', ()), **init_args_dict[key].get('kwargs', {}))for key, user_defined_cls in cls_dict.items():user_defined_cls = _unwrap_ray_remote(user_defined_cls)_bind_workers_method_to_parent(WorkerDict, key, user_defined_cls)remote_cls = ray.remote(WorkerDict)remote_cls = RayClassWithInitArgs(cls=remote_cls)return remote_cls
1. 类WorkerDict
可以看到WorkerDict
是定义在create_colocated_worker_cls
内部的类,其初始化方法__init__
中核心就是构建self.worker_dict
。在本文例子中,key就是actor
和critic
,对应的value就是Actor
和Critic
的实例。
2. _bind_workers_method_to_parent
一句话来说明这个函数的功能:将user_defined_cls
中使用装饰器register
的方法绑定到WorkerDict
,key
是方法绑定至WorkerDict
的前缀。下面是源码:
# 原始代码位于verl/single_controller/ray/base.py
def _bind_workers_method_to_parent(cls, key, user_defined_cls):for method_name in dir(user_defined_cls):if hasattr(method, MAGIC_ATTR):# 遍历user_defined_cls的所有方法,找到使用装饰器`register`装饰的方法def generate_function(name):def func(self, *args, **kwargs):# dispatch to the actual workerreturn getattr(self.worker_dict[key], name)(*args, **kwargs)return funcfunc = generate_function(method_name)# 将原始函数`add`和`sub`的MAGIC_ATTR绑定到func上setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR))try:method_name_with_prefix = key + '_' + method_namesetattr(cls, method_name_with_prefix, func)except Exception as e:raise ValueError(f'Fail to set method_name {method_name}')
_bind_workers_method_to_parent
实现可能比较难理解一些,先跳过generate_function
。
setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR))
这行代码主要是将add
或者sub
的装饰器提供的信息复制到func上。
method_name_with_prefix = key + '_' + method_name
setattr(cls, method_name_with_prefix, func)
这段代码就是将Actor
的方法add
或者Critic
的方法sub
绑定至WorkerDict
上。
想理解generate_function
的功能,最好直接看通过最终组装出的WorkerDict
类。这里展示add
绑定到WorkerDict
的等价代码:
class WorkerDict(worker_cls):def __init__(self):...def actor_add(self, *args, **kwargs):return getattr(self.worker_dict["actor"], "add")(*args, **kwargs)
三、spawn
的作用
回顾一下开始例子中的这段代码:
wg_dict = RayWorkerGroup(resource_pool=resource_pool,ray_cls_with_init=ray_cls_with_init)spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())
print(spawn_wg["actor"].workers)
print(spawn_wg["critic"].workers)
wg_dict = RayWorkerGroup(resource_pool=resource_pool,ray_cls_with_init=ray_cls_with_init)
这行代码会启动两个WorkerDict
的两个远程实例(这里不使用Ray中actor的称呼是因为和示例中的actor会混淆)。在调用spawn
后返回了字典spawn_wg
,打印spawn_wg["actor"].workers
和spawn_wg["critic"].workers
可以发现这两个workergroup中持有的workers都是相同的。那么,spawn
中作用是什么?来看下源码:
def spawn(self, prefix_set):def _rebind_actor_methods(worker_group, actor_name):prefix: str = actor_name + '_'for method_name in dir(worker_group):if method_name.startswith(prefix):# only valid when Python >= 3.9original_method_name = method_name.removeprefix(prefix)method = getattr(worker_group, method_name)setattr(worker_group, original_method_name, method)new_worker_group_dict = {}for prefix in prefix_set:# 从现有的workers中填写出名字为self._worker_names的worker并构成新的RayWorkerGroupnew_worker_group = self.from_detached(worker_names=self._worker_names,ray_cls_with_init=self.ray_cls_with_init)# 将带有前缀的方法名,移除前缀。例如:`actor_add`->`add`_rebind_actor_methods(new_worker_group, prefix)new_worker_group_dict[prefix] = new_worker_groupreturn new_worker_group_dict
self.from_detached
作用是利用现有的worker构造一个新的RayWorkerGroup
,其并不会启动新的Worker。_rebind_actor_methods
则是将actor_add
这种带前缀的方法名改为add
,然后将add
绑定到新的RayWorkerGroup
上。
所以,spawn
的作用就是保证可以像非colocate那么的方式来执行具体的功能。