Skip to content

Commit

Permalink
more groups
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Apr 2, 2024
1 parent 3aa9ad3 commit 7c5e20f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 35 deletions.
48 changes: 21 additions & 27 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,46 +23,40 @@ env:

jobs:
checks:
name: ${{ matrix.task.name }} (py ${{ matrix.python }})
name: ${{ matrix.task.name }}
runs-on: [ubuntu-latest]
timeout-minutes: 15
strategy:
fail-fast: false
matrix:
python: ['3.8', '3.10']
python: ['3.10']
task:
- name: Lint
run: |
make lint-check
run: make lint-check

include:
- python: '3.10'
task:
name: Test (main group)
run: pytest -v --color=yes --durations=5 src/test/ -m 'not group1'
- name: Test (default)
run: pytest -v --color=yes --durations=5 src/test/ -m 'not fsdp1 and not fsdp2'

- python: '3.10'
task:
name: Test (group 1)
run: pytest -v --color=yes --durations=5 src/test/ -m 'group1'
- name: Test FSDP (1)
run: pytest -v --color=yes --durations=5 src/test/ -m 'fsdp1'

- python: '3.10'
task:
name: Type check
run: |
make type-check
- name: Test FSDP (2)
run: pytest -v --color=yes --durations=5 src/test/ -m 'fsdp2'

- python: '3.10'
task:
name: Build
run: |
make build
- name: Type check
run: make type-check

- name: Build
run: make build

- python: '3.10'
- name: Style
run: make style-check

include:
- python: '3.8'
task:
name: Style
run: |
make style-check
name: Lint (min Python)
run: make lint-check

steps:
- uses: actions/checkout@v3
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ log_cli = false
log_cli_level = "DEBUG"
markers = [
"gpu",
"group1",
"fsdp1",
"fsdp2",
]
filterwarnings = [
'ignore::FutureWarning:huggingface_hub\.file_download',
Expand Down
14 changes: 7 additions & 7 deletions src/test/distributed/fsdp/fsdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def run_fsdp_against_non_distributed_model(model_factory, model_data_factory):
)


@pytest.mark.group1
@pytest.mark.fsdp1
@pytest.mark.parametrize("backend", BACKENDS)
def test_fsdp_against_non_distributed_model(backend, tiny_model_factory, tiny_model_data_factory):
run_distributed_test(
Expand Down Expand Up @@ -159,7 +159,7 @@ def run_fsdp_against_ddp(model_factory, model_data_factory):
optim.step()


@pytest.mark.group1
@pytest.mark.fsdp1
@pytest.mark.parametrize("backend", BACKENDS)
def test_fsdp_against_ddp(backend, tiny_model_factory, tiny_model_data_factory):
run_distributed_test(
Expand Down Expand Up @@ -217,7 +217,7 @@ def run_fsdp_with_gradient_accumulation(model_factory, model_data_factory):
)


@pytest.mark.group1
@pytest.mark.fsdp1
@pytest.mark.parametrize("backend", BACKENDS)
def test_fsdp_with_gradient_accumulation(backend, tiny_model_factory, tiny_model_data_factory):
run_distributed_test(
Expand Down Expand Up @@ -350,7 +350,7 @@ def forward(self, x):
loss.backward()


@pytest.mark.group1
@pytest.mark.fsdp1
@pytest.mark.parametrize("backend", BACKENDS)
def test_nested_fsdp_api(backend, tiny_model_factory, tiny_model_data_factory):
run_distributed_test(
Expand Down Expand Up @@ -387,7 +387,7 @@ def run_fsdp_with_mixed_precision(model_factory, model_data_factory, precision):
assert param.grad.dtype == param.dtype


@pytest.mark.group1
@pytest.mark.fsdp2
@pytest.mark.parametrize("backend", BACKENDS)
@pytest.mark.parametrize("precision", FSDP_MIXED_PRECISION)
def test_fsdp_with_mixed_precision(backend, tiny_model_factory, tiny_model_data_factory, precision):
Expand Down Expand Up @@ -436,7 +436,7 @@ def __init__(self):
assert fsdp.module.fc3.out_proj.max_prefetch_count == 3


@pytest.mark.group1
@pytest.mark.fsdp1
@pytest.mark.parametrize("backend", BACKENDS)
def test_auto_wrap(backend):
run_distributed_test(
Expand Down Expand Up @@ -497,7 +497,7 @@ def initialize_and_check(m: nn.Module):
assert (param.data.detach() == 1.1).all()


@pytest.mark.group1
@pytest.mark.fsdp1
@pytest.mark.parametrize("backend", BACKENDS)
def test_apply(backend):
run_distributed_test(
Expand Down

0 comments on commit 7c5e20f

Please sign in to comment.