書寫自動智慧:探索Python文本分類器的開發與應用:支持二分類、多分類、多標簽分類、多層級分類和Kmeans聚類
文本分類器,提供多種文本分類和聚類算法,支持句子和文檔級的文本分類任務,支持二分類、多分類、多標簽分類、多層級分類和Kmeans聚類,開箱即用。python3開發。
-
Classifier支持算法
- LogisticRegression
- Random Forest
- Decision Tree
- K-Nearest Neighbours
- Naive bayes
- Xgboost
- Support Vector Machine(SVM)
- TextCNN
- TextRNN
- Fasttext
- BERT
-
Cluster
- MiniBatchKmeans
While providing rich functions, pytextclassifier internal modules adhere to low coupling, model adherence to inert loading, dictionary publication, and easy to use.
- 安裝
- Requirements and Installation
pip3 install torch # conda install pytorch
pip3 install pytextclassifier
or
git clone https://github.com/shibing624/pytextclassifier.git
cd pytextclassifier
python3 setup.py install
1. English Text Classifier
包括模型訓練、保存、預測、評估等
examples/lr_en_classification_demo.py:
import syssys.path.append('..')
from pytextclassifier import ClassicClassifierif __name__ == '__main__':m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr')# ClassicClassifier support model_name:lr, random_forest, decision_tree, knn, bayes, svm, xgboostprint(m)data = [('education', 'Student debt to cost Britain billions within decades'),('education', 'Chinese education for TV experiment'),('sports', 'Middle East and Asia boost investment in top level sports'),('sports', 'Summit Series look launches HBO Canada sports doc series: Mudhar')]# train and save best modelm.train(data)# load best model from model_dirm.load_model()predict_label, predict_proba = m.predict(['Abbott government spends $8 million on higher education media blitz'])print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')test_data = [('education', 'Abbott government spends $8 million on higher education media blitz'),('sports', 'Middle East and Asia boost investment in top level sports'),]acc_score = m.evaluate_model(test_data)print(f'acc_score: {acc_score}')
output:
ClassicClassifier instance (LogisticRegression(fit_intercept=False), stopwords size: 2438)
predict_label: ['education'], predict_proba: [0.5378236358492112]
acc_score: 1.0
2. Chinese Text Classifier(中文文本分類)
文本分類兼容中英文語料庫。
example examples/lr_classification_demo.py
import syssys.path.append('..')
from pytextclassifier import ClassicClassifierif __name__ == '__main__':m = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr')# 經典分類方法,支持的模型包括:lr, random_forest, decision_tree, knn, bayes, svm, xgboostdata = [('education', '名師指導托福語法技巧:名詞的復數形式'),('education', '中國高考成績海外認可 是“狼來了”嗎?'),('education', '公務員考慮越來越吃香,這是怎么回事?'),('sports', '圖文:法網孟菲爾斯苦戰進16強 孟菲爾斯怒吼'),('sports', '四川丹棱舉行全國長距登山挑戰賽 近萬人參與'),('sports', '米蘭客場8戰不敗國米10年連勝'),]m.train(data)print(m)# load best model from model_dirm.load_model()predict_label, predict_proba = m.predict(['福建春季公務員考試報名18日截止 2月6日考試','意甲首輪補賽交戰記錄:米蘭客場8戰不敗國米10年連勝'])print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')test_data = [('education', '福建春季公務員考試報名18日截止 2月6日考試'),('sports', '意甲首輪補賽交戰記錄:米蘭客場8戰不敗國米10年連勝'),]acc_score = m.evaluate_model(test_data)print(f'acc_score: {acc_score}') # 1.0#### train model with 1w dataprint('-' * 42)m = ClassicClassifier(output_dir='models/lr', model_name_or_model='lr')data_file = 'thucnews_train_1w.txt'm.train(data_file)m.load_model()predict_label, predict_proba = m.predict(['順義北京蘇活88平米起精裝房在售','美EB-5項目“15日快速移民”將推遲'])print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
output:
ClassicClassifier instance (LogisticRegression(fit_intercept=False), stopwords size: 2438)
predict_label: ['education' 'sports'], predict_proba: [0.5, 0.598941806741534]
acc_score: 1.0
------------------------------------------
predict_label: ['realty' 'education'], predict_proba: [0.7302956923617372, 0.2565005445322923]
3.可解釋性分析
例如,顯示模型的特征權重,以及預測詞的權重 examples/visual_feature_importance.ipynb
import syssys.path.append('..')
from pytextclassifier import ClassicClassifier
import jiebatc = ClassicClassifier(output_dir='models/lr-toy', model_name_or_model='lr')
data = [('education', '名師指導托福語法技巧:名詞的復數形式'),('education', '中國高考成績海外認可 是“狼來了”嗎?'),('sports', '圖文:法網孟菲爾斯苦戰進16強 孟菲爾斯怒吼'),('sports', '四川丹棱舉行全國長距登山挑戰賽 近萬人參與'),('sports', '米蘭客場8戰不敗國米10年連勝')
]
tc.train(data)
import eli5infer_data = ['高考指導托福語法技巧國際認可','意甲首輪補賽交戰記錄:米蘭客場8戰不敗國米10年連勝']
eli5.show_weights(tc.model, vec=tc.feature)
seg_infer_data = [' '.join(jieba.lcut(i)) for i in infer_data]
eli5.show_prediction(tc.model, seg_infer_data[0], vec=tc.feature,target_names=['education', 'sports'])
output:
4. Deep Classification model
本項目支持以下深度分類模型:FastText、TextCNN、TextRNN、Bert模型,import
模型對應的方法來調用:
from pytextclassifier import FastTextClassifier, TextCNNClassifier, TextRNNClassifier, BertClassifier
下面以FastText模型為示例,其他模型的使用方法類似。
4.1 FastText 模型
訓練和預測FastText
模型示例examples/fasttext_classification_demo.py
import syssys.path.append('..')
from pytextclassifier import FastTextClassifier, load_dataif __name__ == '__main__':m = FastTextClassifier(output_dir='models/fasttext-toy')data = [('education', '名師指導托福語法技巧:名詞的復數形式'),('education', '中國高考成績海外認可 是“狼來了”嗎?'),('education', '公務員考慮越來越吃香,這是怎么回事?'),('sports', '圖文:法網孟菲爾斯苦戰進16強 孟菲爾斯怒吼'),('sports', '四川丹棱舉行全國長距登山挑戰賽 近萬人參與'),('sports', '米蘭客場8戰不敗保持連勝'),]m.train(data, num_epochs=3)print(m)# load trained best modelm.load_model()predict_label, predict_proba = m.predict(['福建春季公務員考試報名18日截止 2月6日考試','意甲首輪補賽交戰記錄:米蘭客場8戰不敗國米10年連勝'])print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')test_data = [('education', '福建春季公務員考試報名18日截止 2月6日考試'),('sports', '意甲首輪補賽交戰記錄:米蘭客場8戰不敗國米10年連勝'),]acc_score = m.evaluate_model(test_data)print(f'acc_score: {acc_score}') # 1.0#### train model with 1w dataprint('-' * 42)data_file = 'thucnews_train_1w.txt'm = FastTextClassifier(output_dir='models/fasttext')m.train(data_file, names=('labels', 'text'), num_epochs=3)# load best trained model from model_dirm.load_model()predict_label, predict_proba = m.predict(['順義北京蘇活88平米起精裝房在售','美EB-5項目“15日快速移民”將推遲'])print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')x, y, df = load_data(data_file)test_data = df[:100]acc_score = m.evaluate_model(test_data)print(f'acc_score: {acc_score}')
4.2 BERT 類模型
4.2.1 多分類模型
訓練和預測BERT
多分類模型,示例examples/bert_classification_zh_demo.py
import syssys.path.append('..')
from pytextclassifier import BertClassifierif __name__ == '__main__':m = BertClassifier(output_dir='models/bert-chinese-toy', num_classes=2,model_type='bert', model_name='bert-base-chinese', num_epochs=2)# model_type: support 'bert', 'albert', 'roberta', 'xlnet'# model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...data = [('education', '名師指導托福語法技巧:名詞的復數形式'),('education', '中國高考成績海外認可 是“狼來了”嗎?'),('education', '公務員考慮越來越吃香,這是怎么回事?'),('sports', '圖文:法網孟菲爾斯苦戰進16強 孟菲爾斯怒吼'),('sports', '四川丹棱舉行全國長距登山挑戰賽 近萬人參與'),('sports', '米蘭客場8戰不敗國米10年連勝'),]m.train(data)print(m)# load trained best model from model_dirm.load_model()predict_label, predict_proba = m.predict(['福建春季公務員考試報名18日截止 2月6日考試','意甲首輪補賽交戰記錄:米蘭客場8戰不敗國米10年連勝'])print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')test_data = [('education', '福建春季公務員考試報名18日截止 2月6日考試'),('sports', '意甲首輪補賽交戰記錄:米蘭客場8戰不敗國米10年連勝'),]acc_score = m.evaluate_model(test_data)print(f'acc_score: {acc_score}')# train model with 1w data file and 10 classesprint('-' * 42)m = BertClassifier(output_dir='models/bert-chinese', num_classes=10,model_type='bert', model_name='bert-base-chinese', num_epochs=2,args={"no_cache": True, "lazy_loading": True, "lazy_text_column": 1, "lazy_labels_column": 0, })data_file = 'thucnews_train_1w.txt'# 如果訓練數據超過百萬條,建議使用lazy_loading模式,減少內存占用m.train(data_file, test_size=0, names=('labels', 'text'))m.load_model()predict_label, predict_proba = m.predict(['順義北京蘇活88平米起精裝房在售','美EB-5項目“15日快速移民”將推遲','恒生AH溢指收平 A股對H股折價1.95%'])print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
PS:如果訓練數據超過百萬條,建議使用lazy_loading模式,減少內存占用
4.2.2 多標簽分類模型
分類可以分為多分類和多標簽分類。多分類的標簽是排他的,而多標簽分類的所有標簽是不排他的。
多標簽分類比較直觀的理解是,一個樣本可以同時擁有幾個類別標簽,
比如一首歌的標簽可以是流行、輕快,一部電影的標簽可以是動作、喜劇、搞笑等,這都是多標簽分類的情況。
訓練和預測BERT
多標簽分類模型,示例examples/bert_multilabel_classification_zh_demo.py.py
import sys
import pandas as pdsys.path.append('..')
from pytextclassifier import BertClassifierdef load_jd_data(file_path):"""Load jd data from file.@param file_path: format: content,其他,互聯互通,產品功耗,滑輪提手,聲音,APP操控性,呼吸燈,外觀,底座,制熱范圍,遙控器電池,味道,制熱效果,衣物烘干,體積大小@return: """data = []with open(file_path, 'r', encoding='utf-8') as f:for line in f:line = line.strip()if line.startswith('#'):continueif not line:continueterms = line.split(',')if len(terms) != 16:continueval = [int(i) for i in terms[1:]]data.append([terms[0], val])return dataif __name__ == '__main__':# model_type: support 'bert', 'albert', 'roberta', 'xlnet'# model_name: support 'bert-base-chinese', 'bert-base-cased', 'bert-base-multilingual-cased' ...m = BertClassifier(output_dir='models/multilabel-bert-zh-model', num_classes=15,model_type='bert', model_name='bert-base-chinese', num_epochs=2, multi_label=True)# Train and Evaluation data needs to be in a Pandas Dataframe containing at least two columns, a 'text' and a 'labels' column. The `labels` column should contain multi-hot encoded lists.train_data = [["一個小時房間仍然沒暖和", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],["耗電情況:這個沒有注意", [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],]data = load_jd_data('multilabel_jd_comments.csv')train_data.extend(data)print(train_data[:5])train_df = pd.DataFrame(train_data, columns=["text", "labels"])print(train_df.head())m.train(train_df)print(m)# Evaluate the modelacc_score = m.evaluate_model(train_df[:20])print(f'acc_score: {acc_score}')# load trained best model from model_dirm.load_model()predict_label, predict_proba = m.predict(['一個小時房間仍然沒暖和', '耗電情況:這個沒有注意'])print(f'predict_label: {predict_label}, predict_proba: {predict_proba}')
5.模型驗證
- THUCNews中文文本數據集(1.56GB):官方下載地址,抽樣了10萬條THUCNews中文文本10分類數據集(6MB),地址:examples/thucnews_train_10w.txt。
- TNEWS今日頭條中文新聞(短文本)分類 Short Text Classificaiton for News,該數據集(5.1MB)來自今日頭條的新聞版塊,共提取了15個類別的新聞,包括旅游,教育,金融,軍事等,地址:tnews_public.zip
在THUCNews中文文本10分類數據集(6MB)上評估,模型在測試集(test)評測效果如下:
模型 | acc | 說明 |
---|---|---|
LR | 0.8803 | 邏輯回歸Logistics Regression |
TextCNN | 0.8809 | Kim 2014 經典的CNN文本分類 |
TextRNN_Att | 0.9022 | BiLSTM+Attention |
FastText | 0.9177 | bow+bigram+trigram, 效果出奇的好 |
DPCNN | 0.9125 | 深層金字塔CNN |
Transformer | 0.8991 | 效果較差 |
BERT-base | 0.9483 | bert + fc |
ERNIE | 0.9461 | 比bert略差 |
在中文新聞短文本分類數據集TNEWS上評估,模型在開發集(dev)評測效果如下:
模型 | acc | 說明 |
---|---|---|
BERT-base | 0.5660 | 本項目實現 |
BERT-base | 0.5609 | CLUE Benchmark Leaderboard結果 CLUEbenchmark |
- 以上結果均為分類的準確率(accuracy)結果
- THUCNews數據集評測結果可以基于
examples/thucnews_train_10w.txt
數據用examples
下的各模型demo復現 - TNEWS數據集評測結果可以下載TNEWS數據集,運行
examples/bert_classification_tnews_demo.py
復現
- 命令行調用
提供分類模型命令行調用腳本,文件樹:
pytextclassifier
├── bert_classifier.py
├── fasttext_classifier.py
├── classic_classifier.py
├── textcnn_classifier.py
└── textrnn_classifier.py
每個文件對應一個模型方法,各模型完全獨立,可以直接運行,也方便修改,支持通過argparse
修改--data_path
等參數。
直接在終端調用fasttext模型訓練:
python -m pytextclassifier.fasttext_classifier -h
6.文本聚類算法
Text clustering, for example examples/cluster_demo.py
import syssys.path.append('..')
from pytextclassifier.textcluster import TextClusterif __name__ == '__main__':m = TextCluster(output_dir='models/cluster-toy', n_clusters=2)print(m)data = ['Student debt to cost Britain billions within decades','Chinese education for TV experiment','Abbott government spends $8 million on higher education','Middle East and Asia boost investment in top level sports','Summit Series look launches HBO Canada sports doc series: Mudhar']m.train(data)m.load_model()r = m.predict(['Abbott government spends $8 million on higher education media blitz','Middle East and Asia boost investment in top level sports'])print(r)########### load chinese train data from 1w data filefrom sklearn.feature_extraction.text import TfidfVectorizertcluster = TextCluster(output_dir='models/cluster', feature=TfidfVectorizer(ngram_range=(1, 2)), n_clusters=10)data = tcluster.load_file_data('thucnews_train_1w.txt', sep='\t', use_col=1)feature, labels = tcluster.train(data[:5000])tcluster.show_clusters(feature, labels, 'models/cluster/cluster_train_seg_samples.png')r = tcluster.predict(data[:30])print(r)
output:
TextCluster instance (MiniBatchKMeans(n_clusters=2, n_init=10), <pytextclassifier.utils.tokenizer.Tokenizer object at 0x7f80bd4682b0>, TfidfVectorizer(ngram_range=(1, 2)))
[1 1 1 1 1 1 1 1 1 1 1 8 1 1 1 1 1 1 1 1 1 1 9 1 1 8 1 1 9 1]
clustering plot image:
參考鏈接:https://github.com/shibing624/pytextclassifier
如果github進入不了也可進入 https://download.csdn.net/download/sinat_39620217/88205140 免費下載相關資料