From 787764219f874ce2035699ed772af1e9f3bbf813 Mon Sep 17 00:00:00 2001
From: Serhiy Storchaka <storchaka@gmail.com>
Date: Wed, 30 Nov 2022 23:04:30 +0200
Subject: [PATCH] gh-89189: More compact range iterator (GH-27986)

---
 Include/internal/pycore_range.h               |  1 -
 Lib/test/test_range.py                        | 38 +++++++--
 Lib/test/test_sys.py                          |  3 +-
 .../2021-08-29-15-55-19.bpo-45026.z7nTA3.rst  |  3 +
 Objects/rangeobject.c                         | 79 ++++++++++---------
 Python/bytecodes.c                            |  7 +-
 Python/generated_cases.c.h                    |  7 +-
 7 files changed, 88 insertions(+), 50 deletions(-)
 create mode 100644 Misc/NEWS.d/next/Core and Builtins/2021-08-29-15-55-19.bpo-45026.z7nTA3.rst

diff --git a/Include/internal/pycore_range.h b/Include/internal/pycore_range.h
index 809e89a1e01..bf045ec4fd8 100644
--- a/Include/internal/pycore_range.h
+++ b/Include/internal/pycore_range.h
@@ -10,7 +10,6 @@ extern "C" {
 
 typedef struct {
     PyObject_HEAD
-    long index;
     long start;
     long step;
     long len;
diff --git a/Lib/test/test_range.py b/Lib/test/test_range.py
index 851ad5b7c2f..7be76b32ac2 100644
--- a/Lib/test/test_range.py
+++ b/Lib/test/test_range.py
@@ -407,11 +407,7 @@ def test_iterator_pickling_overflowing_index(self):
         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
             with self.subTest(proto=proto):
                 it = iter(range(2**32 + 2))
-                _, _, idx = it.__reduce__()
-                self.assertEqual(idx, 0)
-                it.__setstate__(2**32 + 1)  # undocumented way to set r->index
-                _, _, idx = it.__reduce__()
-                self.assertEqual(idx, 2**32 + 1)
+                it.__setstate__(2**32 + 1)  # undocumented way to advance an iterator
                 d = pickle.dumps(it, proto)
                 it = pickle.loads(d)
                 self.assertEqual(next(it), 2**32 + 1)
@@ -442,6 +438,38 @@ def test_large_exhausted_iterator_pickling(self):
             self.assertEqual(list(i), [])
             self.assertEqual(list(i2), [])
 
+    def test_iterator_unpickle_compat(self):
+        testcases = [
+            b'c__builtin__\niter\n(c__builtin__\nxrange\n(I10\nI20\nI2\ntRtRI2\nb.',
+            b'c__builtin__\niter\n(c__builtin__\nxrange\n(K\nK\x14K\x02tRtRK\x02b.',
+            b'\x80\x02c__builtin__\niter\nc__builtin__\nxrange\nK\nK\x14K\x02\x87R\x85RK\x02b.',
+            b'\x80\x03cbuiltins\niter\ncbuiltins\nrange\nK\nK\x14K\x02\x87R\x85RK\x02b.',
+            b'\x80\x04\x951\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x8c\x04iter\x93\x8c\x08builtins\x8c\x05range\x93K\nK\x14K\x02\x87R\x85RK\x02b.',
+
+            b'c__builtin__\niter\n(c__builtin__\nxrange\n(L-36893488147419103232L\nI20\nI2\ntRtRL18446744073709551623L\nb.',
+            b'c__builtin__\niter\n(c__builtin__\nxrange\n(L-36893488147419103232L\nK\x14K\x02tRtRL18446744073709551623L\nb.',
+            b'\x80\x02c__builtin__\niter\nc__builtin__\nxrange\n\x8a\t\x00\x00\x00\x00\x00\x00\x00\x00\xfeK\x14K\x02\x87R\x85R\x8a\t\x07\x00\x00\x00\x00\x00\x00\x00\x01b.',
+            b'\x80\x03cbuiltins\niter\ncbuiltins\nrange\n\x8a\t\x00\x00\x00\x00\x00\x00\x00\x00\xfeK\x14K\x02\x87R\x85R\x8a\t\x07\x00\x00\x00\x00\x00\x00\x00\x01b.',
+            b'\x80\x04\x95C\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x8c\x04iter\x93\x8c\x08builtins\x8c\x05range\x93\x8a\t\x00\x00\x00\x00\x00\x00\x00\x00\xfeK\x14K\x02\x87R\x85R\x8a\t\x07\x00\x00\x00\x00\x00\x00\x00\x01b.',
+        ]
+        for t in testcases:
+            it = pickle.loads(t)
+            self.assertEqual(list(it), [14, 16, 18])
+
+    def test_iterator_setstate(self):
+        it = iter(range(10, 20, 2))
+        it.__setstate__(2)
+        self.assertEqual(list(it), [14, 16, 18])
+        it = reversed(range(10, 20, 2))
+        it.__setstate__(3)
+        self.assertEqual(list(it), [12, 10])
+        it = iter(range(-2**65, 20, 2))
+        it.__setstate__(2**64 + 7)
+        self.assertEqual(list(it), [14, 16, 18])
+        it = reversed(range(10, 2**65, 2))
+        it.__setstate__(2**64 - 7)
+        self.assertEqual(list(it), [12, 10])
+
     def test_odd_bug(self):
         # This used to raise a "SystemError: NULL result without error"
         # because the range validation step was eating the exception
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index 2403c7c815f..17a5026e257 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -1484,7 +1484,8 @@ def delx(self): del self.__x
         # PyCapsule
         # XXX
         # rangeiterator
-        check(iter(range(1)), size('4l'))
+        check(iter(range(1)), size('3l'))
+        check(iter(range(2**65)), size('3P'))
         # reverse
         check(reversed(''), size('nP'))
         # range
diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-08-29-15-55-19.bpo-45026.z7nTA3.rst b/Misc/NEWS.d/next/Core and Builtins/2021-08-29-15-55-19.bpo-45026.z7nTA3.rst
new file mode 100644
index 00000000000..481ab53e4f5
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2021-08-29-15-55-19.bpo-45026.z7nTA3.rst	
@@ -0,0 +1,3 @@
+Optimize the :class:`range` object iterator. It is now smaller, faster
+iteration of ranges containing large numbers. Smaller pickles, faster
+unpickling.
diff --git a/Objects/rangeobject.c b/Objects/rangeobject.c
index a889aa04db8..992e7c079de 100644
--- a/Objects/rangeobject.c
+++ b/Objects/rangeobject.c
@@ -756,18 +756,19 @@ PyTypeObject PyRange_Type = {
 static PyObject *
 rangeiter_next(_PyRangeIterObject *r)
 {
-    if (r->index < r->len)
-        /* cast to unsigned to avoid possible signed overflow
-           in intermediate calculations. */
-        return PyLong_FromLong((long)(r->start +
-                                      (unsigned long)(r->index++) * r->step));
+    if (r->len > 0) {
+        long result = r->start;
+        r->start = result + r->step;
+        r->len--;
+        return PyLong_FromLong(result);
+    }
     return NULL;
 }
 
 static PyObject *
 rangeiter_len(_PyRangeIterObject *r, PyObject *Py_UNUSED(ignored))
 {
-    return PyLong_FromLong(r->len - r->index);
+    return PyLong_FromLong(r->len);
 }
 
 PyDoc_STRVAR(length_hint_doc,
@@ -794,8 +795,8 @@ rangeiter_reduce(_PyRangeIterObject *r, PyObject *Py_UNUSED(ignored))
     if (range == NULL)
         goto err;
     /* return the result */
-    return Py_BuildValue(
-            "N(N)l", _PyEval_GetBuiltin(&_Py_ID(iter)), range, r->index);
+    return Py_BuildValue("N(N)O", _PyEval_GetBuiltin(&_Py_ID(iter)),
+                         range, Py_None);
 err:
     Py_XDECREF(start);
     Py_XDECREF(stop);
@@ -814,7 +815,8 @@ rangeiter_setstate(_PyRangeIterObject *r, PyObject *state)
         index = 0;
     else if (index > r->len)
         index = r->len; /* exhausted iterator */
-    r->index = index;
+    r->start += index * r->step;
+    r->len -= index;
     Py_RETURN_NONE;
 }
 
@@ -904,13 +906,11 @@ fast_range_iter(long start, long stop, long step, long len)
     it->start = start;
     it->step = step;
     it->len = len;
-    it->index = 0;
     return (PyObject *)it;
 }
 
 typedef struct {
     PyObject_HEAD
-    PyObject *index;
     PyObject *start;
     PyObject *step;
     PyObject *len;
@@ -919,7 +919,8 @@ typedef struct {
 static PyObject *
 longrangeiter_len(longrangeiterobject *r, PyObject *no_args)
 {
-    return PyNumber_Subtract(r->len, r->index);
+    Py_INCREF(r->len);
+    return r->len;
 }
 
 static PyObject *
@@ -946,8 +947,8 @@ longrangeiter_reduce(longrangeiterobject *r, PyObject *Py_UNUSED(ignored))
     }
 
     /* return the result */
-    return Py_BuildValue(
-            "N(N)O", _PyEval_GetBuiltin(&_Py_ID(iter)), range, r->index);
+    return Py_BuildValue("N(N)O", _PyEval_GetBuiltin(&_Py_ID(iter)),
+                         range, Py_None);
 }
 
 static PyObject *
@@ -970,7 +971,22 @@ longrangeiter_setstate(longrangeiterobject *r, PyObject *state)
         if (cmp > 0)
             state = r->len;
     }
-    Py_XSETREF(r->index, Py_NewRef(state));
+    PyObject *product = PyNumber_Multiply(state, r->step);
+    if (product == NULL)
+        return NULL;
+    PyObject *new_start = PyNumber_Add(r->start, product);
+    Py_DECREF(product);
+    if (new_start == NULL)
+        return NULL;
+    PyObject *new_len = PyNumber_Subtract(r->len, state);
+    if (new_len == NULL) {
+        Py_DECREF(new_start);
+        return NULL;
+    }
+    PyObject *tmp = r->start;
+    r->start = new_start;
+    Py_SETREF(r->len, new_len);
+    Py_DECREF(tmp);
     Py_RETURN_NONE;
 }
 
