在scikit-learn中保存分类器到磁盘
技术背景
在机器学习中,训练一个分类器可能需要大量的时间和计算资源。为了避免重复训练,我们可以将训练好的分类器保存到磁盘,在需要使用时再加载。在scikit-learn中,有多种方法可以实现这一目的。
实现步骤
使用pickle模块
pickle
是Python的标准模块,用于对象的序列化和反序列化。可以将训练好的分类器对象保存为二进制文件,后续再加载使用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| import pickle from sklearn import datasets from sklearn.naive_bayes import GaussianNB
iris = datasets.load_iris()
gnb = GaussianNB()
gnb.fit(iris.data, iris.target)
with open('my_dumped_classifier.pkl', 'wb') as fid: pickle.dump(gnb, fid)
with open('my_dumped_classifier.pkl', 'rb') as fid: gnb_loaded = pickle.load(fid)
y_pred = gnb_loaded.predict(iris.data)
|
使用joblib模块
joblib
是一个用于Python的轻量级管道化工具,在处理大型NumPy数组时比 pickle
更高效。它也被集成在scikit-learn中。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| import joblib from sklearn.datasets import load_digits from sklearn.linear_model import SGDClassifier
digits = load_digits()
clf = SGDClassifier().fit(digits.data, digits.target)
filename = 'digits_classifier.joblib.pkl' _ = joblib.dump(clf, filename, compress=9)
clf2 = joblib.load(filename)
|
最佳实践
- 保存额外信息:为了在未来重建类似的模型,除了保存分类器本身,还应保存训练数据、生成模型的Python源代码、scikit-learn及其依赖项的版本、训练数据上的交叉验证分数等信息。
- 处理文本分类:在文本分类中,除了保存分类器,还需要保存向量化器,以便在未来对输入进行向量化。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| import pickle from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.naive_bayes import MultinomialNB
corpus = ["This is the first document.", "This document is the second document."] labels = [0, 1]
vectorizer = TfidfVectorizer() X = vectorizer.fit_transform(corpus)
clf = MultinomialNB() clf.fit(X, labels)
with open('model.pkl', 'wb') as fout: pickle.dump((vectorizer, clf), fout)
with open('model.pkl', 'rb') as fin: vectorizer, clf = pickle.load(fin)
new_samples = ["This is a new document."] X_new = vectorizer.transform(new_samples) X_new_preds = clf.predict(X_new)
|
常见问题
sklearn.externals.joblib
已弃用:从scikit-learn 0.21版本开始,sklearn.externals.joblib
已被弃用,并将在0.23版本中移除。建议直接从 joblib
导入功能。
1 2
| pip install joblib import joblib
|
- 兼容性问题:使用
pickle
保存的模型可能在不同版本的Python或scikit-learn中存在兼容性问题。因此,建议保存模型时记录相关版本信息。 - IPython问题:如果使用IPython,不要使用
--pylab
命令行标志或 %pylab
魔法,因为隐式命名空间重载可能会破坏 pickle
过程。建议使用显式导入和 %matplotlib inline
魔法。