如何控制LLM的输出格式(实战篇)-以Xgrammar为例

大模型向量数据库机器学习
  1. 介绍 =====

在上一篇文章《如何控制LLM的输出格式 ?》中,我们探讨了约束解码技术以及控制大语言模型输出的多种方法。今天,我们将为大家介绍一个专注于结构化输出的开源项目——Xgrammar。这个项目由陈天奇团队开发,我们将通过实际案例,对比传统prompt工程与Xgrammar在结构化输出方面的表现差异。

更多AI相关欢迎关注微信公众号:"小窗幽记机器学习"

  1. 环境配置 =======

参考xgrammar官方文档: https://xgrammar.mlc.ai/docs/start/install.html 安装xgrammar。

选项 1:

通过Conda 环境配置下载XGrammar pip 依赖。

  
conda activate your-environment  
python -m pip install xgrammar  

然后我们在命令行中验证安装:

  
python -c "import xgrammar; print(xgrammar)"  
# Prints out: <module 'xgrammar' from '/path-to-env/lib/python3.11/site-packages/xgrammar/\_\_init\_\_.py'>  

选项2:

从源代码开始构建 XGrammar,当你想修改或获取特定版本的 XGrammar 时,此步骤非常有用。

步骤 1. 设置构建环境。 要从源代码构建,你需要确保满足以下构建依赖项:

  • CMake >= 3.18
  • Git
  • C++ 编译器(例如 apt-get install build-essential)
  
# Using conda  
# make sure to start with a fresh environment  
conda env remove -n xgrammar-venv  
# create the conda environment with build dependency  
conda create -n xgrammar-venv -c conda-forge \  
    "cmake>=3.18" \  
    git \  
    python=3.11 \  
    ninja  
# enter the build environment  
conda activate xgrammar-venv  
  
# Using pip (you will need to install git seperately)  
python -m venv .venv  
source .venv/bin/activate  

步骤2. 配置、构建和安装。 建议使用基于 git 的标准工作流程下载 XGrammar。

  
# 1. clone from GitHub  
git clone --recursive https://github.com/mlc-ai/xgrammar.git && cd xgrammar  
# 2. Install pre-commit hooks (optional, recommended for contributing to XGrammar)  
pre-commit install  
# 3. build and install XGrammar core and Python bindings  
python3 -m pip install .  

步骤 3. 验证安装。 你可以在命令行中验证 XGrammar 是否编译成功。您应该会看到您从源代码构建的路径:

  
python -c "import xgrammar; print(xgrammar)"  

步骤 4.(可选)运行 Python 测试。 你需要一个 HuggingFace 令牌以及访问门控模型的权限,才能运行包含门控模型的测试。

  
# Install the test dependencies  
python3 -m pip install ".[test]"  
  
# To run all tests including the ones that have gated models, you will need a HuggingFace token.  
huggingface-cli login --token YOUR\_HF\_TOKEN  
python3 -m pytest tests/python  
  
# To run a subset of tests that do not require gated models, you can skip the tests with:  
python3 -m pytest tests/python -m "not hf\_token\_required"  

  1. 开始实战 =======

尝试 XGrammar 最简单的方法是transformers在 Python 中使用该库。安装 XGrammar后,运行以下示例,了解 XGrammar 如何实现结构化生成(在本例中为 JSON)。

3.1 加载模型

首先,我们在本地加载LLM,导入模型配置及tokenizer。

  
#!/usr/bin/env python  
# -*- coding: utf-8 -*-  
# @Time    : 2025/5/23 17:36  
# @Author  : <小窗幽记机器学习>  
  
import xgrammar as xgr  
import json  
import torch  
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig  
  
device = "cuda"# Or "cpu", etc.  
model\_name = "Qwen/Qwen3-1.7B"  
model = AutoModelForCausalLM.from\_pretrained(  
    model\_name, torch\_dtype=torch.float32, device\_map=device  
)  
tokenizer = AutoTokenizer.from\_pretrained(model\_name)  
config = AutoConfig.from\_pretrained(model\_name)  

3.2 定义结构化Prompt

