Skip to content

huajingyun/nlp_classifier_zh

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

10 Commits
 
 
 
 
 
 

Repository files navigation

中文说明 | English

中文文本分类模型(nlp_classifier_zh)

本项目包含了面向中文的两种分类模型,包括了multi-classmulti-label两种分类问题,采用的预训练模型为目前SOTA的RoBERTa-zh-Large模型进行finetuning。

两种文本分类的区别
multi-class:单条样本仅对应单个分类名
multi-label:单条样本可能对应多个分类名

中文文本分类模型(multi-label)

1. 数据准备

data文件夹中存放train.csv, dev.csv, test.csv
文件格式:(label部分类似one-hot格式)

index text a b c d e f
0 A 1 1 1 0 0 0
1 B 1 0 1 0 0 0
2 C 1 1 0 0 0 1

2.代码

在bert代码的基础之上,仅需修改run_classifier.py文件,
其中添加MultiLabelClassifierProcessor类,将其中tf.nn.softmax()改为sigmoid
更多细节详见代码

3. 模型训练

运行如下代码,BERT_BASE_DIR为加载预训练模型,DATA_DIR为数据存放文件夹,TRAINED_CLASSIFIER为加载模型

export BERT_BASE_DIR=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
export DATA_DIR=./multi_label_classifier/data
export TRAINED_CLASSIFIER=./pretrain_models/roeberta_zh_L-24_H-1024_A-16

python run_classifier.py \
  --task_name=mlc \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=512 \
  --output_dir=./multi_label_classifier/mlc_output/

4. 模型预测

运行如下代码,BERT_BASE_DIR为加载预训练模型,DATA_DIR为数据存放文件夹,TRAINED_CLASSIFIER为已fine-tuning后的模型

export BERT_BASE_DIR=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
export DATA_DIR=./multi_label_classifier/data
export TRAINED_CLASSIFIER=./pretrain_models/multi_label_classifier/mlc_output/model.ckpt-10000

python run_classifier.py \
  --task_name=mlc \
  --do_train=true \
  --do_eval=true \
  --do_predict=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=512 \
  --output_dir=./multi_label_classifier/mlc_output/

中文文本分类模型(multi-class)

1. 数据准备

data文件夹中存放train.csv, dev.csv, test.csv
文件格式:(label部分类似one-hot格式)

index text label
0 text_A c
1 text_B a
2 text_C a
3 text_D b
4 text_E b

2.代码

在bert代码的基础之上,仅需修改run_classifier.py文件,
其中添加MultiClassClassifierProcessor类 更多细节详见代码

3. 模型训练

运行如下代码,BERT_BASE_DIR为加载预训练模型,DATA_DIR为数据存放文件夹,TRAINED_CLASSIFIER为加载模型

export BERT_BASE_DIR=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
export DATA_DIR=./multi_class_classifier/data
export TRAINED_CLASSIFIER=./pretrain_models/roeberta_zh_L-24_H-1024_A-16

python run_classifier.py \
  --task_name=mcc \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=512 \
  --output_dir=./multi_class_classifier/mcc_output/

4. 模型预测

运行如下代码,BERT_BASE_DIR为加载预训练模型,DATA_DIR为数据存放文件夹,TRAINED_CLASSIFIER为已fine-tuning后的模型

export BERT_BASE_DIR=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
export DATA_DIR=./multi_class_classifier/data
export TRAINED_CLASSIFIER=./pretrain_models/multi_label_classifier/mcc_output/model.ckpt-10000

python run_classifier.py \
  --task_name=mcc \
  --do_train=true \
  --do_eval=true \
  --do_predict=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=512 \
  --output_dir=./multi_class_classifier/mcc_output/

Reference:

  1. RoBERTa: A Robustly Optimized BERT Pretraining Approach
  2. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

About

include multi-class case and multi-label case

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages