|
4 | 4 | from time import time
|
5 | 5 |
|
6 | 6 | import django.test.testcases
|
| 7 | +from django.db.backends.utils import CursorWrapper |
7 | 8 | from django.utils.encoding import force_str
|
8 | 9 |
|
9 | 10 | from debug_toolbar import settings as dt_settings
|
@@ -57,54 +58,47 @@ def cursor(*args, **kwargs):
|
57 | 58 | wrapper = NormalCursorWrapper
|
58 | 59 | else:
|
59 | 60 | wrapper = ExceptionCursorWrapper
|
60 |
| - return wrapper(cursor, connection, logger) |
| 61 | + return wrapper(cursor.cursor, connection, logger) |
61 | 62 |
|
62 | 63 | def chunked_cursor(*args, **kwargs):
|
63 | 64 | # prevent double wrapping
|
64 | 65 | # solves https://github.com/jazzband/django-debug-toolbar/issues/1239
|
65 | 66 | logger = connection._djdt_logger
|
66 | 67 | cursor = connection._djdt_chunked_cursor(*args, **kwargs)
|
67 |
| - if logger is not None and not isinstance(cursor, BaseCursorWrapper): |
| 68 | + if logger is not None and not isinstance(cursor, DjDTCursorWrapper): |
68 | 69 | if allow_sql.get():
|
69 | 70 | wrapper = NormalCursorWrapper
|
70 | 71 | else:
|
71 | 72 | wrapper = ExceptionCursorWrapper
|
72 |
| - return wrapper(cursor, connection, logger) |
| 73 | + return wrapper(cursor.cursor, connection, logger) |
73 | 74 | return cursor
|
74 | 75 |
|
75 | 76 | connection.cursor = cursor
|
76 | 77 | connection.chunked_cursor = chunked_cursor
|
77 | 78 |
|
78 | 79 |
|
79 |
| -class BaseCursorWrapper: |
80 |
| - pass |
| 80 | +class DjDTCursorWrapper(CursorWrapper): |
| 81 | + def __init__(self, cursor, db, logger): |
| 82 | + super().__init__(cursor, db) |
| 83 | + # logger must implement a ``record`` method |
| 84 | + self.logger = logger |
81 | 85 |
|
82 | 86 |
|
83 |
| -class ExceptionCursorWrapper(BaseCursorWrapper): |
| 87 | +class ExceptionCursorWrapper(DjDTCursorWrapper): |
84 | 88 | """
|
85 | 89 | Wraps a cursor and raises an exception on any operation.
|
86 | 90 | Used in Templates panel.
|
87 | 91 | """
|
88 | 92 |
|
89 |
| - def __init__(self, cursor, db, logger): |
90 |
| - pass |
91 |
| - |
92 | 93 | def __getattr__(self, attr):
|
93 | 94 | raise SQLQueryTriggered()
|
94 | 95 |
|
95 | 96 |
|
96 |
| -class NormalCursorWrapper(BaseCursorWrapper): |
| 97 | +class NormalCursorWrapper(DjDTCursorWrapper): |
97 | 98 | """
|
98 | 99 | Wraps a cursor and logs queries.
|
99 | 100 | """
|
100 | 101 |
|
101 |
| - def __init__(self, cursor, db, logger): |
102 |
| - self.cursor = cursor |
103 |
| - # Instance of a BaseDatabaseWrapper subclass |
104 |
| - self.db = db |
105 |
| - # logger must implement a ``record`` method |
106 |
| - self.logger = logger |
107 |
| - |
108 | 102 | def _quote_expr(self, element):
|
109 | 103 | if isinstance(element, str):
|
110 | 104 | return "'%s'" % element.replace("'", "''")
|
@@ -246,22 +240,10 @@ def _record(self, method, sql, params):
|
246 | 240 | self.logger.record(**params)
|
247 | 241 |
|
248 | 242 | def callproc(self, procname, params=None):
|
249 |
| - return self._record(self.cursor.callproc, procname, params) |
| 243 | + return self._record(super().callproc, procname, params) |
250 | 244 |
|
251 | 245 | def execute(self, sql, params=None):
|
252 |
| - return self._record(self.cursor.execute, sql, params) |
| 246 | + return self._record(super().execute, sql, params) |
253 | 247 |
|
254 | 248 | def executemany(self, sql, param_list):
|
255 |
| - return self._record(self.cursor.executemany, sql, param_list) |
256 |
| - |
257 |
| - def __getattr__(self, attr): |
258 |
| - return getattr(self.cursor, attr) |
259 |
| - |
260 |
| - def __iter__(self): |
261 |
| - return iter(self.cursor) |
262 |
| - |
263 |
| - def __enter__(self): |
264 |
| - return self |
265 |
| - |
266 |
| - def __exit__(self, type, value, traceback): |
267 |
| - self.close() |
| 249 | + return self._record(super().executemany, sql, param_list) |
0 commit comments