从零开始的 Ray 框架之旅(三)——分布式训练

本章,我们将通过一个完整的分布式训练例子,具体阐释 Ray 在分布式神经网络模型训练场景中的应用。这个例子将帮助你理解,Ray 是如何成为 OpenAI 在训练 ChatGPT 时钦定的分布式计算框架的。

1. 几个常见分布式计算框架的简要对比

老牌的Spark框架虽然在大数据处理领域非常成熟,但其核心设计更偏向于批处理和 SQL 分析,对于需要高度灵活性、动态调度和复杂并行模型的大规模 AI 训练,其显得非常笨重和受限。
Dask 在并行化现有 Python 代码和科学计算方面非常出色,虽然其比Spark灵活,然而其缺乏原神的 Actor 模型,使得其在需要维护和交互大量独立状态单元的复杂系统(如RLHF等)时非常受限,且对于需要高度定制化、动态变化的分布式计算流程,尤其是AI训练中的复杂并行策略和自适应行为,Dask显得就不那么灵活了。
Spark的 RDD/DataFrame 这种高级抽象相比,Ray 的 Task 和 Actor 提供了更底层的控制,允许开发者更精细地管理资源和执行流程,这对于优化尖端模型的性能至关重要。 OpenAI 的需求不仅仅是数据处理,还包括模型训练、强化学习、模拟、模型部署等。Ray 的通用性使其能够支撑整个机器学习生命周期,从数据预处理到模型服务,这使得 OpenAI 可以用它来构建各种复杂的后台系统和研究平台。

接下来,我们以鸢尾花数据集训练一个多分类模型为例子,来看看 Ray 是怎么完成机器学习的整个生命周期的(包括数据预处理、模型训练、超参数调优和模型服务部署)。

2. 数据预处理

现有的大模型预训练数据集的量级都 TB 级别的,单纯靠单机处理已经不现实了,必须依赖分布式水平扩展才能完成大规模数据预处理工作。而数据预处理对大语言模型的预训练至关重要,高效和高质量的数据预处理是构建高智商大语言模型必不可少的一个环节。

下面,我们完成一个简单的数据清洗流程,包括异常值处理,缺失值处理和特征合成三个流程。

from typing import Dict
import numpy as np
import ray
import os

# ——————————————————————— Ray 初始化 ———————————————————————  #

# 这里如果是多机,可以使用环境变量缺省addr参数,或者是使用ray.init(address="auto")
ray.init(address="local")

# ——————————————————————— 数据读取 ———————————————————————  #
# 这里如果是多机,读取目录一定要是所有节点都能访问的目录,比如hdfs,nfs,s3等
ds = ray.data.read_csv("./iris.csv")

# ——————————————————————— 全局统计量计算 ———————————————————————  #

# 计算全局统计量
numeric_cols = ["sepal.length", "sepal.width", "petal.length", "petal.width"]

# 计算全局的统计量
means = {c: ds.mean(c) for c in numeric_cols}
stds = {c: ds.std(c) for c in numeric_cols}

# 统计量现在只在Driver进程中,需要通过put存储到对象存储中,远端执行match_baches的worker进程才能拿到
mean_id = ray.put(means)
stds_id = ray.put(stds)

# 这里的id是对象存储的id,是一个ObjectRef,需要通过get获取
print(ray.get(mean_id))
print(ray.get(stds_id))

# ——————————————————————— 算子批处理 ———————————————————————  #
import pandas as pd
import numpy as np

# 1.缺失值处理算子
def fillna_batch(df: pd.DataFrame, means):
    return df.fillna(means)

ds_filled = ds.map_batches(
    fillna_batch,
    fn_kwargs={"means": ray.get(mean_id)},
    batch_format="pandas",
    num_cpus=0.5,
)

# 2. z-score 过滤算子
def filter_outliers_batch(df: pd.DataFrame, means, stds):
    mask = np.ones(len(df), dtype=bool)
    for col in numeric_cols:
        mask &= (df[col] >= means[col] - 3 * stds[col]) & (df[col] <= means[col] + 3 * stds[col]) # 3sigma原则
    return df[mask]

ds_filtered = ds_filled.map_batches(
    filter_outliers_batch,
    fn_kwargs={"means": ray.get(mean_id), "stds": ray.get(stds_id)},
    batch_format="pandas",
    num_cpus=0.5,
)

# 3. 新特征合成算子
def add_area_batch(df: pd.DataFrame):
    # Ray2.10起copy-on-write优化,这里不需要深拷贝
    df["petal.area"] = df["petal.length"] * df["petal.width"]
    return df

ds_final = ds_filtered.map_batches(
    add_area_batch,
    batch_format="pandas",
)

# ——————————————————————— 数据保存 ———————————————————————  #
# 4. 数据保存
ds_final.write_json("iris_cleaned/", num_rows_per_file=200)

首先,我们对 Ray 进行一个初始化,如果你使用多机处理,可以直接省略 init 方法中的参数,然后,我们读取数据集,注意:这里数据集的目录一定要是共享存储,否则其他节点无法访问这个数据源会报错。
然后,我们使用全局统计的方法统计数据源各个 Feature 的均值和标准差,用于后续算子来处理缺失值和过滤异常数据,注意:这里我们需要把计算的指标存储对象存储中,不然它只会存在于 Driver 进程的内存中,其他 worker 进程是无法访问的。
接下来,我们使用自定义的算子配合 Ray Dataset的 map_batchs 方法进行处理,我们需要制定 batch_format 来告诉 Ray 我们需要传入这个格式的数据,以及算子需要的额外数据我们使用 fn_kwargs 来传入,注意:这里本来可以直接传入 ObjectRef,但 Ray 存在 bug,且在最新的 Ray 版本中,fn_kwag无法自动解引用,这个bug依然没有解决。

