Commit 287e9f8: Fix tests
SunMarc committed Jan 10, 2025 (1 parent: ba90f85)
Showing 1 changed file with 10 additions and 10 deletions.

tests/test_big_modeling.py: 10 additions and 10 deletions
@@ -238,10 +238,10 @@ def test_cpu_offload_gpt2(self):
 
         gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
         cpu_offload(gpt2, execution_device=0)
-        outputs = gpt2.generate(inputs["input_ids"])
+        outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
         assert (
             tokenizer.decode(outputs[0].tolist())
-            == "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
+            == "Hello world! My name is Kiyoshi, and I'm a student at"
         )
 
     def test_disk_offload(self):
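Why the fix works: without an explicit cap, generate falls back to the default generation length in the model's generation config, so the number of tokens produced, and hence the decoded string, can drift across transformers releases. Pinning max_new_tokens=10 makes the output length explicit, and the expected string is shortened to the 10-token greedy continuation. A minimal standalone sketch of the pattern; the prompt is inferred from the expected output and the execution device is swapped to "cpu" so it runs anywhere (both are assumptions, not shown in the diff):

    # Hedged sketch, not the exact test: prompt and device are assumptions.
    from accelerate import cpu_offload
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    inputs = tokenizer("Hello world! My name is", return_tensors="pt")

    gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
    cpu_offload(gpt2, execution_device="cpu")
    # Exactly 10 new tokens, so the decoded string no longer depends on the
    # library's default generation length.
    outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
    # Greedy decoding is expected to produce:
    # "Hello world! My name is Kiyoshi, and I'm a student at"
    print(tokenizer.decode(outputs[0].tolist()))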
Expand Down Expand Up @@ -301,10 +301,10 @@ def test_disk_offload_gpt2(self):
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
with TemporaryDirectory() as tmp_dir:
disk_offload(gpt2, tmp_dir, execution_device=0)
outputs = gpt2.generate(inputs["input_ids"])
outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
assert (
tokenizer.decode(outputs[0].tolist())
== "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
== "Hello world! My name is Kiyoshi, and I'm a student at"
)

@require_non_cpu
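The disk_offload variant applies the same length cap; here the weights live in a temporary directory and are streamed back in during the forward pass. A hedged standalone sketch, with execution_device="cpu" substituted so it runs without an accelerator (the test itself targets device 0):

    # Hedged sketch: execution_device swapped to "cpu"; the test uses device 0.
    from tempfile import TemporaryDirectory

    from accelerate import disk_offload
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    inputs = tokenizer("Hello world! My name is", return_tensors="pt")

    gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
    with TemporaryDirectory() as tmp_dir:
        # Weights are written under tmp_dir and reloaded per layer at runtime,
        # so generation must happen while the directory still exists.
        disk_offload(gpt2, tmp_dir, execution_device="cpu")
        outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
        print(tokenizer.decode(outputs[0].tolist()))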
@@ -686,21 +686,21 @@ def test_dispatch_model_gpt2_on_two_devices(self):
             device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
 
         gpt2 = dispatch_model(gpt2, device_map)
-        outputs = gpt2.generate(inputs["input_ids"])
+        outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
         assert (
             tokenizer.decode(outputs[0].tolist())
-            == "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
+            == "Hello world! My name is Kiyoshi, and I'm a student at"
         )
 
         # Dispatch with a bit of CPU offload
         gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
         for i in range(4):
             device_map[f"transformer.h.{i}"] = "cpu"
         gpt2 = dispatch_model(gpt2, device_map)
-        outputs = gpt2.generate(inputs["input_ids"])
+        outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
         assert (
             tokenizer.decode(outputs[0].tolist())
-            == "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
+            == "Hello world! My name is Kiyoshi, and I'm a student at"
         )
         # Dispatch with a bit of CPU and disk offload
         gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
@@ -713,10 +713,10 @@ def test_dispatch_model_gpt2_on_two_devices(self):
             }
             offload_state_dict(tmp_dir, state_dict)
             gpt2 = dispatch_model(gpt2, device_map, offload_dir=tmp_dir)
-            outputs = gpt2.generate(inputs["input_ids"])
+            outputs = gpt2.generate(inputs["input_ids"], max_new_tokens=10)
             assert (
                 tokenizer.decode(outputs[0].tolist())
-                == "Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo"
+                == "Hello world! My name is Kiyoshi, and I'm a student at"
             )
 
     @require_non_cpu
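For context, dispatch_model places each submodule on the device named in device_map and installs hooks that move activations between devices, so a single generate call can span two GPUs, the CPU, and disk. A sketch of the two-GPU split the test starts from; the entries outside the loop are assumptions, since the diff shows only the loop over the 12 gpt2 blocks:

    # Hedged sketch of the split device_map (assumes two CUDA devices; the
    # non-loop entries are guesses at gpt2's module layout, not from the diff).
    from accelerate import dispatch_model
    from transformers import AutoModelForCausalLM

    gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
    device_map = {
        "transformer.wte": 0,
        "transformer.wpe": 0,
        "transformer.drop": 0,
        "transformer.ln_f": 1,
        "lm_head": 0,
    }
    for i in range(12):  # gpt2 has 12 transformer blocks
        device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
    gpt2 = dispatch_model(gpt2, device_map)
    # generate then works exactly as in the sketches above, with the hooks
    # shuttling hidden states from GPU 0 to GPU 1 mid-forward.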