算法原理
k近邻(k-Nearest Neighbor,kNN),应该是最简单的传统机器学习模型,给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的k个实例,这k个实例中的大多数属于哪个类别,就把该输入实例划分到这个类别。
k近邻算法没有显示的训练过程,在“训练阶段”仅仅是把样本保存起来,训练时间开销为零,待收到测试样本后在进行计算处理。
这个k实际上是一个超参数,k值的选择会对k近邻法的结果产生重大影响。如果选择较小的k值,意味着只有与输入实例较近的(相似的)训练实例才会对预测结果起作用,预测结果会对近邻的实例点非常敏感,如果近邻的实例点恰巧是噪声点,预测就会出错;如果选择较大的k值,就意味着与输入实例较远的(不相似的)训练实例也会对预测起作用,这样预测也会出错。在实际应用中,k值一般取一个比较小的数值,并且通常采用交叉验证法来选取最优的k值。如上图的k=5。
模型训练代码地址:
def train():
print("start training...")
# 处理训练数据
train_feature, train_target = process_file(train_dir, word_to_id, cat_to_id)
# 模型训练
model.fit(train_feature, train_target)
def test():
print("start testing...")
# 处理测试数据
test_feature, test_target = process_file(test_dir, word_to_id, cat_to_id)
# test_predict = model.predict(test_feature) # 返回预测类别
test_predict_proba = model.predict_proba(test_feature) # 返回属于各个类别的概率
test_predict = np.argmax(test_predict_proba, 1) # 返回概率最大的类别标签
# accuracy
true_false = (test_predict == test_target)
accuracy = np.count_nonzero(true_false) / float(len(test_target))
print()
print("accuracy is %f" % accuracy)
# precision recall f1-score
print()
print(metrics.classification_report(test_target, test_predict, target_names=categories))
# 混淆矩阵
print("Confusion Matrix...")
print(metrics.confusion_matrix(test_target, test_predict))
if not os.path.exists(vocab_dir):
# 构建词典表
build_vocab(train_dir, vocab_dir)
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_dir)
# kNN
model = neighbors.KNeighborsClassifier()
train()
test()运行结果:
read_category...
read_vocab...
start training...
start testing...
accuracy is 0.820000
precision recall f1-score support
时政 0.65 0.85 0.74 94
财经 0.81 0.94 0.87 115
科技 0.96 0.97 0.96 94
游戏 0.99 0.74 0.85 104
娱乐 0.99 0.75 0.85 89
时尚 0.88 0.67 0.76 91
家居 0.44 0.78 0.56 89
房产 0.93 0.82 0.87 104
体育 1.00 0.98 0.99 116
教育 0.96 0.65 0.78 104
avg / total 0.87 0.82 0.83 1000
Confusion Matrix...
[[ 80 4 0 0 0 0 6 3 0 1]
[ 1 108 0 0 0 0 6 0 0 0]
[ 0 0 91 0 0 0 3 0 0 0]
[ 4 0 1 77 0 3 18 0 0 1]
[ 4 3 0 1 67 4 10 0 0 0]
[ 0 0 0 0 1 61 29 0 0 0]
[ 9 5 2 0 0 0 69 3 0 1]
[ 9 3 0 0 0 0 7 85 0 0]
[ 2 0 0 0 0 0 0 0 114 0]
[ 14 10 1 0 0 1 10 0 0 68]]
社群:
公众号:
了解更多干货文章,可以关注小程序八斗问答
作者:gaoyan0335
来源:CSDN
原文:
版权声明:本文为博主原创文章,转载请附上博文链接!