diff --git a/README.md b/README.md index e0d9c1671..7ded8ae35 100644 --- a/README.md +++ b/README.md @@ -2,42 +2,25 @@ extremeText is an extension of [fastText](https://github.com/facebookresearch/fastText) library for multi-label classification including extreme cases with hundreds of thousands and millions of labels. -extremeText adds new options for fastText supervised command: - -``` -$ ./extremetext supervised - -New losses for multi-label classification: - -loss sigmoid - -loss plt (Probabilistic Labels Tree) - -With the following optional arguments: - -treeType tree type of PLT: complete, huffman, kmeans (default = kmeans) - -ensemble number of trees in ensemble (default = 1) - -bagging bagging ration for ensemble (default = 1.0) - -l2 l2 regularization (default = 0) - -tfidfWeights calculates TF-IDF weights for words - -wordsWeights reads words weights from file (format: :) - -weight document weight prefix (default = __weight__; format: :) - -tag tags prefix (default = __tag__), tags are words tha are ingnored, by are outputed with prediction -``` +extremeText implements: -extremeText adds new commands and makes other to work in parallel: -``` -$ ./extremetext predict[-prob] [] [] [] [] -$ ./extremetext get-prob [] [] [] -``` +* Probabilistic Labels Tree (PLT) loss for extreme multi-label classification with top-down hierarchical clustering (k-means) for tree building, +* sigmoid loss for multi-label classification, +* L2 regularization and FOBOS update for all losses, +* ensemble of loss layers with bagging, +* calculation of hidden (document) vector as a weighted average of the word vectors, +* calculation of TF-IDF weights for words. ## Installation ### Building executable -extremeText like fastText can be build as executable using Make or/and CMake: +extremeText, like fastText, can be built as an executable using Make (recommended) or CMake: ``` $ git clone https://github.com/mwydmuch/extremeText.git $ cd extremeText -$ (optional) cmake . +(optional) $ cmake . $ make ``` @@ -45,38 +28,81 @@ This will produce object files for all the classes as well as the main binary `e ### Python package -The easiest way to get extremeText is to use [pip](https://pip.pypa.io/en/stable/): +The easiest way to get extremeText is to use [pip](https://pip.pypa.io/en/stable/). ``` -pip install extremetext +$ pip install extremetext ``` -The latest version of extremeText can be build from sources: +Installing on MacOS may require setting `MACOSX_DEPLOYMENT_TARGET=10.9` first: +``` +$ export MACOSX_DEPLOYMENT_TARGET=10.9 +$ pip install extremetext +``` + +The latest version of extremeText can be built from source using pip or, alternatively, setuptools. ``` $ git clone https://github.com/mwydmuch/extremeText.git $ cd extremeText $ pip install .
+(or) $ python setup.py install ``` -Alternatively you can also install extremeText using setuptools: +Now you can import this library with: ``` -$ git clone https://github.com/mwydmuch/extremeText.git -$ cd extremeText -$ python setup.py install +import extremeText +``` + +## Usage + +extremeText adds new options to the fastText supervised command: + +``` +$ ./extremetext supervised + +New losses for multi-label classification: + -loss sigmoid + -loss plt (Probabilistic Labels Tree) + +With the following optional arguments: + General: + -l2 L2 regularization (default = 0) + -fobos use FOBOS update + -tfidfWeights calculate TF-IDF weights for words + -wordsWeights read word weights from file (format: :) + -weight document weight prefix (default = __weight__; format: :) + -tag tags prefix (default = __tag__), tags are ignored words that are output with the prediction + -addEosToken add EOS token at the end of document (default = 0) + -eosWeight weight of EOS token (default = 1.0) + + PLT (Probabilistic Labels Tree): + -treeType type of PLT: complete, huffman, kmeans (default = kmeans) + -arity arity of PLT (default = 2) + -maxLeaves maximum number of leaves (labels) in one internal node of PLT (default = 100) + -kMeansEps stopping criterion for k-means clustering (default = 0.001) + + Ensemble: + -ensemble size of the ensemble (default = 1) + -bagging bagging ratio (default = 1.0) +``` + +extremeText also adds new commands and makes others work in parallel: +``` +$ ./extremetext predict[-prob] [] [] [] [] +$ ./extremetext get-prob [] [] [] ``` ## Reference -Please cite below work if using this code for classification. +Please cite the work below if using this code for extreme classification. M. Wydmuch, K. Jasinska, M. Kuznetsov, R. Busa-Fekete, K. Dembczyński, [*A no-regret generalization of hierarchical softmax to extreme multi-label classification*](https://arxiv.org/abs/1810.11671) ## TODO -* Merge with latest changes from fastText. -* Rewrite vanilla fastText losses as extremeText loss layers to support new features. - +* Merge with the latest changes from fastText. +* Rewrite vanilla fastText losses as extremeText loss layers to support all new features. ---
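Since the Python bindings accept the same arguments and defaults as the CLI (the bundled train_supervised.py example, updated later in this patch, notes this), the options above can also be exercised from Python. A minimal sketch; the import style follows the README, while the file names, parameter values, and exact keyword spellings are illustrative assumptions:

```python
# Illustrative sketch (not part of the patch): train a PLT model with a
# k-means built tree and a bagged ensemble, mirroring the CLI flags above.
# Assumes train.txt / test.txt in fastText format (__label__ prefixes).
from extremeText import train_supervised

model = train_supervised(
    input="train.txt",
    loss="plt",           # Probabilistic Labels Tree loss
    treeType="kmeans",    # build the tree by top-down k-means clustering
    ensemble=3,           # ensemble of 3 loss layers
    bagging=0.8,          # bagging ratio for the ensemble
    l2=0.001,             # L2 regularization
    tfidfWeights=True,    # weight words by TF-IDF
)
print(model.test("test.txt"))  # in this fork, test() also returns the value printed as C@1
model.save_model("model.bin")
```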
diff --git a/python/README.md b/python/README.md index fb63495e7..91c767b13 100644 --- a/python/README.md +++ b/python/README.md @@ -2,6 +2,15 @@ [extremeText](https://github.com/mwydmuch/extremeText) is an extension of [fastText](https://github.com/facebookresearch/fastText) library for multi-label classification including extreme cases with hundreds of thousands and millions of labels. +[extremeText](https://github.com/mwydmuch/extremeText) implements: + +* Probabilistic Labels Tree (PLT) loss for extreme multi-label classification with top-down hierarchical clustering (k-means) for tree building, +* sigmoid loss for multi-label classification, +* L2 regularization and FOBOS update for all losses, +* ensemble of loss layers with bagging, +* calculation of hidden (document) vector as a weighted average of the word vectors, +* calculation of TF-IDF weights for words. + ## Requirements [extremeText](https://github.com/mwydmuch/extremeText) builds on modern Mac OS and Linux distributions. @@ -18,26 +27,25 @@ You will need: ## Installing extremeText -The easiest way to get extremeText is to use [pip](https://pip.pypa.io/en/stable/): +The easiest way to get [extremeText](https://github.com/mwydmuch/extremeText) is to use [pip](https://pip.pypa.io/en/stable/). ``` -pip install extremetext +$ pip install extremetext ``` -The latest version of extremeText can be build from sources: - +Installing on MacOS may require setting `MACOSX_DEPLOYMENT_TARGET=10.9` first: ``` -$ git clone https://github.com/mwydmuch/extremeText.git -$ cd extremeText -$ pip install . +$ export MACOSX_DEPLOYMENT_TARGET=10.9 +$ pip install extremetext ``` -Alternatively you can also install extremeText using setuptools: +The latest version of [extremeText](https://github.com/mwydmuch/extremeText) can be built from source using pip or, alternatively, setuptools. ``` $ git clone https://github.com/mwydmuch/extremeText.git $ cd extremeText -$ python setup.py install +$ pip install . +(or) $ python setup.py install ``` Now you can import this library with: @@ -54,7 +62,7 @@ We recommend you look at the [examples within the doc folder](https://github.com As with any package you can get help on any Python function using the help function. -For example +For example: ``` +>>> import extremeText @@ -84,7 +92,7 @@ FUNCTIONS ## IMPORTANT: Preprocessing data / enconding conventions -In general it is important to properly preprocess your data. In particular our example scripts in the [root folder](https://github.com/mwydmuch/extremeText/extremeText) do this. +In general it is important to properly preprocess your data. Example scripts in the [root folder](https://github.com/mwydmuch/extremeText/extremeText) do this. extremeText like fastText assumes UTF-8 encoded text. All text must be [unicode for Python2](https://docs.python.org/2/library/functions.html#unicode) and [str for Python3](https://docs.python.org/3.5/library/stdtypes.html#textseq). The passed text will be [encoded as UTF-8 by pybind11](https://pybind11.readthedocs.io/en/master/advanced/cast/strings.html?highlight=utf-8#strings-bytes-and-unicode-conversions) before passed to the extremeText C++ library. This means it is important to use UTF-8 encoded text when building a model. On Unix-like systems you can convert text using [iconv](https://en.wikipedia.org/wiki/Iconv). @@ -100,3 +108,9 @@ extremeText will tokenize (split text into pieces) based on the following ASCII The newline character is used to delimit lines of text. In particular, the EOS token is appended to a line of text if a newline character is encountered. The only exception is if the number of tokens exceeds the MAX\_LINE\_SIZE constant as defined in the [Dictionary header](https://github.com/mwydmuch/extremeText/blob/master/src/dictionary.h). This means if you have text that is not separate by newlines, such as the [fil9 dataset](http://mattmahoney.net/dc/textdata), it will be broken into chunks with MAX\_LINE\_SIZE of tokens and the EOS token is not appended. The length of a token is the number of UTF-8 characters by considering the [leading two bits of a byte](https://en.wikipedia.org/wiki/UTF-8#Description) to identify [subsequent bytes of a multi-byte sequence](https://github.com/mwydmuch/extremeText/blob/master/src/dictionary.cc). Knowing this is especially important when choosing the minimum and maximum length of subwords. Further, the EOS token (as specified in the [Dictionary header](https://github.com/mwydmuch/extremeText/blob/master/src/dictionary.h)) is considered a character and will not be broken into subwords. + +## Reference + +Please cite the work below if using this package for extreme classification. + +M. Wydmuch, K. Jasinska, M. Kuznetsov, R. Busa-Fekete, K. Dembczyński, [A no-regret generalization of hierarchical softmax to extreme multi-label classification](https://arxiv.org/abs/1810.11671) \ No newline at end of file
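One practical consequence of the preprocessing notes above: the tokenizer splits only on a fixed set of ASCII bytes and is not aware of UTF-8 whitespace, so it is worth normalizing text before training. A small illustrative pass (the normalization policy and file names are assumptions, not something the package ships):

```python
# Illustrative sketch: normalize Unicode whitespace to ASCII spaces so
# the extremeText tokenizer sees the intended word boundaries.
# io.open keeps the snippet working on both Python 2 and Python 3.
import io
import re

def normalize(line):
    # Collapse any whitespace run (including Unicode spaces) to one ASCII space.
    return re.sub(r"\s+", " ", line, flags=re.UNICODE).strip()

with io.open("raw.txt", encoding="utf-8") as src, \
        io.open("train.txt", "w", encoding="utf-8") as dst:
    for line in src:
        dst.write(normalize(line) + u"\n")
```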
diff --git a/python/README.rst b/python/README.rst index a1ad60f2d..3dbae09e8 100644 --- a/python/README.rst +++ b/python/README.rst @@ -1,51 +1,68 @@ extremeText -======== +=========== -`extremeText `__ is an extension -of `fastText `__ library -for multi-label classification including extreme cases with hundreds of thousands and millions of labels. +`extremeText `__ is an +extension of `fastText `__ +library for multi-label classification including extreme cases with +hundreds of thousands and millions of labels. + +`extremeText `__ implements: + +- Probabilistic Labels Tree (PLT) loss for extreme multi-label + classification with top-down hierarchical clustering (k-means) for + tree building, +- sigmoid loss for multi-label classification, +- L2 regularization and FOBOS update for all losses, +- ensemble of loss layers with bagging, +- calculation of hidden (document) vector as a weighted average of the + word vectors, +- calculation of TF-IDF weights for words. Requirements ------------ -`extremeText `__ builds on modern Mac OS and Linux -distributions. Since it uses C++11 features, it requires a compiler with -good C++11 support. These include : +`extremeText `__ builds on +modern Mac OS and Linux distributions. Since it uses C++11 features, it +requires a compiler with good C++11 support. These include: - (gcc-4.8 or newer) or (clang-3.3 or newer) -You will need +You will need: - `Python `__ version 2.7 or >=3.4 - `NumPy `__ & `SciPy `__ - `pybind11 `__ -Building extremeText ----------------- +Installing extremeText +---------------------- -The easiest way to get extremeText is to use -pip `__: +The easiest way to get +`extremeText `__ is to use +`pip `__. :: $ pip install extremetext -The latest version of extremeText can be build from sources: +Installing on MacOS may require setting +``MACOSX_DEPLOYMENT_TARGET=10.9`` first: :: - $ git clone https://github.com/mwydmuch/extremeText.git - $ cd extremeText - $ pip install . + $ export MACOSX_DEPLOYMENT_TARGET=10.9 + $ pip install extremetext -Alternatively you can also install extremeText using setuptools: +The latest version of +`extremeText `__ can be built +from source using pip or, alternatively, setuptools. :: $ git clone https://github.com/mwydmuch/extremeText.git $ cd extremeText - $ python setup.py install + $ pip install . + (or) $ python setup.py install Now you can import this library with: @@ -56,10 +73,11 @@ Examples -------- -In general it is assumed that the reader already has good knowledge of fastText/extremeText. - For this consider the main +In general it is assumed that the reader already has good knowledge of +fastText/extremeText. For this consider the main `README `__ -and `the tutorials on fastText website `__. +and `the tutorials on fastText +website `__. We recommend you look at the `examples within the doc folder `__. @@ -67,7 +85,7 @@ folder As with any package you can get help on any Python function using the help function. -For example +For example: :: @@ -98,23 +116,25 @@ For example IMPORTANT: Preprocessing data / enconding conventions ----------------------------------------------------- -In general it is important to properly preprocess your data. In -particular our example scripts in the `root -folder `__ do this. +In general it is important to properly preprocess your data. Example
+scripts in the `root +folder `__ do this. -extremeText like extremeText assumes UTF-8 encoded text. All text must be `unicode for +extremeText like fastText assumes UTF-8 encoded text. All text must be +`unicode for Python2 `__ and `str for Python3 `__. The passed text will be `encoded as UTF-8 by pybind11 `__ -before passed to the extremeText C++ library. This means it is important to -use UTF-8 encoded text when building a model. On Unix-like systems you -can convert text using `iconv `__. - -extremeText like fastText will tokenize (split text into pieces) based on the following -ASCII characters (bytes). In particular, it is not aware of UTF-8 -whitespace. We advice the user to convert UTF-8 whitespace / word +before passed to the extremeText C++ library. This means it is important +to use UTF-8 encoded text when building a model. On Unix-like systems +you can convert text using +`iconv `__. + +extremeText will tokenize (split text into pieces) based on the +following ASCII characters (bytes). In particular, it is not aware of +UTF-8 whitespace. We advise the user to convert UTF-8 whitespace / word boundaries into one of the following symbols as appropiate. - space @@ -144,3 +164,12 @@ maximum length of subwords. Further, the EOS token (as specified in the `Dictionary header `__) is considered a character and will not be broken into subwords. + +Reference +--------- + +Please cite the work below if using this package for extreme classification. + +M. Wydmuch, K. Jasinska, M. Kuznetsov, R. Busa-Fekete, K. Dembczyński, +`A no-regret generalization of hierarchical softmax to extreme +multi-label classification `__ diff --git a/python/doc/examples/bin_to_vec.py b/python/doc/examples/bin_to_vec.py index 4f01908b5..78c8a2eff 100644 --- a/python/doc/examples/bin_to_vec.py +++ b/python/doc/examples/bin_to_vec.py @@ -19,7 +19,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( - description=("Print fasttext .vec file to stdout from .bin file") + description=("Print fasttext/extremetext .vec file to stdout from .bin file") ) parser.add_argument( "model", diff --git a/python/doc/examples/train_supervised.py b/python/doc/examples/train_supervised.py index 5b75c8455..c92f5ae7f 100644 --- a/python/doc/examples/train_supervised.py +++ b/python/doc/examples/train_supervised.py @@ -21,20 +21,15 @@ def print_results(N, p, r, c): print("N\t" + str(N)) print("P@{}\t{:.3f}".format(1, p)) print("R@{}\t{:.3f}".format(1, r)) + print("C@{}\t{:.3f}".format(1, c)) if __name__ == "__main__": train_data = os.path.join(os.getenv("DATADIR", ''), 'dbpedia.train') valid_data = os.path.join(os.getenv("DATADIR", ''), 'dbpedia.test') - # train_supervised uses the same arguments and defaults as the fastText cli - # model = train_supervised( - # input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=3, minCount=1 - # ) - # print_results(*model.test(valid_data)) - + # train_supervised uses the same arguments and defaults as the fastText/extremeText cli model = train_supervised( input=train_data, epoch=25, lr=1.0, wordNgrams=2, verbose=3, minCount=1, - loss="sigmoid" ) print_results(*model.test(valid_data)) model.save_model("dbpedia.bin")
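For completeness next to the training example above, prediction with probabilities looks as follows. This assumes the bindings keep fastText's load_model/predict signatures, which this patch does not change:

```python
# Illustrative sketch: load the model saved by train_supervised.py above
# and print the top-5 labels with their probabilities for one document.
from extremeText import load_model

model = load_model("dbpedia.bin")
labels, probs = model.predict("Which baking dish is best to bake a banana bread ?", k=5)
for label, prob in zip(labels, probs):
    print("{}\t{:.3f}".format(label, prob))
```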
diff --git a/python/doc/examples/train_unsupervised.py b/python/doc/examples/train_unsupervised.py index e1585928d..9943aa19e 100644 --- a/python/doc/examples/train_unsupervised.py +++ b/python/doc/examples/train_unsupervised.py @@ -51,6 +51,7 @@ def similarity(v1, v2): model = train_unsupervised( input=os.path.join(os.getenv("DATADIR", ''), 'fil9'), model='skipgram', + verbose=3 ) model.save_model("fil9.bin") dataset, corr, oov = compute_similarity('rw.txt') diff --git a/setup.py b/setup.py index 1be513f11..fd94fea70 100644 --- a/setup.py +++ b/setup.py @@ -144,7 +144,7 @@ def _get_readme(): version=__version__, author='Marek Wydmuch', author_email='mwydmuch@cs.put.poznan.pl', - description='extremeText Python bindings', + description='A Python interface for the extremeText library', long_description=_get_readme(), ext_modules=ext_modules, url='https://github.com/mwydmuch/extremeText', diff --git a/src/args.cc b/src/args.cc index 8fa99cf26..e05209b58 100644 --- a/src/args.cc +++ b/src/args.cc @@ -53,7 +53,8 @@ Args::Args() { // Features args wordsWeights = false; tfidfWeights = false; - addEosToken = true; + addEosToken = false; + eosWeight = 1.0; weight = "__weight__"; tag = "__tag__"; @@ -72,16 +73,15 @@ Args::Args() { probNorm = false; maxLeaves = 100; - // KMeans + // K-means kMeansEps = 0.001; kMeansBalanced = true; // Update args l2 = 0; fobos = false; - labelsWeights = false; - // Bagging args + // Ensemble args bagging = 1.0; ensemble = 1; } @@ -122,6 +122,20 @@ std::string Args::modelToString(model_name mn) const { return "Unknown model name!"; // should never happen } +std::string Args::treeTypeToString(tree_type_name ttn) const { + switch (ttn) { + case tree_type_name::huffman: + return "huffman"; + case tree_type_name::complete: + return "complete"; + case tree_type_name::kmeans: + return "kmeans"; + case tree_type_name::custom: + return "custom"; + } + return "Unknown tree type name!"; // should never happen +} + void Args::parseArgs(const std::vector& args) { std::string command(args[1]); if (command == "supervised") { @@ -222,6 +236,8 @@ void Args::parseArgs(const std::vector& args) { } else if (args[ai] == "-addEosToken") { addEosToken = true; ai--; + } else if (args[ai] == "-eosWeight") { + eosWeight = std::stof(args.at(ai + 1)); } else if (args[ai] == "-probNorm") { probNorm = true; ai--; @@ -253,9 +269,6 @@ void Args::parseArgs(const std::vector& args) { } else if (args[ai] == "-fobos") { fobos = true; ai--; - } else if (args[ai] == "-labelsWeights") { - labelsWeights = true; - ai--; } else if (args[ai] == "-treeStructure") { treeStructure = std::string(args.at(ai + 1)); } else if (args[ai] == "-randomTree") { @@ -298,6 +311,9 @@ void Args::parseArgs(const std::vector& args) { if (wordNgrams <= 1 && maxn == 0) { bucket = 0; } + if (wordsWeights) { + tfidfWeights = false; + } } void Args::printHelp() { @@ -327,7 +343,11 @@ void Args::printDictionaryHelp() { << " -minn min length of char ngram [" << minn << "]\n" << " -maxn max length of char ngram [" << maxn << "]\n" << " -t sampling threshold [" << t << "]\n" - << " -label labels prefix [" << label << "]\n"; + << " -label labels prefix [" << label << "]\n" + << " -weight document weight prefix [" << weight << "]\n" + << " -tag tags prefix [" << tag << "]\n" + << " -tfidfWeights calculate TF-IDF weights for words\n" + << " -wordsWeights read words weights from file (format: :)\n"; } void Args::printTrainingHelp() { std::cerr << "\nThe following arguments for training are optional:\n" << " -lr learning rate [" << lr << "]\n" << " -lrUpdateRate change the rate of updates for the learning rate [" << lrUpdateRate << "]\n" - << " -l2 l2 regularization [" << l2 << "]\n" + << " -l2 L2 regularization [" << l2 << "]\n" + << " -fobos use FOBOS update [" << boolToString(fobos) << "]\n" << " -dim size of word vectors [" << dim << "]\n"
<< " -ws size of the context window [" << ws << "]\n" << " -epoch number of epochs [" << epoch << "]\n" << " -neg number of negatives sampled [" << neg << "]\n" - << " -loss loss function {ns, hs, softmax} [" << lossToString(loss) << "]\n" + << " -loss loss function {ns, hs, softmax, plt, sigmoid} [" << lossToString(loss) << "]\n" << " -thread number of threads [" << thread << "]\n" << " -pretrainedVectors pretrained word vectors for supervised learning ["<< pretrainedVectors <<"]\n" - << " -wordsWeights TODO" - << " -saveOutput whether output params should be saved [" << boolToString(saveOutput) << "]\n"; + << " -saveOutput whether output params should be saved [" << boolToString(saveOutput) << "]\n" + << " -saveVectors whether word vectors should be saved [" << boolToString(saveVectors) << "]\n" + << " -treeType type of PLT {complete, huffman, kmeans} [" << treeTypeToString(treeType) << "]\n" + << " -arity arity of PLT [" << arity << "]\n" + << " -maxLeaves maximum number of leaves (labels) in one internal node of PLT [" << maxLeaves << "]\n" + << " -kMeansEps stopping criterion for k-means clustering [" << kMeansEps << "]\n" + << " -ensemble size of the ensemble [" << ensemble << "]\n" + << " -bagging bagging ratio [" << bagging << "]\n" + << " -addEosToken add EOS token at the end of document [" << boolToString(addEosToken) << "]\n" + << " -eosWeight weight of EOS token [" << eosWeight << "]\n"; } void Args::printQuantizationHelp() { @@ -360,14 +389,19 @@ void Args::printInfo(){ std::cerr << " Model: " << modelToString(model) << ", loss: " << lossToString(loss) << "\n Features: "; if(model == model_name::sup){ - if(tfidfWeights) std::cerr << "tf-idf weights\n"; - else if(wordsWeights) std::cerr << "word weights\n"; - else std::cerr << "bow\n"; + if(tfidfWeights) std::cerr << "TF-IDF weights"; + else if(wordsWeights) std::cerr << "words weights"; + else std::cerr << "BOW"; } - if(ensemble > 1) std::cerr << " Ensemble: " << ensemble << ", bagging ratio: " << bagging << "\n"; - std::cerr << " Lr: " << lr << ", L2: " << l2 << ", dims: " << dim << ", epochs: " << epoch - << ", buckets: " << bucket << ", neg: " << neg << "\n"; - //std::cerr << " Fobos: " << fobos << ", prob. norm.: " << probNorm << "\n"; + std::cerr << ", buckets: " << bucket << std::endl; + if(loss == loss_name::plt) std::cerr << " Tree type: " << treeTypeToString(treeType) << ", arity: " << arity << ", maxLeaves: " << maxLeaves; + if(loss == loss_name::plt && treeType == tree_type_name::kmeans) std::cerr << ", kMeansEps: " << kMeansEps << std::endl; + else std::cerr << std::endl; + if(ensemble > 1) std::cerr << " Ensemble: " << ensemble << std::endl; //", bagging ratio: " << bagging << std::endl; + std::cerr << " Update: "; + if(fobos) std::cerr << "FOBOS"; + else std::cerr << "SGD"; + std::cerr << ", lr: " << lr << ", L2: " << l2 << ", dims: " << dim << ", epochs: " << epoch << ", neg: " << neg << std::endl; } void Args::save(std::ostream& out) { @@ -450,6 +484,17 @@ void Args::dump(std::ostream& out) const { out << "maxn" << " " << maxn << std::endl; out << "lrUpdateRate" << " " << lrUpdateRate << std::endl; out << "t" << " " << t << std::endl; + out << "l2" << " " << l2 << std::endl; + out << "fobos" << " " << fobos << std::endl; + out << "ensemble" << " " << ensemble << std::endl; + if(loss == loss_name::plt) { + out << "treeType" << " " << treeTypeToString(treeType) << std::endl; + out << "arity" << " " << arity << std::endl; + out << "maxLeaves" << " " << maxLeaves << std::endl; + if(treeType == tree_type_name::kmeans) { + out << "kMeansEps" << " " << kMeansEps << std::endl; + } + } } }
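The PLT arguments handled above (-treeType kmeans, -arity, -maxLeaves, -kMeansEps) describe a tree built by top-down clustering of labels. The following is a conceptual Python sketch of that procedure only; the C++ implementation differs (for one, it supports balanced clustering, see kMeansBalanced), and the label representation here is hypothetical:

```python
# Conceptual sketch of top-down k-means PLT building: recursively split the
# label set into `arity` clusters until a node holds at most `max_leaves`
# labels; `eps` plays the role of the -kMeansEps stopping criterion.
import numpy as np

def build_tree(labels, vectors, arity=2, max_leaves=100, eps=0.001, rng=None):
    rng = np.random.default_rng(0) if rng is None else rng
    if len(labels) <= max_leaves:
        return {"leaf": labels}
    centroids = vectors[rng.choice(len(labels), size=arity, replace=False)]
    shift = np.inf
    while shift > eps:
        dists = ((vectors[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        assign = dists.argmin(axis=1)
        new = np.stack([vectors[assign == k].mean(axis=0) if (assign == k).any()
                        else centroids[k] for k in range(arity)])
        shift = float(np.abs(new - centroids).max())
        centroids = new
    if (assign == assign[0]).all():  # degenerate split; stop to keep the sketch simple
        return {"leaf": labels}
    return {"children": [
        build_tree([l for l, a in zip(labels, assign) if a == k],
                   vectors[assign == k], arity, max_leaves, eps, rng)
        for k in range(arity) if (assign == k).any()]}
```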
diff --git a/src/args.h b/src/args.h index ec25d5abf..dc6236177 100644 --- a/src/args.h +++ b/src/args.h @@ -28,6 +28,7 @@ class Args { std::string lossToString(loss_name) const; std::string boolToString(bool) const; std::string modelToString(model_name) const; + std::string treeTypeToString(tree_type_name) const; public: Args(); @@ -66,6 +67,7 @@ class Args { bool wordsWeights; bool tfidfWeights; bool addEosToken; + real eosWeight; std::string weight; std::string tag; @@ -84,18 +86,17 @@ class Args { bool randomTree; int maxLeaves; - // KMeans + // K-means real kMeansEps; bool kMeansBalanced; // Update args bool fobos; real l2; - bool labelsWeights; // Ensemble args - real bagging; int ensemble; + real bagging; void parseArgs(const std::vector& args); void printHelp();
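The dictionary changes below are what feed per-word weights (uniform, user-supplied via -wordsWeights, or TF-IDF) and the optional weighted EOS token into the model, so that the document (hidden) vector becomes a weighted average of the word vectors. A toy illustration of that computation; names are hypothetical and the real computation lives in the C++ model code:

```python
# Toy sketch: document vector as a weighted average of word vectors,
# with an optional EOS vector whose weight mirrors -eosWeight.
import numpy as np

def doc_vector(word_vecs, weights, eos_vec=None, eos_weight=1.0):
    vecs, w = list(word_vecs), list(weights)
    if eos_vec is not None:  # -addEosToken appends EOS, weighted by -eosWeight
        vecs.append(eos_vec)
        w.append(eos_weight)
    w = np.asarray(w, dtype=np.float32)
    return (np.stack(vecs) * w[:, None]).sum(axis=0) / w.sum()

# Example: two 3-dimensional word vectors, the second weighted twice as heavily.
print(doc_vector([np.ones(3), np.zeros(3)], [1.0, 2.0]))
```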
diff --git a/src/dictionary.cc b/src/dictionary.cc index 6736d41d5..251aa2fa4 100644 --- a/src/dictionary.cc +++ b/src/dictionary.cc @@ -227,6 +227,7 @@ bool Dictionary::readWord(std::istream& in, std::string& word, real& value) cons if (word.empty()) { if (c == '\n') { word += EOS; + value = args_->eosWeight; return true; } continue; } @@ -421,7 +422,7 @@ real Dictionary::getLine(std::istream& in, words_values.clear(); tags.clear(); - // for tf-idf + // For TF-IDF std::vector doc_counts; while (readWord(in, token, value)) { @@ -434,6 +435,8 @@ continue; } + if(token == EOS && !args_->addEosToken) break; + uint32_t h = utils::hash(token); int32_t wid = getId(token, h); entry_type type = wid < 0 ? getType(token) : getType(wid); @@ -442,7 +445,7 @@ real Dictionary::getLine(std::istream& in, if (type == entry_type::word) { if (args_->tfidfWeights) { addSubwordsTfIdf(words, token, doc_counts, wid); - // TODO: support for word hashes for tf-idf + // TODO: support for word hashes for TF-IDF } else { addSubwords(words, token, words_values, value, wid); word_hashes.push_back(h); } } else if (type == entry_type::label && wid >= 0) { labels.push_back(wid - nwords_); } + if (token == EOS) break; } - //TODO: add support for word ngrams - /* - addWordNgrams(words, word_hashes, args_->wordNgrams); - while(words.size() != words_values.size()) - words_values.push_back(1); - */ - real values_sum = 0; if(args_->tfidfWeights){ assert(words.size() == doc_counts.size()); for(...){ ... values_sum += tfidf; } } else { - assert(words.size() == words_values.size()); for(auto &it : words_values) values_sum += it; } + // TODO: support for wordNgrams for TF-IDF + if(!args_->tfidfWeights && !args_->wordsWeights) { + addWordNgrams(words, word_hashes, args_->wordNgrams); + while (words.size() != words_values.size()) + words_values.push_back(1.0); + } + for(auto &it : words_values) it /= values_sum / words.size(); - // Add EOS word - auto eosId = getId(EOS, utils::hash(EOS)); - if(args_->addEosToken && words.back() != eosId) { - words.push_back(eosId); - words_values.push_back(1.0); - } else if (words.back() == eosId){ - words.pop_back(); - words_values.pop_back(); - } + assert(words.size() == words_values.size()); return weight; } diff --git a/src/dictionary.h b/src/dictionary.h index 4cac9b0ed..521bf46a1 100644 --- a/src/dictionary.h +++ b/src/dictionary.h @@ -30,7 +30,7 @@ enum class entry_type : int8_t {word=0, label=1}; struct entry { std::string word; int64_t count; - int32_t doc_count; + int64_t doc_count; entry_type type; std::vector subwords; }; @@ -64,7 +64,7 @@ class Dictionary { int32_t nwords_; int32_t nlabels_; int64_t ntokens_; - int32_t ndocs_; + int64_t ndocs_; int64_t pruneidx_size_; std::unordered_map pruneidx_;
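PLT::getLabelP below computes a label's probability by the chain rule over the path from the label's leaf to the root: every tree node carries a binary classifier, and the label probability is the product of the node probabilities along that path. The same logic in a short Python sketch (predict_node stands in for the per-node sigmoid classifier):

```python
# Sketch of the chain rule implemented by PLT::getLabelP (non-normalized
# branch): P(label | doc) = product of node scores on the leaf-to-root path.
def label_probability(leaf, hidden, predict_node):
    p = predict_node(leaf, hidden)
    node = leaf.parent
    while node is not None:
        p *= predict_node(node, hidden)  # multiply in each ancestor's score
        node = node.parent
    return p
```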
diff --git a/src/loss_plt.cc b/src/loss_plt.cc index d33f767f1..b87ed9907 100644 --- a/src/loss_plt.cc +++ b/src/loss_plt.cc @@ -424,49 +424,48 @@ real PLT::getLabelP(int32_t label, Vector &hidden, const Model *model_){ std::vector path; NodePLT *n = tree_leaves[label]; - real p = 1.0; if(!args_->probNorm){ - p = predictNode(n, hidden, model_); + real p = predictNode(n, hidden, model_); while(n->parent) { n = n->parent; p = p * predictNode(n, hidden, model_); } assert(n == tree_root); return p; - } - path.push_back(n); - while (n->parent) { - n = n->parent; + } else { path.push_back(n); - } - - assert(tree_root == n); - assert(tree_root == path.back()); + while (n->parent) { + n = n->parent; + path.push_back(n); + } - p = predictNode(tree_root, hidden, model_); - for(auto n = path.rbegin(); n != path.rend(); ++n){ - if ((*n)->label < 0) { + assert(tree_root == n); + assert(tree_root == path.back()); - //TODO: rewrite - /* - for (auto child : (*n)->children) { - normChildren.push_back({child, }) - child->p = cp * predictNode(child, hidden, model_); - sumOfP += child->p; - } - if ((sumOfP < cp) //&& (sumOfP > 10e-6)) { + real p = predictNode(tree_root, hidden, model_); + for (auto n = path.rbegin(); n != path.rend(); ++n) { + if ((*n)->label < 0) { + //TODO: rewrite this part + /* for (auto child : (*n)->children) { + normChildren.push_back({child, }) + child->p = cp * predictNode(child, hidden, model_); + sumOfP += child->p; } + if ((sumOfP < cp) //&& (sumOfP > 10e-6)) { + for (auto child : (*n)->children) { + child->p = (child->p * cp) / sumOfP; + } + } + float sumOfP = 0.0f; + */ } - } - return p; + return p; + } } void PLT::setup(std::shared_ptr dict, uint32_t seed){ @@ -492,9 +491,11 @@ int32_t PLT::getSize(){ } void PLT::printInfo(){ + /* std::cerr << " Avg n vis: " << static_cast(n_vis_count) / x_count << "\n"; std::cerr << " Avg n in vis: " << static_cast(n_in_vis_count) / x_count << "\n"; std::cerr << " Avg y: " << static_cast(y_count) / x_count << "\n"; + */ } void PLT::save(std::ostream& out){ diff --git a/src/model.cc b/src/model.cc index 6d2495477..af1a2dd85 100644 --- a/src/model.cc +++ b/src/model.cc @@ -54,9 +54,23 @@ void Model::setQuantizePointer(std::shared_ptr qwi, real Model::binaryLogistic(int32_t target, bool label, real lr) { real score = sigmoid(wo_->dotRow(hidden_, target)); + real diff = real(label) - score; + + // Original update + /* real alpha = lr * (real(label) - score); grad_.addRow(*wo_, target, alpha); wo_->addRow(hidden_, target, alpha); + */ + + if(args_->fobos){ + grad_.addRowL2Fobos(*wo_, target, lr, diff, args_->l2); + wo_->addRowL2Fobos(hidden_, target, lr, diff, args_->l2); + } else { + grad_.addRowL2(*wo_, target, lr, diff, args_->l2); + wo_->addRowL2(hidden_, target, lr, diff, args_->l2); + } + if (label) { return -log(score); } else {