圖 1.DB-GPT-Hub 的架構流程圖

如圖一所示:DB-GPT-Hub 項目重點關注在數據預處理 – 數據集構建 – 模型微調 – 模型預測 – 模型驗證部分,微調得到的模型可以無縫銜接部署到 DB-GPT 框架中,然后結合知識問答和數據分析等能力展示模型在 Text2SQL 領域的優越性能。

具體功能:

2.1 數據集構建

以開源數據集 Spider 為例做一個詳細的介紹,Spider 數據集是一個多數據庫、多表、單輪查詢的 Text2SQL 數據集,是 Text2SQL 任務中最具挑戰性的數據集之一,由耶魯大學的 LILY 實驗室于 2018 年發布,具有如下特點:

圖 2: 不同數據集的語法分布

spider 數據集將 SQL 生成分成了四個等級:

為了充分利用數據庫中的表和字段等相關信息,對 Spider 中的原始數據進行處理,用自然語言表示數據庫包含的表結構以及表結構包含的字段以及相應的主鍵和外鍵等,經過數據預處理后,可以得到如下的數據格式:

{"instruction": "concert_singer(數據庫名) contains tables(表) such as stadium, singer, concert, singer_in_concert. Table stadium has columns(列) such as stadium_id, location, name, capacity, highest, lowest, average. stadium_id is the primary key(主鍵). Table singer has columns such as singer_id, name, country, song_name, song_release_year, age, is_male. singer_id is the primary key. Table concert has columns such as concert_id, concert_name, theme, stadium_id, year. concert_id is the primary key. Table singer_in_concert has columns such as concert_id, singer_id. concert_id is the primary key. The year of concert is the foreign key(外鍵)of location of stadium. The stadium_id of singer_in_concert is the foreign key of name of singer. The singer_id of singer_in_concert is the foreign key of concert_name of concert.", 

"input": "How many singers do we have?",

"response": "select count(*) from singer"}
{"instruction": "concert_singer(數據庫名)包含表(表),例如stadium, singer, concert, singer_in_concert。表體育場有列(列),如stadium_id、位置、名稱、容量、最高、最低、平均。Stadium_id是主鍵(主鍵)。表singer有這樣的列:singer_id、name、country、song_name、song_release_year、age、is_male。Singer_id為主鍵。表concert有如下列:concert_id、concert_name、theme、stadium_id、year。Concert_id是主鍵。表singer_in_concert有如下列:concert_id, singer_id。Concert_id是主鍵。演唱會年份是場館位置的外鍵(外鍵)。singer_in_concert的stadium_id是歌手名的外鍵。singer_in_concert的singer_id是concert的concert_name的外鍵。

"input": "我們有多少歌手?"

"response": "select count(*) from singer"}

同時,為了更好的利用大語言模型的理解能力,定制了 prompt dict 以優化輸入,如下所示:

SQL_PROMPT_DICT = {

"prompt_input": (

"I want you to act as a SQL terminal in front of an example database. "

"Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\n"

"###Instruction:\n{instruction}\n\n###Input:\n{input}\n\n###Response: "

),

"prompt_no_input": (

"I want you to act as a SQL terminal in front of an example database. "

"Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\n"

"###Instruction:\n{instruction}\n\n### Response: "

),

}

2.2 模型訓練

將從基礎模型和微調方式來進行

2.2.1基礎模型

目前支持的模型結構如下所示,包含了當下主流的中外開源模型系列,比如 Llama 系列、Baichuan 系列、GLM 系列、Qwen 系列等,覆蓋面廣,同時 benchmark 橫跨 7b/13B/70B 的規模。

圖 5: 不同模型的微調模式

2.2.2 微調方式

Text2SQL微調主要包含以下流程:

在大語言模型對特定任務或領域進行微調任務時,重新訓練所有模型參數將會帶來昂貴的訓練成本,因此出現了各種優化的微調方案,綜合評估模型微調速度和精度,實現了當下流行的 LoRA(Low-Rank Adaptation 的簡寫) 方法和 QLoRA(量化 + lora)方法。 LoRA 的基本原理是在凍結原模型參數的情況下,通過向模型中加入額外的網絡層,并只訓練這些新增的網絡層參數。由于這些新增參數數量較少,這樣不僅 finetune 的成本顯著下降,還能獲得和全模型微調類似的效果,如下圖所示:

