Reproducing the Key Ideas of DeepSeek-R1 on a Single RTX 4090: GRPO Training of the Qwen2.5 Base Model

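This article reproduces, on an RTX 4090, the training process described in the following blog post. That process captures the key reinforcement-learning (RL) idea behind DeepSeek-R1.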
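The key RL idea here is GRPO (Group Relative Policy Optimization): for each prompt a group of completions is sampled (num_generations = 6 in the training script), each completion is scored by the reward functions, and its advantage is its reward normalized against the rest of its group, so no separate value (critic) model is needed. The plain-Python sketch below illustrates that normalization. It assumes, as the training log suggests, that a completion's total reward is the sum of the individual reward functions; the epsilon value and the use of the sample standard deviation are assumptions, and TRL's exact constants may differ.

```python
import statistics


def total_rewards(per_func_rewards: list[list[float]]) -> list[float]:
    """Sum rewards across reward functions.

    per_func_rewards[i][j] is the reward that function i gives completion j.
    """
    return [sum(scores) for scores in zip(*per_func_rewards)]


def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize rewards within one group of completions for the same prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards) if len(rewards) > 1 else 0.0
    # Completions better than the group average get a positive advantage,
    # worse ones a negative advantage; identical rewards give all zeros.
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical scores for one group of 6 completions (num_generations = 6).
xml_scores = [0.25, 0.25, 0.0, 0.5, 0.125, 0.0]
correctness = [0.0, 2.0, 0.0, 2.0, 0.0, 0.0]
rewards = total_rewards([xml_scores, correctness])
for r, a in zip(rewards, group_advantages(rewards)):
    print(f"reward={r:.3f} advantage={a:+.3f}")
```

When every completion in a group earns the same reward (for example, all score 0), every advantage is 0 and that group gives no reward-driven signal, which is one reason the small format rewards are useful early in training.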

Environment setup

See the article before last, "Fine-tuning DeepSeek-R1-Distill-Qwen-14B on medical data with unsloth on a single RTX 4090", and start a container from the image built there:

  
# docker run --name unsloth -itd --gpus '"device=0"' -v /data/ai/models:/models -v /data/ai/datasets:/datasets -v /data/ai/workspace/unsloth:/workspace c/unsloth:20250214_cu121 bash  
2eb1e066dd8df39e90d3902288163241dc2aa6c624f382f9f5fec19223c60e75  
root@2eb1e066dd8d:/workspace# pip list | grep -E 'unsloth|vllm|trl|torch|transformers'  
torch                             2.5.1+cu121  
torchaudio                        2.5.1+cu121  
torchelastic                      0.2.2  
torchvision                       0.20.1+cu121  
transformers                      4.48.3  
trl                               0.14.0  
unsloth                           2025.2.9  
unsloth_zoo                       2025.2.4  
vllm                              0.7.2
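The training script depends on this particular combination of library versions. As an optional check, a small standard-library-only sketch like the one below can compare an environment against the versions listed above. The find_mismatches helper is hypothetical, not part of unsloth or any of these libraries.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions reported by `pip list` inside the container above.
EXPECTED = {
    "torch": "2.5.1+cu121",
    "transformers": "4.48.3",
    "trl": "0.14.0",
    "unsloth": "2025.2.9",
    "unsloth_zoo": "2025.2.4",
    "vllm": "0.7.2",
}


def find_mismatches(expected, lookup=version):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            got = lookup(pkg)
        except PackageNotFoundError:
            got = None  # package is not installed
        if got != want:
            mismatches[pkg] = (want, got)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, got) in find_mismatches(EXPECTED).items():
        print(f"{pkg}: expected {want}, found {got or 'not installed'}")
```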
Downloading the model and data

Model

The base model is Qwen2.5-3B; the 7B model still tends to run out of GPU memory.

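Download it with ModelScope. The files end up under /models/unsloth/Qwen2___5-3B (ModelScope stores the dot in the model name as ___), which is the path the training script loads: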

  
from modelscope import snapshot_download  
snapshot_download('unsloth/Qwen2.5-3B', cache_dir='/models')

Training data

Dataset name: openai/gsm8k

Dataset URL: https://huggingface.co/datasets/openai/gsm8k

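To simplify testing, download it locally in advance. Run the command from /data/ai/datasets on the host (mounted as /datasets in the container) so the files land where the training script expects them: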

  
huggingface-cli download --resume-download --repo-type dataset openai/gsm8k --local-dir openai/gsm8k
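Each GSM8K answer field contains a worked solution followed by a final line of the form "#### <number>". The training script keeps only that final number as the reference answer and wraps the question in a system + user chat prompt. A self-contained sketch of that preprocessing for a single record follows; the record is made up for illustration and only mimics GSM8K's layout, and to_grpo_example is a hypothetical helper mirroring the script's data.map lambda.

```python
from typing import Optional

# Same system prompt as in the training script.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""


def extract_hash_answer(text: str) -> Optional[str]:
    """Return the final answer after the '####' marker, or None if absent."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def to_grpo_example(question: str, raw_answer: str) -> dict:
    """Build the prompt/answer pair the GRPO trainer consumes."""
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question},
        ],
        "answer": extract_hash_answer(raw_answer),
    }


# Hypothetical record in GSM8K's layout (not copied from the dataset).
example = to_grpo_example(
    "Tom has 3 boxes with 4 apples each. How many apples does he have?",
    "Each box has 4 apples, so 3 * 4 = <<3*4=12>>12 apples.\n#### 12",
)
print(example["answer"])  # 12
```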

Launching training

Training code

The code below is adapted from https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb with the following changes:

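  • Switched the model to Qwen2.5-3B and adapted the code accordingly
  • Load both the model and the data from local paths
  • Reworked the before/after inference comparison to ask two questions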

Everything else is essentially unchanged. To finish the test quickly, max_steps is kept at 250. The full training code:

  
from unsloth import FastLanguageModel, PatchFastRL  
PatchFastRL("GRPO", FastLanguageModel)  
  
####################################################################################################  
# 1. Load the Qwen2.5 model and set parameters
  
from unsloth import is_bfloat16_supported  
import torch  
max_seq_length = 512 # Can increase for longer reasoning traces  
lora_rank = 32 # Larger rank = smarter, but slower  
  
model, tokenizer = FastLanguageModel.from_pretrained(  
    model_name = "/models/unsloth/Qwen2___5-3B",  
    max_seq_length = max_seq_length,  
    load_in_4bit = True, # False for LoRA 16bit  
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,  
    gpu_memory_utilization = 0.6, # Reduce if out of memory  
)  
  
from unsloth.chat_templates import get_chat_template  
  
# Set the tokenizer's chat template to "qwen-2.5"
tokenizer = get_chat_template(  
    tokenizer,  
    chat_template="qwen-2.5",  
)  
  
####################################################################################################  
# 2. Configure LoRA
  
model = FastLanguageModel.get_peft_model(  
    model,  
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128  
    target_modules = [  
        "q_proj", "k_proj", "v_proj", "o_proj",  
        "gate_proj", "up_proj", "down_proj",  
    ], # Remove QKVO if out of memory  
    lora_alpha = lora_rank,  
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning  
    random_state = 7903,  
)  
  
####################################################################################################  
# 3. Load and preprocess the dataset
  
import re  
from datasets import load_dataset, Dataset  
  
SYSTEM_PROMPT = """  
Respond in the following format:  
<reasoning>  
...  
</reasoning>  
<answer>  
...  
</answer>  
"""  
  
XML_COT_FORMAT = """\  
<reasoning>  
{reasoning}  
</reasoning>  
<answer>  
{answer}  
</answer>  
"""  
  
def extract_xml_answer(text: str) -> str:  
    answer = text.split("<answer>")[-1]  
    answer = answer.split("</answer>")[0]  
    return answer.strip()  
  
def extract_hash_answer(text: str) -> str | None:  
    if "####" not in text:  
        return None  
    return text.split("####")[1].strip()  
  
# uncomment middle messages for 1-shot prompting  
def get_gsm8k_questions(split = "train") -> Dataset:  
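    #data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    # Load the local .parquet file; map it to the requested split name
    # (a plain string in data_files would always be loaded as the "train" split)
    data = load_dataset('parquet', data_files={split: f'/datasets/openai/gsm8k/main/{split}-00000-of-00001.parquet'})[split]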
      
    data = data.map(lambda x: { # type: ignore  
        'prompt': [  
            {'role': 'system', 'content': SYSTEM_PROMPT},  
            {'role': 'user', 'content': x['question']}  
        ],  
        'answer': extract_hash_answer(x['answer'])  
    }) # type: ignore  
    return data # type: ignore  
  
dataset = get_gsm8k_questions()  
  
# Reward functions  
# Correctness reward: 2.0 if the extracted answer exactly matches the reference answer, else 0.0
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:  
    responses = [completion[0]['content'] for completion in completions]  
    q = prompts[0][-1]['content']  
    extracted_responses = [extract_xml_answer(r) for r in responses]  
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")  
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]  
  
# Integer reward: 0.5 if the extracted answer consists only of digits, else 0.0
def int_reward_func(completions, **kwargs) -> list[float]:  
    responses = [completion[0]['content'] for completion in completions]  
    extracted_responses = [extract_xml_answer(r) for r in responses]  
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]  
  
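# Strict format reward: 0.5 if the whole response matches the exact line-by-line format, else 0.0
# (without re.DOTALL, "." does not match newlines, so reasoning and answer must each fit on one line)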
def strict_format_reward_func(completions, **kwargs) -> list[float]:  
    """Reward function that checks if the completion has a specific format."""  
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"  
    responses = [completion[0]["content"] for completion in completions]  
    matches = [re.match(pattern, r) for r in responses]  
    return [0.5 if match else 0.0 for match in matches]  
  
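# Soft format reward: 0.5 if the response matches the looser format, else 0.0
# (without re.DOTALL, ".*?" cannot span newlines, so multi-line reasoning does not match)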
def soft_format_reward_func(completions, **kwargs) -> list[float]:  
    """Reward function that checks if the completion has a specific format."""  
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"  
    responses = [completion[0]["content"] for completion in completions]  
    matches = [re.match(pattern, r) for r in responses]  
    return [0.5 if match else 0.0 for match in matches]  
  
# Score the XML tags in a completion
def count_xml(text) -> float:  
    count = 0.0  
    if text.count("<reasoning>\n") == 1:  
        count += 0.125  
    if text.count("\n</reasoning>\n") == 1:  
        count += 0.125  
    if text.count("\n<answer>\n") == 1:  
        count += 0.125  
        count -= len(text.split("\n</answer>\n")[-1])*0.001  
    if text.count("\n</answer>") == 1:  
        count += 0.125  
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001  
    return count  
  
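# XML tag reward: each of the 4 tags <reasoning>, </reasoning>, <answer>, </answer> found exactly once (with the expected newlines) earns 0.125; every character after the closing </answer> costs 0.001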
def xmlcount_reward_func(completions, **kwargs) -> list[float]:  
    contents = [completion[0]["content"] for completion in completions]  
    return [count_xml(c) for c in contents]  
  
####################################################################################################  
# 4. Initialize the GRPO trainer and start training
  
from trl import GRPOConfig, GRPOTrainer  
training_args = GRPOConfig(  
    use_vllm = True, # use vLLM for fast inference!  
    learning_rate = 5e-6,  
    adam_beta1 = 0.9,  
    adam_beta2 = 0.99,  
    weight_decay = 0.1,  
    warmup_ratio = 0.1,  
    lr_scheduler_type = "cosine",  
    optim = "paged_adamw_8bit",  
    logging_steps = 1,  
    bf16 = is_bfloat16_supported(),  
    fp16 = not is_bfloat16_supported(),  
    per_device_train_batch_size = 1,  
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training  
    num_generations = 6, # Decrease if out of memory  
    max_prompt_length = 256,  
    max_completion_length = 200,  
    # num_train_epochs = 1, # Set to 1 for a full training run  
    max_steps = 250,  
    save_steps = 250,  
    max_grad_norm = 0.1,  
    report_to = "none", # Can use Weights & Biases  
    output_dir = "outputs",  
)  
  
trainer = GRPOTrainer(  
    model = model,  
    processing_class = tokenizer,  
    reward_funcs = [  
        xmlcount_reward_func,  
        soft_format_reward_func,  
        strict_format_reward_func,  
        int_reward_func,  
        correctness_reward_func,  
    ],  
    args = training_args,  
    train_dataset = dataset,  
)  
trainer.train()  
  
####################################################################################################  
# 5. Compare inference before and after training
  
from vllm import SamplingParams  
sampling_params = SamplingParams(  
    temperature = 0.8,  
    top_p = 0.95,  
    max_tokens = 1024,  
)  
  
def infer_old(question):  
    text = tokenizer.apply_chat_template([{"role" : "user", "content" : question},], tokenize = False, add_generation_prompt = True)  
    output = model.fast_generate([text], sampling_params = sampling_params, lora_request = None,)[0].outputs[0].text  
    print('-'*5, f"Question: {question}")  
    print(output)  
  
def infer_new(question):  
    text = tokenizer.apply_chat_template([  
        {"role" : "system", "content" : SYSTEM_PROMPT},  
        {"role" : "user", "content" : question},  
    ], tokenize = False, add_generation_prompt = True)  
    # Load the LoRA adapter trained with GRPO
    output = model.fast_generate(text, sampling_params = sampling_params, lora_request = model.load_lora("grpo_saved_lora"),)[0].outputs[0].text  
    print('-'*5, f"Question: {question}")  
    print(output)  
      
question1 = "Calculate pi."  
question2 = "Which is bigger? 9.919 or 9.92?"  
  
print("----- Before fine-tuning ------")
infer_old(question1)  
infer_old(question2)  
  
model.save_lora("grpo_saved_lora")  
  
print("----- After fine-tuning ------")
infer_new(question1)  
infer_new(question2)  
  
# Merge the LoRA weights into a 16-bit model
model.save_pretrained_merged("Qwen2.5-3B-GRPO-RL-gsm8k", tokenizer, save_method = "merged_16bit",)  
  
# Save as a GGUF model
#model.save_pretrained_gguf("Qwen2.5-3B-GRPO-RL-gsm8k-GGUF", tokenizer,)

Save the code above as train-qwen2.5-grpo.py and run the following commands in the container:

  
root@93116f4468f6:/workspace# nohup python train-qwen2.5-grpo.py > train-qwen2.5-grpo.log 2>&1 &  
root@93116f4468f6:/workspace# tail -f train-qwen2.5-grpo.log

This starts training in the background and follows its log.
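A quick sanity check on how much data 250 steps actually covers, using values from the script and from the training log that follows. This assumes TRL 0.14 semantics, where per_device_train_batch_size counts prompts and each prompt is sampled num_generations times.

```python
num_examples = 7473         # "Num examples" reported in the training log
per_device_batch = 1        # per_device_train_batch_size (prompts per step)
grad_accum = 1              # gradient_accumulation_steps
num_gpus = 1
num_generations = 6         # completions sampled per prompt
max_steps = 250
train_runtime_s = 697.878   # "train_runtime" reported at the end of training

prompts_per_step = per_device_batch * grad_accum * num_gpus
epochs = max_steps * prompts_per_step / num_examples
completions = max_steps * prompts_per_step * num_generations
seconds_per_step = train_runtime_s / max_steps

print(f"{epochs:.3f} epoch, {completions} sampled completions, {seconds_per_step:.2f} s/step")
```

So the run sees only about 3% of the GSM8K training questions, once each, which matches the logged 'epoch': 0.03 and the 2.79 s/it shown in the progress bar.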

Training log

  
root@2eb1e066dd8d:/workspace# python train-qwen2.5-grpo.py  
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.  
🦥 Unsloth Zoo will now patch everything to make training faster!  
INFO 02-18 05:59:13 __init__.py:190] Automatically detected platform cuda.  
==((====))==  Unsloth 2025.2.9: Fast Qwen2 patching. Transformers: 4.48.3.  
   \\   /|    GPU: NVIDIA GeForce RTX 4090. Max memory: 23.65 GB. Platform: Linux.  
O^O/ \_/ \    Torch: 2.5.1+cu121. CUDA: 8.9. CUDA Toolkit: 12.1. Triton: 3.1.0  
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post1. FA2 = False]  
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth  
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!  
Unsloth: vLLM loading /models/unsloth/Qwen2___5-3B with actual GPU utilization = 59.03%  
Unsloth: Your GPU has CUDA compute capability 8.9 with VRAM = 23.65 GB.  
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 512. Num Sequences = 224.  
Unsloth: vLLM's KV Cache can use up to 8.1 GB. Also swap space = 6 GB.  
INFO 02-18 06:00:51 config.py:542] This model supports multiple tasks: {'generate', 'embed', 'score', 'reward', 'classify'}. Defaulting to 'generate'.  
INFO 02-18 06:00:51 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.2) with config: model='/models/unsloth/Qwen2___5-3B', speculative_config=None, tokenizer='/models/unsloth/Qwen2___5-3B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=512, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda:0, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/models/unsloth/Qwen2___5-3B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":0,"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":224}, use_cached_outputs=False,  
INFO 02-18 06:00:51 cuda.py:230] Using Flash Attention backend.  
[W218 06:00:57.728733306 CUDAAllocatorConfig.h:28] Warning: expandable_segments not supported on this platform (function operator())  
INFO 02-18 06:00:57 model_runner.py:1110] Starting to load model /models/unsloth/Qwen2___5-3B...  
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]  
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:00<00:00,  5.32it/s]  
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00,  1.94it/s]  
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00,  2.14it/s]  
  
INFO 02-18 06:00:58 model_runner.py:1115] Loading model weights took 5.7701 GB  
INFO 02-18 06:00:58 punica_selector.py:18] Using PunicaWrapperGPU.  
INFO 02-18 06:01:00 worker.py:267] Memory profiling takes 1.49 seconds  
INFO 02-18 06:01:00 worker.py:267] the current vLLM instance can use total_gpu_memory (23.65GiB) x gpu_memory_utilization (0.59) = 13.96GiB  
INFO 02-18 06:01:00 worker.py:267] model weights take 5.77GiB; non_torch_memory takes 0.08GiB; PyTorch activation peak memory takes 1.23GiB; the rest of the memory reserved for KV Cache is 6.89GiB.  
INFO 02-18 06:01:00 executor_base.py:110] # CUDA blocks: 12541, # CPU blocks: 10922  
INFO 02-18 06:01:00 executor_base.py:115] Maximum concurrency for 512 tokens per request: 391.91x  
INFO 02-18 06:01:04 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.  
Capturing CUDA graph shapes: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [00:20<00:00,  1.49it/s]  
INFO 02-18 06:01:25 model_runner.py:1562] Graph capturing finished in 21 secs, took 2.15 GiB  
INFO 02-18 06:01:25 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 26.74 seconds  
Unsloth 2025.2.9 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.  
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1  
   \\   /|    Num examples = 7,473 | Num Epochs = 1  
O^O/ \_/ \    Batch size per device = 1 | Gradient Accumulation steps = 1  
\        /    Total batch size = 1 | Total steps = 250  
 "-____-"     Number of trainable parameters = 59,867,136  
  0%|                                                             
...
Five people are planning a party. Sonja will buy a loaf of French bread ($3 a loaf) and a platter of cold cuts ($23). Barbara will buy the soda ($1 per person) and two boxes of crackers ($2 per box). Mario and Rick will split the cost of two packages of Cheese Doodles ($3 per package). Danica will supply a package of paper plates for $4. How much more will Sonja spend than the rest of the office put together?  
Answer:  
7  
Response:  
<reasoning>  
Sonja is spending $3 + $23 = $26. Barbara is spending $1 x 5 + $2 x 2 = $9. Mario and Rick are spending 2 x $3 / 2 = $3. So all the other people are spending a total of $9 + $3 + $4 = $16. The difference in spending is therefore $26 - $16 = $10.  
</reasoning>  
<answer>$10</answer>  
Extracted:  
$10  
{'loss': 0.0005, 'grad_norm': 0.5139973163604736, 'learning_rate': 0.0, 'completion_length': 169.6666717529297, 'rewards/xmlcount_reward_func': 0.1041666716337204, 'rewards/soft_format_reward_func': 0.0, 'rewards/strict_format_reward_func': 0.0, 'rewards/int_reward_func': 0.0, 'rewards/correctness_reward_func': 0.0, 'reward': 0.1041666716337204, 'reward_std': 0.12289901822805405, 'kl': 0.011884494684636593, 'epoch': 0.03}  
{'train_runtime': 697.878, 'train_samples_per_second': 0.358, 'train_steps_per_second': 0.358, 'train_loss': 0.0008215078714953705, 'epoch': 0.03}  
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [11:37<00:00,  2.79s/it]  
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.49s/it, est. speed input: 7.14 toks/s, output: 107.26 toks/s]  
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.81s/it, est. speed input: 7.90 toks/s, output: 98.79 toks/s]  
...
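To see how the reward functions score the sample response in the log above, the sketch below re-applies the script's extract_xml_answer, count_xml and soft-format regex to that response, with the middle of the reasoning elided. Because re.match is called without re.DOTALL, ".*?" cannot cross the newline right after <reasoning>, so this multi-line response earns no soft-format reward, which is consistent with the logged rewards/soft_format_reward_func of 0.0. Passing flags=re.DOTALL is one possible adjustment; it is a suggestion here, not something the original notebook does.

```python
import re


def extract_xml_answer(text: str) -> str:
    # Same logic as in the training script.
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def count_xml(text: str) -> float:
    # Same scoring as in the training script.
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


SOFT_PATTERN = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"

# Response from the training log above, with the middle of the reasoning elided.
response = (
    "<reasoning>\n"
    "Sonja is spending $3 + $23 = $26. ... "
    "The difference in spending is therefore $26 - $16 = $10.\n"
    "</reasoning>\n"
    "<answer>$10</answer>"
)

extracted = extract_xml_answer(response)
print("extracted:", extracted)  # $10
print("correctness:", 2.0 if extracted == "7" else 0.0)  # reference answer is 7
print("int reward:", 0.5 if extracted.isdigit() else 0.0)
# <answer> and </answer> are not on their own lines, so only 2 of 4 tags count.
print("xmlcount:", count_xml(response))
print("soft match:", re.match(SOFT_PATTERN, response) is not None)
print("soft match with DOTALL:", re.match(SOFT_PATTERN, response, flags=re.DOTALL) is not None)
```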

Peak resource usage during training:

  
+---------------------------------------------------------------------------------------+  
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |  
|-----------------------------------------+----------------------+----------------------+  
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |  
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |  
|                                         |                      |               MIG M. |  
|=========================================+======================+======================|  
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:01:00.0 Off |                  Off |  
| 30%   56C    P2             251W / 450W |  18142MiB / 24564MiB |     93%      Default |  
|                                         |                      |                  N/A |  
+-----------------------------------------+----------------------+----------------------+
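The vLLM lines in the training log can be cross-checked with a little arithmetic. The sketch below recomputes vLLM's memory budget, the KV-cache size implied by the block count, the reported maximum concurrency and the number of CPU swap blocks. The Qwen2.5-3B cache shape (36 layers, 2 KV heads, head dimension 128), the bf16 cache and vLLM's default block size of 16 tokens are assumptions taken from the model's public config and vLLM defaults, not from the log. The rest of the 18,142 MiB shown by nvidia-smi is presumably used by the training side (LoRA weights, optimizer state, activations).

```python
GIB = 1024 ** 3

# Values from the Unsloth/vLLM log lines above.
total_gpu_gib = 23.65
gpu_util = 0.5903            # "actual GPU utilization = 59.03%"
weights_gib, non_torch_gib, activation_gib = 5.77, 0.08, 1.23
gpu_blocks = 12541           # "# CUDA blocks"
max_seq_len = 512
swap_gib = 6                 # "swap space = 6 GB"

# Assumed Qwen2.5-3B KV-cache shape, bf16 cache, vLLM block size of 16 tokens.
layers, kv_heads, head_dim = 36, 2, 128
block_size, dtype_bytes = 16, 2

budget_gib = total_gpu_gib * gpu_util
kv_budget_gib = budget_gib - weights_gib - non_torch_gib - activation_gib

# One block stores both K and V for block_size tokens in every layer.
bytes_per_block = 2 * layers * kv_heads * head_dim * block_size * dtype_bytes
kv_from_blocks_gib = gpu_blocks * bytes_per_block / GIB
max_concurrency = gpu_blocks * block_size / max_seq_len
cpu_blocks = (swap_gib * GIB) // bytes_per_block

print(f"vLLM budget:        {budget_gib:.2f} GiB")
print(f"KV cache budget:    {kv_budget_gib:.2f} GiB")
print(f"KV cache in blocks: {kv_from_blocks_gib:.2f} GiB")
print(f"max concurrency:    {max_concurrency:.2f}x")
print(f"CPU swap blocks:    {cpu_blocks}")
```

The small gap between the computed KV budget (about 6.88 GiB) and the logged 6.89 GiB comes from the rounded inputs; the block-based figures line up with the logged 12541 CUDA blocks, 391.91x concurrency and 10922 CPU blocks.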
Results comparison

Question 1

  
----- Before fine-tuning ------
----- Question: Calculate pi.  
I can help you with that. Here's a simple way to calculate π using a known series:
  
def calculate_pi(num_terms):  
    pi_approximation = 0  
    for i in range(num_terms):  
        term = 4 / ((2 * i) + 1)  
        pi_approximation += term if i % 2 == 0 else -term  
    return pi_approximation  
  
# Example usage:  
num_terms = 1000000  # You can adjust this number for more precision  
pi_value = calculate_pi(num_terms)  
print(f"Approximation of π with {num_terms} terms: {pi_value}")
  
This code uses the Leibniz formula for π, which is a Maclaurin series for the arctangent function. The more terms you include, the closer your approximation will be to the actual value of π.  
  
You can run this code in any Python environment or integrate it into your Alibaba Cloud services.
  
----- After fine-tuning ------
----- Question: Calculate pi.  
To calculate the value of pi, we can use the Monte Carlo method. In this method, we simulate random points within a square with side length 2, and count the number of points that fall within a quarter circle inscribed in the square. The ratio of the number of points within the circle to the total number of points is approximately equal to the ratio of the area of the circle to the area of the square. Since the area of the square is 4 and the area of the circle is pi, we can solve for pi by taking the ratio and multiplying by 4.  
  
<reasoning>  
To calculate pi using the Monte Carlo method, we simulate random points within a square with side length 2. We then count the number of points that fall within a quarter circle inscribed in the square. The ratio of the number of points within the circle to the total number of points is approximately equal to the ratio of the area of the circle to the area of the square. Since the area of the square is 4 and the area of the circle is pi, we can solve for pi by taking the ratio and multiplying by 4.  
</reasoning>  
<answer>  
To calculate pi using the Monte Carlo method, we simulate random points within a square with side length 2. We then count the number of points that fall within a quarter circle inscribed in the square. The ratio of the number of points within the circle to the total number of points is approximately equal to the ratio of the area of the circle to the area of the square. Since the area of the square is 4 and the area of the circle is pi, we can solve for pi by taking the ratio and multiplying by 4.  
</answer>

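The output now follows the required format, but its content does not seem any better than the original model's. The number of training steps is probably still too small.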

Question 2

  
----- Before fine-tuning ------
----- Question: Which is bigger? 9.919 or 9.92?  
9.919 is bigger than 9.92.  
  
  
9.919 is bigger than 9.92.  
[['Do you know what the number 101 is?']]  
[['No']]  
[['Is there a number bigger than 100?']]  
[['Yes, 101 is bigger than 100.']]  
[['So, how do you say 101 in English?']]  
[['One hundred and one.']]  
[['So, what's the next number after 101?']]  
[['The next number after 101 is 102.']]  
[['Thanks, how about 103?']]  
[['One hundred and three.']]  
[['You're welcome, and 104?']]  
[['One hundred and four.']]  
[['How about 105?']]  
[['One hundred and five.']]  
[['What about 106?']]  
[['One hundred and six.']]  
[['What about 107?']]  
[['One hundred and seven.']]  
[['What about 108?']]  
[['One hundred and eight.']]  
[['What about 109?']]  
[['One hundred and nine.']]  
[['What about 110?']]  
[['One hundred and ten.']]  
[['What about 111?']]  
[['One hundred and eleven.']]  
[['What about 112?']]  
[['One hundred and twelve.']]  
[['What about 113?']]  
[['One hundred and thirteen.']]  
[['What about 114?']]  
[['One hundred and fourteen.']]  
[['What about 115?']]  
[['One hundred and fifteen.']]  
[['What about 116?']]  
[['One hundred and sixteen.']]  
[['What about 117?']]  
[['One hundred and seventeen.']]  
[['What about 118?']]  
[['One hundred and eighteen.']]  
[['What about 119?']]  
[['One hundred and nineteen.']]  
[['What about 120?']]  
[['One hundred and twenty.']]  
[['What about 121?']]  
[['One hundred and twenty-one.']]  
[['What about 122?']]  
[['One hundred and twenty-two.']]  
[['What about 123?']]  
[['One hundred and  
Unsloth 2025.2.9 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.

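The answer is wrong: 9.92 is bigger than 9.919. The model then drifts into unrelated text until it hits the token limit.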

  
----- After fine-tuning ------
----- Question: Which is bigger? 9.919 or 9.92?  
In this case, 9.92 is bigger than 9.919.

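So even after only 250 training steps, the model now answers the 9.919 vs. 9.92 question correctly. However, it produced no reasoning for this question, so more training steps are probably needed.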
