Adaptive Clipping (with Ghost Clipping) (#711)
Summary: Adds adaptive clipping as a capability for Opacus. It is supported only with ghost clipping; distributed data parallel training is also supported. Differential Revision: D67522957
1 parent: 9b78543. Commit: 19480b3. Showing 4 changed files with 474 additions and 0 deletions.
@@ -0,0 +1,155 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import unittest

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
    DistributedDPOptimizerFastGradientClipping,
)
from opacus.utils.adaptive_clipping.adaptive_clipping_utils import (
    PrivacyEngineAdaptiveClipping,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def setup(rank, world_size):
    if sys.platform == "win32":
        raise ValueError("Windows platform is not supported for this test")
    else:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # initialize the process group

        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        torch.distributed.init_process_group(
            init_method="env://",
            backend="nccl",
        )


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, weight, world_size, dp):
    torch.manual_seed(world_size)
    batch_size = 32
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model.net1.weight.data.zero_()
    optimizer = optim.SGD(model.parameters(), lr=1)

    # create dataset
    labels = torch.randn(2 * batch_size, 5).to(rank)
    data = torch.randn(2 * batch_size, 10)
    dataset = TensorDataset(data, labels)

    criterion = nn.CrossEntropyLoss(reduction="mean")

    max_grad_norm = 1e8

    ddp_model = DDP(model, device_ids=[rank])

    privacy_engine = PrivacyEngineAdaptiveClipping()

    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    if dp:
        ddp_model, optimizer, criterion, data_loader = privacy_engine.make_private(
            module=ddp_model,
            optimizer=optimizer,
            criterion=criterion,
            data_loader=data_loader,
            noise_multiplier=0,
            max_grad_norm=max_grad_norm,
            poisson_sampling=False,
            grad_sample_mode="ghost",
            target_unclipped_quantile=1.0,
        )
        assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)

    for x, y in data_loader:
        outputs = ddp_model(x.to(rank))
        loss = criterion(outputs, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        break

    weight.copy_(model.net1.weight.data.cpu())
    cleanup()


def run_demo(demo_fn, weight, world_size, dp):
    mp.spawn(
        demo_fn,
        args=(weight, world_size, dp),
        nprocs=world_size,
        join=True,
    )


class GradientComputationTestAdaptiveClipping(unittest.TestCase):
    def test_gradient_correct_adaptive(self) -> None:

        # Tests that gradient is the same with DP or without DP in the distributed setting
        n_gpus = torch.cuda.device_count()
        self.assertTrue(
            n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
        )

        weight_dp, weight_nodp = torch.ones(10, 10), torch.ones(10, 10)

        run_demo(
            demo_basic,
            weight_nodp,
            2,
            dp=False,
        )
        run_demo(
            demo_basic,
            weight_dp,
            2,
            dp=True,
        )

        self.assertTrue(torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3))
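
Why the private and non-private runs are expected to agree: the DP run uses `noise_multiplier=0`, an effectively infinite initial clipping norm (`max_grad_norm=1e8`), and `target_unclipped_quantile=1.0`, so no gradient is clipped and no noise is added, and the single optimizer step taken before `break` should match the non-private baseline up to numerical tolerance. A possible standalone entry point for running this module directly is sketched below; this is hypothetical (the repository's own test runner may be used instead) and requires at least two CUDA devices, since the test spawns a world of size 2 over NCCL.

```python
# Hypothetical entry point for running the test module above directly.
if __name__ == "__main__":
    unittest.main()
```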
@@ -0,0 +1,53 @@
# Adaptive Clipping (with Ghost Clipping)

Adaptive clipping [1] adapts the clipping norm (and the amount of noise) during training to a quantile of the per-sample gradient norms. It can reduce hyper-parameter tuning effort and improve model accuracy by injecting less noise.

It is supported with:
- Ghost clipping
- Distributed data parallel training

It is **not** currently supported with:
- Vanilla DP-SGD
- Virtual batch sizes via Batch Memory Manager

## Overview

`PrivacyEngineAdaptiveClipping` is the entry point for adaptive clipping training. It extends `PrivacyEngine` with additional arguments for adaptive clipping:

* `target_unclipped_quantile`: the quantile of per-sample gradient norms at which to clip (between 0 and 1)
* `min_clipbound`: the minimum allowed clipping norm
* `max_clipbound`: the maximum allowed clipping norm
* `clipbound_learning_rate`: the learning rate for tracking the true quantile
* `max_grad_norm`: the initial clipping norm (used at step 0)

The main hyper-parameter to tune is `target_unclipped_quantile`, which replaces tuning the clipping norm (`max_grad_norm`) in constant-clipping DP-SGD. This parameter can be easier to tune, since the search is over a smaller range of values.
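
For intuition, the rule from [1] nudges the clipping norm until roughly a `target_unclipped_quantile` fraction of per-sample gradient norms falls below it. The sketch below illustrates that geometric update using the parameter names listed above; `unclipped_fraction` is an illustrative input (the observed, possibly privatized, fraction of unclipped gradients in a batch), and the exact update used by Opacus lives in `adaptive_clipping_utils` and may differ in details.

```python
import math


def update_clipbound(
    clipbound: float,
    unclipped_fraction: float,
    target_unclipped_quantile: float,
    clipbound_learning_rate: float = 0.2,
    min_clipbound: float = 1.0,
    max_clipbound: float = 1e8,
) -> float:
    # Geometric update from Andrew et al. [1]: shrink the bound when more than
    # the target fraction of per-sample gradients is already unclipped,
    # grow it when fewer are unclipped.
    new_clipbound = clipbound * math.exp(
        -clipbound_learning_rate * (unclipped_fraction - target_unclipped_quantile)
    )
    # Keep the bound inside the configured range.
    return min(max(new_clipbound, min_clipbound), max_clipbound)
```

For example, with `target_unclipped_quantile=0.5` the bound drifts toward the median per-sample gradient norm.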

## Example usage

```python
from opacus.utils.adaptive_clipping.adaptive_clipping_utils import PrivacyEngineAdaptiveClipping

# ...
privacy_engine = PrivacyEngineAdaptiveClipping()
model, optimizer, criterion, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    criterion=criterion,
    noise_multiplier=args.sigma,
    max_grad_norm=10,  # initial clipping norm
    grad_sample_mode="ghost",
    target_unclipped_quantile=0.5,  # key parameter, may need tuning
    min_clipbound=1,  # default value
    max_clipbound=1e8,  # default value
    clipbound_learning_rate=0.2,  # default value, tuning not recommended
)
# ...
```

Note that `grad_sample_mode` must be set to `"ghost"` for adaptive clipping to work.
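
The objects returned by `make_private` are then used in an ordinary training loop. The sketch below mirrors the distributed test added in this commit; `model`, `optimizer`, `criterion`, and `train_loader` are the values returned above, and the wrapped `criterion` is responsible for the per-sample clipping bookkeeping when `loss.backward()` is called.

```python
# Training-loop sketch using the objects returned by make_private above
# (same pattern as the distributed test in this commit).
for x, y in train_loader:
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)  # wrapped loss from make_private
    loss.backward()
    optimizer.step()
```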

## References

[1] Galen Andrew, Om Thakkar, H. Brendan McMahan, Swaroop Ramaswamy, "Differentially Private Learning with Adaptive Clipping", NeurIPS, 2021.
@@ -0,0 +1,11 @@
from .adaptive_clipping_utils import (
    DPLossFastGradientAdaptiveClipping,
    DPTensorFastGradientAdaptiveClipping,
    PrivacyEngineAdaptiveClipping,
)

__all__ = [
    "DPTensorFastGradientAdaptiveClipping",
    "DPLossFastGradientAdaptiveClipping",
    "PrivacyEngineAdaptiveClipping",
]
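
A small usage note: assuming this `__init__.py` belongs to `opacus.utils.adaptive_clipping` (as the import path in the README suggests), the re-exports above should also allow the shorter package-level import.

```python
# Shorter import enabled by the package-level re-exports above.
from opacus.utils.adaptive_clipping import PrivacyEngineAdaptiveClipping

privacy_engine = PrivacyEngineAdaptiveClipping()
```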