以及,在特征合成算子中,其利用了copy-on-write(写时复制)机制,下面会进行补充说明。
最后,我们将数据存储在指定目录(这里如果多机,也需要存储在共享目录中,方便后续处理)。

1.1 Copy On Write 机制

简单来说,CoW 机制就是「先共享,写时再拷贝」,在写操作时,相较于深拷贝(立即复制整个数据结构)或部分浅拷贝(立即复制顶层结构但共享底层数据),CoW 在处理写操作时更为高效。当数据需要修改时,只有被修改的部分才会被复制,在此之前,数据是共享的,这使得 CoW 在内存开销上非常经济。
当 Ray Data 将底层的 Arrow 数据块转换为其算子(例如 map_batches 中的函数)所使用的 pandas DataFrame 时,生成的 DataFrame 中的列通常是零拷贝视图 (zero-copy views)。这些视图直接指向原始的 Arrow 内存区域,意味着:

  • 只读操作:在不修改数据的前提下,无论是多个并行的 map_batches 算子,还是同一算子内处理的多条记录,都可以同时、直接地读取这块共享内存。这完全避免了数据复制的开销,实现了高效的零成本数据访问。
  • 写入操作:当首次尝试修改 DataFrame 中的某一列或特定区域时,pandas 的 copy-on-write 机制便会启动。此时,包含被修改数据的那整个数据块(Block,通常对应一列或一组相同类型的列)会被复制到一块新的内存区域,并标记为可写状态。后续的写入操作将作用于这个新生成的副本。关键在于,DataFrame 中其他未被修改的列(或数据块)仍然共享原有的只读内存。同时,其他算子或后续操作所持有的 DataFrame 若引用的是旧数据,则它们依然指向旧的、未经修改的内存区域。这种方式有效地保证了数据的隔离性,防止了意外的数据污染,并最大限度地减少了不必要的内存拷贝。

实际上,当 map_batches 算子利用 CoW 处理数据并返回一个 DataFrame(此时为视图 + CoW状态)时,Ray 会立即将此 DataFrame 转换回 pyarrow.Table 格式并存入其对象存储。转换规则如下:

  • 旧列:直接被包装成 Arrow 的 ChunkedArray,通过零拷贝方式继续指向原始的 Arrow 内存缓冲区。
  • 新列:由于 CoW 机制,这些列的数据已经存在于新的内存区域(通常是 pandas/NumPy 管理的内存)。这部分数据会被复制(或者在特定条件下进行零拷贝转换)到一个新的 Arrow 内存缓冲区,然后包装成 ChunkedArray

至此,「视图 + CoW」这一语义已经彻底消失,取而代之的是标准的、基于列的 ChunkedArray 结构。换言之,上游算子在 pandas层面获得的「视图 + CoW」状态不会传递给 Ray 的下一个算子;下游算子只会感知到标准的 Arrow 列数据。

1.2 Apache Arrow

在 Apache Arrow 中,每一列都是 ChunckArray

  • 每个 ChunkedArray 由一个或多个 chunk 组成,每个 chunk 是一段物理上连续的内存缓冲区。
  • 当对表进行追加、切片或多批次拼接等操作时,Arrow 通常只操作元数据(如指针列表),使得逻辑上的一列数据可能由多个物理上分散的 chunk 构成。
  • combine_chunks() 函数的作用是:为指定的一列分配一块足够大的新连续内存缓冲区,然后将该列所有 chunk 的数据按顺序复制到这个新缓冲区中,使得此列最终只包含一个 chunk。这个操作有利于后续的顺序扫描、数据压缩或磁盘写入,但其代价是一次该列数据的全量复制。

总结来说:

  • Apache Arrow (combine_chunks)

    • 原来列 a 有 [chunk0chunk1] 两段 buffer。
    • combine_chunks() ➜ 把它们复制到一块连续 buffer,列 a 现在只有 chunk0' 一个指针。
    • 此操作针对每一列独立进行,不涉及将多个不同列的数据混合或合并到单个统一的内存块中
  • pandas + CoW

    • 执行如 df_view = df[['a', 'b']] 这样的操作(视图创建)时,df_view 最初会共享 df 中列 a 和列 b 的底层数据(通过共享指向这些数据块的指针,这些数据块由 pandas 的 BlockManager 管理)。
    • a 和列 b 的数据在物理上仍然是独立的(除非它们原本就在同一个数据块中,例如因数据类型相同而被pandas整合),并不会因为创建视图而被「粘合」成一个更大的数组。
    • 当执行如 df_view.loc[0, 'a'] = 999 这样的写入操作时,pandas 的 CoW 机制会针对列 a 所在的数据块(Block)创建一个副本,用于承载这个新值以及该块中原有的其他数据。

关键差异对比:

  1. 操作粒度

    • Arrow 的 combine_chunks()列内操作,合并一个列内部的多个 chunk。
    • pandas CoW 的复制是当列所在的数据块(Block)发生写入时触发,复制的是整个数据块。
  2. 主要目标

    • Arrow (combine_chunks()) 追求的是物理内存的连续性,这对于提升顺序I/O性能、实现零拷贝传输至关重要。
    • pandas CoW 追求的是延迟复制,以节省内存开销,仅在真正需要写入时才付出复制成本。
  3. 跨列行为:

    • Arrow 的 combine_chunks() 不进行跨列合并。每个列的 chunk 合并是独立进行的。
    • pandas 内部可能会将数据类型相同 (same dtype) 的多个列存储在同一个二维数据块 (Block) 中(此过程称为 consolidation)。然而,这与 CoW 机制本身是正交的。创建视图或进行写入操作本身并不会自动导致不同数据块中的列被合并。CoW 机制作用于这些已存在的块。

    2. 分布式训练

评论区
头像