圖三. LoRA 微調示意圖

QLoRA 方法使用一種低精度的存儲數據類型(NF4)來壓縮預訓練的語言模型。通過凍結 LM 參數,將相對少量的可訓練參數以 Low-Rank Adapters 的形式添加到模型中,LoRA 層是在訓練期間更新的唯一參數,使得模型體量大幅壓縮同時推理效果幾乎沒有受到影響。從 QLoRA 的名字可以看出,QLoRA 實際上是 Quantize+LoRA 技術。

圖 4:QLora 示意圖

2.3 模型預測

模型微調完后,基于保存的權重和基座大模型,對 spider 數據集的 dev 測試集進行測試,可以得到模型預測的 sql。 預測的 dev_sql.json 總共有 1034 條數據,同樣需要經過數據預處理后,再拿給模型預測結果。

{"instruction": "concert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as stadium_id, location, name, capacity, highest, lowest, average. stadium_id is the primary key. Table singer has columns such as singer_id, name, country, song_name, song_release_year, age, is_male. singer_id is the primary key. Table concert has columns such as concert_id, concert_name, theme, stadium_id, year. concert_id is the primary key. Table singer_in_concert has columns such as concert_id, singer_id. concert_id is the primary key. The stadium_id of concert is the foreign key of stadium_id of stadium. The singer_id of singer_in_concert is the foreign key of singer_id of singer. The concert_id of singer_in_concert is the foreign key of concert_id of concert.", "input": "How many singers do we have?", "output": "select count(*) from singer"}

模型預測的核心代碼如下:

def inference(model: ChatModel, predict_data: List[Dict], **input_kwargs):

res = []

# test

# for item in predict_data[:20]:

for item in tqdm(predict_data, desc="Inference Progress", unit="item"):

response, _ = model.chat(query=item["input"], history=[], **input_kwargs)

res.append(response)

return res

2.4 模型評估

模型預測得到 sql 后,需要和 spider 數據集的標準答案對比,使用 EX(execution accuracy)和 EM(Exact Match)指標進行評估 EX 指標是計算 SQL 執行結果正確的數量在數據集中的比例,公示如下所示:

EM 指標是計算模型生成的 SQL 和標注 SQL 的匹配程度。

3.benchmark 設計

3.1 數據集

的 benchmark 在 bird 和 spirder 兩個數據上構建:

整體代碼適配 WikiSQL,CoSQL 等數據集。

更多內容參考:NL2SQL基礎系列(1):業界頂尖排行榜、權威測評數據集及LLM大模型(Spider vs BIRD)全面對比優劣分析[Text2SQL、Text2DSL]

3.1.1 spider

表 1.Spider 的 EX 準確率表,L 代表 LoRA,QL 代表 QLoRA

表 2.Spider 的 EM 準確率表,L 代表 LoRA,QL 代表 QLoRA

3.1.2 BIRD

表 3.BIRD 的 EX 準確率表,L 代表 LoRA,QL 代表 QLoRA

表 4.BIRD 的 EM 準確率表,L 代表 LoRA,QL 代表 QLoRA

4. 實驗 Insight

4.1 不同難易程度任務的效果差異

如下圖所示,以三個 7B 模型為例,展示了調整后的 LLM 針對一系列 SQL 生成難度級別的有效性。對于所有三個微調后的模型,結果都表明性能提升的大小與 SQL 復雜性呈負相關,并且微調對簡單 SQL 的改進更為顯著。

4.2 LoRA 和 QLoRA 對比

如下表所示,總結 Lora 和 QLora 在 EX、EM、時間成本和 GPU 內存指標之間的差異。首先,發現使用 LoRA 和 QLoRA 調整的模型在生成性能(以 EX 和 EM 衡量)方面差異有限。其次,與量化機制一致,QLoRA 需要更多時間才能收斂,而 GPU 內存較少。例如,與 Qwen-14B-LoRA 相比,其 QLoRA 對應模型僅需要 2 倍的時間和 50%GPU 內存

本文章轉載微信公眾號@汀丶人工智能

上一篇:

使用 GraphQL、Prisma 和 React 實現端到端的類型安全:API 準備

下一篇:

Istio 使用 GatewayAPI實現流量管理
#你可能也喜歡這些API文章!

我們有何不同?

API服務商零注冊

多API并行試用

數據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創意新穎性、情感共鳴力、商業轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費