DataEngine

1. DataEngine 简介

DataEngine 是 LLaMA-Factory v1 数据处理的核心类,继承自 PyTorch 的 Dataset,负责各种插件的接入,其他功能(如数据格式转换、数据加载等)均通过插件的形式实现并接入 DataEngine

DataEngine接受一个唯一入参:DataArguments 实例,所有的元数据集信息均通过该参数配置传入。

2. DataEngine 与 DataArguments 接口定义

@dataclass
class DataArguments:
    """ `DataEngine`初始化入参

    args:
        dataset (str): 数据集路径,远程数据集 repo id / dataset_info.yaml 路径,或本地数据集路径/dataset_info.yaml路径
        cutoff_len (int): 数据集截止长度,即数据集最大样本采样数量
    """
    ...


class DataEngine(Dataset):
    """数据引擎(DataEngine)

    `DataEngine` 负责数据集的加载与统一管理,支持:
        - 从本地路径或 Hugging Face Hub 加载数据
        - 通过插件机制加载自定义数据
        - 构建统一的数据索引
        - 支持流式(streaming)与非流式数据访问

    attr:
        args (DataArguments): 数据参数配置
        datasets (dict[str, HFDataset]): 数据集名称到数据对象的映射
        dataset_infos (dict[str, DatasetInfo]): 数据集名称到元信息的映射
        data_index (list[tuple[str, int]]): 数据索引列表,每项为 (dataset_name, sample_index)
        streaming (bool): 是否为流式数据集
    """

    def __init__(self, data_args: DataArguments) -> None:
        """初始化 `DataEngine`

        初始化时自动执行以下步骤:
            1. 调用 `get_dataset_info`, 从 `data_args` 读取并解析数据集元信息
            2. 调用 `load_dataset`,根据配置加载数据集
            3. 调用 `build_data_index`,构建统一的索引列表

        args:
            data_args (DataArguments): 数据参数配置对象
        """
        ...

    def get_dataset_info(self) -> None:
        """从配置文件或远程仓库加载数据集元信息

        根据 `self.args.dataset` 确定数据源,数据源支持如下选项:
            - 本地 YAML 配置文件路径
            - Hugging Face Hub 上的 YAML 配置文件路径
            - 本地数据集文件路径
            - Hugging Face Hub 数据集 repo id

        """
        ...

    def load_dataset(self) -> None:
        """根据数据集元信息加载所有数据集

        每个数据集条目可以包含以下字段:
            - `hf_hub_url`: 使用 `datasets.load_dataset` 加载
            - 本地数据文件:通过 `DataLoaderPlugin` 插件加载
            - `streaming`: 是否启用流式模式

        更新:
            self.datasets (dict): 数据集名称到已加载数据对象的映射
            self.streaming (bool): 如果任一数据集为流式模式,则设置为 True
        """
        ...

    def build_data_index(self) -> None:
        """构建统一的数据索引

        为所有数据集创建全局索引列表 `(dataset_name, sample_index)`

        当启用流式模式时,生成固定长度(例如 1000)的占位索引;
        否则,为每条样本建立索引。

        插件 `DataIndexPlugin` 可根据数据集大小或权重调整索引分布
        """
        ...

    def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
        """将原始样本转换为统一格式

        根据 `dataset_info` 中的 `converter` 字段,调用对应的转换插件,
        将原始样本标准化为统一的数据结构。

        args:
            raw_sample (dict[str, Any]): 原始数据样本
            dataset_name (str): 样本所属的数据集名称

        return:
            Sample: 转换后的标准化格式样本
        """
        ...

    def __len__(self) -> int:
        """返回数据集的总样本数

        return:
            int: 数据集长度
                如果为流式数据集,返回 `-1`
        """
        ...

    def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
        """根据索引或选择器获取样本

        args:
            index (Union[int, Any]): 数据索引,int 或 list[int]

        return:
            Union[Sample, list[Sample]]: 单个样本或样本列表
        """
        ...

    def __iter__(self) -> Iterable:
        """返回数据集迭代器

        用于非流式数据集的顺序或随机访问
        流式模式下需要实现异步加载逻辑

        return:
            Iterable: 数据集迭代器。
        """
        ...

    async def __aiter__(self) -> AsyncIterable:
        """返回异步数据集迭代器

        用于流式数据集或异步数据加载场景
        允许在异步环境中以流的方式读取样本

        return:
            AsyncIterable: 异步迭代器,按顺序产出样本
        """
        ...


