Continuous Training in an MLOps Pipeline
In this article, we dive into the code that implements continuous training.
In this series of articles, we walk you through the process of applying CI/CD to your AI tasks. You'll end up with a pipeline that meets the requirements of level 2 of the Google MLOps Maturity Model. We assume that you have some familiarity with Python, deep learning, Docker, DevOps, and Flask.
In the previous article, we discussed model creation, auto-adjustment, and notifications for our CI/CD MLOps pipeline. In this article, we'll look at the code required to implement continuous training in our ML pipeline. The diagram below shows where we are in the project.
Remember that this workflow is executed whenever there is a push to the dataset repository. The script checks whether a model exists in the production or testing registry, and then retrains whichever model it finds. Here is our application's file structure (the four files we'll walk through below):
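data_utils.py
email_notifications.py
task.py
Dockerfile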
We show the code files in abridged form. For the full versions, see the code repository.
data_utils.py
The data_utils.py code is almost exactly the same as before. It loads the data from the repository, transforms it, and saves the resulting model to GCS. The only difference is that the file now contains two additional functions: one checks whether a model exists in the testing registry, and the other loads that model.
Take the data_utils.py file from the previous article and add these functions to the end of it:
# Requires the existing "from google.cloud import storage" import from the previous article's version of this file.
def previous_model(bucket_name, model_type, model_filename):
    try:
        storage_client = storage.Client()  # if running on GCP
        bucket = storage_client.bucket(bucket_name)
        status = storage.Blob(bucket=bucket, name='{}/{}'.format(model_type, model_filename)).exists(storage_client)
        return status, None
    except Exception as e:
        print('Something went wrong when trying to check if previous model exists in GCS bucket. Exception: ' + str(e), flush=True)
        return None, e

def load_model(bucket_name, model_type, model_filename):
    try:
        storage_client = storage.Client()  # if running on GCP
        bucket = storage_client.bucket(bucket_name)
        blob1 = bucket.blob('{}/{}'.format(model_type, model_filename))
        blob1.download_to_filename('/root/' + str(model_filename))
        return True, None
    except Exception as e:
        print('Something went wrong when trying to load previous model from GCS bucket. Exception: ' + str(e), flush=True)
        return False, e
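The abridged listing omits the functions carried over from the previous article, including save_model, which task.py calls to push the resulting model to the testing registry. As a rough, hypothetical sketch only (the actual implementation is in the code repository), save_model presumably uploads the checkpoint file to the testing folder and returns a (status, error) tuple like the functions above:

def save_model(bucket_name, model_filename):
    # Hypothetical sketch; see the repository for the real implementation.
    try:
        storage_client = storage.Client()  # if running on GCP
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob('{}/{}'.format('testing', model_filename))
        blob.upload_from_filename('/root/' + str(model_filename))
        return True, None
    except Exception as e:
        print('Something went wrong when trying to save the model to the GCS bucket. Exception: ' + str(e), flush=True)
        return False, e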
email_notifications.py
The email_notifications.py code is essentially the same as before, except that it now sends a different set of messages:
import smtplib
import os

# Variables definition
sender = 'example@gmail.com'
receiver = ['svirahonda@gmail.com'] # replace this with the owner's email address
smtp_provider = 'smtp.gmail.com' # replace this with your SMTP provider
smtp_port = 587
smtp_account = 'example@gmail.com'
smtp_password = 'your_password'

def training_result(result, accuracy):
    if result == 'old_evaluation_prod':
        message = "A data push has been detected. Old model from production has reached more than 0.85 of accuracy. There's no need to retrain it."
    if result == 'retrain_prod':
        message = 'A data push has been detected. Old model from production has been retrained and has reached more than 0.85 of accuracy. It has been saved into /testing.'
    if result == 'old_evaluation_test':
        message = "A data push has been detected. Old model from /testing has reached more than 0.85 of accuracy. There's no need to retrain it."
    if result == 'retrain_test':
        message = 'A data push has been detected. Old model from /testing has been retrained and reached more than 0.85 of accuracy. It has been saved into /testing.'
    if result == 'poor_metrics':
        message = 'A data push has been detected. Old models from /production and /testing have been retrained but none of them reached more than 0.85 of accuracy.'
    if result == 'not_found':
        message = 'No previous models were found at GCS.'
    message = 'Subject: {}\n\n{}'.format('An automatic training job has ended recently', message)
    try:
        server = smtplib.SMTP(smtp_provider, smtp_port)
        server.starttls()
        server.login(smtp_account, smtp_password)
        server.sendmail(sender, receiver, message)
        return
    except Exception as e:
        print('Something went wrong. Unable to send email: ' + str(e), flush=True)
        return

def exception(e_message):
    try:
        message = 'Subject: {}\n\n{}'.format('An automatic training job has failed recently', e_message)
        server = smtplib.SMTP(smtp_provider, smtp_port)
        server.starttls()
        server.login(smtp_account, smtp_password)
        server.sendmail(sender, receiver, message)
        return
    except Exception as e:
        print('Something went wrong. Unable to send email.', flush=True)
        print('Exception: ', e)
        return
task.py
The task.py code orchestrates execution of the files above. As before, it checks whether the host has a GPU, initializes the GPU if one is found, processes the arguments passed to the code execution, loads the data, and starts retraining. Once retraining ends, the code pushes the resulting model to the testing registry and notifies the product owner. Let's see what the code looks like:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
import argparse
import data_utils, email_notifications
import sys
import os
from google.cloud import storage
import datetime

# general variables declaration
model_name = 'best_model.hdf5'

def initialize_gpu():
    if len(tf.config.experimental.list_physical_devices('GPU')) > 0:
        tf.config.set_soft_device_placement(True)
        tf.debugging.set_log_device_placement(True)
        return

def start_training(args):
    # Loading the split data
    X_train, X_test, y_train, y_test = data_utils.load_data(args)
    # Initializing GPU if available
    initialize_gpu()
    # Checking if there's any model saved at testing or production folders in GCS
    model_gcs_prod = data_utils.previous_model(args.bucket_name, 'production', model_name)
    model_gcs_test = data_utils.previous_model(args.bucket_name, 'testing', model_name)
    # If a model exists at production, load it, test it on the data, and if it doesn't reach a good metric, retrain it and save it to the testing folder
    if model_gcs_prod[0] == True:
        train_prod_model(X_train, X_test, y_train, y_test, args)
    if model_gcs_prod[0] == False:
        if model_gcs_test[0] == True:
            train_test_model(X_train, X_test, y_train, y_test, args)
        if model_gcs_test[0] == False:
            email_notifications.training_result('not_found', ' ')
            sys.exit(1)
        if model_gcs_test[0] == None:
            email_notifications.exception('Something went wrong when trying to check if old testing model exists. Exception: ' + str(model_gcs_test[1]) + '. Aborting automatic training.')
            sys.exit(1)
    if model_gcs_prod[0] == None:
        email_notifications.exception('Something went wrong when trying to check if old production model exists. Exception: ' + str(model_gcs_prod[1]) + '. Aborting automatic training.')
        sys.exit(1)

def train_prod_model(X_train, X_test, y_train, y_test, args):
    model_gcs_prod = data_utils.load_model(args.bucket_name, 'production', model_name)
    if model_gcs_prod[0] == True:
        try:
            cnn = load_model(model_name)
            model_loss, model_acc = cnn.evaluate(X_test, y_test, verbose=2)
            if model_acc > 0.85:
                saved_ok = data_utils.save_model(args.bucket_name, model_name)
                if saved_ok[0] == True:
                    email_notifications.training_result('old_evaluation_prod', model_acc)
                    sys.exit(0)
                else:
                    email_notifications.exception(saved_ok[1])
                    sys.exit(1)
            else:
                cnn = load_model(model_name)
                checkpoint = ModelCheckpoint(model_name, monitor='val_loss', verbose=1, save_best_only=True, mode='auto', save_freq="epoch")
                cnn.fit(X_train, y_train, epochs=args.epochs, validation_data=(X_test, y_test), callbacks=[checkpoint])
                model_loss, model_acc = cnn.evaluate(X_test, y_test, verbose=2)
                if model_acc > 0.85:
                    saved_ok = data_utils.save_model(args.bucket_name, model_name)
                    if saved_ok[0] == True:
                        email_notifications.training_result('retrain_prod', model_acc)
                        sys.exit(0)
                    else:
                        email_notifications.exception(saved_ok[1])
                        sys.exit(1)
                else:
                    return
        except Exception as e:
            email_notifications.exception('Something went wrong when trying to retrain old production model. Exception: ' + str(e))
            sys.exit(1)
    else:
        email_notifications.exception('Something went wrong when trying to load old production model. Exception: ' + str(model_gcs_prod[1]))
        sys.exit(1)

def train_test_model(X_train, X_test, y_train, y_test, args):
    model_gcs_test = data_utils.load_model(args.bucket_name, 'testing', model_name)
    if model_gcs_test[0] == True:
        try:
            cnn = load_model(model_name)
            model_loss, model_acc = cnn.evaluate(X_test, y_test, verbose=2)
            if model_acc > 0.85: # Nothing to do, keep the model the way it is.
                email_notifications.training_result('old_evaluation_test', model_acc)
                sys.exit(0)
            else:
                cnn = load_model(model_name)
                checkpoint = ModelCheckpoint(model_name, monitor='val_loss', verbose=1, save_best_only=True, mode='auto', save_freq="epoch")
                cnn.fit(X_train, y_train, epochs=args.epochs, validation_data=(X_test, y_test), callbacks=[checkpoint])
                model_loss, model_acc = cnn.evaluate(X_test, y_test, verbose=2)
                if model_acc > 0.85:
                    saved_ok = data_utils.save_model(args.bucket_name, model_name)
                    if saved_ok[0] == True:
                        email_notifications.training_result('retrain_test', model_acc)
                        sys.exit(0)
                    else:
                        email_notifications.exception(saved_ok[1])
                        sys.exit(1)
                else:
                    email_notifications.training_result('poor_metrics', model_acc)
                    sys.exit(1)
        except Exception as e:
            email_notifications.exception('Something went wrong when trying to retrain old testing model. Exception: ' + str(e))
            sys.exit(1)
    else:
        email_notifications.exception('Something went wrong when trying to load old testing model. Exception: ' + str(model_gcs_test[1]))
        sys.exit(1)

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--bucket-name', type=str, default='automatictrainingcicd-aiplatform', help='GCP bucket name')
    parser.add_argument('--epochs', type=int, default=2, help='Epochs number')
    args = parser.parse_args()
    return args

def main():
    args = get_args()
    start_training(args)

if __name__ == '__main__':
    main()
Dockerfile
The Dockerfile handles the Docker container build. It loads the dataset from its repository, loads the code files from their repository, and defines where container execution starts:
FROM gcr.io/deeplearning-platform-release/tf2-cpu.2-0
WORKDIR /root
RUN pip install pandas numpy google-cloud-storage scikit-learn opencv-python
RUN apt-get update; apt-get install git -y; apt-get install -y libgl1-mesa-dev
ADD "https://www.random.org/cgi-bin/randbyte?nbytes=10&format=h" skipcache
RUN git clone https://github.com/sergiovirahonda/AutomaticTraining-Dataset.git
ADD "https://www.random.org/cgi-bin/randbyte?nbytes=10&format=h" skipcache
RUN git clone https://github.com/sergiovirahonda/AutomaticTraining-DataCommit.git
RUN mv /root/AutomaticTraining-DataCommit/task.py /root
RUN mv /root/AutomaticTraining-DataCommit/data_utils.py /root
RUN mv /root/AutomaticTraining-DataCommit/email_notifications.py /root
ENTRYPOINT ["python","task.py"]
You'll notice the ADD instructions in the code. These force the build process to always pull the repository contents when building the container, rather than reusing copies cached locally.
Once you've built and run the container locally, you should be able to retrain your model with the newly collected data. We haven't yet discussed how this job is triggered. We'll cover that step later when we talk about GitHub webhooks and Jenkins, but essentially, Jenkins can trigger this workflow whenever it detects a push to the corresponding repository. Pushes are detected via a webhook, a mechanism configured in the repository itself.
At the end of the process, you should find the resulting model file stored in the GCS testing registry.
Next Steps
In the next article, we'll develop the model unit testing container. Stay tuned!