From 67b4d2772c5124b908f8ed9b13166a79bbeb88d2 Mon Sep 17 00:00:00 2001
From: Nikita Sobolev <mail@sobolevn.me>
Date: Fri, 11 Nov 2022 11:04:30 +0300
Subject: [PATCH] gh-98086: Now ``patch.dict`` can decorate async functions
 (#98095)

---
 Lib/test/test_unittest/testmock/testasync.py   | 17 +++++++++++++++++
 Lib/unittest/mock.py                           | 18 ++++++++++++++++++
 ...22-10-08-19-39-27.gh-issue-98086.y---WC.rst |  1 +
 3 files changed, 36 insertions(+)
 create mode 100644 Misc/NEWS.d/next/Library/2022-10-08-19-39-27.gh-issue-98086.y---WC.rst

diff --git a/Lib/test/test_unittest/testmock/testasync.py b/Lib/test/test_unittest/testmock/testasync.py
index 1bab671acde..e05a22861d4 100644
--- a/Lib/test/test_unittest/testmock/testasync.py
+++ b/Lib/test/test_unittest/testmock/testasync.py
@@ -149,6 +149,23 @@ async def test_async():
 
         run(test_async())
 
+    def test_patch_dict_async_def(self):
+        foo = {'a': 'a'}
+        @patch.dict(foo, {'a': 'b'})
+        async def test_async():
+            self.assertEqual(foo['a'], 'b')
+
+        self.assertTrue(iscoroutinefunction(test_async))
+        run(test_async())
+
+    def test_patch_dict_async_def_context(self):
+        foo = {'a': 'a'}
+        async def test_async():
+            with patch.dict(foo, {'a': 'b'}):
+                self.assertEqual(foo['a'], 'b')
+
+        run(test_async())
+
 
 class AsyncMockTest(unittest.TestCase):
     def test_iscoroutinefunction_default(self):
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index 096b1a57147..a273753d6a0 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -1809,6 +1809,12 @@ def __init__(self, in_dict, values=(), clear=False, **kwargs):
     def __call__(self, f):
         if isinstance(f, type):
             return self.decorate_class(f)
+        if inspect.iscoroutinefunction(f):
+            return self.decorate_async_callable(f)
+        return self.decorate_callable(f)
+
+
+    def decorate_callable(self, f):
         @wraps(f)
         def _inner(*args, **kw):
             self._patch_dict()
@@ -1820,6 +1826,18 @@ def _inner(*args, **kw):
         return _inner
 
 
+    def decorate_async_callable(self, f):
+        @wraps(f)
+        async def _inner(*args, **kw):
+            self._patch_dict()
+            try:
+                return await f(*args, **kw)
+            finally:
+                self._unpatch_dict()
+
+        return _inner
+
+
     def decorate_class(self, klass):
         for attr in dir(klass):
             attr_value = getattr(klass, attr)
diff --git a/Misc/NEWS.d/next/Library/2022-10-08-19-39-27.gh-issue-98086.y---WC.rst b/Misc/NEWS.d/next/Library/2022-10-08-19-39-27.gh-issue-98086.y---WC.rst
new file mode 100644
index 00000000000..f4a1d272e13
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-10-08-19-39-27.gh-issue-98086.y---WC.rst
@@ -0,0 +1 @@
+Make sure ``patch.dict()`` can be applied on async functions.
-- 
GitLab