@@ -987,7 +1003,6 @@ static PyMethodDef longrangeiter_methods[] = {
 static void
 longrangeiter_dealloc(longrangeiterobject *r)
 {
-    Py_XDECREF(r->index);
     Py_XDECREF(r->start);
     Py_XDECREF(r->step);
     Py_XDECREF(r->len);
@@ -997,29 +1012,21 @@ longrangeiter_dealloc(longrangeiterobject *r)
 static PyObject *
 longrangeiter_next(longrangeiterobject *r)
 {
-    PyObject *product, *new_index, *result;
-    if (PyObject_RichCompareBool(r->index, r->len, Py_LT) != 1)
+    if (PyObject_RichCompareBool(r->len, _PyLong_GetZero(), Py_GT) != 1)
         return NULL;
 
-    new_index = PyNumber_Add(r->index, _PyLong_GetOne());
-    if (!new_index)
+    PyObject *new_start = PyNumber_Add(r->start, r->step);
+    if (new_start == NULL) {
         return NULL;
-
-    product = PyNumber_Multiply(r->index, r->step);
-    if (!product) {
-        Py_DECREF(new_index);
-        return NULL;
-    }
-
-    result = PyNumber_Add(r->start, product);
-    Py_DECREF(product);
-    if (result) {
-        Py_SETREF(r->index, new_index);
     }
-    else {
-        Py_DECREF(new_index);
+    PyObject *new_len = PyNumber_Subtract(r->len, _PyLong_GetOne());
+    if (new_len == NULL) {
+        Py_DECREF(new_start);
+        return NULL;
     }
-
+    PyObject *result = r->start;
+    r->start = new_start;
+    Py_SETREF(r->len, new_len);
     return result;
 }
 
@@ -1108,7 +1115,6 @@ range_iter(PyObject *seq)
     it->start = Py_NewRef(r->start);
     it->step = Py_NewRef(r->step);
     it->len = Py_NewRef(r->length);
-    it->index = Py_NewRef(_PyLong_GetZero());
     return (PyObject *)it;
 }
 
@@ -1186,7 +1192,7 @@ range_reverse(PyObject *seq, PyObject *Py_UNUSED(ignored))
     it = PyObject_New(longrangeiterobject, &PyLongRangeIter_Type);
     if (it == NULL)
         return NULL;
-    it->index = it->start = it->step = NULL;
+    it->start = it->step = NULL;
 
     /* start + (len - 1) * step */
     it->len = Py_NewRef(range->length);
@@ -1210,7 +1216,6 @@ range_reverse(PyObject *seq, PyObject *Py_UNUSED(ignored))
     if (!it->step)
         goto create_failure;
 
-    it->index = Py_NewRef(_PyLong_GetZero());
     return (PyObject *)it;
 
 create_failure:
diff --git a/Python/bytecodes.c b/Python/bytecodes.c
index a1f910da8ed..41dd1acc937 100644
--- a/Python/bytecodes.c
+++ b/Python/bytecodes.c
@@ -2620,14 +2620,15 @@ dummy_func(
             STAT_INC(FOR_ITER, hit);
             _Py_CODEUNIT next = next_instr[INLINE_CACHE_ENTRIES_FOR_ITER];
             assert(_PyOpcode_Deopt[_Py_OPCODE(next)] == STORE_FAST);
-            if (r->index >= r->len) {
+            if (r->len <= 0) {
                 STACK_SHRINK(1);
                 Py_DECREF(r);
                 JUMPBY(INLINE_CACHE_ENTRIES_FOR_ITER + oparg + 1);
             }
             else {
-                long value = (long)(r->start +
-                                    (unsigned long)(r->index++) * r->step);
+                long value = r->start;
+                r->start = value + r->step;
+                r->len--;
                 if (_PyLong_AssignValue(&GETLOCAL(_Py_OPARG(next)), value) < 0) {
                     goto error;
                 }
diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h
index ae8fdd5e99c..3af60b83d84 100644
--- a/Python/generated_cases.c.h
+++ b/Python/generated_cases.c.h
@@ -2638,14 +2638,15 @@
             STAT_INC(FOR_ITER, hit);
             _Py_CODEUNIT next = next_instr[INLINE_CACHE_ENTRIES_FOR_ITER];
             assert(_PyOpcode_Deopt[_Py_OPCODE(next)] == STORE_FAST);
-            if (r->index >= r->len) {
+            if (r->len <= 0) {
                 STACK_SHRINK(1);
                 Py_DECREF(r);
                 JUMPBY(INLINE_CACHE_ENTRIES_FOR_ITER + oparg + 1);
             }
             else {
-                long value = (long)(r->start +
-                                    (unsigned long)(r->index++) * r->step);
+                long value = r->start;
+                r->start = value + r->step;
+                r->len--;
                 if (_PyLong_AssignValue(&GETLOCAL(_Py_OPARG(next)), value) < 0) {
                     goto error;
                 }
-- 
GitLab