fanfuhan OpenCV 教學115 ~ opencv-115-KNN算法的使用 [手寫辨識]
資料來源: https://fanfuhan.github.io/
https://fanfuhan.github.io/2019/05/24/opencv-115/
GITHUB:https://github.com/jash-git/fanfuhan_ML_OpenCV
OpenCV機器學習模塊(ml)提供最近鄰算法KNN。使用KNN訓練並保存好的模型文件(XML或YAML格式)，可以通過Algorithm接口的load方法加載成為KNN分類器，再使用findNearest方法進行預測。
C++
#include <opencv2/opencv.hpp>
#include <iostream>
using namespace cv;
using namespace cv::ml;
using namespace std;
void knn_test(Mat& data, Mat& labels);
int main(int argc, char** argv) {
Mat data = imread("D:/projects/opencv_tutorial/data/images/digits.png");
Mat gray;
cvtColor(data, gray, COLOR_BGR2GRAY);
// 分割为5000个cells
Mat images = Mat::zeros(5000, 400, CV_8UC1);
Mat labels = Mat::zeros(5000, 1, CV_8UC1);
int index = 0;
Rect roi;
roi.x = 0;
roi.height = 1;
roi.width = 400;
for (int row = 0; row < 50; row++) {
int label = row / 5;
int offsety = row * 20;
for (int col = 0; col < 100; col++) {
int offsetx = col * 20;
Mat digit = Mat::zeros(Size(20, 20), CV_8UC1);
for (int sr = 0; sr < 20; sr++) {
for (int sc = 0; sc < 20; sc++) {
digit.at<uchar>(sr, sc) = gray.at<uchar>(sr + offsety, sc + offsetx);
}
}
Mat one_row = digit.reshape(1, 1);
printf("index : %d \n", index);
roi.y = index;
one_row.copyTo(images(roi));
labels.at<uchar>(index, 0) = label;
index++;
}
}
printf("load sample hand-writing data...\n");
imwrite("D:/result.png", images);
// 转换为浮点数
images.convertTo(images, CV_32FC1);
labels.convertTo(labels, CV_32SC1);
printf("load sample hand-writing data...\n");
// 开始KNN训练
printf("Start to knn train...\n");
Ptr<KNearest> knn = KNearest::create();
knn->setDefaultK(5);
knn->setIsClassifier(true);
Ptr<ml::TrainData> tdata = ml::TrainData::create(images, ml::ROW_SAMPLE, labels);
knn->train(tdata);
knn->save("D:/vcworkspaces/knn_knowledge.yml");
printf("Finished KNN...\n");
// 测试KNN
printf("start to test knn...\n");
knn_test(images, labels);
waitKey(0);
return true;
}
// Load the KNN model saved by main() and report two things:
//  1. The fraction of rows of `data` whose findNearest() prediction equals
//     the corresponding entry of `labels`. Labels are read as int
//     (CV_32SC1), one per row of `data`, matching what main() passes.
//  2. Predictions for two test images read from disk, each resized to 20x20
//     and flattened like the training cells.
// Prints a message and returns early if the model or a test image cannot be used.
void knn_test(Mat& data, Mat& labels) {
    // Load the KNN classifier. Algorithm::load may yield an empty Ptr when the
    // file does not contain a usable model, so check before dereferencing.
    Ptr<ml::KNearest> knn = Algorithm::load<ml::KNearest>("D:/vcworkspaces/knn_knowledge.yml");
    if (knn.empty()) {
        printf("could not load KNN model\n");
        return;
    }
    const int k = 5;

    Mat result;
    knn->findNearest(data, k, result);
    int correct = 0;
    for (int row = 0; row < result.rows; row++) {
        // Predictions come back as floats holding integral class ids.
        const int predict = cvRound(result.at<float>(row, 0));
        if (labels.at<int>(row, 0) == predict) {
            correct++;
        }
    }
    if (result.rows > 0) {
        printf("test acc of digit numbers : %.2f \n ", static_cast<float>(correct) / result.rows);
    }

    // Real test on hand-written images.
    Mat t1 = imread("D:/images/knn_01.png", IMREAD_GRAYSCALE);
    Mat t2 = imread("D:/images/knn_02.png", IMREAD_GRAYSCALE);
    if (t1.empty() || t2.empty()) {
        printf("could not load test images\n");
        return;
    }
    imshow("t1", t1);
    imshow("t2", t2);
    Mat m1, m2;
    resize(t1, m1, Size(20, 20));
    resize(t2, m2, Size(20, 20));

    // One flattened 400-value row per test image, matching the training layout.
    // resize() allocates fresh (continuous) output, so reshape(1, 1) is valid.
    Mat testdata = Mat::zeros(2, 400, CV_8UC1);
    m1.reshape(1, 1).copyTo(testdata.row(0));
    m2.reshape(1, 1).copyTo(testdata.row(1));
    testdata.convertTo(testdata, CV_32F);

    // Expected digits for knn_01.png and knn_02.png respectively.
    const int expected[2] = { 1, 2 };

    Mat result2;
    knn->findNearest(testdata, k, result2);
    for (int i = 0; i < result2.rows; i++) {
        const int predict = cvRound(result2.at<float>(i, 0));
        printf("knn t%d predict : %d, actual label :%d \n", (i + 1), predict, expected[i]);
    }
}