import cv2
import os
import numpy as np
import sys
os.chdir(sys.path[0])  #vscode读取相对路径

def changeImg(path):  #将图像处理成相同大小的像素块,再拉成一维数组
    img = cv2.imread(path)
    img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)  #灰度化
    ret,thresh=cv2.threshold(img,127,255,cv2.THRESH_BINARY)  #二值化
    contours, hierarchy = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)  #轮廓检测
    x,y,w,h = cv2.boundingRect(contours[1])  #有三个轮廓，第一个是最外层的，第二个是数字边缘的，所以i取1
    img =thresh[y:y+h,x:x+w]  #裁剪图像
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_NEAREST)#标准化图像大小，28*28
    img = img.ravel()  #拉成一维数组
    for i in range(784):  #将255改成1*784的一维数组
        if img[i] == 255:
            img[i] = 1
    return img

class Bayes:
    def __init__(self):
        self.length=-1
        self.labelrate=dict()  #先验概率字典
        self.vectorrate=dict()  #训练集字典，一共有0-9个key，每个的value是9个array([])
    def fit(self,dataset,labels):  #构造训练集
        self.length=len(dataset[0])#训练数据特征值的长度
        labelsnum=len(labels) #类别的数量
        norlabels=set(labels) #不重复类别的数量
        for item in norlabels:
            self.labelrate[item] = labels.count(item) / labelsnum #求当前类别占总类别的比例 即先验概率
        print("先验概率",self.labelrate)
        for vector,label in zip(dataset,labels):
            if label not in self.vectorrate:
                self.vectorrate[label] = []
            self.vectorrate[label].append(vector)
        print("训练结束")
        return self
    def btest(self,testdata,labelset):
        #计算testdata分别为各个类别的概率
        lbDict=dict()  #后验概率字典
        for thislb in labelset:
            p = 1  #类条件概率
            alllabel = self.labelrate[thislb]  #每一个数字的先验概率
            allvector = self.vectorrate[thislb]  #每一个数字的训练集，是9*784的矩阵
            vnum=len(allvector)  #为9
            allvector=np.array(allvector).T  #矩阵转置，结果是一个784*9的二维数组，这样直接取一行就是原来矩阵的一列
            for index in range(0,len(testdata)):  #遍历784次
                vector=list(allvector[index])  #取每一行，就是图片28*28中每一个切块的数据
                p *= (vector.count(testdata[index]) + 1) / (vnum + 1)  
                #类条件概率，9张图片的第index块中和测试图片第index块数据相同的个数/9，同时为了避免出现0，默认加一张与测试图片一样的在训练集里，分子分母+1
            lbDict[thislb]=p * alllabel
        print(str(thisfilename),"后验概率：",lbDict)  #概率字典
        thislbabel=sorted(lbDict,key=lambda x:lbDict[x],reverse=True)
        return thislbabel[0]

def seplabel(fname): #读取文件名字里的真实数字
    filestr = fname.split(".")[0] #提取文件名，返回标签值
    labels = int(filestr.split("-")[0])
    return labels

def traindata():
    labels = []
    trainfile = os.listdir('train_picture/')  #训练文件夹下的图片文件
    num = len(trainfile)  #因该有九个图像文件
    #创建一个数组存放训练数据，行为文件总数，列为784，为一个手写体的内容 zeros创建规定大小的数组
    trainarr = np.zeros((num,784))  #里面放图像数组为一行，num列
    for i in range(0,num):
        thisfname = trainfile[i]  #求取文件名
        thislabel = seplabel(thisfname)  #读出真实数字
        labels.append(thislabel)
        trainarr[i] = changeImg('train_picture/'  + thisfname)
    return trainarr,labels  #返回文件的内容和名字里的值
   #将训练集的灰度图像信息保存在一个数组里

bys = Bayes()
#训练数据
train_data,labels=traindata()
train_data=list(train_data)
bys.fit(train_data,labels)
labelsall=[0,1,2,3,4,5,6,7,8,9]
'''测试一张图片
thisdata=changeImg("test_picture/1-10.bmp")
test=bys.btest(thisdata,labelsall)
print("识别结果",test)
'''
testfile=os.listdir("test_picture")
num=len(testfile)
right_count = 0
for i in range(num):
    thisfilename=testfile[i]
    thislabel=seplabel(thisfilename)
    thisdataarr=changeImg("test_picture/"+thisfilename)
    label=bys.btest(thisdataarr,labelsall)
    if label == thislabel:
        right_count += 1
        print(str(thisfilename) + " : 识别结果：" + str(label) +'  识别成功 √')
    else:
        print(str(thisfilename) + " : 识别结果：" + str(label) + '  识别失败 ×')
print('正确率为: {:.2%}'.format(right_count/num))