diff --git a/opacus/tests/multigpu_adaptive_clipping.py b/opacus/tests/multigpu_adaptive_clipping.py
new file mode 100644
index 00000000..ffd7c6f1
--- /dev/null
+++ b/opacus/tests/multigpu_adaptive_clipping.py
@@ -0,0 +1,155 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import unittest
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.optim as optim
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
+from opacus.utils.adaptive_clipping.adaptive_clipping_utils import (
+    PrivacyEngineAdaptiveClipping,
+)
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data.distributed import DistributedSampler
+
+
+def setup(rank, world_size):
+    if sys.platform == "win32":
+        raise ValueError("Windows platform is not supported for this test")
+    else:
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+
+        # initialize the process group
+
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+        torch.distributed.init_process_group(
+            init_method="env://",
+            backend="nccl",
+        )
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+class ToyModel(nn.Module):
+    def __init__(self):
+        super(ToyModel, self).__init__()
+        self.net1 = nn.Linear(10, 10)
+        self.relu = nn.ReLU()
+        self.net2 = nn.Linear(10, 5)
+
+    def forward(self, x):
+        return self.net2(self.relu(self.net1(x)))
+
+
+def demo_basic(rank, weight, world_size, dp):
+    torch.manual_seed(world_size)
+    batch_size = 32
+    setup(rank, world_size)
+
+    # create model and move it to GPU with id rank
+    model = ToyModel().to(rank)
+    model.net1.weight.data.zero_()
+    optimizer = optim.SGD(model.parameters(), lr=1)
+
+    # create dataset
+    labels = torch.randn(2 * batch_size, 5).to(rank)
+    data = torch.randn(2 * batch_size, 10)
+    dataset = TensorDataset(data, labels)
+
+    criterion = nn.CrossEntropyLoss(reduction="mean")
+
+    max_grad_norm = 1e8
+
+    ddp_model = DDP(model, device_ids=[rank])
+
+    privacy_engine = PrivacyEngineAdaptiveClipping()
+
+    sampler = DistributedSampler(
+        dataset, num_replicas=world_size, rank=rank, shuffle=False
+    )
+    data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
+
+    if dp:
+        ddp_model, optimizer, criterion, data_loader = privacy_engine.make_private(
+            module=ddp_model,
+            optimizer=optimizer,
+            criterion=criterion,
+            data_loader=data_loader,
+            noise_multiplier=0,
+            max_grad_norm=max_grad_norm,
+            poisson_sampling=False,
+            grad_sample_mode="ghost",
+            target_unclipped_quantile=1.0,
+        )
+        assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
+
+    for x, y in data_loader:
+        outputs = ddp_model(x.to(rank))
+        loss = criterion(outputs, y)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        break
+
+    weight.copy_(model.net1.weight.data.cpu())
+    cleanup()
+
+
+def run_demo(demo_fn, weight, world_size, dp):
+    mp.spawn(
+        demo_fn,
+        args=(weight, world_size, dp),
+        nprocs=world_size,
+        join=True,
+    )
+
+
+class GradientComputationTestAdaptiveClipping(unittest.TestCase):
+    def test_gradient_correct_adaptive(self) -> None:
+
+        # Tests that the gradient is the same with and without DP in the distributed setting
+        n_gpus = torch.cuda.device_count()
+        self.assertTrue(
+            n_gpus >= 2, f"Need at least 2 GPUs, but only {n_gpus} were provided."
+        )
+
+        weight_dp, weight_nodp = torch.ones(10, 10), torch.ones(10, 10)
+
+        run_demo(
+            demo_basic,
+            weight_nodp,
+            2,
+            dp=False,
+        )
+        run_demo(
+            demo_basic,
+            weight_dp,
+            2,
+            dp=True,
+        )
+
+        self.assertTrue(torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3))
diff --git a/opacus/utils/adaptive_clipping/README.md b/opacus/utils/adaptive_clipping/README.md
new file mode 100644
index 00000000..1f7923eb
--- /dev/null
+++ b/opacus/utils/adaptive_clipping/README.md
@@ -0,0 +1,53 @@
+# Adaptive Clipping (with Ghost Clipping)
+
+Adaptive clipping [1] adapts the clipping norm (and the amount of noise) during training to track a quantile of per-sample gradient norms. It can reduce hyper-parameter tuning effort and improve model accuracy by injecting less noise.
+
+It is supported with:
+- Ghost clipping
+- Distributed data parallel training
+
+It is **not** currently supported with:
+- Vanilla DP-SGD
+- Virtual batch sizes via Batch Memory Manager
+
+## Overview
+
+`PrivacyEngineAdaptiveClipping` is the entry point for adaptive clipping training. It extends `PrivacyEngine` with additional arguments for adaptive clipping:
+
+* `target_unclipped_quantile`: the quantile of per-sample gradient norms at which to clip (between 0 and 1)
+* `min_clipbound`: the minimum allowed clipping norm
+* `max_clipbound`: the maximum allowed clipping norm
+* `clipbound_learning_rate`: the learning rate for tracking the true quantile
+* `max_grad_norm`: the initial clipping norm (used at step 0)
+
+The main hyper-parameter to tune is `target_unclipped_quantile`, which replaces tuning the clipping norm (`max_grad_norm`) in constant-clipping DP-SGD. This parameter can be easier to tune, since the search is over a smaller range of values.
+
+
+## Example usage
+
+```python
+from opacus.utils.adaptive_clipping.adaptive_clipping_utils import PrivacyEngineAdaptiveClipping
+
+# ...
+privacy_engine = PrivacyEngineAdaptiveClipping()
+model, optimizer, criterion, train_loader = privacy_engine.make_private(
+    module=model,
+    optimizer=optimizer,
+    data_loader=train_loader,
+    criterion=criterion,
+    noise_multiplier=args.sigma,
+    max_grad_norm=10,  # initial clipping norm
+    grad_sample_mode="ghost",
+    target_unclipped_quantile=0.5,  # key parameter, may need tuning
+    min_clipbound=1,  # default value
+    max_clipbound=1e8,  # default value
+    clipbound_learning_rate=0.2  # default value, tuning not recommended
+)
+# ...
+```
+
+Note that `grad_sample_mode` must be set to `"ghost"` for adaptive clipping to work.
+
+## References
+
+[1] Galen Andrew, Om Thakkar, H. Brendan McMahan, Swaroop Ramaswamy, "Differentially Private Learning with Adaptive Clipping", NeurIPS 2021.
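For intuition, the quantile tracking described in the README's Overview reduces to a geometric update of the clip bound, driven by the fraction of samples whose gradient norm already falls at or below the current bound; the actual implementation lives in `DPTensorFastGradientAdaptiveClipping._update_clip_and_noise` in `adaptive_clipping_utils.py` below. The following is only a minimal standalone sketch of that update; the helper name and its plain-float interface are illustrative, not part of the Opacus API:

```python
import torch


def updated_clipbound(
    current_clipbound: float,
    per_sample_norms: torch.Tensor,  # shape [batch_size]
    target_unclipped_quantile: float = 0.5,
    clipbound_learning_rate: float = 0.2,
    min_clipbound: float = 1.0,
    max_clipbound: float = 1e8,
) -> float:
    """One geometric step of the clip bound towards the target unclipped quantile."""
    batch_size = len(per_sample_norms)
    unclipped_num = (per_sample_norms <= current_clipbound).sum().item()
    # The real implementation adds Gaussian noise (std = batch_size / 20) to
    # unclipped_num before using it, to account for releasing this statistic.
    unclipped_frac = unclipped_num / batch_size
    new_clipbound = current_clipbound * torch.exp(
        torch.tensor(
            -clipbound_learning_rate
            * (unclipped_frac - target_unclipped_quantile)
        )
    )
    return float(new_clipbound.clamp(min=min_clipbound, max=max_clipbound))
```

If more than the target fraction of samples is already unclipped, the bound shrinks; if fewer, it grows. Over training the clip bound therefore tracks the `target_unclipped_quantile` quantile of the per-sample gradient norms.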
diff --git a/opacus/utils/adaptive_clipping/__init__.py b/opacus/utils/adaptive_clipping/__init__.py
new file mode 100644
index 00000000..cfde5854
--- /dev/null
+++ b/opacus/utils/adaptive_clipping/__init__.py
@@ -0,0 +1,11 @@
+from .adaptive_clipping_utils import (
+    DPLossFastGradientAdaptiveClipping,
+    DPTensorFastGradientAdaptiveClipping,
+    PrivacyEngineAdaptiveClipping,
+)
+
+__all__ = [
+    "DPTensorFastGradientAdaptiveClipping",
+    "DPLossFastGradientAdaptiveClipping",
+    "PrivacyEngineAdaptiveClipping",
+]
diff --git a/opacus/utils/adaptive_clipping/adaptive_clipping_utils.py b/opacus/utils/adaptive_clipping/adaptive_clipping_utils.py
new file mode 100644
index 00000000..ddd490d7
--- /dev/null
+++ b/opacus/utils/adaptive_clipping/adaptive_clipping_utils.py
@@ -0,0 +1,255 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
+from opacus.grad_sample import GradSampleModule
+from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
+    GradSampleModuleFastGradientClipping,
+)
+from opacus.optimizers import DPOptimizerFastGradientClipping
+from opacus.privacy_engine import PrivacyEngine
+from opacus.utils.fast_gradient_clipping_utils import (
+    DPLossFastGradientClipping,
+    DPTensorFastGradientClipping,
+)
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+class DPTensorFastGradientAdaptiveClipping(DPTensorFastGradientClipping):
+    """
+    Packages the training loop for Adaptive clipping (with Fast Gradient and Ghost Clipping) into loss.backward().
+    """
+
+    def __init__(
+        self,
+        module: GradSampleModuleFastGradientClipping,
+        optimizer: DPOptimizerFastGradientClipping,
+        loss_per_sample: torch.Tensor,
+        loss_reduction: str = "mean",
+        target_unclipped_quantile: float = 0.5,
+        min_clipbound: float = 1,
+        max_clipbound: float = 1e8,
+        clipbound_learning_rate: float = 0.2,
+        initial_noise_multiplier: float = 1.0,
+    ):
+        """
+
+        Args:
+            module: the module to train
+            optimizer: the optimizer used to train the module
+            loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
+            target_unclipped_quantile: target quantile for unclipped gradients, between 0 and 1
+            min_clipbound: minimum clipping norm allowed
+            max_clipbound: maximum clipping norm allowed
+            clipbound_learning_rate: learning rate for the descent algorithm that finds the target unclipped quantile
+            initial_noise_multiplier: initial noise multiplier provided at step 0
+
+        """
+
+        super().__init__(module, optimizer, loss_per_sample, loss_reduction)
+
+        self.target_unclipped_quantile = target_unclipped_quantile
+        self.min_clipbound = min_clipbound
+        self.max_clipbound = max_clipbound
+        self.clipbound_learning_rate = clipbound_learning_rate
+        self.initial_clipping_norm = self.optimizer.max_grad_norm
+        self.initial_noise_multiplier = initial_noise_multiplier
+
+    def backward(self):
+        """
+        Repurposes loss.backward() to perform two backward passes, as well as the loss rescaling and hook operations in between
+        """
+
+        if self.loss_reduction == "mean":
+            reduced_loss = torch.mean(self.loss_per_sample, dim=0)
+        elif self.loss_reduction == "sum":
+            reduced_loss = torch.sum(self.loss_per_sample, dim=0)
+        else:
+            raise ValueError(
+                f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
+            )
+        reduced_loss.backward(retain_graph=True)
+        self.optimizer.zero_grad()
+
+        # calc per_sample gradient norms
+        per_sample_norms = self.module.get_norm_sample()
+
+        # calculate new max grad norm and noise multiplier
+        new_max_grad_norm, new_noise_multiplier = self._update_clip_and_noise(
+            per_sample_norms
+        )
+
+        # update max grad norm and noise multiplier
+        self.module.max_grad_norm = new_max_grad_norm
+        self.optimizer.max_grad_norm = new_max_grad_norm
+        self.optimizer.noise_multiplier = new_noise_multiplier
+
+        # get the loss rescaling coefficients using the updated max_grad_norm
+        coeff = torch.where(
+            per_sample_norms <= self.module.max_grad_norm,
+            torch.ones_like(per_sample_norms),
+            self.module.max_grad_norm / per_sample_norms,
+        )  # per-sample coeff [batch_size]
+
+        second_loss_per_sample = coeff * self.loss_per_sample
+        second_loss = torch.sum(second_loss_per_sample)
+        self.module.disable_hooks()
+        second_loss.backward()
+        self.module.enable_hooks()
+
+    def _is_distributed(self):
+
+        return isinstance(self.module, (DPDDP, DDP))
+
+    def _update_clip_and_noise(self, per_sample_norms):
+
+        assert (
+            self.module.max_grad_norm == self.optimizer.max_grad_norm
+        ), "Max grad norm does not match between optimizer and model."
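+
+        # Geometric update of the clip bound towards the target quantile:
+        #   C_new = C * exp(-lr * (unclipped_frac - target_unclipped_quantile)),
+        # with lr = clipbound_learning_rate and unclipped_frac the (noised)
+        # fraction of per-sample gradient norms at or below the current bound C.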
+
+        # calculate new max_grad_norm
+        new_max_grad_norm = current_max_norm = self.module.max_grad_norm
+        batch_size = len(per_sample_norms)
+        unclipped_num = (per_sample_norms <= current_max_norm).sum()
+
+        if self._is_distributed():
+            unclipped_and_batch = torch.tensor([unclipped_num, batch_size])
+            torch.distributed.all_reduce(
+                unclipped_and_batch, op=torch.distributed.ReduceOp.SUM
+            )
+            unclipped_num = unclipped_and_batch[0].item()
+            batch_size = unclipped_and_batch[1].item()
+
+        unclipped_num_std = batch_size / 20.0  # use heuristic from [ATMR'22]
+        unclipped_num = (
+            unclipped_num
+            + torch.normal(mean=0.0, std=unclipped_num_std, size=(1,)).item()
+        )
+        unclipped_frac = unclipped_num / batch_size
+
+        new_max_grad_norm = current_max_norm * torch.exp(
+            -self.clipbound_learning_rate
+            * (unclipped_frac - self.target_unclipped_quantile)
+        )
+        new_max_grad_norm = new_max_grad_norm.clamp(
+            min=self.min_clipbound, max=self.max_clipbound
+        ).item()
+
+        # the following ensures that the updated noise multiplier is a real number
+        assert (
+            batch_size > 10 * self.initial_noise_multiplier
+        ), "Batch size is too small. For adaptive clipping, please use a batch size larger than 10 * noise_multiplier."
+        if self.initial_noise_multiplier > 0:
+            new_noise_multiplier = (
+                self.initial_noise_multiplier ** (-2)
+                - (2.0 * unclipped_num_std) ** (-2)
+            ) ** (-1 / 2.0)
+        else:
+            new_noise_multiplier = self.initial_noise_multiplier
+
+        return new_max_grad_norm, new_noise_multiplier
+
+
+class DPLossFastGradientAdaptiveClipping(DPLossFastGradientClipping):
+    """
+    Wrapper on the loss function to be used with Adaptive Clipping (together with Fast Gradient and Ghost Clipping).
+    It computes the per-sample loss, and wraps it in DPTensorFastGradientAdaptiveClipping.
+    """
+
+    def __init__(
+        self,
+        module: GradSampleModuleFastGradientClipping,
+        optimizer: DPOptimizerFastGradientClipping,
+        criterion,
+        loss_reduction: str = "mean",
+        target_unclipped_quantile: float = 0.5,
+        min_clipbound: float = 1,
+        max_clipbound: float = 1e8,
+        clipbound_learning_rate: float = 0.2,
+        initial_noise_multiplier: float = 1.0,
+    ):
+
+        super().__init__(module, optimizer, criterion, loss_reduction)
+
+        self.target_unclipped_quantile = target_unclipped_quantile
+        self.min_clipbound = min_clipbound
+        self.max_clipbound = max_clipbound
+        self.clipbound_learning_rate = clipbound_learning_rate
+        self.initial_noise_multiplier = initial_noise_multiplier
+
+    def __call__(self, input, target) -> DPTensorFastGradientAdaptiveClipping:
+        """
+        Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientAdaptiveClipping
+        """
+
+        loss_per_sample = self.criterion(
+            input,
+            target,
+        )
+        return DPTensorFastGradientAdaptiveClipping(
+            self.module,
+            self.optimizer,
+            loss_per_sample,
+            self.loss_reduction,
+            self.target_unclipped_quantile,
+            self.min_clipbound,
+            self.max_clipbound,
+            self.clipbound_learning_rate,
+            self.initial_noise_multiplier,
+        )
+
+
+class PrivacyEngineAdaptiveClipping(PrivacyEngine):
+
+    def __init__(self, *, accountant: str = "prv", secure_mode: bool = False):
+        super().__init__(accountant=accountant, secure_mode=secure_mode)
+
+    def _prepare_criterion(
+        self,
+        *,
+        module: GradSampleModule,
+        optimizer: DPOptimizerFastGradientClipping,
+        criterion=torch.nn.CrossEntropyLoss(),
+        loss_reduction: str = "mean",
+        target_unclipped_quantile: float = 0.5,
+        min_clipbound: float = 1,
+        max_clipbound: float = 1e8,
+        clipbound_learning_rate: float = 0.2,
+        **kwargs,
+    ) -> DPLossFastGradientAdaptiveClipping:
+        """
+        Args:
+            module: the module to train
+            optimizer: the optimizer used to train the module
+            criterion: the loss function used to train the module
+            loss_reduction: "mean" or "sum", indicates the loss reduction used for aggregating the gradients
+            target_unclipped_quantile: target quantile for unclipped gradients, between 0 and 1
+            min_clipbound: minimum clipping norm allowed
+            max_clipbound: maximum clipping norm allowed
+            clipbound_learning_rate: learning rate for the descent algorithm that finds the target unclipped quantile
+        """
+
+        return DPLossFastGradientAdaptiveClipping(
+            module,
+            optimizer,
+            criterion,
+            loss_reduction=loss_reduction,
+            target_unclipped_quantile=target_unclipped_quantile,
+            min_clipbound=min_clipbound,
+            max_clipbound=max_clipbound,
+            clipbound_learning_rate=clipbound_learning_rate,
+            initial_noise_multiplier=optimizer.noise_multiplier,
+        )
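A note on the noise-multiplier correction in `_update_clip_and_noise` above: because a noised unclipped count (std `unclipped_num_std = batch_size / 20`) is released at every step, the gradient noise multiplier is raised from `sigma` to `(sigma**-2 - (2 * sigma_b)**-2)**-0.5` with `sigma_b = batch_size / 20`, following the adaptive clipping paper cited in the README, so that the per-step privacy cost stays comparable to non-adaptive DP-SGD with noise multiplier `sigma`. The assertion `batch_size > 10 * noise_multiplier` is exactly the condition `sigma < 2 * sigma_b` that keeps the corrected value real. A small standalone check; the helper name and the example numbers are illustrative only:

```python
def corrected_noise_multiplier(sigma: float, batch_size: int) -> float:
    """Gradient noise multiplier after accounting for the noised unclipped count."""
    sigma_b = batch_size / 20.0  # std of the noise added to the unclipped count
    # Real-valuedness requires sigma < 2 * sigma_b, i.e. batch_size > 10 * sigma,
    # which is what the assert in _update_clip_and_noise enforces.
    return (sigma ** (-2) - (2.0 * sigma_b) ** (-2)) ** (-0.5)


print(corrected_noise_multiplier(1.0, 64))  # ~1.0124, slightly above the nominal 1.0
```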