Skip to main content
 Web开发网 » 站长学院 » 浏览器插件

学了那么多算法你都能理解其原理吗?

2021年11月05日6010百度已收录

算法原理

k近邻(k-Nearest Neighbor,kNN),应该是最简单的传统机器学习模型,给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的k个实例,这k个实例中的大多数属于哪个类别,就把该输入实例划分到这个类别。

k近邻算法没有显示的训练过程,在“训练阶段”仅仅是把样本保存起来,训练时间开销为零,待收到测试样本后在进行计算处理。

这个k实际上是一个超参数,k值的选择会对k近邻法的结果产生重大影响。如果选择较小的k值,意味着只有与输入实例较近的(相似的)训练实例才会对预测结果起作用,预测结果会对近邻的实例点非常敏感,如果近邻的实例点恰巧是噪声点,预测就会出错;如果选择较大的k值,就意味着与输入实例较远的(不相似的)训练实例也会对预测起作用,这样预测也会出错。在实际应用中,k值一般取一个比较小的数值,并且通常采用交叉验证法来选取最优的k值。如上图的k=5。

模型训练代码地址:

def train():

print("start training...")

# 处理训练数据

train_feature, train_target = process_file(train_dir, word_to_id, cat_to_id)

# 模型训练

model.fit(train_feature, train_target)

def test():

print("start testing...")

# 处理测试数据

test_feature, test_target = process_file(test_dir, word_to_id, cat_to_id)

# test_predict = model.predict(test_feature) # 返回预测类别

test_predict_proba = model.predict_proba(test_feature) # 返回属于各个类别的概率

test_predict = np.argmax(test_predict_proba, 1) # 返回概率最大的类别标签

# accuracy

true_false = (test_predict == test_target)

accuracy = np.count_nonzero(true_false) / float(len(test_target))

print()

print("accuracy is %f" % accuracy)

# precision recall f1-score

print()

print(metrics.classification_report(test_target, test_predict, target_names=categories))

# 混淆矩阵

print("Confusion Matrix...")

print(metrics.confusion_matrix(test_target, test_predict))

if not os.path.exists(vocab_dir):

# 构建词典表

build_vocab(train_dir, vocab_dir)

categories, cat_to_id = read_category()

words, word_to_id = read_vocab(vocab_dir)

# kNN

model = neighbors.KNeighborsClassifier()

train()

test()运行结果:

read_category...

read_vocab...

start training...

start testing...

accuracy is 0.820000

precision recall f1-score support

时政 0.65 0.85 0.74 94

财经 0.81 0.94 0.87 115

科技 0.96 0.97 0.96 94

游戏 0.99 0.74 0.85 104

娱乐 0.99 0.75 0.85 89

时尚 0.88 0.67 0.76 91

家居 0.44 0.78 0.56 89

房产 0.93 0.82 0.87 104

体育 1.00 0.98 0.99 116

教育 0.96 0.65 0.78 104

avg / total 0.87 0.82 0.83 1000

Confusion Matrix...

[[ 80 4 0 0 0 0 6 3 0 1]

[ 1 108 0 0 0 0 6 0 0 0]

[ 0 0 91 0 0 0 3 0 0 0]

[ 4 0 1 77 0 3 18 0 0 1]

[ 4 3 0 1 67 4 10 0 0 0]

[ 0 0 0 0 1 61 29 0 0 0]

[ 9 5 2 0 0 0 69 3 0 1]

[ 9 3 0 0 0 0 7 85 0 0]

[ 2 0 0 0 0 0 0 0 114 0]

[ 14 10 1 0 0 1 10 0 0 68]]

社群:

公众号:

了解更多干货文章,可以关注小程序八斗问答

作者:gaoyan0335

来源:CSDN

原文:

版权声明:本文为博主原创文章,转载请附上博文链接!

评论列表暂无评论
发表评论
微信