Skip to content

Commit

Permalink
Merge pull request caikit#351 from evaline-ju/llm-tok
Browse files Browse the repository at this point in the history
✨ Add tokenization task to generation modules
  • Loading branch information
gkumbhat authored Apr 25, 2024
2 parents c12cb82 + 79f20d8 commit d81f11e
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 7 deletions.
21 changes: 19 additions & 2 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@
ClassificationTrainRecord,
GeneratedTextResult,
GeneratedTextStreamResult,
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import TextGenerationTask
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
import alog

# Local
Expand Down Expand Up @@ -87,7 +88,7 @@
id="6655831b-960a-4dc5-8df4-867026e2cd41",
name="Peft generation",
version="0.1.0",
task=TextGenerationTask,
tasks=[TextGenerationTask, TokenizationTask],
)
class PeftPromptTuning(ModuleBase):

Expand Down Expand Up @@ -274,6 +275,22 @@ def run_stream_out(
stop_sequences=stop_sequences,
)

@TokenizationTask.taskmethod()
def run_tokenizer(
self,
text: str,
) -> TokenizationResults:
"""Run tokenization task against the model
Args:
text: str
Text to tokenize
Returns:
TokenizationResults
The token count
"""
raise NotImplementedError("Tokenization not implemented for local")

@classmethod
def train(
cls,
Expand Down
23 changes: 20 additions & 3 deletions caikit_nlp/modules/text_generation/text_generation_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from caikit.core.data_model import DataStream
from caikit.core.exceptions import error_handler
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
from caikit.interfaces.nlp.data_model import GeneratedTextResult
from caikit.interfaces.nlp.tasks import TextGenerationTask
from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
import alog

# Local
Expand Down Expand Up @@ -60,7 +60,7 @@
id="f9181353-4ccf-4572-bd1e-f12bcda26792",
name="Text Generation",
version="0.1.0",
task=TextGenerationTask,
tasks=[TextGenerationTask, TokenizationTask],
)
class TextGeneration(ModuleBase):
"""Module to provide text generation capabilities"""
Expand Down Expand Up @@ -521,6 +521,7 @@ def save(self, model_path):
json.dump(loss_log, f)
f.write("\n")

@TextGenerationTask.taskmethod()
def run(
self,
text: str,
Expand Down Expand Up @@ -575,6 +576,22 @@ def run(
**kwargs,
)

@TokenizationTask.taskmethod()
def run_tokenizer(
self,
text: str,
) -> TokenizationResults:
"""Run tokenization task against the model
Args:
text: str
Text to tokenize
Returns:
TokenizationResults
The token count
"""
raise NotImplementedError("Tokenization not implemented for local")

################################## Private Functions ######################################

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion caikit_nlp/toolkit/torch_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

# Third Party
from torch import cuda
from torch.distributed.launcher.api import LaunchConfig, Std
from torch.distributed.elastic.multiprocessing.api import Std
from torch.distributed.launcher.api import LaunchConfig
import torch.distributed as dist

# First Party
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
"scipy>=1.8.1",
"sentence-transformers>=2.3.1,<2.4.0",
"tokenizers>=0.13.3",
"torch>=2.0.1",
"torch>=2.0.1,<2.3.0",
"tqdm>=4.65.0",
"transformers>=4.32.0",
"peft==0.6.0",
Expand Down
8 changes: 8 additions & 0 deletions tests/modules/text_generation/test_peft_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,14 @@ def test_run_exponential_decay_len_penatly_object(causal_lm_dummy_model):
assert isinstance(pred, GeneratedTextResult)


def test_run_tokenizer_not_implemented(causal_lm_dummy_model):
with pytest.raises(NotImplementedError):
causal_lm_dummy_model.run_tokenizer("This text doesn't matter")


######################## Test train ###############################################


def test_train_with_data_validation_raises(causal_lm_train_kwargs, set_cpu_device):
"""Check if we are able to throw error for when number of examples are more than configured limit"""
patch_kwargs = {
Expand Down
6 changes: 6 additions & 0 deletions tests/modules/text_generation/test_text_generation_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,9 @@ def test_zero_epoch_case(disable_wip):
}
model = TextGeneration.train(**train_kwargs)
assert isinstance(model.model, HFAutoSeq2SeqLM)


def test_run_tokenizer_not_implemented():
with pytest.raises(NotImplementedError):
model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL)
model.run_tokenizer("This text doesn't matter")

0 comments on commit d81f11e

Please sign in to comment.