LLM微调(四)| 微调Llama 2实现Text-to-SQL,并使用LlamaIndex在数据库上进行推理

技术
    Llama 2是开源LLM发展的一个巨大里程碑。最大模型及其经过微调的变体位居Hugging Face Open LLM排行榜(https://huggingface.co/spaces/HuggingFaceH4/open\_llm\_leaderboard)前列。多个基准测试表明,就性能而言,它正在接近GPT-3.5(在某些情况下甚至超过它)。所有这些都意味着,对于从RAG系统到Agent的复杂LLM应用程序,开源LLM是一种越来越可行和可靠的选择。

picture.image

一、Llama-2–7B不擅长从文本到SQL

   最小的Llama 2模型(7B参数)有一个缺点是它不太擅长生成SQL,因此它不适用于结构化分析示例。例如,我们尝试在给定以下提示模板的情况下提示Llama 2生成正确的SQL语句:

          
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. 
          

          
You must output the SQL query that answers the question.
          

          
### Input:
          
{input}
          

          
### Context:
          
{context}
          

          
### Response:
      
     在这里,我们使用sqlcreatecontext数据集(https://huggingface.co/datasets/b-mc2/sql-create-context)的一个示例来测试一下效果:

          
input: In 1981 which team picked overall 148?
          
context: CREATE TABLE table_name_8 (team VARCHAR, year VARCHAR, overall_pick VARCHAR)
      
     同时,这里是生成的输出与正确输出的对比:

          
Generated output: SELECT * FROM `table_name_8` WHERE '1980' = YEAR AND TEAM = "Boston Celtics" ORDER BY OVERALL_PICK DESC LIMIT 1;
          

          
Correct output: SELECT team FROM table_name_8 WHERE year = 1981 AND overall_pick = "148"
      
   这显然并不理想。与ChatGPT和GPT-4不同,  **原始的Llama 2不能生成期望的的格式和正确的SQL** 。




  这正是微调的作用所在——如果有一个合适的文本到SQL数据的语料库,我们可以教Llama 2更好地从自然语言生成SQL输出。

微调 有不同的方法,可以更新模型的所有参数(比如:全量微调),也可以冻结大模型参数仅微调附加参数(比如:LoRA)。

二、微调Ll**** ama-2–7B,使其可以从文本生成SQL

   接下来,我们将展示如何在文本到SQL数据集上微调Llama 2,然后使用LlamaIndex的功能对任何SQL数据库进行结构化分析。

准备工作:

微调数据集 :来自Hugging Face的b-mc2/sql-create-context(https://huggingface.co/datasets/b-mc2/sql-create-context)

base模型 :OpenLLaMa 的open_lama_7b_v2(https://github.com/openlm-research/open\_llama)

步骤1:加载微调LLaMa的训练数据

PS:1)以下代码来自doppel-bot:https://github.com/modal-labs/doppel-bot;2)许多Python代码都包含在src目录中;3)需要设置一个Modal帐户,并生成token。


        
            

          !pip install -r requirements.txt
        
      
   首先,我们使用Modal加载b-mc2/sql-create-context数据集,并将其格式化为.jsonl文件。

        
            

          modal run src.load\_data\_sql --data-dir "data\_sql"
        
      

结果如下所示:


          
# Modal stubs allow our function to run remotely
          
@stub.function(
          
    retries=Retries(
          
        max_retries=3,
          
        initial_delay=5.0,
          
        backoff_coefficient=2.0,
          
    ),
          
    timeout=60 * 60 * 2,
          
    network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
          
    cloud="gcp",
          
)
          
def load_data_sql(data_dir: str = "data_sql"):
          
    from datasets import load_dataset
          

          
    dataset = load_dataset("b-mc2/sql-create-context")
          

          
    dataset_splits = {"train": dataset["train"]}
          
    out_path = get_data_path(data_dir)
          

          
    out_path.parent.mkdir(parents=True, exist_ok=True)
          

          
    for key, ds in dataset_splits.items():
          
        with open(out_path, "w") as f:
          
            for item in ds:
          
                newitem = {
          
                    "input": item["question"],
          
                    "context": item["context"],
          
                    "output": item["answer"],
          
                }
          
                f.write(json.dumps(newitem) + "\n")
      

步骤2:运行微调脚本

在微调数据集微调llama2模型,代码如下:


        
            

          modal run src.finetune\_sql --data-dir "data\_sql" --model-dir "model\_sql"
        
      

微调脚本会执行以下步骤:

将数据集拆分为训练和验证拆分


          
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
          
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
          
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
      
   将每个拆分为元组的格式(输入Prompt、标签):输入query和上下文被格式化为输入Prompt,然后对输入Prompt和标签进行 tokenize,模型采用自回归的方法预测下一个token来进行训练。

          
def generate_and_tokenize_prompt(data_point):
          
  full_prompt = generate_prompt_sql(
          
      data_point["input"],
          
      data_point["context"],
          
      data_point["output"],
          
  )
          
  tokenized_full_prompt = tokenize(full_prompt)
          
  if not train_on_inputs:
          
      raise NotImplementedError("not implemented yet")
          
  return tokenized_full_prompt
      

PS:输入Prompt与开始测试llama2的格式完全相同。

   运行微调脚本时,模型将保存在model\_dir指定的远程云目录中(如果未指定,则设置为默认值)。

步骤3:评估

   该模型已经进行了微调,可以从云端提供服务。下面我们使用b-mc2/sql-create-context中的示例数据进行一些基本评估,比较微调后模型与原始Llama 2模型的性能。

        
            

          modal run src.eval\_sql::main
        
      

结果表明,微调后的模型有了巨大的改进:


          
Input 1: {'input': 'Which region (year) has Abigail at number 7, Sophia at number 1 and Aaliyah at number 5?', 'context': 'CREATE TABLE table_name_12 (region__year_ VARCHAR, no_5 VARCHAR, no_7 VARCHAR, no_1 VARCHAR)', 'output': 'SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" AND
          
no_5 = "aaliyah"'}
          
Output 1 (finetuned model): SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "aaliyah" AND no_5 = "sophia"
          
Output 1 (base model): SELECT * FROM table_name_12 WHERE region__year = '2018' AND no_5 = 'Abigail' AND no_7 = 'Sophia' AND no_1 = 'Aaliyah';
          

          

          
Input 2: {'input': 'Name the result/games for 54741', 'context': 'CREATE TABLE table_21436373_11 (result_games VARCHAR, attendance VARCHAR)', 'output': 'SELECT result_games FROM table_21436373_11 WHERE attendance = 54741'}
          
Output 2 (finetuned model): SELECT result_games FROM table_21436373_11 WHERE attendance = "54741"
          
Output 2 (base model): SELECT * FROM table_21436373_11 WHERE result_games = 'name' AND attendance > 0;
      

步骤4:将微调模型与LlamaIndex集成

   我们现在可以在LlamaIndex中使用这个模型,在任何数据库上进行文本到SQL。






   我们首先定义一个测试SQL数据库,然后可以使用该数据库来测试模型的推理能力。






   我们创建了一个玩具city\_stats表,其中包含城市名称、人口和国家信息,并用几个示例城市填充它。

          
db_file = "cities.db"
          
engine = create_engine(f"sqlite:///{db_file}")
          
metadata_obj = MetaData()
          
# create city SQL table
          
table_name = "city_stats"
          
city_stats_table = Table(
          
    table_name,
          
    metadata_obj,
          
    Column("city_name", String(16), primary_key=True),
          
    Column("population", Integer),
          
    Column("country", String(16), nullable=False),
          
)
          
metadata_obj.create_all(engine)
      

这存储在cities.db文件中。

 然后,我们可以使用Modal将微调后的模型和该数据库文件加载到LlamaIndex中的NLSQLTableQueryEngine中——该查询引擎允许用户轻松地开始在给定的数据库上执行文本到SQL

        
            

          modal run src.inference\_sql\_llamaindex::main --query "Which city has the highest population?" --sqlite-file-path "nbs/cities.db" --model-dir "model\_sql" --use-finetuned-model True
        
      

我们得到如下回复:


          
SQL Query: SELECT MAX(population) FROM city_stats WHERE country = "United States"
          
Response: [(2679000,)]
      

三、结论

    本文提供了一种非常高级的方法来开始微调生成SQL语句的Llama 2模型,并展示了如何使用LlamaIndex将其端到端插入到文本到SQL工作流中。

参考文献:

[1] https://blog.llamaindex.ai/easily-finetune-llama-2-for-your-text-to-sql-applications-ecd53640e10d

[2] https://github.com/run-llama/modal\_finetune\_sql

[3] https://github.com/run-llama/modal\_finetune\_sql/blob/main/tutorial.ipynb

0
0
0
0
关于作者
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动大数据容器化构建与落地实践
随着字节跳动旗下业务的快速发展,数据急剧膨胀,原有的大数据架构在面临日趋复杂的业务需求时逐渐显现疲态。而伴随着大数据架构向云原生演进的行业趋势,字节跳动也对大数据体系进行了云原生改造。本次分享将详细介绍字节跳动大数据容器化的演进与实践。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论