Skip to content

Add escaping string literal in sql for copy&pasting in dbshell #224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 13, 2012
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions debug_toolbar/utils/tracking/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@
from django.utils.encoding import force_unicode, smart_str
from django.utils.hashcompat import sha_constructor

from debug_toolbar.utils import ms_from_timedelta, tidy_stacktrace, get_template_info, \
get_stack
from debug_toolbar.utils import ms_from_timedelta, tidy_stacktrace, \
get_template_info, get_stack
from debug_toolbar.utils.compat.db import connections
# TODO:This should be set in the toolbar loader as a default and panels should
# get a copy of the toolbar object with access to its config dictionary
SQL_WARNING_THRESHOLD = getattr(settings, 'DEBUG_TOOLBAR_CONFIG', {}) \
.get('SQL_WARNING_THRESHOLD', 500)


class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query"""
pass


class ThreadLocalState(local):
def __init__(self):
self.enabled = True
Expand All @@ -34,12 +36,15 @@ def Wrapper(self):
def recording(self, v):
self.enabled = v


state = ThreadLocalState()
recording = state.recording # export function
recording = state.recording # export function


def CursorWrapper(*args, **kwds): # behave like a class
return state.Wrapper(*args, **kwds)


class ExceptionCursorWrapper(object):
"""
Wraps a cursor and raises an exception on any operation.
Expand All @@ -51,6 +56,7 @@ def __init__(self, cursor, db, logger):
def __getattr__(self, attr):
raise SQLQueryTriggered()


class NormalCursorWrapper(object):
"""
Wraps a cursor and logs queries.
Expand All @@ -63,6 +69,19 @@ def __init__(self, cursor, db, logger):
# logger must implement a ``record`` method
self.logger = logger

def _quote_expr(self, element):
if isinstance(element, basestring):
element = element.replace("'", "''")
return "'%s'" % element
else:
return repr(element)

def _quote_params(self, params):
if isinstance(params, dict):
return dict((key, self._quote_expr(value))
for key, value in params.iteritems())
return map(self._quote_expr, params)

def execute(self, sql, params=()):
__traceback_hide__ = True
start = datetime.now()
Expand All @@ -71,17 +90,20 @@ def execute(self, sql, params=()):
finally:
stop = datetime.now()
duration = ms_from_timedelta(stop - start)
enable_stacktraces = getattr(settings, 'DEBUG_TOOLBAR_CONFIG', {}) \
enable_stacktraces = getattr(settings,
'DEBUG_TOOLBAR_CONFIG', {}) \
.get('ENABLE_STACKTRACES', True)
if enable_stacktraces:
stacktrace = tidy_stacktrace(reversed(get_stack()))
else:
stacktrace = []
_params = ''
try:
_params = simplejson.dumps([force_unicode(x, strings_only=True) for x in params])
_params = simplejson.dumps(
[force_unicode(x, strings_only=True) for x in params]
)
except TypeError:
pass # object not JSON serializable
pass # object not JSON serializable

template_info = None
cur_frame = sys._getframe().f_back
Expand All @@ -108,11 +130,14 @@ def execute(self, sql, params=()):
params = {
'engine': engine,
'alias': alias,
'sql': self.db.ops.last_executed_query(self.cursor, sql, params),
'sql': self.db.ops.last_executed_query(self.cursor, sql,
self._quote_params(params)),
'duration': duration,
'raw_sql': sql,
'params': _params,
'hash': sha_constructor(settings.SECRET_KEY + smart_str(sql) + _params).hexdigest(),
'hash': sha_constructor(settings.SECRET_KEY \
+ smart_str(sql) \
+ _params).hexdigest(),
'stacktrace': stacktrace,
'start_time': start,
'stop_time': stop,
Expand All @@ -129,7 +154,6 @@ def execute(self, sql, params=()):
'encoding': conn.encoding,
})


# We keep `sql` to maintain backwards compatibility
self.logger.record(**params)

Expand Down