博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
机器学习-训练模型的保存与恢复(sklearn)
阅读量:4041 次
发布时间:2019-05-24

本文共 2255 字,大约阅读时间需要 7 分钟。

在做模型训练的时候,尤其是在训练集上做交叉验证,通常想要将模型保存下来,然后放到独立的测试集上测试,下面介绍的是Python中训练模型的保存和再使用。

 

模型保存(pickle模块和joblib模块)

使用pickle模块或者sklearn内部的joblib

一、使用pickle模块

    from sklearn import svm

    from sklearn import datasets
    clf=svm.SVC()
    iris=datasets.load_iris()
    X,y=iris.data,iris.target
    clf.fit(X,y)
     
    import pickle
    s=pickle.dumps(clf)
    f=open('svm.txt','w')
    f.write(s)
    f.close()
    f2=open('svm.txt','r')
    s2=f2.read()
    clf2=pickle.loads(s2)
    clf2.score(X,y)

二、使用joblib

joblib更适合大数据量的模型,且只能往硬盘存储,不能往字符串存储

    from sklearn.externals import joblib

    joblib.dump(clf,'filename.pkl')
    clf=joblib.load('filename.pkl')

三、具体例子如下:

scikit-learn已经有了模型持久化的操作,导入joblib即可

from sklearn.externals import joblib

模型保存

通过joblib的dump可以将模型保存到本地,clf是训练的分类器

    from sklearn.linear_model import LogisticRegression

    from sklearn.svm import SVC
    from sklearn.externals import joblib
    def test_save_model(self):
        model_save_path = "./model_save/"
        train_X = [[0, 0], [1, 1]]
        train_y = [0, 1]
        print "Start LR method ..."
        print('LR Train classifier...')
        clf = LogisticRegression()
        clf.fit(train_X, train_y)
        print "LR Model save..."
        save_path_name = model_save_path + "lr_" + "train_model.m"
        self.is_exist(model_save_path, save_path_name)
        joblib.dump(clf, save_path_name)
        clf = joblib.load(save_path_name)
        print('LR Predict...')
        pred = clf.predict_proba(train_X)
        submit_csv_name = model_save_path + "lr" + '_submission.csv'
        self.make_submission(pred[:, 0], submit_csv_name)
     
        print "Start SVM method ..."
        # 训练
        print('SVM Train classifier...')
        from sklearn import svm
        clf = svm.SVC()
        clf.fit(train_X, train_y)
        # 保存
        print "SVM Model save..."
        save_path_name=model_save_path+"svm_"+"train_model.m"
        self.is_exist(model_save_path,save_path_name)
        joblib.dump(clf, save_path_name)
        clf = joblib.load(save_path_name)
        # 预测
        print('SVM Predict...')
        pred=clf.predict(train_X)
        submit_csv_name = model_save_path+"svm" + '_submission.csv'
        self.make_submission(pred, submit_csv_name)
     
        train_X = [[0, 1], [1, 1]]
        train_y = [0, 1]
        print clf.score(train_X, train_y, sample_weight=None)

模型从本地调回

clf = joblib.load("train_model.m")

通过joblib的load方法,加载保存的模型。

然后就可以在测试集上测试了

clf.predict(test_X) #此处test_X为特征集

————————————————
版权声明:本文为CSDN博主「Data_IT_Farmer」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/helloxiaozhe/article/details/80658438

你可能感兴趣的文章
学习设计模式(3)——单例模式和类的成员函数中的静态变量的作用域
查看>>
自然计算时间复杂度杂谈
查看>>
当前主要目标和工作
查看>>
使用 Springboot 对 Kettle 进行调度开发
查看>>
一文看清HBase的使用场景
查看>>
解析zookeeper的工作流程
查看>>
搞定Java面试中的数据结构问题
查看>>
慢慢欣赏linux make uImage流程
查看>>
linux内核学习(7)脱胎换骨解压缩的内核
查看>>
以太网基础知识
查看>>
慢慢欣赏linux 内核模块引用
查看>>
kprobe学习
查看>>
慢慢欣赏linux phy驱动初始化2
查看>>
慢慢欣赏linux CPU占用率学习
查看>>
2020年终总结
查看>>
Homebrew指令集
查看>>
React Native(一):搭建开发环境、出Hello World
查看>>
React Native(二):属性、状态
查看>>
JSX使用总结
查看>>
React Native(四):布局(使用Flexbox)
查看>>