fanfuhan OpenCV 教學115 ~ opencv-115-KNN算法的使用 [手寫辨識]


資料來源: https://fanfuhan.github.io/

https://fanfuhan.github.io/2019/05/24/opencv-115/

GITHUB:https://github.com/jash-git/fanfuhan_ML_OpenCV


OpenCV機器學習模塊(ml)提供最近鄰算法KNN。使用KNN訓練好並保存的模型文件(XML或YAML格式,本例保存為.yml),可以通過Algorithm接口的load方法重新加載為KNN分類器,再使用findNearest方法進行預測。


C++

#include <opencv2/opencv.hpp>
#include <iostream>

using namespace cv;
using namespace cv::ml;
using namespace std;

void knn_test(Mat& data, Mat& labels);
int main(int argc, char** argv) {
	Mat data = imread("D:/projects/opencv_tutorial/data/images/digits.png");
	Mat gray;
	cvtColor(data, gray, COLOR_BGR2GRAY);

	// 分割为5000个cells
	Mat images = Mat::zeros(5000, 400, CV_8UC1);
	Mat labels = Mat::zeros(5000, 1, CV_8UC1);

	int index = 0;
	Rect roi;
	roi.x = 0;
	roi.height = 1;
	roi.width = 400;
	for (int row = 0; row < 50; row++) {
		int label = row / 5;
		int offsety = row * 20;
		for (int col = 0; col < 100; col++) {
			int offsetx = col * 20;
			Mat digit = Mat::zeros(Size(20, 20), CV_8UC1);
			for (int sr = 0; sr < 20; sr++) {
				for (int sc = 0; sc < 20; sc++) {
					digit.at<uchar>(sr, sc) = gray.at<uchar>(sr + offsety, sc + offsetx);
				}
			}
			Mat one_row = digit.reshape(1, 1);
			printf("index : %d \n", index);
			roi.y = index;
			one_row.copyTo(images(roi));
			labels.at<uchar>(index, 0) = label;
			index++;
		}
	}
	printf("load sample hand-writing data...\n");
	imwrite("D:/result.png", images);

	// 转换为浮点数
	images.convertTo(images, CV_32FC1);
	labels.convertTo(labels, CV_32SC1);

	printf("load sample hand-writing data...\n");


	// 开始KNN训练
	printf("Start to knn train...\n");
	Ptr<KNearest> knn = KNearest::create();
	knn->setDefaultK(5);
	knn->setIsClassifier(true);
	Ptr<ml::TrainData> tdata = ml::TrainData::create(images, ml::ROW_SAMPLE, labels);
	knn->train(tdata);
	knn->save("D:/vcworkspaces/knn_knowledge.yml");
	printf("Finished KNN...\n");

	// 测试KNN
	printf("start to test knn...\n");
	knn_test(images, labels);

	waitKey(0);
	return true;
}

// Reload the KNN model saved by main() and evaluate it.
// Step 1 reports the fraction of rows of `data` whose prediction matches `labels`.
// main() passes the training samples, so this is training accuracy.
// Step 2 classifies two standalone digit images read from disk.
//   data   : one sample per row (main() passes CV_32FC1, 400 columns = 20x20 pixels)
//   labels : CV_32SC1, one expected digit per row of `data`
void knn_test(Mat& data, Mat& labels) {
	// Load the KNN classifier saved by main().
	Ptr<ml::KNearest> knn = Algorithm::load<ml::KNearest>("D:/vcworkspaces/knn_knowledge.yml");
	if (knn.empty()) {
		printf("could not load knn model\n");
		return;
	}

	Mat result;
	knn->findNearest(data, 5, result);
	int correct = 0;
	for (int row = 0; row < result.rows; row++) {
		// findNearest writes the predicted class as a float.
		const int predict = cvRound(result.at<float>(row, 0));
		if (labels.at<int>(row, 0) == predict) {
			correct++;
		}
	}
	if (result.rows > 0) {
		printf("test acc of digit numbers : %.2f \n", correct / static_cast<float>(result.rows));
	}

	// Real test: classify two individual digit images.
	Mat t1 = imread("D:/images/knn_01.png", IMREAD_GRAYSCALE);
	Mat t2 = imread("D:/images/knn_02.png", IMREAD_GRAYSCALE);
	if (t1.empty() || t2.empty()) {
		printf("could not load test images\n");
		return;
	}
	imshow("t1", t1);
	imshow("t2", t2);

	// Match the 20x20 training cell size, then flatten each image to one 400-pixel row.
	Mat m1, m2;
	resize(t1, m1, Size(20, 20));
	resize(t2, m2, Size(20, 20));
	Mat testdata = Mat::zeros(2, 400, CV_8UC1);
	m1.reshape(1, 1).copyTo(testdata.row(0));
	m2.reshape(1, 1).copyTo(testdata.row(1));
	testdata.convertTo(testdata, CV_32F);

	// Expected digits for the two test images.
	Mat testlabels = Mat::zeros(2, 1, CV_32SC1);
	testlabels.at<int>(0, 0) = 1;
	testlabels.at<int>(1, 0) = 2;

	Mat result2;
	knn->findNearest(testdata, 5, result2);
	for (int i = 0; i < result2.rows; i++) {
		const int predict = cvRound(result2.at<float>(i, 0));
		printf("knn t%d predict : %d, actual label :%d \n", (i + 1), predict, testlabels.at<int>(i, 0));
	}
}

發表迴響

你的電子郵件位址並不會被公開。 必要欄位標記為 *