-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate.py
36 lines (32 loc) · 1.21 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import classify
def validate(method, truth, data, labellist, k, m, *, rng=None):
    '''
    Evaluate one classification method on a single perturbed training sample.

    No separate test data was collected, so the test input is synthesised by
    adding noise ``m * U[0, 1)`` to the training row ``data[truth]``.

    :param method: classification function to evaluate (e.g. nn, knn, wknn);
        called as ``method(sample, data, labellist, k)`` and must return a
        2-element predicted location
    :param truth: index of the training sample used as ground truth
    :param data: training dataset, indexable by ``truth``
    :param labellist: location label (2-tuple) of every training sample
    :param k: forwarded unchanged to ``method`` (presumably the neighbour
        count for the knn/wknn variants — confirm against classify)
    :param m: perturbation range; every element receives noise in [0, m)
    :param rng: optional ``numpy.random.Generator`` for reproducible noise;
        when None the global ``np.random`` state is used
    :return: Euclidean distance between the predicted and the true location
    '''
    # Draw noise with the row's full shape (not len()), so scalar and
    # multi-dimensional rows are perturbed element-wise as well.
    draw = np.random.random if rng is None else rng.random
    disturbdata = data[truth] + m * draw(np.shape(data[truth]))

    predict = method(disturbdata, data, labellist, k)
    pred_x, pred_y = predict
    true_x, true_y = labellist[truth]

    # functools.partial and other callables may lack __name__.
    name = getattr(method, '__name__', repr(method))
    print('use method {}, the predict result is {}'.format(name, predict))

    # Squaring inside the norm makes abs() unnecessary.
    return np.linalg.norm((pred_x - true_x, pred_y - true_y))
if __name__ == '__main__':
    # Smoke test: run every classifier on a perturbed copy of each sample
    # and report the mean localisation error per method.
    data = np.load('data/dataset.npy')

    # Build location labels: sample i gets (x, y) = (i % 6, i // 6).
    # The original author annotated x as the row and y as the column.
    grid_size = 6
    labellist = [(i % grid_size, i // grid_size) for i in range(len(data))]

    k = 3          # forwarded to each classifier
    noise = 50     # perturbation range passed to validate()
    methods = (classify.classifynn, classify.classifyknn, classify.classifywknn)
    errors = {method: [] for method in methods}

    for truth in range(len(data)):
        print('the truth is {}'.format(labellist[truth]))
        # Same call order as before, so the global RNG is consumed identically.
        for method in methods:
            errors[method].append(
                validate(method, truth, data, labellist, k, noise))

    # validate() returns the Euclidean error of each prediction; previously
    # these values were discarded and no accuracy figure was ever reported.
    for method, errs in errors.items():
        if errs:
            print('method {}: mean error {:.3f} over {} samples'.format(
                method.__name__, float(np.mean(errs)), len(errs)))