二、預(yù)創(chuàng)建的 Estimator

借助預(yù)創(chuàng)建的 Estimator,您能夠在比基本 TensorFlow API 高級很多的概念層面上進行操作。由于 Estimator 會為您處理所有 “ 管道工作 ”,因此您不必再為創(chuàng)建計算圖或會話而操心。也就是說,預(yù)創(chuàng)建的 Estimator 會為您創(chuàng)建和管理 Graph(https://tensorflow.google.cn/api_docs/python/tf/Graph?hl=zh-CN)和 Session(https://tensorflow.google.cn/api_docs/python/tf/Session?hl=zh-CN)對象。此外,借助預(yù)創(chuàng)建的 Estimator,您只需稍微更改下代碼,就可以嘗試不同的模型架構(gòu)。例如,DNNClassifier 是一個預(yù)創(chuàng)建的 Estimator 類(https://tensorflow.google.cn/api_docs/python/tf/estimator/DNNClassifier?hl=zh-CN),它根據(jù)密集的前饋神經(jīng)網(wǎng)絡(luò)訓(xùn)練分類模型。

1、預(yù)創(chuàng)建的 Estimator 程序的結(jié)構(gòu)

依賴預(yù)創(chuàng)建的 Estimator 的 TensorFlow 程序通常包含下列四個步驟:

編寫一個或多個數(shù)據(jù)集導(dǎo)入函數(shù)。

例如,您可以創(chuàng)建一個函數(shù)來導(dǎo)入訓(xùn)練集,并創(chuàng)建另一個函數(shù)來導(dǎo)入測試集。每個數(shù)據(jù)集導(dǎo)入函數(shù)都必須返回兩個對象 :

例如,以下代碼展示了輸入函數(shù)的基本框架:

def input_fn(dataset):

... # manipulate dataset, extracting the feature dict and the label
return feature_dict, label

(要了解完整的詳細信息,請參閱導(dǎo)入數(shù)據(jù)(https://tensorflow.google.cn/guide/datasets?hl=zh-CN)。)

定義特征列。

每個 tf.feature_column(https://tensorflow.google.cn/api_docs/python/tf/feature_column?hl=zh-CN)都標識了特征名稱、特征類型和任何輸入預(yù)處理操作。例如,以下代碼段創(chuàng)建了三個存儲整數(shù)或浮點數(shù)據(jù)的特征列。前兩個特征列僅標識了特征的名稱和類型。第三個特征列還指定了一個 lambda,該程序?qū)⒄{(diào)用此 lambda 來調(diào)節(jié)原始數(shù)據(jù):

# Define three numeric feature columns.

population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column('median_education',
normalizer_fn=lambda x: x - global_education_mean)

2、實例化相關(guān)的預(yù)創(chuàng)建的 Estimator。 

例如,下面是對名為 LinearClassifier 的預(yù)創(chuàng)建 Estimator 進行實例化的示例代碼:

# Instantiate an estimator, passing the feature columns.

estimator = tf.estimator.LinearClassifier(
feature_columns=[population, crime_rate, median_education],
)

3、調(diào)用訓(xùn)練、評估或推理方法。

例如,所有 Estimator 都提供訓(xùn)練模型的 train 方法。

# my_training_set is the function created in Step 1

estimator.train(input_fn=my_training_set, steps=2000)

4、預(yù)創(chuàng)建的 Estimator 的優(yōu)勢

預(yù)創(chuàng)建的 Estimator 會編碼最佳做法,從而具有下列優(yōu)勢:

如果您不使用預(yù)創(chuàng)建的 Estimator,則必須自行實現(xiàn)上述功能。

5、利用Estimator如何把python庫導(dǎo)出

利用Estimator如何把Python庫導(dǎo)出,可以通過以下幾種方式實現(xiàn):

  1. 使用tf.estimator.export
    TensorFlow提供了tf.estimator.export模塊,它包含了一系列用于導(dǎo)出Estimator的類和函數(shù)。你可以使用tf.estimator.export.build_parsing_serving_input_receiver_fntf.estimator.export.build_raw_serving_input_receiver_fn來構(gòu)建輸入接收函數(shù),然后使用tf.estimator.Estimator.export_saved_model方法將模型導(dǎo)出為SavedModel格式。
  2. 導(dǎo)出為SavedModel格式
    SavedModel是TensorFlow 2.0推薦的模型導(dǎo)出格式,它包含了完整的TensorFlow程序,包括權(quán)重和計算圖。使用tf.saved_model.save函數(shù),你可以將模型保存到指定的目錄。在需要載入SavedModel文件時,使用tf.saved_model.load函數(shù)即可。
  3. 使用tf.keras.estimator.model_to_estimator
    如果你有一個Keras模型,可以使用tf.keras.estimator.model_to_estimator將Keras模型轉(zhuǎn)換為Estimator,這樣你的Keras模型就可以利用Estimator的優(yōu)勢,例如分布式訓(xùn)練。
  4. 使用tf.estimator.Exporter
    tf.estimator.Exporter是一個表示模型導(dǎo)出類型的類,它允許你將Estimator導(dǎo)出為特定格式。通過實現(xiàn)export方法,你可以將Estimator導(dǎo)出為不同格式,包括SavedModel。
  5. 導(dǎo)出檢查點(Checkpoints)
    在TensorFlow 1.x中,你可以使用tf.train.Saver來保存和恢復(fù)模型的檢查點。檢查點包含了模型的結(jié)構(gòu)和參數(shù)權(quán)重,可以通過saver.savesaver.restore方法來導(dǎo)出和加載。
  6. 導(dǎo)出為Pickle文件
    對于使用Scikit-learn等庫訓(xùn)練的模型,可以使用Pickle庫將模型導(dǎo)出為Pickle文件。使用pickle.dump函數(shù)將模型保存到文件,使用pickle.load函數(shù)加載模型。

這些方法提供了不同的方式來將Estimator和Python庫導(dǎo)出,可以根據(jù)具體的需求和環(huán)境選擇合適的導(dǎo)出方式。

三、自定義 Estimator

每個 Estimator(無論是預(yù)創(chuàng)建還是自定義)的核心都是其模型函數(shù),這是一種為訓(xùn)練、評估和預(yù)測構(gòu)建圖的方法。如果您使用預(yù)創(chuàng)建的 Estimator,則有人已經(jīng)實現(xiàn)了模型函數(shù)。如果您使用自定義 Estimator,則必須自行編寫模型函數(shù)。隨附文檔介紹了如何編寫模型函數(shù)(https://tensorflow.google.cn/guide/custom_estimators?hl=zh-CN)。

四、推薦的工作流程

我們推薦以下工作流程:

  1. 假設(shè)存在合適的預(yù)創(chuàng)建的 Estimator,使用它構(gòu)建第一個模型并使用其結(jié)果確定基準。
  2. 使用此預(yù)創(chuàng)建的 Estimator 構(gòu)建和測試整體管道,包括數(shù)據(jù)的完整性和可靠性。
  3. 如果存在其他合適的預(yù)創(chuàng)建的 Estimator,則運行實驗來確定哪個預(yù)創(chuàng)建的 Estimator 效果最好。
  4. 可以通過構(gòu)建自定義 Estimator 進一步改進模型。

五、從 Keras 模型創(chuàng)建 Estimator

您可以將現(xiàn)有的 Keras 模型轉(zhuǎn)換為 Estimator。這樣做之后,Keras 模型就可以利用 Estimator 的優(yōu)勢,例如分布式訓(xùn)練。調(diào)用

 tf.keras.estimator.model_to_estimator(https://tensorflow.google.cn/api_docs/python/tf/keras/estimator/model_to_estimator?hl=zh-CN),如下例所示:

# Instantiate a Keras inception v3 model.

keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
loss='categorical_crossentropy',
metric='accuracy')
# Create an Estimator from the compiled Keras model. Note the initial model
# state of the keras model is preserved in the created Estimator.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)

# Treat the derived Estimator as you would with any other Estimator.
# First, recover the input name(s) of Keras model, so we can use them as the
# feature column name(s) of the Estimator input function:
keras_inception_v3.input_names # print out: ['input_1']
# Once we have the input name(s), we can create the input function, for example,
# for input(s) in the format of numpy ndarray:
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"input_1": train_data},
y=train_labels,
num_epochs=1,
shuffle=False)
# To train, we call Estimator's train function:
est_inception_v3.train(input_fn=train_input_fn, steps=2000)

請注意,Keras Estimator 的特征列名稱和標簽來自經(jīng)過編譯的對應(yīng) Keras 模型。例如,上面的 train_input_fn 的輸入鍵名稱可以從 keras_inception_v3.input_names 獲得;同樣,預(yù)測的輸出名稱可以從 keras_inception_v3.output_names 獲得。

要了解詳情,請參閱 tf.keras.estimator.model_to_estimator 的文檔(https://tensorflow.google.cn/api_docs/python/tf/keras/estimator/model_to_estimator?hl=zh-CN)。

文章轉(zhuǎn)自微信公眾號@TensorFlow

上一篇:

從人臉識別到機器翻譯:58個超有用的機器學(xué)習(xí)和預(yù)測API

下一篇:

將 API 從 MySQL 遷移到 AWS OpenSearch 后,我們將響應(yīng)時間提高了 1000 倍
#你可能也喜歡這些API文章!

我們有何不同?

API服務(wù)商零注冊

多API并行試用

數(shù)據(jù)驅(qū)動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內(nèi)容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉(zhuǎn)化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費