Adaptive Clipping (with Ghost Clipping) (#711)
Summary:

Added adaptive clipping as a capability for Opacus. 

Supported only with ghost-clipping. 

Distributed data parallel training is supported.

Differential Revision: D67522957
iden-kalemaj authored and facebook-github-bot committed Jan 10, 2025
1 parent 9b78543 commit 19480b3
Showing 4 changed files with 474 additions and 0 deletions.
155 changes: 155 additions & 0 deletions opacus/tests/multigpu_adaptive_clipping.py
@@ -0,0 +1,155 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import unittest

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
    DistributedDPOptimizerFastGradientClipping,
)
from opacus.utils.adaptive_clipping.adaptive_clipping_utils import (
    PrivacyEngineAdaptiveClipping,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def setup(rank, world_size):
    if sys.platform == "win32":
        raise ValueError("Windows platform is not supported for this test")
    else:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # initialize the process group

        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        torch.distributed.init_process_group(
            init_method="env://",
            backend="nccl",
        )


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, weight, world_size, dp):
    torch.manual_seed(world_size)
    batch_size = 32
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model.net1.weight.data.zero_()
    optimizer = optim.SGD(model.parameters(), lr=1)

    # create dataset
    labels = torch.randn(2 * batch_size, 5).to(rank)
    data = torch.randn(2 * batch_size, 10)
    dataset = TensorDataset(data, labels)

    criterion = nn.CrossEntropyLoss(reduction="mean")

    max_grad_norm = 1e8

    ddp_model = DDP(model, device_ids=[rank])

    privacy_engine = PrivacyEngineAdaptiveClipping()

    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    if dp:
        ddp_model, optimizer, criterion, data_loader = privacy_engine.make_private(
            module=ddp_model,
            optimizer=optimizer,
            criterion=criterion,
            data_loader=data_loader,
            noise_multiplier=0,
            max_grad_norm=max_grad_norm,
            poisson_sampling=False,
            grad_sample_mode="ghost",
            target_unclipped_quantile=1.0,
        )
        assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)

    for x, y in data_loader:
        outputs = ddp_model(x.to(rank))
        loss = criterion(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        break

    weight.copy_(model.net1.weight.data.cpu())
    cleanup()


def run_demo(demo_fn, weight, world_size, dp):
    mp.spawn(
        demo_fn,
        args=(weight, world_size, dp),
        nprocs=world_size,
        join=True,
    )


class GradientComputationTestAdaptiveClipping(unittest.TestCase):
    def test_gradient_correct_adaptive(self) -> None:

        # Tests that gradient is the same with DP or without DP in the distributed setting
        n_gpus = torch.cuda.device_count()
        self.assertTrue(
            n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
        )

        weight_dp, weight_nodp = torch.ones(10, 10), torch.ones(10, 10)

        run_demo(
            demo_basic,
            weight_nodp,
            2,
            dp=False,
        )
        run_demo(
            demo_basic,
            weight_dp,
            2,
            dp=True,
        )

        self.assertTrue(torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3))
53 changes: 53 additions & 0 deletions opacus/utils/adaptive_clipping/README.md
@@ -0,0 +1,53 @@
# Adaptive Clipping (with Ghost Clipping)

Adaptive clipping [1] adapts the clipping norm (and, with it, the amount of noise added) during training so that it tracks a target quantile of the per-sample gradient norms. It can reduce hyper-parameter tuning effort and improve model accuracy by injecting less noise.

It is supported with:
- Ghost clipping
- Distributed data parallel training

It is **not** currently supported with:
- Vanilla DP-SGD
- Virtual batch sizes via Batch Memory Manager

## Overview

`PrivacyEngineAdaptiveClipping` is the entry point for training with adaptive clipping. It extends `PrivacyEngine` with the following additional arguments:

* `target_unclipped_quantile`: the quantile of per-sample gradient norms at which to clip (between 0 and 1)
* `min_clipbound`: the minimum allowed clipping norm
* `max_clipbound`: the maximum allowed clipping norm
* `clipbound_learning_rate`: the learning rate for tracking the true quantile
* `max_grad_norm`: the initial clipping norm (used at step 0)

The main hyper-parameter to tune is `target_unclipped_quantile`, which replaces the clipping norm (`max_grad_norm`) tuned in constant-clipping DP-SGD. It can be easier to tune, since the search is over the bounded range [0, 1] rather than the much wider range of plausible clipping norms.
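
For intuition, the sketch below shows the geometric quantile-tracking update from [1] that these hyper-parameters control. The function name and exact form are illustrative rather than the Opacus internals, and in the full DP algorithm the unclipped fraction is itself estimated with noise.

```python
import math

def update_clipbound(
    clipbound: float,              # current clipping norm C_t
    unclipped_fraction: float,     # fraction of per-sample gradients with norm <= C_t
    target_unclipped_quantile: float,
    clipbound_learning_rate: float = 0.2,
    min_clipbound: float = 1.0,
    max_clipbound: float = 1e8,
) -> float:
    # Shrink the clip bound when more samples are unclipped than targeted,
    # grow it when fewer are, then clamp to the allowed range.
    new_clipbound = clipbound * math.exp(
        -clipbound_learning_rate * (unclipped_fraction - target_unclipped_quantile)
    )
    return max(min_clipbound, min(max_clipbound, new_clipbound))
```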


## Example usage

```python
from opacus.utils.adaptive_clipping.adaptive_clipping_utils import PrivacyEngineAdaptiveClipping

# ...
privacy_engine = PrivacyEngineAdaptiveClipping()
model, optimizer, criterion, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    criterion=criterion,
    noise_multiplier=args.sigma,
    max_grad_norm=10,  # initial clipping norm
    grad_sample_mode="ghost",
    target_unclipped_quantile=0.5,  # key parameter, may need tuning
    min_clipbound=1,  # default value
    max_clipbound=1e8,  # default value
    clipbound_learning_rate=0.2,  # default value, tuning not recommended
)
# ...
```

Note that `grad_sample_mode` must be set to `"ghost"` for adaptive clipping to work.
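
For distributed data parallel training (supported, as noted above), the flow mirrors the multi-GPU test added in this commit: wrap the model in `DistributedDataParallel` first, then pass the wrapped module to `make_private`. The snippet below is a sketch; `model`, `optimizer`, `criterion`, `train_loader`, `args.sigma`, and `rank` are assumed to be set up per process, with `torch.distributed` already initialized.

```python
from torch.nn.parallel import DistributedDataParallel as DDP

from opacus.utils.adaptive_clipping.adaptive_clipping_utils import PrivacyEngineAdaptiveClipping

# ...
ddp_model = DDP(model, device_ids=[rank])

privacy_engine = PrivacyEngineAdaptiveClipping()
ddp_model, optimizer, criterion, train_loader = privacy_engine.make_private(
    module=ddp_model,
    optimizer=optimizer,
    criterion=criterion,
    data_loader=train_loader,
    noise_multiplier=args.sigma,
    max_grad_norm=10,  # initial clipping norm
    grad_sample_mode="ghost",
    target_unclipped_quantile=0.5,
)
# ...
```

The returned optimizer is a `DistributedDPOptimizerFastGradientClipping`, as asserted in the accompanying multi-GPU test.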

## References

[1] Galen Andrew, Om Thakkar, H. Brendan McMahan, Swaroop Ramaswamy, "Differentially Private Learning with Adaptive Clipping", NeurIPS, 2021.
11 changes: 11 additions & 0 deletions opacus/utils/adaptive_clipping/__init__.py
@@ -0,0 +1,11 @@
from .adaptive_clipping_utils import (
    DPLossFastGradientAdaptiveClipping,
    DPTensorFastGradientAdaptiveClipping,
    PrivacyEngineAdaptiveClipping,
)

__all__ = [
    "DPTensorFastGradientAdaptiveClipping",
    "DPLossFastGradientAdaptiveClipping",
    "PrivacyEngineAdaptiveClipping",
]
255 changes: 255 additions & 0 deletions opacus/utils/adaptive_clipping/adaptive_clipping_utils.py (diff not loaded)
