鸽了一个多月,终于更新辣,这篇是[基于CNN+GRU的文本分类实践](基于CNN+GRU的文本分类实践 | Rufus的B滚木 (gitee.io))的续集,话不多说,直接进入正题!

经过无数次的尝试,最后发现,还是全连接层坠爽,有奇效。

model.add(Embedding(len(vocab) + 1, 256, input_length=300))  
# 使用Embedding层将每个词编码转换为词向量
model.add(Dense(512))
model.add(BatchNormalization())
model.add(LeakyReLU())
model.add(Dropout(0.2))
model.add(Dense(256))
model.add(BatchNormalization())
model.add(LeakyReLU())
model.add(Dropout(0.2))
model.add(Dense(256))
model.add(BatchNormalization())
model.add(LeakyReLU())
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(9, activation='softmax'))

然后,这次做了一个重大的改动,就是把”其他“类去除了, 因为在搜集数据集的时候,发现其他类的数据集根本无从下手,搜集出来,感觉也只是在模型里面充当噪声,然后经过一波冥思苦想,最后决定这样干,把神经网络最后不太确定的,就归作其他类。

那么应该怎样找出“不确定”的数据呢,我想到了利用标准差来解决这个问题,每次神经网络最后输出的是一个9个float类型数字的数组,现在计算这9个数字的方差,如果是机器比较确定的话,这9个数字势必有一个会特别接近1

其他8个数字特别接近0,如果不确定,则可能数据分布会相对平均,经过多次测验,当这九个数字的标准差小于0.25时,神经网络对判定结果的置信程度比较低。

下面是改进之后的detect.py检测

import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
import os
import pandas as pd
import jieba
import pickle
import numpy as np

cw = lambda x: list(jieba.cut(x))  # jieba分词器,对输入文本分词时的必备工具
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # 为了detect程序的稳定性,还是用cpu进行计算
fr = open('tokenizer', 'rb')  # 打开之前已经保存好的词典映射文件
tokenizer = pickle.load(fr)
fr.close()
model = tf.keras.models.load_model("model_save/model_checkpoints")  # 加载模型
tags = ['体育', '军事', '娱乐', '房产', '教育', '汽车', '游戏', '科技', '财经']
while True:
    y1 = input("------------\n请输入一段新闻:(输入q退出)")
    if y1 == 'q':
        break
    y1 = pd.DataFrame(y1, columns=["a"], index=['b'])  # 处理输入数据
    y1 = y1['a'].apply(cw)
    y1 = str(y1['b'])
    y2 = tokenizer.texts_to_sequences([y1])
    y3 = pad_sequences(y2, maxlen=300)
    result = model.predict(y3)
    print('置信指数:' + str(np.std(result)))  # 计算置信指数(方差)
    if np.std(result) < 0.25:  # 根据方差判定是否需要相信预测结果
        print('该新闻类别为其他')
    else:
        pred = tf.argmax(result, axis=1)
        c = pred.__int__()
        print('该新闻类别为' + tags[c])

简单的UI界面

为了做出一个能让正常人使用的UI效果,单独制作了两个文件

image-20210618171150322

它们分别是用于预测单条新闻和多条新闻的两个函数(好像不用分成两个文件。。。)

另外感谢学长用pyqt写的页面

最后的效果大概是这样

image-20210618171447314

感觉还不错,只是有的时候经常出很多奇奇怪怪的错误,等修复一些BUG,比赛完成之后,我会把源码放上来。