diff --git a/docs/index.md b/docs/index.md index 47a0192ab..f8a35cb75 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,6 +4,7 @@ [kiui](https://kiui.moe/)'s notebook. ## Recent Updates +- [monky_patch.md](python\monky_patch/)
2023-12-12 18:58:16.041742
- [camera_intrinsics_exintrics.md](vision\camera_intrinsics_exintrics/)
2023-12-11 10:48:44.821914
- [set_usual_apps_proxy.md](web\proxy\set_usual_apps_proxy/)
2023-12-08 16:21:56.252977
- [slurm.md](linux\slurm/)
2023-12-07 15:20:34.127883
@@ -23,4 +24,3 @@ - [git_tricks.md](linux\git_tricks/)
2023-11-18 11:04:41.502742
- [trojan.md](web\proxy\trojan/)
2023-11-17 09:55:15.367723
- [bash.md](linux\bash/)
2023-11-17 09:19:57.540734
-- [datetime.md](python\datetime/)
2023-11-06 14:25:24.795182
diff --git a/docs/python/monky_patch.md b/docs/python/monky_patch.md new file mode 100644 index 000000000..687b58bb5 --- /dev/null +++ b/docs/python/monky_patch.md @@ -0,0 +1,125 @@ +## Monkey patch a method + + + +To patch the `forward` of a `nn.Module`, **define a closure** that keeps temporary variables and returns your new `forward`: + +```python +import torch.nn as nn + +class A(nn.Module): + def __init__(self, name): + super().__init__() + self.name = name + + def forward(self): + print(f'original forward of {self.name}') + +a = A('a') +b = A('b') + +for name, m in zip(['a', 'b'], [a, b]): + + def make_forward(): + # record the current name in closure ! + cur_name = name + def _forward(): + print(f'patched forward of {cur_name}') + return _forward + + m.forward = make_forward() + +a() +b() +``` + +Output: + +``` +patched forward of a +patched forward of b +``` + + + +However, you cannot patch magic methods like `__call__` by this: + +```python +class A: + def __init__(self, name): + super().__init__() + self.name = name + + def __call__(self): + print(f'original forward of {self.name}') + +a = A('a') +b = A('b') + +for name, m in zip(['a', 'b'], [a, b]): + + def make_call(): + # record the current name in closure ! + cur_name = name + def _call(): + print(f'patched forward of {cur_name}') + return _call + + m.__call__ = make_call() + +a() +b() +``` + +Output: + +``` +original forward of a +original forward of b +``` + +This is because `__call__` is looked-up with respect to the class instead of instance, so we are still calling the original `__call__`. + +We have to patch the class to make this work, and **cast instances to the derived class**: + +```python +class A: + def __init__(self, name): + super().__init__() + self.name = name + + def __call__(self): + print(f'original forward of {self.name}') + +a = A('a') +b = A('b') + +# a derived class that redirect __call__ to our patched call +class B(A): + def __call__(self): + self.patched_call() + +for name, m in zip(['a', 'b'], [a, b]): + + def make_call(): + # record the current name in closure ! + cur_name = name + def _call(): + print(f'patched forward of {cur_name}') + + return _call + + m.__class__ = B # magic cast! + m.patched_call = make_call() + +a() +b() +``` + +Output: + +``` +patched forward of a +patched forward of b +``` +