scripts
utils
examples : 使用方法,參考案例
...
src/transformers (Transformer相關(guān)的代碼)
data: 數(shù)據(jù)處理
models : 模型的實(shí)現(xiàn)代碼,比如BERT, GPT,Whisper模型,都在此目錄下實(shí)現(xiàn)
generation : 文本生成相關(guān)代碼

...

從上面的代碼中,examples中提供了模型的使用方法的參考例子。
我們的今天介紹的主要內(nèi)容都在 src/transformers 目錄下,其中 models 目錄下,是基于transformer的各種模型的實(shí)現(xiàn)代碼,Generation 包含通用的文本產(chǎn)生的實(shí)現(xiàn)代碼。

模型 models/whisper

我們以Whisper 模型為例來詳細(xì)介紹一下代碼的結(jié)構(gòu)和調(diào)用關(guān)系。下面我們以v4.29.1的版本為例進(jìn)行介紹。
首先,whisper模型的代碼位于:src/transformers/models/whisper 目錄下。其主要功能都封裝在 modeling_whisper.py 文件中。

調(diào)用入口:WhisperForConditionalGeneration類
此python文件中包含多個(gè)類,繼承的關(guān)系比較復(fù)雜,它們之間的主要調(diào)用關(guān)系如下(以greedy search為例):

WhisperForConditionalGeneration (L1312) : 調(diào)用入口類
forward() (L1359) --> 細(xì)節(jié)在:WhisperModel,WhisperEncoder,WhisperDecoder 類中實(shí)現(xiàn)
generate() (L1455) --> 細(xì)節(jié)在: generation/utils.py#L1146 中實(shí)現(xiàn)
greedy_search(): L2164 --> 調(diào)用 search 函數(shù)來做實(shí)際的處理,比如自回歸處理

Forward函數(shù):
forward函數(shù)位于 class transformers.WhisperModel 類中,代碼位置請參考:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1215

def forward():
# Encoder將輸入的語音信號,編碼為聲學(xué)信息,也就是 encoder_outputs
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Decoder 的主要輸入為 decoder_input_ids (對應(yīng)文本) 和 encoder_outputs (對應(yīng)聲學(xué)信息,在翻譯任務(wù)中,對應(yīng)著源語言)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

通過上面,我們可以看到有 Encoder 部分和 Decoder 部分,分別對應(yīng)聲學(xué)特征的提取和文本產(chǎn)生部分。

在 WhisperForConditionalGeneration 類中,也有一個(gè)forward函數(shù),是對上面forward函數(shù)的封裝。
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1359

