Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterSH6 committed Jan 13, 2025
1 parent ebc0ecc commit c22b919
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/model/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def test_hf_casual_models():
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here

input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

# unpad the position_ids to align the rotary
Expand Down Expand Up @@ -96,8 +96,8 @@ def test_hf_value_models():
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here

input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

# unpad the position_ids to align the rotary
Expand Down

0 comments on commit c22b919

Please sign in to comment.