微信截圖_17435904448874.png)
跟大牛學(xué)LLM訓(xùn)練和使用技巧
scripts
utils
examples : 使用方法,參考案例
...
src/transformers (Transformer相關(guān)的代碼)
data: 數(shù)據(jù)處理
models : 模型的實(shí)現(xiàn)代碼,比如BERT, GPT,Whisper模型,都在此目錄下實(shí)現(xiàn)
generation : 文本生成相關(guān)代碼
...
從上面的代碼中,examples中提供了模型的使用方法的參考例子。
我們的今天介紹的主要內(nèi)容都在 src/transformers 目錄下,其中 models 目錄下,是基于transformer的各種模型的實(shí)現(xiàn)代碼,Generation 包含通用的文本產(chǎn)生的實(shí)現(xiàn)代碼。
我們以Whisper 模型為例來詳細(xì)介紹一下代碼的結(jié)構(gòu)和調(diào)用關(guān)系。下面我們以v4.29.1的版本為例進(jìn)行介紹。
首先,whisper模型的代碼位于:src/transformers/models/whisper 目錄下。其主要功能都封裝在 modeling_whisper.py 文件中。
調(diào)用入口:WhisperForConditionalGeneration類
此python文件中包含多個(gè)類,繼承的關(guān)系比較復(fù)雜,它們之間的主要調(diào)用關(guān)系如下(以greedy search為例):
WhisperForConditionalGeneration (L1312) : 調(diào)用入口類
forward() (L1359) --> 細(xì)節(jié)在:WhisperModel,WhisperEncoder,WhisperDecoder 類中實(shí)現(xiàn)
generate() (L1455) --> 細(xì)節(jié)在: generation/utils.py#L1146 中實(shí)現(xiàn)
greedy_search(): L2164 --> 調(diào)用 search 函數(shù)來做實(shí)際的處理,比如自回歸處理
Forward函數(shù):
forward函數(shù)位于 class transformers.WhisperModel 類中,代碼位置請參考:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1215
def forward():
# Encoder將輸入的語音信號,編碼為聲學(xué)信息,也就是 encoder_outputs
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Decoder 的主要輸入為 decoder_input_ids (對應(yīng)文本) 和 encoder_outputs (對應(yīng)聲學(xué)信息,在翻譯任務(wù)中,對應(yīng)著源語言)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
通過上面,我們可以看到有 Encoder 部分和 Decoder 部分,分別對應(yīng)聲學(xué)特征的提取和文本產(chǎn)生部分。
在 WhisperForConditionalGeneration 類中,也有一個(gè)forward函數(shù),是對上面forward函數(shù)的封裝。
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1359
其中self.encoder的實(shí)現(xiàn)代碼位于 class WhisperEncoder(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L735
其中self.decoder的實(shí)現(xiàn)代碼位于 class WhisperDecoder(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L881
Generate函數(shù)
Generate函數(shù)入口:位于 class WhisperForConditionalGeneration(WhisperPreTrainedModel) 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455
此處只是調(diào)用的入口,具體的實(shí)現(xiàn)代碼位于 class GenerationMixin 類中:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146
def generate() L1146
其中g(shù)enerate函數(shù)使用的greedy_search的實(shí)現(xiàn)位于:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L2164
下面,我們來進(jìn)一步了解 generate 的實(shí)現(xiàn)代碼,來看看如何對此代碼進(jìn)行修改。
入口代碼
Generate函數(shù)的入口位于: WhisperForConditionalGeneration類中的 def generate 函數(shù)
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/models/whisper/modeling_whisper.py#L1455
代碼的概要如下,從代碼中可以看到,這個(gè)函數(shù)主要是進(jìn)行了一些參數(shù)設(shè)置,具體的實(shí)現(xiàn)是調(diào)用了父類中的對應(yīng)函數(shù)來執(zhí)行的。
def generate()
# 參數(shù)設(shè)置部分
# 調(diào)用部分(此處調(diào)用了父類中的generate實(shí)現(xiàn))
return super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
**kwargs,
)
然后,我們可以逐級向上搜索其父類,可以看到
到此為止,我們就可以看到,具體的實(shí)現(xiàn)都在 GenerationMixin 類中。
Generate函數(shù)實(shí)現(xiàn)細(xì)節(jié)
下面,我們來看一下 GenerationMixin類中的 generate 函數(shù)的實(shí)現(xiàn)細(xì)節(jié)。
代碼位置:
https://github.com/huggingface/transformers/blob/v4.29.1/src/transformers/generation/utils.py#L1146
其代碼概要如下:
def generate(): L1146
# 根據(jù)解碼方式的不同,此函數(shù)中有最多14步的處理步驟,我們以greedy search為例
# 1. Handle generation_config
and kwargs that might update it, and validate the .generate()
call
# 2. Set generation parameters if not already defined
# 3. Define model inputs
# 4. Define other model kwargs
# 5. Prepare input_ids
which will be used for auto-regressive generation
# 6. Prepare max_length
depending on other stopping criteria.
# 7. determine generation mode
# 8. prepare distribution pre_processing samplers
# 9. prepare stopping criteria
# 10. go into different generation modes
# 11. run greedy search (L1515)
def greedy_search(): L2164
# 初始化,設(shè)置
# 循環(huán)處理
while True: # L2317
# prepare model inputs (下面函數(shù)的具體實(shí)現(xiàn)位于: modeling_whisper.py#L1627)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
# 這里是調(diào)用了 WhisperForConditionalGeneration 中的forward函數(shù)。這是因?yàn)?PyTorch 的 nn.Module 基類定義了一個(gè) __call__ 方法,當(dāng)你調(diào)用模型實(shí)例(即 self)時(shí),它會(huì)自動(dòng)調(diào)用這個(gè) __call__ 方法,而這個(gè) __call__ 方法又會(huì)調(diào)用 forward 方法。
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# 得到下一個(gè)token的logits
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution 得到其score
next_tokens_scores = logits_processor(input_ids, next_token_logits)
# argmax :使用argmax 獲取對應(yīng)的 tokens
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# 判斷是否結(jié)束search:
# if eos_token was found in one sentence, set sentence to finished
# stop if we exceed the maximum length
通過上面的代碼概要,我們就可以知道generate函數(shù)進(jìn)行了很多的設(shè)置以后,會(huì)調(diào)用 greedy_search() 函數(shù)來進(jìn)行文本產(chǎn)生的實(shí)際處理。
到此為止,我們就已經(jīng)對整個(gè)的代碼結(jié)構(gòu)了解了。
下面我們通過幾個(gè)問題,來回顧一下對代碼的理解。
針對上面的問題7,如果要對generate或者其他部分進(jìn)行修改,建議在 models/whisper的目錄下對父類函數(shù)進(jìn)行重構(gòu)。
比如,如果要對greedy_search功能進(jìn)行調(diào)整來實(shí)現(xiàn)一些獨(dú)特的功能時(shí),可以在modeling_whisper.py中重構(gòu) greedy_search(),具體做法可以是:
函數(shù)在子類中被重新實(shí)現(xiàn)之后,調(diào)用時(shí),將會(huì)優(yōu)先調(diào)用新重構(gòu)的函數(shù)。這樣既實(shí)現(xiàn)了自己獨(dú)特的功能,還不影響其他的模型的運(yùn)行。
文章轉(zhuǎn)載自: Transformers Generate 功能介紹