DataArguments 参数说明:

dataset: 数据集路径,支持本地或远程,当传入本地数据集文件路径时,需要满足该数据集为标准格式;否则需要传入 dataset_info.yaml 来配置数据集的 converter 等元信息,以告知 DataEngine 应当如何处理该数据。

cutoff_len: 数据集的截止长度,即该数据集的最大样本数量。


3. DataEngine 核心方法

3.1 get_dataset_info:加载数据元信息

根据 dataset 参数加载数据集配置,获取数据位置、数据格式、插件配置等所有数据元信息,在实例化 DataEngine 时会自动调用此方法。

3.2 加载数据集:load_dataset

遍历所有数据源,根据不同的数据源加载数据,在实例化 DataEngine 时会自动调用此方法。

for key, value in self.dataset_infos.items():
    split = value.get("split", "train")
    streaming = value.get("streaming", False)

    if "hf_hub_url" in value:
        # 从 HF Hub 加载
        dataset = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
    else:
        # 使用 DataLoaderPlugin 加载本地文件
        dataset = DataLoaderPlugin(args=self.args).auto_load_data(value)

    self.datasets[key] = dataset

3.3 build_data_index:构建数据索引

为每个数据集创建索引列表 [(dataset_name, sample_index), ...], DataIndexPlugin插件在此处被调用,可控制各数据集的采样频率、采样方式等,在实例化DataEngine时会自动调用此方法。

for dataset_name, dataset in self.datasets.items():
    # 创建基础索引
    data_index = [(dataset_name, idx) for idx in range(len(dataset))]

    # 根据 size 和 weight 调整索引
    size = self.dataset_infos[dataset_name].get("size")
    weight = self.dataset_infos[dataset_name].get("weight")
    if size or weight:
        data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)

    self.data_index.extend(data_index)

3.4 _convert_data_sample:数据格式标准化

将原始数据转换为标准格式,DataConverterPlugin插件在此处被调用,具体调用的插件由 get_dataset_info 方法获取的 converter 信息指定,若 converter 为空则假定数据集为标准格式,此方法由DataEngine__getitem__ 方法调用。

def _convert_data_sample(self, raw_sample: dict, dataset_name: str) -> Sample:
    converter = self.dataset_infos[dataset_name].get("converter")
    if converter is not None:
        # 使用指定的转换器
        from ..plugins.data_plugins.converter import get_converter
        return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
    else:
        # 已经是标准格式
        return {"_dataset_name": dataset_name, **raw_sample}

4. 初始化

DataEngine 初始化过程只需传入一个构建好的 DataArguments 即可,后续可通过该 DataEngine 访问数据集中的数据。

from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine

# 1. 创建数据参数
data_args = DataArguments(
    dataset="~/data/v1_sft_demo.jsonl",
    cutoff_len=2048
)

# 2. 初始化 Data Engine
data_engine = DataEngine(data_args=data_args)

# 3. 访问数据
sample = data_engine[0]  # 获取第一个样本

5. 数据访问方式

实例化后的DataEngine支持整数索引、列表索引、以及切片等访问方式,其数据读取用法可等价于 Python 列表。

sample = data_engine[0]  # 获取第一个样本

sample = data_engine[0:10]  # 获取前 10 个样本

sample = data_engine[[0, 5, 10]]  # 获取指定索引的样本