为了获得恰当的结构化输出,小编在以下测试了一个固定模版化的Prompt,其中Prompt的写法按照Few-shot的方式来定义:

  
Prompt = 你是一个聪明的文本生成助手。请帮我生成由快递公司或快递柜公司发送的快递提醒的通知,目的用于提醒客户及时取快递,同时你需要按一个固定的jsonlines格式返回给我。  
          
        文本要求:  
        - 生成的文本需包含快递公司(非必须)、快递柜(非必须)、快递取件地址(必须)、取件码(必须)  
          
        文本提示:  
        - 快递公司有 以下种类【中通快递, 韵达快递, 圆通快递, 顺丰速运, 申通快递, 中国邮政, 极兔快递, 京东快递, 达达快送, 顺丰同城急送, 安能物流, 联邦快递, 苏宁物流】;  
        - 快递柜分别有【丰巢, 菜鸟驿站, 兔喜生活】;  
        - 取件码的格式一般为 X-XX-XXXX或 XXXXXXXX,X为纯数字  
          
        输出:  
        - {"text": "【通知文本】", "poi\_address": "【取件地址】", "text\_category": "提醒-快递"}  
          
        举例:  
        {"text": "【菜鸟驿站】凭4-3-6308到上海新凯家园一期步行街29号店取件。点此极速取件不排队 p.tb.cn\/\_3sarS5", ''poi\_address": "上海新凯家园一期步行街29号店", "text\_category": "提醒-快递"}  
        {"text": "【兔喜生活】取货码041017,您有中通快递包裹已到古楼公路1858弄步行街14号(原圆通快递),询23XXXX11", ''poi\_address": "古楼公路1858弄步行街14号", "text\_category": "提醒-快递"}"},  
  

3.3 定义Json格式及输出

我们通过Xgrammar的内置函数来控制稳定LLM的输出,其中Xgrammar已经提供了一个GrammarCompiler的控制函数,我们可以改写内置的Json语法、Json模式字符串或 EBNF 字符串,EBNF 提供了更大的自定义灵活性。请参阅 GBNF 文档了解规范。

  
texts = tokenizer.apply\_chat\_template(messages, tokenize=False, add\_generation\_prompt=True)  
model\_inputs = tokenizer(texts, return\_tensors="pt").to(model.device)  
  
tokenizer\_info = xgr.TokenizerInfo.from\_huggingface(tokenizer, vocab\_size=config.vocab\_size)  
grammar\_compiler = xgr.GrammarCompiler(tokenizer\_info)  
person\_schema = {  
    "text": "【短信通知文案】",  
    "poi\_address": "【POI地址】",  
    "text\_category": "【类别】",  
    "properties": {  
        "text": {  
            "type": "string"  
        },  
        "poi\_address": {  
            "type": "string",  
        },  
        "text\_category": {  
            "type": "string",  
        }  
    },  
    "required": ["text", "poi\_address", "text\_category"]  
}  
compiled\_grammar = grammar\_compiler.compile\_json\_schema(json.dumps(person\_schema))  
xgr\_logits\_processor = xgr.contrib.hf.LogitsProcessor(compiled\_grammar)  

最后正常通过LLM输出结果,添加logits_processor参数让模型输出遵循我们所定义的格式。

  
generated\_ids = model.generate(  
    **model\_inputs, max\_new\_tokens=512, logits\_processor=[xgr\_logits\_processor]  
)  
  
generated\_ids = generated\_ids[0][len(model\_inputs.input\_ids[0]):]  

3.4 完整代码

完整代码如下:

  
#!/usr/bin/env python  
# -*- coding: utf-8 -*-  
# @Time    : 2025/3/23 13:36  
# @Author  : <小窗幽记机器学习>  
  
import xgrammar as xgr  
import json  
import torch  
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig  
  
device = "cuda"# Or "cpu", etc.  
model\_name = "Qwen/Qwen3-1.7B"  
model = AutoModelForCausalLM.from\_pretrained(  
    model\_name, torch\_dtype=torch.float32, device\_map=device  
)  
tokenizer = AutoTokenizer.from\_pretrained(model\_name)  
config = AutoConfig.from\_pretrained(model\_name)  
messages = [  
    {"role": "system",  
        "content": "你是一个聪明的文本生成助手。请帮我生成由快递公司或快递柜公司发送的快递提醒的通知,目的用于提醒客户及时取快递,同时你需要按一个固定的jsonlines格式返回给我。"},  
    {"role": "user", "content": """任务背景:  
        你是一个聪明的文本生成助手。请帮我生成由快递公司或快递柜公司发送的快递提醒的通知,目的用于提醒客户及时取快递,同时你需要按一个固定的jsonlines格式返回给我。  
          
        文本要求:  
        - 生成的文本需包含快递公司(非必须)、快递柜(非必须)、快递取件地址(必须)、取件码(必须)  
          
        文本提示:  
        - 快递公司有 以下种类【中通快递, 韵达快递, 圆通快递, 顺丰速运, 申通快递, 中国邮政, 极兔快递, 京东快递, 达达快送, 顺丰同城急送, 安能物流, 联邦快递, 苏宁物流】;  
        - 快递柜分别有【丰巢, 菜鸟驿站, 兔喜生活】;  
        - 取件码的格式一般为 X-XX-XXXX或 XXXXXXXX,X为纯数字  
          
        输出:  
        - {"text": "【通知文本】", "poi\_address": "【取件地址】", "text\_category": "提醒-快递"}  
          
        举例:  
        {"text": "【菜鸟驿站】凭4-3-6308到上海新凯家园一期步行街29号店取件。点此极速取件不排队 p.tb.cn\/\_3sarS5", ''poi\_address": "上海新凯家园一期步行街29号店", "text\_category": "提醒-快递"}  
        {"text": "【兔喜生活】取货码041017,您有中通快递包裹已到古楼公路1858弄步行街14号(原圆通快递),询23XXXX1", ''poi\_address": "古楼公路1858弄步行街14", "text\_category": "提醒-快递"}"},  
"""}]  
texts = tokenizer.apply\_chat\_template(messages, tokenize=False, add\_generation\_prompt=True)  
model\_inputs = tokenizer(texts, return\_tensors="pt").to(model.device)  
  
tokenizer\_info = xgr.TokenizerInfo.from\_huggingface(tokenizer, vocab\_size=config.vocab\_size)  
grammar\_compiler = xgr.GrammarCompiler(tokenizer\_info)  
# compiled\_grammar = grammar\_compiler.compile\_builtin\_json\_grammar()  
person\_schema = {  
    "text": "【短信通知文案】",  
    "poi\_address": "POI地址】",  
    "text\_category": "【类别】",  
    "properties": {  
        "text": {  
            "type": "string"  
        },  
        "poi\_address": {  
            "type": "string",  
        },  
        "text\_category": {  
            "type": "string",  
        }  
    },  
    "required": ["text", "poi\_address", "text\_category"]  
}  
compiled\_grammar = grammar\_compiler.compile\_json\_schema(json.dumps(person\_schema))  
  
xgr\_logits\_processor = xgr.contrib.hf.LogitsProcessor(compiled\_grammar)  
generated\_ids = model.generate(  
    **model\_inputs, max\_new\_tokens=512, logits\_processor=[xgr\_logits\_processor]  
)  
  
generated\_ids = generated\_ids[0][len(model\_inputs.input\_ids[0]):]  
print(tokenizer.decode(generated\_ids, skip\_special\_tokens=True))  

  1. 常规Prompt输出比对 ===============

为了测试Xgrammar的效果,我们同时也拿同样版本的模型训练,我们只需固定同一个基座模型,其他参数及输出不需要做变动。格式较为固定,我们在下方直接提供完整代码。

  
#!/usr/bin/env python  
# -*- coding: utf-8 -*-  
# @Time    : 2025/6/03 13:36  
# @Author  : <小窗幽记机器学习>  
  
import json  
import torch  
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig  
  
device = "cuda"# Or "cpu", etc.  
model\_name = "Qwen3-1.7B"  
model = AutoModelForCausalLM.from\_pretrained(  
    model\_name, torch\_dtype=torch.float32, device\_map=device  
)  
tokenizer = AutoTokenizer.from\_pretrained(model\_name)  
config = AutoConfig.from\_pretrained(model\_name)  
messages = [  
    {"role": "system",  
     "content": "你是一个聪明的文本生成助手。请帮我生成由快递公司或快递柜公司发送的快递提醒的通知,目的用于提醒客户及时取快递,同时你需要按一个固定的jsonlines格式返回给我。"},  
    {"role": "user", "content":  
        """ 任务背景:  
            你是一个聪明的文本生成助手。请帮我生成由快递公司或快递柜公司发送的快递提醒的通知,目的用于提醒客户及时取快递,同时你需要按一个固定的jsonlines格式返回给我。  
      
            文本要求:  
            - 生成的文本需包含快递公司(非必须)、快递柜(非必须)、快递取件地址(必须)、取件码(必须)  
      
            文本提示:  
            - 快递公司有 以下种类【中通快递, 韵达快递, 圆通快递, 顺丰速运, 申通快递, 中国邮政, 极兔快递, 京东快递, 达达快送, 顺丰同城急送, 安能物流, 联邦快递, 苏宁物流】;  
            - 快递柜分别有【丰巢, 菜鸟驿站, 兔喜生活】;  
            - 取件码的格式一般为 X-XX-XXXX或 XXXXXXXX,X为纯数字  
      
            输出:  
            - {"text": "【通知文本】", "poi\_address": "【取件地址】", "text\_category": "提醒-快递"}  
      
            举例:  
            {"text": "【菜鸟驿站】凭4-3-6308到上海新凯家园一期步行街29号店取件。点此极速取件不排队 p.tb.cn\/\_3sarS5", ''poi\_address": "上海新凯家园一期步行街29号店", "text\_category": "提醒-快递"}  
            {"text": "【兔喜生活】取货码041017,您有中通快递包裹已到古楼公路1858弄步行街14号(原圆通快递),询18XX1XX8932", ''poi\_address": "古楼公路1858弄步行街14号", "text\_category": "提醒-快递"}"},  
            """  
     }]  
text = tokenizer.apply\_chat\_template(  
    messages,  
    tokenize=False,  
    add\_generation\_prompt=True,  
    enable\_thinking=True,  # Switches between thinking and non-thinking modes. Default is True.  
)  
model\_inputs = tokenizer([text], return\_tensors="pt").to(model.device)  
  
# conduct text completion  
generated\_ids = model.generate(  
    **model\_inputs,  
    max\_new\_tokens=32768  
)  
output\_ids = generated\_ids[0][len(model\_inputs.input\_ids[0]):].tolist()  
  
# parse thinking content  
try:  
    # rindex finding 151668 (</think>)  
    index = len(output\_ids) - output\_ids[::-1].index(151668)  
except ValueError:  
    index = 0  
  
thinking\_content = tokenizer.decode(output\_ids[:index], skip\_special\_tokens=True).strip("\n")  
content = tokenizer.decode(output\_ids[index:], skip\_special\_tokens=True).strip("\n")  
  
# print("thinking content:", thinking\_content)  
print("content:", content)  

  1. Xgrammar对比输出 ===============

5.1 实验输出

以下是常规模型输出:

  
content: ```json  
{"text": "【中通快递】凭取件码123-4567到北京朝阳区XX街XX号取件。点此极速取件不排队 p.tb.cn\/\_3sarS5", "poi\_address": "北京朝阳区XX街XX号", "text\_category": "提醒-快递"}  

content: ```json
[
{
"text": "【中通快递】凭4-3-6308到上海新凯家园一期步行街29号店取件。点此极速取件不排队 p.tb.cn/_3sarS5",
"poi_address": "上海新凯家园一期步行街29号店",
"text_category": "提醒-快递"
}
]

而Xgrammar输出则是:

  
{  
    "text": "【中通快递】取件码325678,您有快递到上海浦东新区张江路2000号(原韵达快递),请尽快取件",  
    "poi\_address": "上海浦东新区张江路2000号",  
    "text\_category": "提醒-快递"  
}  
==================================================  
{  
    "text": "【顺丰速运】取件码888888,您有快递到上海浦东新区世纪大道100号环球金融中心48楼,敬请尽快取件。",   
    "poi\_address": "上海浦东新区世纪大道100号环球金融中心48楼",   
    "text\_category": "提醒-快递"  
}  

可以看出,Prompt工程的输出结果是makdown-JSON格式,这种并不是我们所期望的,这些被```包裹的结果需要进一步的后处理才能够提取出我们想要的结果。

5.2 多次实验

为了防止模型输出出现幻觉导致实验出现误差,我们将各自模型输出脚本的方式设立了100次的循环,并且为了防止多轮对话导致的模型“愚钝”,我们设立的循环实验皆为一次对话的结果取值。在下方我们统计了错误率(非正确格式的输出的比率)及单纯Prompt工程的错误率。

| 次/模型 | Qwen3-1.7B + Xgrammar + Prompt | Qwen3-1.7B + Prompt Engineering | | --- | --- | --- | | 10次 | 0% | 70% | | 50次 | 2% | 78% | | 100次 | 6% | 65% |

  1. 小结 =====

通过本文的实战与对比实验,我们可以清晰看到:仅依赖Prompt工程很难稳定控制大模型输出格式,而借助Xgrammar的语法约束能力,能显著提升结构化输出的准确性与一致性。在实际应用中,Xgrammar为构建安全、稳定、高可控的生成式系统提供了一种高效可落地的方案。

更多AI相关欢迎关注微信公众号:"小窗幽记机器学习"

picture.image

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动 XR 技术的探索与实践
火山引擎开发者社区技术大讲堂第二期邀请到了火山引擎 XR 技术负责人和火山引擎创作 CV 技术负责人,为大家分享字节跳动积累的前沿视觉技术及内外部的应用实践,揭秘现代炫酷的视觉效果背后的技术实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论