Petals

Date: 2023-03-21

A reading report on the Petals system

Load balancing

Kademlia: the core technique behind the P2P network is the Kademlia protocol (reference article: "P2P 网络核心技术:Kademlia 协议").
That article somewhat conflates xor and dis: what it calls dis is the LCP (longest common prefix) taken after the xor.
In other words, its "distance" is the length of the common prefix of the xor-ed IDs, so a larger value actually means the two nodes are closer (in the standard XOR metric, a smaller value means closer).
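
To make the two notions concrete, here is a minimal sketch (plain Python, not Petals/hivemind code) of the standard XOR distance versus the common-prefix length described above:

# Standard Kademlia metric: XOR of the two node IDs; smaller means closer.
def xor_distance(a: int, b: int) -> int:
    return a ^ b

# Length of the shared ID prefix, i.e. the number of leading zero bits of the XOR;
# larger means closer, which is the "distance" convention the note above refers to.
def common_prefix_length(a: int, b: int, id_bits: int = 160) -> int:
    return id_bits - (a ^ b).bit_length()

assert xor_distance(0b1010, 0b1000) == 0b0010
assert common_prefix_length(0b1010, 0b1000, id_bits=4) == 2   # the IDs share the prefix "10"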

Server/

Tree:
server/
├── backend.py √
├── block_selection.py √
├── block_utils.py √
├── handler.py
├── __init__.py √
├── memory_cache.py √
├── reachability.py
├── server.py
├── task_pool.py √
├── task_prioritizer.py
└── throughput.py

block_selection.py

Defines the Span class and one function interface:

class Span: 
    ...

def should_choose_other_blocks(
    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
) -> bool: 
    ...

should_choose_other_blocks() returns a bool indicating whether this server should switch to serving a different set of blocks.
The main goal of the algorithm is to have servers join the segment with the worst throughput, thereby raising the system's performance bottleneck; it also covers the rebalancing logic.

(P5, C2) Formally, servers maximize the total model throughput by choosing the blocks with the worst throughput and eliminating potential bottlenecks.
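
A simplified sketch of that idea (my own illustration, not the actual should_choose_other_blocks() logic): a server able to host num_blocks blocks joins the contiguous span whose aggregate throughput is currently the lowest, i.e. the bottleneck.

from typing import List

def choose_best_start(block_throughputs: List[float], num_blocks: int) -> int:
    """Return the start index of the weakest contiguous window of num_blocks blocks."""
    best_start, worst_total = 0, float("inf")
    for start in range(len(block_throughputs) - num_blocks + 1):
        total = sum(block_throughputs[start : start + num_blocks])
        if total < worst_total:
            best_start, worst_total = start, total
    return best_start

# e.g. a server hosting 2 blocks would join at index 2, the current bottleneck
print(choose_best_start([5.0, 4.0, 1.0, 1.5, 6.0], num_blocks=2))  # -> 2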

block_utils.py

Defines the get_block_size() -> int function, which computes the size of a block.
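
A rough sketch of what such an estimate boils down to (an assumption about the idea, not the exact petals formula): the number of parameters in one transformer block multiplied by the bytes per element of the chosen dtype.

import torch

def estimate_block_size(n_params: int, dtype: torch.dtype = torch.bfloat16) -> int:
    """Approximate memory footprint of one transformer block, in bytes."""
    return n_params * torch.finfo(dtype).bits // 8

# e.g. roughly 2.5e9 parameters per BLOOM-176B block in bfloat16 -> about 5 GB
print(estimate_block_size(2_500_000_000, torch.bfloat16))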

memory_cache.py

This module defines a MemoryCache class:

class MemoryCache:
    def __init__():
        ...

    @contextlib.asynccontextmanager
    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
        ...

    @staticmethod
    def get_allocation_size(*descriptors: TensorDescriptor) -> int:
        ...
    
    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
        """
        This method should be called inside asyncio.shield() because:
            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
        Returns: a tuple containing the handles (handle_counter values) assigned to the tensors allocated in this call.
        """
        ...
    
    @contextlib.contextmanager
    def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]:
        ...
    
    

The core function here is _schedule_alloc().

Main use of this class: it is called by both the connection handler and the runtime process. The former records the allocation and sends the allocation result through a pipe; the latter, upon receiving it, performs the actual allocation. Freeing is built in (it happens when the context manager exits).
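
A minimal usage sketch based on the signatures above (the surrounding plumbing and the way handles travel between processes are assumptions on my part):

async def handler_side(cache: MemoryCache, descriptor: TensorDescriptor, run_inference):
    # Connection handler: record the allocation and obtain handles; the memory is
    # released automatically when the async context manager exits.
    async with cache.allocate_cache(descriptor) as handles:
        await run_inference(handles)  # forward the handles along with the request

def runtime_side(cache: MemoryCache, handles):
    # Runtime process: map the handles issued above back to the actual tensors.
    with cache.use_cache(*handles) as cache_tensors:
        for tensor in cache_tensors:
            ...  # read/update the attention key/value cache in place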

task_pool.py

This module defines a task pool maintained by a priority queue.

It aggregates requests from multiple ConnectionHandler instances, orders them in the runtime for processing, and returns the results (or exceptions) to the corresponding ConnectionHandler.

Each task pool serves one specific layer operation of the model (e.g. layer1.forward, layer2.backward).

The task type stored in the queue is Task, which has five members: priority, submission time, a future holding the task state/result, the task arguments, and a uid.

@dataclass(order=True)
class Task:
    priority: float
    time_submitted: float
    future: MPFuture = field(compare=False)
    args: Sequence[torch.Tensor] = field(compare=False)

    @property
    def uid(self) -> int:
        return self.future._uid

The PrioritizedTaskPool class derives from hivemind.moe.server.task_pool.TaskPoolBase and is essentially a threading.Thread subclass. It maintains an ordinary queue, submitted_tasks, to record submitted tasks, and a priority queue, _ordered_tasks, to decide which task runs next.

  • The submit_task() method submits a task into submitted_tasks.
  • The _prioritize_tasks() method moves tasks from the ordinary queue into the priority queue.
  • The load_batch_to_runtime() method pops the first task from _ordered_tasks, processes its arguments (moving them to the target device, etc.), and updates the pool's state.
  • The send_outputs_from_runtime() method writes the task's outputs back to the CPU.
class PrioritizedTaskPool(TaskPoolBase):
    self.submitted_tasks: SimpleQueue
    self._ordered_tasks: PriorityQueue
    ...
    @staticmethod
    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
        ...
    def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
        ...
    def load_batch_to_runtime(
        self, timeout: Optional[float] = None, device: Optional[torch.device] = None
    ) -> Tuple[Any, List[torch.Tensor]]:
        ...
    def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
        ...
    
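A minimal sketch of how the pool is driven (the argument values and the surrounding loop are placeholders, not petals code): the connection handler submits work and waits on the returned future, while the runtime repeatedly pulls the highest-priority batch, runs it, and ships the outputs back.

import torch

def runtime_loop(pool: PrioritizedTaskPool, module):
    while True:
        uid, batch_inputs = pool.load_batch_to_runtime(device=torch.device("cuda"))
        batch_outputs = module(*batch_inputs)              # run the transformer block
        pool.send_outputs_from_runtime(uid, batch_outputs)

# Connection-handler side:
#   future = pool.submit_task(hidden_states, priority=0.0)
#   outputs = future.result()   # resolves once send_outputs_from_runtime() is called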

backend.py

This module derives the TransformerBackend class from hivemind.moe.server.module_backend.ModuleBackend; it wraps a single BLOOM block and handles forward, backward, and inference requests for it.

The forward and backward methods are already implemented in the ModuleBackend class; the main addition is the inference_step() method:

    @torch.inference_mode()
    def inference_step(
        self,
        hidden_states: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_info: InferenceMetadata,
    ) -> Tuple[torch.Tensor, ...]:
        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
            self._reorder_cache_inplace(cache_tensors, hypo_ids)
            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
            hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
            self._update_cache_inplace(cache_tensors, new_kvs, inference_info.prefix_length)
            return (hidden_states,)

It takes the hidden states plus some other inference-related information as input and returns the computed result.

Here, self.memory_cache.use_cache() returns a context manager that yields all tensors currently allocated in the MemoryCache. After reordering them according to hypo_ids, the past layer data is selected according to the prefix_length given in inference_info, and forward() is called on the block; here self.module is the wrapped BLOOM block (whose attention layer is a BloomAttention). Finally, the cache is updated in place with the new key/values.
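
A simplified illustration of the two cache operations mentioned above (the tensor layout [batch, seq_len, ...] is an assumption for readability; the real _select_layer_past / _update_cache_inplace work on BLOOM's internal key/value layout):

import torch

def select_layer_past(cache_key: torch.Tensor, cache_value: torch.Tensor, prefix_length: int):
    # Only the first prefix_length positions hold valid past attention states.
    return cache_key[:, :prefix_length], cache_value[:, :prefix_length]

def update_cache_inplace(cache_key, cache_value, new_key, new_value, prefix_length: int):
    # Write back the freshly computed keys/values for the tokens after the prefix.
    new_length = new_key.shape[1]
    cache_key[:, prefix_length:new_length] = new_key[:, prefix_length:]
    cache_value[:, prefix_length:new_length] = new_value[:, prefix_length:]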

handler.py

Client

remote_model.py

DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
This is the model demonstrated in the usage example. Part of its transformer member is hosted on the swarm.

class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
    def __init__(self, config: DistributedBloomConfig):
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config, self.transformer.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

_LowCPUMemoryMixin is a wrapper that enables the low_cpu_mem_usage option of transformers.PreTrainedModel.

low_cpu_mem_usage algorithm:
This is an experimental function that loads the model using ~1x model size CPU memory
Here is how it works:
1. save which state_dict keys we have
2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
3. after the model has been instantiated switch to the meta device all params/buffers that
are going to be replaced from the loaded state_dict
4. load state_dict 2nd time
5. replace the params/buffers from the state_dict
Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
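
For reference, this is the flag the mixin turns on; in plain transformers it would look like the following (the checkpoint name is only an example):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",   # example checkpoint
    low_cpu_mem_usage=True,    # load with ~1x model size of CPU RAM, as described above
)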

RemoteGenerationMixin contains the algorithms and helper functions needed for autoregressive text generation, such as greedy search and beam search.
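
A usage sketch in the spirit of the Petals README of this period (the import path and model name may differ between versions, so treat them as assumptions):

from transformers import BloomTokenizerFast
from petals import DistributedBloomForCausalLM

MODEL_NAME = "bigscience/bloom-petals"
tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
# generate() comes from RemoteGenerationMixin (greedy / beam-search decoding)
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))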

Next, let's break down DistributedBloomModel.

class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""
    def __init__(self, config: DistributedBloomConfig):
        ...
        # obtain the DHT (or create one) and build a RemoteSequential on top of it
        dht = config.dht if config.dht is not None else hivemind.DHT(...)
        self.h = RemoteSequential(config, dht, config.dht_prefix)
        ...
        # configure the fine-tuning method
        if config.tuning_mode and "ptune" in config.tuning_mode:
            ...
    def get_prompt(self, ...):
        # get the prompts needed for p-tuning
        ...

    def forward(self, ..., inputs_embeds: Optional[torch.Tensor] = None,...):
        ...
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        hidden_states = self.h(hidden_states)
        hidden_states = self.ln_f(hidden_states)

        return BaseModelOutputWithPastAndCrossAttentions(   # a class from the transformers library
            last_hidden_state=hidden_states,
            ...
        )

The init function takes the DHT from the config (creating one if it is absent) and configures the fine-tuning method accordingly.
The get_prompt() method fetches the prompts needed for p-tuning.
The forward() method runs the hidden states through the transformer and finally returns a BaseModelOutputWithPastAndCrossAttentions object from transformers.

Next, let's break down RemoteSequential.

class RemoteSequential(nn.Module):
    """
    A sequence of transformer blocks hosted by the swarm.
    """
    def __init__(
        self,
        config: petals.client.DistributedBloomConfig,
        dht: DHT,
        dht_prefix: Optional[str] = None,
        p2p: Optional[P2P] = None,
        sequence_manager: Optional[RemoteSequenceManager] = None,
        **kwargs,
    ):
        self.config = config
        self.dht = dht
        self.dht_prefix = dht_prefix or config.dht_prefix
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
        ...
        if sequence_manager is None:
            self.sequence_manager = RemoteSequenceManager(...)
        else:
            self.sequence_manager = sequence_manager

    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
        assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
        assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
        return outputs

Next, let's break down _RemoteSequentialAutogradFunction.

class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """
    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
    This function splits input data into batches with <MAX_TOKENS_IN_BATCH> and performs efficient parallel processing.
    """
    @staticmethod
    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
        if is_dummy(prompts):
            prompt_batches = [DUMMY] * len(input_batches)
        else:
            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

        sequence_manager.rpc_info  # lazy init
        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
        output_batches = [output[0] for output in outputs]
        ...
        return torch.cat(output_batches, dim=0)
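
A small worked example of the splitting above (the MAX_TOKENS_IN_BATCH value and tensor sizes are illustrative): with 40 sequences of 1024 tokens and a 16384-token budget per batch, the input is cut into chunks of 16 sequences.

import torch

MAX_TOKENS_IN_BATCH = 16384                  # illustrative value
inputs = torch.randn(40, 1024, 16)           # [batch_size, seq_len, hidden_size]; tiny hidden size for the demo

batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)    # 16384 // 1024 = 16
input_batches = inputs.detach().split(batch_size)
print(batch_size, [b.shape[0] for b in input_batches])         # 16 [16, 16, 8]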

_gather_forward() calls run_remote_forward(); let's look at it next.

async def run_remote_forward(
    uid: ModuleUID,
    stub: StubBase,
    rpc_info: RPCInfo,
    *inputs: torch.Tensor,
    timeout: float,
    metadata: Optional[bytes] = None,
    **kwargs,
) -> Tuple[torch.Tensor, ...]:
    """
    Serializes input tensors and calls "rpc_forward" on a remote server.
    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
    """

    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    # detach to avoid pickling the computation graph
    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}

    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
    forward_inputs = (inputs, kwargs)

    # Modify forward_schema to support prompts
    args_schema, kwargs_schema = rpc_info["forward_schema"]
    # TODO: rm this assert when support arbitrary number of input tensors
    assert len(args_schema) == 1 and len(inputs) == 2
    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)

    if not nested_compare(forward_inputs, forward_schema_with_prompts):
        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")

    forward_inputs = nested_flatten(forward_inputs)
    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)

    # Asynchronous serialization
    loop = asyncio.get_running_loop()
    serialized_tensors = await asyncio.gather(
        *(
            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
        )
    )

    # call RPC on remote server
    size = sum(t.element_size() * t.nelement() for t in inputs)
    forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
    # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])


The call chain:

async def sequential_forward(
    inputs: torch.Tensor,
    prompts: torch.Tensor,
    sequence_manager: RemoteSequenceManager,
    start_index: int = 0,
    end_index: Optional[int] = None,
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
    """
    Constructs a routing path from <start_index> to <end_index>.
    Performs chained forward for each subsequence of blocks on the path.
    If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
    """
    ...
    block_idx = start_index
    while block_idx < end_index:
        for ...:
            sequences = deque(...)      # a deque of spans (block index ranges)
            span = sequences.popleft()
            stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
            inputs_and_prompts = [inputs, prompts[span.start : span.end]]   
            span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
            metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
            
            (outputs,) = await run_remote_forward(
                span_uids,
                stub,
                sequence_manager.rpc_info,
                *inputs_and_prompts,
                timeout=sequence_manager.request_timeout,
                metadata=MSGPackSerializer.dumps(metadata),
            )

Next, let's break down RemoteSequenceManager.

File list

client/
├── inference_session.py
├── __init__.py
├── remote_forward_backward.py   run_remote_forward() calls _forward_stream();
the latter calls TransformerConnectionHandler.rpc_forward_stream() [in server/handler.py, line 226], which then calls
_rpc_forward() [same file], which in turn calls backend.forward_pool.submit_task(), where backend is of type TransformerBackend

├── remote_generation.py RemoteGenerationMixin

├── remote_model.py DistributedBloomConfig, _LowCPUMemoryMixin
DistributedBloomModel, DistributedBloomForCausalLM
DistributedBloomForSequenceClassification

├── remote_sequential.py RemoteSequential, RemoteTransformerBlock
├── routing
│ ├── __init__.py
│ ├── sequence_info.py
│ ├── sequence_manager.py RemoteSequenceManager
│ └── spending_policy.py
└── sequential_autograd.py   _RemoteSequentialAutogradFunction, _gather_forward(), sequential_forward() [calls run_remote_forward()]

Original post: https://www.cnblogs.com/linxiaoshu/p/17238542.html