LLM distributed supervised fine-tuning with JAX — ROCm Blogs (amd.com)
24年1月25日,Douglas Jia 发布在AMD ROCm 博客上的文章。
在这篇文章中,我们回顾了使用 JAX 对基于双向编码器表示(BERT)的大型语言模型(LLM)进行文本分类任务微调的过程。我们探讨了在多个 AMD GPU 上并行化这一微调过程的技术,然后评估模型在测试数据集上的性能。为此,我们使用了一个基于 BERT的 cased transformer 模型和 General Language Understanding Evaluation(GLUE)基准数据集在多个 AMD GPU 上进行实验。
我们重点关注 JAX 中两个单程序多数据(SPMD)并行化方法。这两个方法是:
- 使用 pmap
函数在单个领先轴上进行简单的数据分发。
- 使用 jit
、`Mesh` 和 mesh_utils
函数在设备之间分片数据,提供更大的并行化控制。
我们主要强调第一个方法,并在文章的最后部分提供了第二个方法的详细说明。
在撰写本文时,我们参考了这个教程,我们强烈推荐阅读。
什么是监督微调?
在人工智能(AI)时代,基于Transformer架构的模型(如 BERT、GPT-3 及其后续版本)为实现各种自然语言处理(NLP)任务(如文本分类、文本生成和情感分析)的尖端性能提供了坚实的基础。然而,当这些大型预训练模型单独应用于这些特定任务时,常常表现出一定的局限性。监督微调(SFT)为解决这些局限性提供了方案。
与在大规模、多样化数据集上进行广泛无监督训练的预训练模型不同,SFT采用了一种专注且资源高效的方法。通常,这需要一个相对紧凑、高质量的数据集,该数据集精确地针对特定任务量身定制。SFT可以在不需要长时间训练的情况下,将模型性能提升到最先进的水平,因为它能够利用预训练模型所获得的广泛知识。
SFT过程包括微调模型的现有权重或添加额外参数,以确保与指定任务的复杂性保持一致。通常,这种适应会结合任务特定的层,例如为分类添加一个 softmax 层,从而增强模型解决监督任务的能力。
什么是 JAX?
JAX 是一个高性能的 Python 数值计算库。与传统的机器学习框架(如 TensorFlow 和 PyTorch)相比,JAX 的速度和效率都非常出色。JAX 利用即时编译(JIT),无缝的自动微分,以及高效向量化和并行化代码的能力,使其能简单地适配 AI 加速器(如 GPU 和 TPU)。
为什么使用 AMD GPU?
AMD GPU 因其强大的开源支持而脱颖而出,工具如 ROCm 和 HIP 使其易于适配 AI 工作流程。AMD 具有竞争力的性价比,非常适合寻求成本效益的 AI 和深度学习任务解决方案的用户。随着 AMD 在市场上的影响力不断增长,越来越多的机器学习库和框架正在添加对 AMD GPU 的支持。
硬件要求和运行环境
为了利用完成此任务所需的计算能力,我们使用AMD加速器云平台 (AAC)。AAC 是一个按需提供云计算资源和API的付费平台。具体来说,我们使用一个JAX Docker容器,其在AAC上拥有8个GPU,以充分利用先进的GPU并行计算能力。
本文是硬件无关的,这意味着要成功运行提供的代码示例,不需要访问AAC。只要您有加速器设备(如GPU或TPU),您应该能够以最小的代码修改来运行这些代码示例。如果您使用的是AMD GPU,请确保正确安装了ROCm及其兼容版本的JAX和Jaxlib。参考以下教程进行安装:
-
JAX and Jaxlib 安装: 您也可以直接通过链接拉取一个JAX Docker镜像。
代码示例:对Transformer模型进行SFT
为了演示,我们使用一个通用语言理解评估(GLUE)基准数据集Quora Question Pairs(QQP)微调一个基于transformer的LLM(如:bert-base-cased)。该数据集包含超过40万对问题,每对问题都有一个二进制注释,指示这两个问题是否是相互的复述。输入变量是两个问题的句子,而输出变量是一个二进制指标,表示这两个问题是否具有相同的含义。
安装
首先,安装所需的软件包 (%%capture
是一个 _cell magic_,它将抑制单元格的输出)。
%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git
!pip install evaluate
!pip install ipywidgets
!pip install black isort # 单元格中的格式化器;可选项
导入剩余的软件包和功能。
import os
from itertools import chain
from typing import Callable
import evaluate
import flax
import jax
import jax.numpy as jnp
import optax
import pandas as pd
from datasets import load_dataset
from flax import traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ipywidgets import IntProgress as IProgress
from tqdm.notebook import tqdm
from transformers import (
AutoConfig,
AutoTokenizer,
FlaxAutoModelForSequenceClassification,
)
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
JAX 预先分配75%的GPU内存以减少首次运行JAX操作时的开销和碎片,但可能会触发内存不足(OOM)错误。为了避免OOM问题,可通过将 XLA_PYTHON_CLIENT_PREALLOCATE
标志设置为 false 来抑制默认行为。
检查是否可以通过JAX检测到GPU设备。如果不能,可能需要重新安装ROCm、JAX和Jaxlib。如果JAX安装正确,你可以看到所有请求的GPU设备,在我们的例子中是8个GPU。
jax.local_devices()
[gpu(id=0),
gpu(id=1),
gpu(id=2),
gpu(id=3),
gpu(id=4),
gpu(id=5),
gpu(id=6),
gpu(id=7)]
获取微调数据集和预训练模型检查点
指定你的微调过程的设置:数据集、预训练模型以及每个设备每批次要处理的数据量。
task = "qqp"
model_checkpoint = "bert-base-cased"
per_device_batch_size = 64
加载数据集和评估指标模块。
raw_dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)
接下来的几段代码展示了如何使用模型特定的分词器对文本数据进行分词,并加载分词后的训练和验证数据。使用与预训练模型相同的分词器确保在微调过程中相同的词会被转换为相同的嵌入向量。
重要的是,我们在原始训练数据中对训练和评估数据集进行了10%的抽样。尽管如此,QQP数据集仍然提供了足够的数据来实现令人满意的性能,并且可以在每个epoch后观察到指标的改进。这种抽样方法还加快了我们的训练过程,便于说明。
使用数据预处理函数和map包装器的批处理和并行处理功能处理训练和评估数据集。你可以在以下输出中查看分词后的数据集。
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
def preprocess_function(examples):
texts = (examples["question1"], examples["question2"])
processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
processed["labels"] = examples["label"]
return processed
# 关于如何处理和操作 huggingface 数据集的详细信息:
# https://huggingface.co/docs/datasets/process
data = raw_dataset["train"].shuffle(seed=0)
train_data = data.select(list(range(int(data.shape[0] * 0.1))))
eval_data = data.select(list(range(int(data.shape[0] * 0.1), int(data.shape[0] * 0.2))))
print(f"原始训练数据集的形状为: {data.shape}")
print(f"当前训练数据集的形状为: {train_data.shape}")
print(f"当前验证数据集的形状为: {eval_data.shape}")
原始训练数据集的形状为: (363846, 4)
当前训练数据集的形状为: (36384, 4)
当前验证数据集的形状为: (36385, 4)
train_dataset = train_data.map(
preprocess_function, batched=True, remove_columns=train_data.column_names
)
eval_dataset = eval_data.map(
preprocess_function, batched=True, remove_columns=eval_data.column_names
)
# 你可以在以下单元格的输出中查看已分词的数据集
pd.DataFrame(train_dataset[:3])
从Hugging Face下载预训练模型配置和检查点。注意,你会看到一个警告信息,指出某些模型权重未使用。这是预期的,因为BERT模型检查点是一个PreTraining模型类,而你正在初始化一个
SequenceClassification模型。警告信息指出:你可能需要在下游任务上训练该模型,以便能够将其用于预测和推理。 这就是我们接下来要关注的内容。
num_labels = 2
seed = 0
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
model_checkpoint, config=config, seed=seed
)
某些在bert-base-cased模型检查点中的权重在初始化FlaxBertForSequenceClassification时未被使用: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- 如果您正在从另一个任务或架构的模型检查点初始化FlaxBertForSequenceClassification(例如,从BertForPreTraining模型初始化BertForSequenceClassification模型),这是预期的。
- 如果您正在从您期望完全相同的模型检查点初始化FlaxBertForSequenceClassification(从BertForSequenceClassification模型初始化BertForSequenceClassification模型),这不是预期的。
某些在bert-base-cased模型检查点中的权重未被初始化到FlaxBertForSequenceClassification并被重新初始化: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
您可能需要在下游任务中训练此模型,以便能够使用它进行预测和推理。
定义微调模型的状态
以下代码块展示了如何设置训练参数,比如训练周期数和初始学习率。学习率调度是为了使学习率在训练过程中线性衰减,以确保学习的效率和稳定性。
num_train_epochs = 6
learning_rate = 2e-5
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)
The overall batch size (both for training and eval) is 512
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
learning_rate_function = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps
)
接下来,需要建立训练状态,包括优化器和损失函数的职责,并监督模型参数在训练过程中的更新。
使用状态对象,初始化和更新模型。当调用模型时,将状态作为输入,模型会返回通过新数据批次更新后的状态,同时保留模型实例。
Flax 提供了一个用户友好的类(`flax.training.train_state.TrainState`),它将模型参数、损失函数和优化器封装在一起。当提供数据时,它可以使用 apply_gradients
函数更新模型参数。
下面的代码块展示了如何定义和建立训练状态、优化器和损失函数。
class TrainState(train_state.TrainState):
logits_function: Callable = flax.struct.field(pytree_node=False)
loss_function: Callable = flax.struct.field(pytree_node=False)
# 创建一个 decay_mask_fn 函数,以确保对任何偏置项或 LayerNorm 权重不应用权重衰减,因为这可能不会提高模型性能甚至会有害。
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
flat_mask = {
path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
for path in flat_params
}
return traverse_util.unflatten_dict(flat_mask)
# 标准的带权重衰减的 Adam 优化器
def adamw(weight_decay):
return optax.adamw(
learning_rate=learning_rate_function,
b1=0.9,
b2=0.999,
eps=1e-6,
weight_decay=weight_decay,
mask=decay_mask_fn,
)
def loss_function(logits, labels):
xentropy = optax.softmax_cross_entropy(
logits, onehot(labels, num_classes=num_labels)
)
return jnp.mean(xentropy)
def eval_function(logits):
return logits.argmax(-1)
# 实例化 TrainState
state = TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=adamw(weight_decay=0.01),
logits_function=eval_function,
loss_function=loss_function,
)
定义如何训练、评估模型并启用并行化
train_step
和 eval_step
参数定义了如何训练和评估模型。训练步骤遵循标准的训练过程:
-
使用当前的权重计算损失。
-
计算损失函数相对于权重的梯度。
-
使用梯度和学习率更新权重。
-
使用梯度和学习率更新权重。
需要强调的是,`lax.pmean` 函数计算跨所有 8 个 GPU 设备的数据批次梯度的均值。这个关键步骤保证了所有 GPU 设备上的模型参数同步。
def train_step(state, batch, dropout_rng):
targets = batch.pop("labels")
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
def loss_function(params):
logits = state.apply_fn(
**batch, params=params, dropout_rng=dropout_rng, train=True
)[0]
loss = state.loss_function(logits, targets)
return loss
grad_function = jax.value_and_grad(loss_function)
loss, grad = grad_function(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = jax.lax.pmean(
{"loss": loss, "learning_rate": learning_rate_function(state.step)},
axis_name="batch",
)
return new_state, metrics, new_dropout_rng
def eval_step(state, batch):
logits = state.apply_fn(**batch, params=state.params, train=False)[0]
return state.logits_function(logits)
接下来,应用 jax.pmap
函数到定义的 train_step
和 eval_step
函数。将 pmap()
应用于函数时,该函数会使用 XLA 编译(类似于 jit()
),然后在 XLA 设备上并行运行,例如多 GPU 设备或多 TPU 核。简单来说,这一步将训练和评估函数发送到所有 GPU 设备。你还需要通过 flax.jax_utils.replicate
将训练状态发送到所有 GPU 设备,这些步骤确保你通过分布式训练在所有 GPU 设备上更新模型状态。
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
state = flax.jax_utils.replicate(state)
定义数据加载函数,这些函数返回数据批次生成器。在最终的训练和评估循环中,每一步都会输入一个新的数据批次。
def glue_train_data_loader(rng, dataset, batch_size):
steps_per_epoch = len(dataset) // batch_size
perms = jax.random.permutation(rng, len(dataset))
perms = perms[: steps_per_epoch * batch_size] # 跳过不完整的批次。
perms = perms.reshape((steps_per_epoch, batch_size))
for perm in perms:
batch = dataset[perm]
batch = {k: jnp.array(v) for k, v in batch.items()}
batch = shard(batch)
yield batch
def glue_eval_data_loader(dataset, batch_size):
for i in range(len(dataset) // batch_size):
batch = dataset[i * batch_size : (i + 1) * batch_size]
batch = {k: jnp.array(v) for k, v in batch.items()}
batch = shard(batch)
yield batch
基于整数种子生成伪随机数生成器(PRNG)密钥,然后将其拆分为 8 个新的密钥,以确保每个 GPU 设备都得到不同的密钥。然后运行训练步骤,以根据预定义的训练参数(如训练轮次和总批次大小)更新 state
。在每个轮次结束时,运行评估步骤,以查看评估数据集上的准确率和 F1 指标。由于使用的训练数据集比基准中的原始训练数据集要小,可以看到在前几轮训练中,评估指标(训练损失和评估准确率)稳定提升。
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
for i, epoch in enumerate(
tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):
rng, input_rng = jax.random.split(rng)
# train
with tqdm(
total=len(train_dataset) // total_batch_size, desc="Training...", leave=True
) as progress_bar_train:
for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
state, train_metrics, dropout_rngs = parallel_train_step(
state, batch, dropout_rngs
)
progress_bar_train.update(1)
# 评估
with tqdm(
total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False
) as progress_bar_eval:
for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
labels = batch.pop("labels")
predictions = parallel_eval_step(state, batch)
metric.add_batch(
predictions=list(chain(*predictions)), references=list(chain(*labels))
)
progress_bar_eval.update(1)
eval_metric = metric.compute()
loss = round(flax.jax_utils.unreplicate(train_metrics)["loss"].item(), 3)
eval_score1 = round(list(eval_metric.values())[0], 3)
metric_name1 = list(eval_metric.keys())[0]
eval_score2 = round(list(eval_metric.values())[1], 3)
metric_name2 = list(eval_metric.keys())[1]
print(
f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}"
)
Epoch ...: 0%| | 0/6 [00:00<?, ?it/s]
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
1/6 | Train loss: 0.475 | Eval accuracy: 0.799, f1: 0.762
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
2/6 | Train loss: 0.369 | Eval accuracy: 0.834, f1: 0.789
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
3/6 | Train loss: 0.299 | Eval accuracy: 0.846, f1: 0.797
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
4/6 | Train loss: 0.239 | Eval accuracy: 0.846, f1: 0.806
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
5/6 | Train loss: 0.252 | Eval accuracy: 0.849, f1: 0.802
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
6/6 | Train loss: 0.212 | Eval accuracy: 0.849, f1: 0.805
使用JAX设备网格来实现并行化
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
model_checkpoint, config=config, seed=seed
)
state = TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=adamw(weight_decay=0.01),
logits_function=eval_function,
loss_function=loss_function,
)
一些来自 bert-base-cased 模型检查点的权重在初始化 FlaxBertForSequenceClassification 时未被使用: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- 当你用模型训练其他任务或用另一种架构初始化 FlaxBertForSequenceClassification 时,这是预期中的情况(例如从 BertForPreTraining 模型初始化 BertForSequenceClassification 模型)。
- 当你期望从与 FlaxBertForSequenceClassification 模型完全相同的检查点初始化时(从 BertForSequenceClassification 模型初始化 BertForSequenceClassification 模型),这不是预期情况。
FlaxBertForSequenceClassification 中一些权重没有从 bert-base-cased 模型检查点初始化,是新初始化的: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
应该将这个模型训练到下游任务上以便用于预测和推断。
@jax.jit
def train_step(state, batch, dropout_rng):
targets = batch.pop("labels")
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
def loss_function(params):
logits = state.apply_fn(
**batch, params=params, dropout_rng=dropout_rng, train=True
)[0]
loss = state.loss_function(logits, targets)
return loss
grad_function = jax.value_and_grad(loss_function)
loss, grad = grad_function(state.params)
new_state = state.apply_gradients(grads=grad)
metrics = {"loss": loss, "learning_rate": learning_rate_function(state.step)}
return new_state, metrics, new_dropout_rng
@jax.jit
def eval_step(state, batch):
logits = state.apply_fn(**batch, params=state.params, train=False)[0]
return state.logits_function(logits)
num_devices = len(jax.local_devices())
devices = mesh_utils.create_device_mesh((num_devices,))
# 数据将沿批处理轴进行分割
data_mesh = Mesh(devices, axis_names=("batch",)) # naming axes of the mesh
data_sharding = NamedSharding(
data_mesh,
P(
"batch",
),
) # 命名网格的轴
def glue_train_data_loader(rng, dataset, batch_size):
steps_per_epoch = len(dataset) // batch_size
perms = jax.random.permutation(rng, len(dataset))
perms = perms[: steps_per_epoch * batch_size] # 略过不完整的批处理。
perms = perms.reshape((steps_per_epoch, batch_size))
for perm in perms:
batch = dataset[perm]
batch = {
k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()
}
yield batch
def glue_eval_data_loader(dataset, batch_size):
for i in range(len(dataset) // batch_size):
batch = dataset[i * batch_size : (i + 1) * batch_size]
batch = {
k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()
}
yield batch
# 在所有设备上复制模型和优化器变量
def get_replicated_train_state(devices, state):
# 所有变量将在所有设备上复制
var_mesh = Mesh(devices, axis_names=("_"))
# 在 NamedSharding 中,未提到的轴将被复制(此处为所有轴)
var_replication = NamedSharding(var_mesh, P())
# 应用分布设置到模型变量
state = jax.device_put(state, var_replication)
return state
state = get_replicated_train_state(devices, state)
rng = jax.random.PRNGKey(seed)
dropout_rng = jax.random.PRNGKey(seed)
for i, epoch in enumerate(
tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):
rng, input_rng = jax.random.split(rng)
# 训练
with tqdm(
total=len(train_dataset) // total_batch_size, desc="Training...", leave=True
) as progress_bar_train:
for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
state, train_metrics, dropout_rng = train_step(state, batch, dropout_rng)
progress_bar_train.update(1)
# 评估
with tqdm(
total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False
) as progress_bar_eval:
for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
labels = batch.pop("labels")
predictions = eval_step(state, batch)
metric.add_batch(predictions=list(predictions), references=list(labels))
progress_bar_eval.update(1)
eval_metric = metric.compute()
loss = round(train_metrics["loss"].item(), 3)
eval_score1 = round(list(eval_metric.values())[0], 3)
metric_name1 = list(eval_metric.keys())[0]
eval_score2 = round(list(eval_metric.values())[1], 3)
metric_name2 = list(eval_metric.keys())[1]
print(
f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}"
)
Epoch ...: 0%| | 0/6 [00:00<?, ?it/s]
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
1/6 | Train loss: 0.469 | Eval accuracy: 0.796, f1: 0.759
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
2/6 | Train loss: 0.376 | Eval accuracy: 0.833, f1: 0.788
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
3/6 | Train loss: 0.296 | Eval accuracy: 0.844, f1: 0.795
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
4/6 | Train loss: 0.267 | Eval accuracy: 0.846, f1: 0.805
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
5/6 | Train loss: 0.263 | Eval accuracy: 0.848, f1: 0.804
Training...: 0%| | 0/71 [00:00<?, ?it/s]
Evaluating...: 0%| | 0/71 [00:00<?, ?it/s]
6/6 | Train loss: 0.222 | Eval accuracy: 0.849, f1: 0.805
本站资源均来自互联网,仅供研究学习,禁止违法使用和商用,产生法律纠纷本站概不负责!如果侵犯了您的权益请与我们联系!
转载请注明出处: 免费源码网-免费的源码资源网站 » 使用 JAX 进行 LLM 分布式监督微调
发表评论 取消回复