在scikit-learn中保存分类器到磁盘

在scikit-learn中保存分类器到磁盘

技术背景

在机器学习中,训练一个分类器可能需要大量的时间和计算资源。为了避免重复训练,我们可以将训练好的分类器保存到磁盘,在需要使用时再加载。在scikit-learn中,有多种方法可以实现这一目的。

实现步骤

使用pickle模块

pickle 是Python的标准模块,用于对象的序列化和反序列化。可以将训练好的分类器对象保存为二进制文件,后续再加载使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import pickle
from sklearn import datasets
from sklearn.naive_bayes import GaussianNB

# 加载数据集
iris = datasets.load_iris()
# 创建分类器
gnb = GaussianNB()
# 训练分类器
gnb.fit(iris.data, iris.target)

# 保存分类器到文件
with open('my_dumped_classifier.pkl', 'wb') as fid:
pickle.dump(gnb, fid)

# 从文件中加载分类器
with open('my_dumped_classifier.pkl', 'rb') as fid:
gnb_loaded = pickle.load(fid)

# 使用加载的分类器进行预测
y_pred = gnb_loaded.predict(iris.data)

使用joblib模块

joblib 是一个用于Python的轻量级管道化工具,在处理大型NumPy数组时比 pickle 更高效。它也被集成在scikit-learn中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import joblib
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier

# 加载数据集
digits = load_digits()
# 创建分类器
clf = SGDClassifier().fit(digits.data, digits.target)

# 保存分类器到文件
filename = 'digits_classifier.joblib.pkl'
_ = joblib.dump(clf, filename, compress=9)

# 从文件中加载分类器
clf2 = joblib.load(filename)

最佳实践

  • 保存额外信息:为了在未来重建类似的模型,除了保存分类器本身,还应保存训练数据、生成模型的Python源代码、scikit-learn及其依赖项的版本、训练数据上的交叉验证分数等信息。
  • 处理文本分类:在文本分类中,除了保存分类器,还需要保存向量化器,以便在未来对输入进行向量化。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import pickle
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

# 示例数据
corpus = ["This is the first document.", "This document is the second document."]
labels = [0, 1]

# 创建向量化器
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(corpus)

# 创建分类器
clf = MultinomialNB()
clf.fit(X, labels)

# 保存向量化器和分类器
with open('model.pkl', 'wb') as fout:
pickle.dump((vectorizer, clf), fout)

# 加载向量化器和分类器
with open('model.pkl', 'rb') as fin:
vectorizer, clf = pickle.load(fin)

# 新样本预测
new_samples = ["This is a new document."]
X_new = vectorizer.transform(new_samples)
X_new_preds = clf.predict(X_new)

常见问题

  • sklearn.externals.joblib 已弃用:从scikit-learn 0.21版本开始,sklearn.externals.joblib 已被弃用,并将在0.23版本中移除。建议直接从 joblib 导入功能。
1
2
pip install joblib
import joblib
  • 兼容性问题:使用 pickle 保存的模型可能在不同版本的Python或scikit-learn中存在兼容性问题。因此,建议保存模型时记录相关版本信息。
  • IPython问题:如果使用IPython,不要使用 --pylab 命令行标志或 %pylab 魔法,因为隐式命名空间重载可能会破坏 pickle 过程。建议使用显式导入和 %matplotlib inline 魔法。

在scikit-learn中保存分类器到磁盘
https://119291.xyz/posts/2025-04-22.save-classifier-to-disk-in-scikit-learn/
作者
ww
发布于
2025年4月22日
许可协议