本项目包含了面向中文的两种分类模型,包括了multi-class
和multi-label
两种分类问题,采用的预训练模型为目前SOTA的RoBERTa-zh-Large模型进行finetuning。
两种文本分类的区别:
multi-class
:单条样本仅对应单个分类名
multi-label
:单条样本可能对应多个分类名
data文件夹中存放train.csv, dev.csv, test.csv
文件格式:(label部分类似one-hot格式)
index | text | a | b | c | d | e | f |
---|---|---|---|---|---|---|---|
0 | A | 1 | 1 | 1 | 0 | 0 | 0 |
1 | B | 1 | 0 | 1 | 0 | 0 | 0 |
2 | C | 1 | 1 | 0 | 0 | 0 | 1 |
在bert代码的基础之上,仅需修改run_classifier.py
文件,
其中添加MultiLabelClassifierProcessor
类,将其中tf.nn.softmax()
改为sigmoid
更多细节详见代码
运行如下代码,BERT_BASE_DIR
为加载预训练模型,DATA_DIR
为数据存放文件夹,TRAINED_CLASSIFIER
为加载模型
export BERT_BASE_DIR=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
export DATA_DIR=./multi_label_classifier/data
export TRAINED_CLASSIFIER=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
python run_classifier.py \
--task_name=mlc \
--do_train=true \
--do_eval=true \
--data_dir=$DATA_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER \
--max_seq_length=512 \
--output_dir=./multi_label_classifier/mlc_output/
运行如下代码,BERT_BASE_DIR
为加载预训练模型,DATA_DIR
为数据存放文件夹,TRAINED_CLASSIFIER
为已fine-tuning后的模型
export BERT_BASE_DIR=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
export DATA_DIR=./multi_label_classifier/data
export TRAINED_CLASSIFIER=./pretrain_models/multi_label_classifier/mlc_output/model.ckpt-10000
python run_classifier.py \
--task_name=mlc \
--do_train=true \
--do_eval=true \
--do_predict=true \
--data_dir=$DATA_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER \
--max_seq_length=512 \
--output_dir=./multi_label_classifier/mlc_output/
data文件夹中存放train.csv, dev.csv, test.csv
文件格式:(label部分类似one-hot格式)
index | text | label |
---|---|---|
0 | text_A | c |
1 | text_B | a |
2 | text_C | a |
3 | text_D | b |
4 | text_E | b |
在bert代码的基础之上,仅需修改run_classifier.py
文件,
其中添加MultiClassClassifierProcessor
类
更多细节详见代码
运行如下代码,BERT_BASE_DIR
为加载预训练模型,DATA_DIR
为数据存放文件夹,TRAINED_CLASSIFIER
为加载模型
export BERT_BASE_DIR=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
export DATA_DIR=./multi_class_classifier/data
export TRAINED_CLASSIFIER=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
python run_classifier.py \
--task_name=mcc \
--do_train=true \
--do_eval=true \
--data_dir=$DATA_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER \
--max_seq_length=512 \
--output_dir=./multi_class_classifier/mcc_output/
运行如下代码,BERT_BASE_DIR
为加载预训练模型,DATA_DIR
为数据存放文件夹,TRAINED_CLASSIFIER
为已fine-tuning后的模型
export BERT_BASE_DIR=./pretrain_models/roeberta_zh_L-24_H-1024_A-16
export DATA_DIR=./multi_class_classifier/data
export TRAINED_CLASSIFIER=./pretrain_models/multi_label_classifier/mcc_output/model.ckpt-10000
python run_classifier.py \
--task_name=mcc \
--do_train=true \
--do_eval=true \
--do_predict=true \
--data_dir=$DATA_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER \
--max_seq_length=512 \
--output_dir=./multi_class_classifier/mcc_output/