fanfuhan OpenCV 教學116 ~ opencv-116-決策樹算法介紹與使用
fanfuhan OpenCV 教學116 ~ opencv-116-決策樹算法介紹與使用
資料來源: https://fanfuhan.github.io/
https://fanfuhan.github.io/2019/05/25/opencv-116/
GITHUB:https://github.com/jash-git/fanfuhan_ML_OpenCV
OpenCV 中機器學習模塊的決策樹算法分為兩個類別:一個是隨機森林(Random Trees),另外一個是強化分類(Boosting)。
C++
#include <opencv2/opencv.hpp>

#include <cstdio>
#include <cstdlib>
#include <iostream>

using namespace cv;
using namespace cv::ml;

/// @brief Trains an OpenCV Random Trees classifier on a sheet of hand-written digits.
///
/// The input image is treated as a 50 x 100 grid of 20 x 20 pixel cells. Each
/// cell is flattened into one 400-element sample row; every 5 consecutive grid
/// rows share one class label (0..9). The flattened sample matrix is written to
/// "D:/result.png" and the trained model to "D:/vcworkspaces/rtrees_knowledge.yml".
///
/// @param argc Number of command-line arguments.
/// @param argv Optional argv[1]: path of the digits image. When omitted, the
///             built-in default path is used.
/// @return EXIT_SUCCESS when the model was trained and saved, EXIT_FAILURE when
///         the image could not be loaded, is too small, or training failed.
int main(int argc, char** argv) {
    constexpr int kCellSize = 20;       // each digit cell is 20 x 20 pixels
    constexpr int kGridRows = 50;       // cells per column of the sheet
    constexpr int kGridCols = 100;      // cells per row of the sheet
    constexpr int kRowsPerDigit = 5;    // consecutive grid rows sharing one label
    constexpr int kSampleCount = kGridRows * kGridCols;   // 5000
    constexpr int kFeatureCount = kCellSize * kCellSize;  // 400

    const String imagePath =
        argc > 1 ? argv[1] : "D:/projects/opencv_tutorial/data/images/digits.png";

    const Mat data = imread(imagePath);
    if (data.empty()) {
        std::fprintf(stderr, "could not load image: %s\n", imagePath.c_str());
        return EXIT_FAILURE;
    }

    Mat gray;
    cvtColor(data, gray, COLOR_BGR2GRAY);

    // The cell grid below reads rows [0, 1000) and columns [0, 2000); refuse
    // smaller images instead of reading outside the pixel buffer.
    if (gray.rows < kGridRows * kCellSize || gray.cols < kGridCols * kCellSize) {
        std::fprintf(stderr, "image is %d x %d, expected at least %d x %d\n",
                     gray.cols, gray.rows, kGridCols * kCellSize, kGridRows * kCellSize);
        return EXIT_FAILURE;
    }

    // Split into 5000 cells, one flattened cell per row of `images`.
    Mat images = Mat::zeros(kSampleCount, kFeatureCount, CV_8UC1);
    Mat labels = Mat::zeros(kSampleCount, 1, CV_32SC1);
    int index = 0;
    for (int row = 0; row < kGridRows; ++row) {
        const int label = row / kRowsPerDigit;
        for (int col = 0; col < kGridCols; ++col) {
            const Rect cell(col * kCellSize, row * kCellSize, kCellSize, kCellSize);
            // clone() yields a continuous copy of the ROI, which reshape() requires.
            const Mat one_row = gray(cell).clone().reshape(1, 1);
            std::printf("index : %d \n", index);
            one_row.copyTo(images.row(index));
            labels.at<int>(index, 0) = label;
            ++index;
        }
    }
    std::printf("load sample hand-writing data...\n");

    // Dump of the flattened samples for inspection; a failed write is not fatal.
    if (!imwrite("D:/result.png", images)) {
        std::fprintf(stderr, "warning: could not write D:/result.png\n");
    }

    // Convert the samples to floating point for training.
    Mat samples;
    images.convertTo(samples, CV_32FC1);
    std::printf("load sample hand-writing data...\n");

    // Start training.
    std::printf("Start to Random Trees train...\n");
    Ptr<RTrees> model = RTrees::create();
    /* Optional tuning, left at library defaults:
    model->setMaxDepth(10);
    model->setMinSampleCount(10);
    model->setRegressionAccuracy(0);
    model->setUseSurrogates(false);
    model->setMaxCategories(15);
    model->setPriors(Mat());
    model->setCalculateVarImportance(true);
    model->setActiveVarCount(4);
    */
    const TermCriteria tc(TermCriteria::MAX_ITER + TermCriteria::EPS, 100, 0.01);
    model->setTermCriteria(tc);

    const Ptr<ml::TrainData> tdata = ml::TrainData::create(samples, ml::ROW_SAMPLE, labels);
    if (!model->train(tdata)) {
        std::fprintf(stderr, "Random Trees training failed\n");
        return EXIT_FAILURE;
    }
    model->save("D:/vcworkspaces/rtrees_knowledge.yml");
    std::printf("Finished Random trees...\n");

    // No waitKey(): this program never creates a HighGUI window to wait on.
    return EXIT_SUCCESS;
}
Python
""" 决策树算法 介绍与使用 """ import cv2 as cv import numpy as np # 读取数据 img = cv.imread('images/digits.png') gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY) cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)] x = np.array(cells) # 创建训练与测试数据 train = x[:, :50].reshape(-1, 400).astype(np.float32) test = x[:, 50:100].reshape(-1, 400).astype(np.float32) k = np.arange(10) train_labels = np.repeat(k, 250)[:, np.newaxis] test_labels = train_labels.copy() # 训练随机树 dt = cv.ml.RTrees_create() dt.train(train, cv.ml.ROW_SAMPLE, train_labels) retval, results = dt.predict(test) # 计算准确率 matches = results == test_labels correct = np.count_nonzero(matches) accuracy = correct / results.size print(accuracy) cv.waitKey(0) cv.destroyAllWindows()