其中self.encoder的實(shí)現(xiàn)代碼位于 class WhisperEncoder(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L735

其中self.decoder的實(shí)現(xiàn)代碼位于 class WhisperDecoder(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L881

Generate函數(shù)
Generate函數(shù)入口:位于 class WhisperForConditionalGeneration(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455

此處只是調(diào)用的入口,具體的實(shí)現(xiàn)代碼位于 class GenerationMixin 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146
def generate() L1146

其中g(shù)enerate函數(shù)使用的greedy_search的實(shí)現(xiàn)位于:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L2164

Generate 代碼分析

下面,我們來進(jìn)一步了解 generate 的實(shí)現(xiàn)代碼,來看看如何對此代碼進(jìn)行修改。

入口代碼
Generate函數(shù)的入口位于: WhisperForConditionalGeneration類中的 def generate 函數(shù)
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455

代碼的概要如下,從代碼中可以看到,這個(gè)函數(shù)主要是進(jìn)行了一些參數(shù)設(shè)置,具體的實(shí)現(xiàn)是調(diào)用了父類中的對應(yīng)函數(shù)來執(zhí)行的。

def generate()
# 參數(shù)設(shè)置部分

# 調(diào)用部分(此處調(diào)用了父類中的generate實(shí)現(xiàn))
return super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
**kwargs,
)

然后,我們可以逐級向上搜索其父類,可以看到

到此為止,我們就可以看到,具體的實(shí)現(xiàn)都在 GenerationMixin 類中。

Generate函數(shù)實(shí)現(xiàn)細(xì)節(jié)

下面,我們來看一下 GenerationMixin類中的 generate 函數(shù)的實(shí)現(xiàn)細(xì)節(jié)。
代碼位置:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146

其代碼概要如下:

def generate(): L1146
# 根據(jù)解碼方式的不同,此函數(shù)中有最多14步的處理步驟,我們以greedy search為例
# 1. Handle generation_config and kwargs that might update it, and validate the .generate() call # 2. Set generation parameters if not already defined # 3. Define model inputs # 4. Define other model kwargs # 5. Prepare input_ids which will be used for auto-regressive generation # 6. Prepare max_length depending on other stopping criteria. # 7. determine generation mode # 8. prepare distribution pre_processing samplers # 9. prepare stopping criteria # 10. go into different generation modes # 11. run greedy search (L1515) def greedy_search(): L2164 # 初始化,設(shè)置 # 循環(huán)處理 while True: # L2317 # prepare model inputs (下面函數(shù)的具體實(shí)現(xiàn)位于: modeling_whisper.py#L1627) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token # 這里是調(diào)用了 WhisperForConditionalGeneration 中的forward函數(shù)。這是因?yàn)?PyTorch 的 nn.Module 基類定義了一個(gè) __call__ 方法,當(dāng)你調(diào)用模型實(shí)例(即 self)時(shí),它會(huì)自動(dòng)調(diào)用這個(gè) __call__ 方法,而這個(gè) __call__ 方法又會(huì)調(diào)用 forward 方法。 outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # 得到下一個(gè)token的logits next_token_logits = outputs.logits[:, -1, :] # pre-process distribution 得到其score next_tokens_scores = logits_processor(input_ids, next_token_logits) # argmax :使用argmax 獲取對應(yīng)的 tokens next_tokens = torch.argmax(next_tokens_scores, dim=-1) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) # 判斷是否結(jié)束search: # if eos_token was found in one sentence, set sentence to finished # stop if we exceed the maximum length

通過上面的代碼概要,我們就可以知道generate函數(shù)進(jìn)行了很多的設(shè)置以后,會(huì)調(diào)用 greedy_search() 函數(shù)來進(jìn)行文本產(chǎn)生的實(shí)際處理。
到此為止,我們就已經(jīng)對整個(gè)的代碼結(jié)構(gòu)了解了。

下面我們通過幾個(gè)問題,來回顧一下對代碼的理解。

代碼修改建議

針對上面的問題7,如果要對generate或者其他部分進(jìn)行修改,建議在 models/whisper的目錄下對父類函數(shù)進(jìn)行重構(gòu)。
比如,如果要對greedy_search功能進(jìn)行調(diào)整來實(shí)現(xiàn)一些獨(dú)特的功能時(shí),可以在modeling_whisper.py中重構(gòu) greedy_search(),具體做法可以是:

  1. 將 utils.py 中的 greedy_search 函數(shù)拷貝到 modeling_whisper.py 文件中。
  2. 需要import 一些必要的庫文件。(具體的庫,可以根據(jù)運(yùn)行時(shí)的錯(cuò)誤提示確定)
  3. 在greedy_search函數(shù)中進(jìn)行修改,來實(shí)現(xiàn)想要的功能。

函數(shù)在子類中被重新實(shí)現(xiàn)之后,調(diào)用時(shí),將會(huì)優(yōu)先調(diào)用新重構(gòu)的函數(shù)。這樣既實(shí)現(xiàn)了自己獨(dú)特的功能,還不影響其他的模型的運(yùn)行。

參考文獻(xiàn)

  1. 【基本概念】https://huggingface.co/blog/how-to-generate
  2. https://huggingface.co/docs/transformers/main_classes/text_generation
  3. https://huggingface.co/docs/transformers/internal/generation_utils

文章轉(zhuǎn)載自: Transformers Generate 功能介紹

上一篇:

如何使用python和django構(gòu)建后端rest api

下一篇:

18種最佳 RAG 技術(shù)
#你可能也喜歡這些API文章!

我們有何不同?

API服務(wù)商零注冊

多API并行試用

數(shù)據(jù)驅(qū)動(dòng)選型,提升決策效率

查看全部API→
??

熱門場景實(shí)測,選對API

#AI文本生成大模型API

對比大模型API的內(nèi)容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉(zhuǎn)化潛力

25個(gè)渠道
一鍵對比試用API 限時(shí)免費(fèi)

#AI深度推理大模型API

對比大模型API的邏輯推理準(zhǔn)確性、分析深度、可視化建議合理性

10個(gè)渠道
一鍵對比試用API 限時(shí)免費(fèi)