|
问题背景:
之前做的聊天项目中涉及敏感词过滤,当时只是采用了简单的字符串匹配,敏感词量还不大,性能还能接受。
最近刚好有时间,于是就研究了trie tree来实现敏感词过滤。
资料
trie tree的原理网上很多。
几种Trie树性能比较 https://www.hankcs.com/nlp/performance-comparison-of-several-trie-tree.html
找到了darts的c/c++版 https://github.com/s-yata/darts-clone
使用
主要涉及几个接口的使用:
build : 传入敏感词数组,构建trie tree(敏感词数组需事先按字典序升序排好)
exactMatchSearch : 查找完全匹配的敏感词
commonPrefixSearch : 前缀匹配。找出所有从字符串开头开始匹配的敏感词(即是该字符串前缀的敏感词),并以数组形式返回匹配结果
性能比较
主要是和之前普通的字符串匹配进行性能比较,以10000个敏感词做测试,差距在10000倍以上。理论上性能差距的量级与敏感词数量相当:普通字符串匹配要对每个敏感词各扫描一遍文本,而trie tree在文本的每个位置只需沿树走一次。
测试代码
// redistest.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//
#include <stdint.h>

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <iostream>
#include <set>
#include <string>
#include <vector>

#include "windows.h"
#include "darts.h"
using namespace Darts;
// Fills *valid_keys with random keys until it holds num_keys distinct entries.
//
// Each key is 1-8 upper-case letters 'A'-'Z' drawn from std::rand(), so the
// generated set depends on the current srand() seed. Entries already present
// in *valid_keys count towards num_keys. valid_keys must not be null.
void generate_valid_keys(std::size_t num_keys,
                         std::set<std::string>* valid_keys) {
    std::string key;
    while (valid_keys->size() < num_keys) {
        key.resize(1 + (std::rand() % 8));
        // Fill every position of the key, not the container as a whole.
        for (std::size_t i = 0; i < key.size(); ++i) {
            key[i] = static_cast<char>('A' + (std::rand() % 26));
        }
        valid_keys->insert(key);
    }
}
// Fills *invalid_keys with random keys that are NOT in valid_keys, until it
// holds num_keys distinct entries.
//
// Keys have the same shape as generate_valid_keys produces (1-8 letters
// 'A'-'Z' from std::rand()), so they can serve as negative lookups against a
// dictionary built from valid_keys. invalid_keys must not be null.
void generate_invalid_keys(std::size_t num_keys,
                           const std::set<std::string>& valid_keys,
                           std::set<std::string>* invalid_keys) {
    std::string key;
    while (invalid_keys->size() < num_keys) {
        key.resize(1 + (std::rand() % 8));
        // Fill every position of the key, not the container as a whole.
        for (std::size_t i = 0; i < key.size(); ++i) {
            key[i] = static_cast<char>('A' + (std::rand() % 26));
        }
        if (valid_keys.find(key) == valid_keys.end()) {
            invalid_keys->insert(key);
        }
    }
}
// One flag per byte of a text; true where the byte is covered by a match.
using Mask = std::vector<bool>;
// A list of matched words.
using FilterWords = std::vector<std::string>;
// A collection of masks (e.g. one per matched word).
using MaskArray = std::vector<Mask>;
// Returns a copy of s in which the ASCII letters 'A'-'Z' are lowered; every
// other byte is copied unchanged.
//
// Lowering is deliberately byte-wise and ASCII-only. The result therefore has
// the same length and byte offsets as s, even for multibyte (GBK/UTF-8) text.
// ReplaceAndReport relies on this to map matches in the lowered copy back onto
// the original string. Wide-character towlower() on plain (possibly signed)
// chars is locale-dependent and could rewrite bytes of multibyte characters.
inline std::string to_lower_str(const std::string& s) {
    std::string lowered(s);
    for (char& c : lowered) {
        if (c >= 'A' && c <= 'Z') {
            c = static_cast<char>(c - 'A' + 'a');
        }
    }
    return lowered;
}
// Sensitive-word dictionary used by the benchmarks.
// g_valid_keys is scanned key by key in ReplaceAndReport. It is also the
// (sorted) key source from which main() builds g_dic, the Darts-clone
// double-array trie searched by ReplaceAndReportV2.
std::set<std::string> g_valid_keys;
Darts::DoubleArray g_dic;
// Naive baseline filter: masks every occurrence of every key in g_valid_keys
// with '*'. ASCII letters are matched case-insensitively.
//
// The text is rescanned once per key, i.e. O(#keys * text length) work.
// ReplaceAndReportV2 does the same job with one trie walk over the text.
// sText is taken by value and the masked copy is discarded: the function
// exists only to be timed by main().
void ReplaceAndReport(std::string sText)
{
    const std::string sCopy = to_lower_str(sText);
    const std::size_t nTextLength = sCopy.length();
    // textMask[k] is true when byte k of the text belongs to a matched key.
    Mask textMask(nTextLength);
    FilterWords matchedWords;
    for (const auto& word : g_valid_keys)
    {
        // Lower the key as well: comparing raw upper-case keys against the
        // lowered text could never match.
        const std::string needle = to_lower_str(word);
        // An empty needle matches at every position and would loop forever.
        if (needle.empty()) continue;
        std::size_t pos = sCopy.find(needle);
        if (pos == std::string::npos) continue;
        matchedWords.push_back(word);
        do {
            std::fill_n(textMask.begin() + pos, needle.length(), true);
            pos = sCopy.find(needle, pos + needle.length());
        } while (pos != std::string::npos);
    }
    if (matchedWords.empty()) return;
    // to_lower_str keeps byte offsets, so mask indices apply to sText as-is.
    for (std::size_t k = 0; k < nTextLength; ++k)
    {
        if (textMask[k]) sText[k] = '*';
    }
}
void ReplaceAndReportV2(std::string sText)
{
std::string sCopy = to_lower_str(sText);
static const std::size_t MAX_NUM_RESULTS = 16;
for (int i = 0; i < sText.size();)
{
typename Darts:oubleArray::result_pair_type results[MAX_NUM_RESULTS] = { 0 };
std::size_t num_results = g_dic.commonPrefixSearch(&sText, results, MAX_NUM_RESULTS);
if (num_results > 0)
{
int offset = results[(std::min)(num_results, MAX_NUM_RESULTS)-1].length;
std::fill_n(sText.begin() + i, offset, '*');
i += offset;
}
else
{
if (sText < 0)//表示中文
{
i += 2;
}
else
{
i++;
}
}
}
}
// Benchmarks the naive filter (ReplaceAndReport) against the trie-based filter
// (ReplaceAndReportV2). Both run on the same sentence against the same
// dictionary of random keys plus a few hand-written ones.
int main()
{
    static const std::size_t NUM_VALID_KEYS = 1 << 17;
    generate_valid_keys(NUM_VALID_KEYS, &g_valid_keys);
    // Hand-written keys sharing a common prefix, so the test sentence
    // produces several prefix hits at the same position.
    g_valid_keys.insert("傻逼BCD");
    g_valid_keys.insert("傻逼A");
    g_valid_keys.insert("傻逼");

    // Darts-clone's build() needs the keys in sorted order; iterating a
    // std::set yields them sorted, which is why the keys live in a set.
    std::vector<const char*> keys;
    keys.reserve(g_valid_keys.size());
    for (const auto& key : g_valid_keys)
    {
        keys.push_back(key.c_str());
    }
    g_dic.build(keys.size(), keys.data());

    const std::string strTest = "你是个傻逼BCD啊你是个玩";
    const int testTimes = 1000;

    typedef std::chrono::steady_clock Clock;
    // Milliseconds elapsed since `from`, with sub-millisecond resolution;
    // a tick-count timer is too coarse for the fast trie run.
    auto elapsedMs = [](Clock::time_point from) {
        return std::chrono::duration<double, std::milli>(Clock::now() - from).count();
    };

    Clock::time_point start = Clock::now();
    for (int i = 0; i < testTimes; ++i)
    {
        ReplaceAndReport(strTest);
    }
    std::printf("find count %.3f ms\n", elapsedMs(start));

    /* exactMatchSearch benchmark (disabled):
    start = Clock::now();
    int result;
    for (int i = 0; i < testTimes; ++i)
    {
        g_dic.exactMatchSearch(strTest.c_str(), result);
    }
    std::printf("exactMatchSearch count %.3f ms\n", elapsedMs(start));
    */

    start = Clock::now();
    for (int i = 0; i < testTimes; ++i)
    {
        ReplaceAndReportV2(strTest);
    }
    std::printf("commonPrefixSearch count %.3f ms\n", elapsedMs(start));
    return 0;
}
|
|