fanfuhan OpenCV 教學114 ~ opencv-114-KNN算法介紹

fanfuhan OpenCV 教學114 ~ opencv-114-KNN算法介紹

fanfuhan OpenCV 教學114 ~ opencv-114-KNN算法介紹


資料來源: https://fanfuhan.github.io/

https://fanfuhan.github.io/2019/05/24/opencv-114/

GITHUB:https://github.com/jash-git/fanfuhan_ML_OpenCV


本篇介紹OpenCV機器學習模塊中的最近鄰算法KNN,並使用KNN算法實現手寫數字識別。OpenCV在samples/data中有一張自帶的手寫數字數據集圖像digits.png,0~9每個數字有500個樣本,總計5000個數字。圖像高1000像素、寬2000像素,分割為縱向50個、橫向100個、每個20×20大小的單個數字圖像,每個樣本400個像素。然後使用KNN相關API實現訓練與結果的保存。大致的順序如下:

 01.讀入測試圖像digits.png(可以在我的github下載,不知道地址看置頂帖子)
 02.構建樣本數據與標籤

 03.創建KNN訓練並保存訓練結果


C++

#include <opencv2/opencv.hpp>
#include <iostream>

using namespace cv;
using namespace cv::ml;
using namespace std;

int main(int argc, char** argv) {
	Mat data = imread("D:/projects/opencv_tutorial/data/images/digits.png");
	Mat gray;
	cvtColor(data, gray, COLOR_BGR2GRAY);

	// 分割为5000个cells
	Mat images = Mat::zeros(5000, 400, CV_8UC1);
	Mat labels = Mat::zeros(5000, 1, CV_8UC1);
	Rect rect;
	rect.height = 20;
	rect.width = 20;
	int index = 0;
	Rect roi;
	roi.x = 0;
	roi.height = 1;
	roi.width = 400;
	for (int row = 0; row < 50; row++) {
		int label = row / 5;
		for (int col = 0; col < 100; col++) {
			Mat digit = Mat::zeros(20, 20, CV_8UC1);
			index = row * 100 + col;
			rect.x = col * 20;
			rect.y = row * 20;
			gray(rect).copyTo(digit);
			Mat one_row = digit.reshape(1, 1);
			roi.y = index;
			one_row.copyTo(images(roi));
			labels.at<uchar>(index, 0) = label;
		}
	}
	printf("load sample hand-writing data...\n");

	// 转换为浮点数
	images.convertTo(images, CV_32FC1);
	labels.convertTo(labels, CV_32SC1);

	// 开始KNN训练
	printf("Start to knn train...\n");
	Ptr<KNearest> knn = KNearest::create();
	knn->setDefaultK(5);
	knn->setIsClassifier(true);
	Ptr<ml::TrainData> tdata = ml::TrainData::create(images, ml::ROW_SAMPLE, labels);
	knn->train(tdata);
	knn->save("D:/vcworkspaces/knn_knowledge.yml");
	printf("Finished KNN...\n");
	return true;
}


Python

"""
Introduction to the KNN algorithm.

Splits the digit sheet into 20x20 cells, trains an OpenCV KNearest model on
the left half of every grid row and evaluates it on the right half.
"""

import cv2 as cv
import numpy as np

# Load the data
image_path = "images/digits.png"
img = cv.imread(image_path)
if img is None:
    # cv.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError("could not read image: " + image_path)
gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
# 50 grid rows x 100 grid columns of 20x20 cells
cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)]
x = np.array(cells)

# Build the training and test sets: columns 0-49 train, columns 50-99 test
train = x[:, :50].reshape(-1, 400).astype(np.float32)
test = x[:, 50:100].reshape(-1, 400).astype(np.float32)
k = np.arange(10)
# 250 samples per digit in each half; int32 so the labels are passed as
# 32-bit integers rather than relying on an implicit conversion
train_labels = np.repeat(k, 250)[:, np.newaxis].astype(np.int32)
test_labels = train_labels.copy()

# Train the KNN model and classify the test set with 5 neighbours
knn = cv.ml.KNearest_create()
knn.train(train, cv.ml.ROW_SAMPLE, train_labels)
ret, result, neighbours, dist = knn.findNearest(test, k=5)

# Compute the accuracy in percent
matches = result == test_labels
correct = np.count_nonzero(matches)
acc = correct * 100.0 / result.size
print(acc)
# Prediction accuracy reported by the original author: 91.76

發表迴響

你的電子郵件位址並不會被公開。 必要欄位標記為 *