"""
Implementations of SQL functions for SQLite.
"""
import functools
import random
import statistics
import zoneinfo
from datetime import timedelta
from hashlib import md5, sha1, sha224, sha256, sha384, sha512
from math import (
acos,
asin,
atan,
atan2,
ceil,
cos,
degrees,
exp,
floor,
fmod,
log,
pi,
radians,
sin,
sqrt,
tan,
)
from re import search as re_search
from django.db.backends.utils import (
split_tzname_delta,
typecast_time,
typecast_timestamp,
)
from django.utils import timezone
from django.utils.duration import duration_microseconds
def register(connection):
    """
    Install Django's Python implementations of SQL functions and aggregates
    on ``connection``.

    Every scalar function except RAND is created with deterministic=True.
    The math functions at the end are only installed when
    sqlite_compileoption_used('ENABLE_MATH_FUNCTIONS') is false. Otherwise
    SQLite's own versions (available in 3.35+ builds) are used.
    """
    create_deterministic_function = functools.partial(
        connection.create_function,
        deterministic=True,
    )
    for name, num_args, func in (
        ("django_date_extract", 2, _sqlite_datetime_extract),
        ("django_date_trunc", 4, _sqlite_date_trunc),
        ("django_datetime_cast_date", 3, _sqlite_datetime_cast_date),
        ("django_datetime_cast_time", 3, _sqlite_datetime_cast_time),
        ("django_datetime_extract", 4, _sqlite_datetime_extract),
        ("django_datetime_trunc", 4, _sqlite_datetime_trunc),
        ("django_time_extract", 2, _sqlite_time_extract),
        ("django_time_trunc", 4, _sqlite_time_trunc),
        ("django_time_diff", 2, _sqlite_time_diff),
        ("django_timestamp_diff", 2, _sqlite_timestamp_diff),
        ("django_format_dtdelta", 3, _sqlite_format_dtdelta),
        ("regexp", 2, _sqlite_regexp),
        ("BITXOR", 2, _sqlite_bitxor),
        ("COT", 1, _sqlite_cot),
        ("LPAD", 3, _sqlite_lpad),
        ("MD5", 1, _sqlite_md5),
        ("REPEAT", 2, _sqlite_repeat),
        ("REVERSE", 1, _sqlite_reverse),
        ("RPAD", 3, _sqlite_rpad),
        ("SHA1", 1, _sqlite_sha1),
        ("SHA224", 1, _sqlite_sha224),
        ("SHA256", 1, _sqlite_sha256),
        ("SHA384", 1, _sqlite_sha384),
        ("SHA512", 1, _sqlite_sha512),
        ("SIGN", 1, _sqlite_sign),
    ):
        create_deterministic_function(name, num_args, func)
    # Don't use the built-in RANDOM() function because it returns a value
    # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
    connection.create_function("RAND", 0, random.random)
    for name, aggregate_class in (
        ("STDDEV_POP", StdDevPop),
        ("STDDEV_SAMP", StdDevSamp),
        ("VAR_POP", VarPop),
        ("VAR_SAMP", VarSamp),
    ):
        connection.create_aggregate(name, 1, aggregate_class)
    # Some math functions are enabled by default in SQLite 3.35+.
    sql = "select sqlite_compileoption_used('ENABLE_MATH_FUNCTIONS')"
    if connection.execute(sql).fetchone()[0]:
        return
    for name, num_args, func in (
        ("ACOS", 1, _sqlite_acos),
        ("ASIN", 1, _sqlite_asin),
        ("ATAN", 1, _sqlite_atan),
        ("ATAN2", 2, _sqlite_atan2),
        ("CEILING", 1, _sqlite_ceiling),
        ("COS", 1, _sqlite_cos),
        ("DEGREES", 1, _sqlite_degrees),
        ("EXP", 1, _sqlite_exp),
        ("FLOOR", 1, _sqlite_floor),
        ("LN", 1, _sqlite_ln),
        ("LOG", 2, _sqlite_log),
        ("MOD", 2, _sqlite_mod),
        ("PI", 0, _sqlite_pi),
        ("POWER", 2, _sqlite_power),
        ("RADIANS", 1, _sqlite_radians),
        ("SIN", 1, _sqlite_sin),
        ("SQRT", 1, _sqlite_sqrt),
        ("TAN", 1, _sqlite_tan),
    ):
        create_deterministic_function(name, num_args, func)
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
    """
    Parse a datetime value coming from SQLite and optionally shift its zone.

    dt is parsed with typecast_timestamp(). When conn_tzname is given, the
    naive result is tagged with that zone via replace(tzinfo=...). When
    tzname is given and differs from conn_tzname, split_tzname_delta() splits
    it into a zone name and an optional "HH:MM" offset with a sign. The
    offset is added for "+" and subtracted otherwise, and the value is then
    converted to the named zone with timezone.localtime().

    Return None if dt is NULL or cannot be parsed.
    """
    if dt is None:
        return None
    try:
        dt = typecast_timestamp(dt)
    except (TypeError, ValueError):
        return None
    if dt is None:
        # typecast_timestamp() may signal "no value" by returning None rather
        # than raising (presumably for empty input). Without this guard,
        # dt.replace() below fails with AttributeError.
        return None
    if conn_tzname:
        dt = dt.replace(tzinfo=zoneinfo.ZoneInfo(conn_tzname))
    if tzname is not None and tzname != conn_tzname:
        tzname, sign, offset = split_tzname_delta(tzname)
        if offset:
            hours, minutes = offset.split(":")
            offset_delta = timedelta(hours=int(hours), minutes=int(minutes))
            dt += offset_delta if sign == "+" else -offset_delta
        dt = timezone.localtime(dt, zoneinfo.ZoneInfo(tzname))
    return dt
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
    """
    Truncate a datetime value to a calendar boundary and return it as a
    "YYYY-MM-DD" string.

    lookup_type is one of "year", "quarter", "month", "week" (the Monday
    starting that week) or "day". Return None when dt is NULL or unparsable.
    Raise ValueError for any other lookup_type.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is None:
        return None
    if lookup_type == "year":
        year, month, day = parsed.year, 1, 1
    elif lookup_type == "quarter":
        year, month, day = parsed.year, parsed.month - (parsed.month - 1) % 3, 1
    elif lookup_type == "month":
        year, month, day = parsed.year, parsed.month, 1
    elif lookup_type == "week":
        # weekday() is 0 on Monday, so this steps back to the week's Monday.
        monday = parsed - timedelta(days=parsed.weekday())
        year, month, day = monday.year, monday.month, monday.day
    elif lookup_type == "day":
        year, month, day = parsed.year, parsed.month, parsed.day
    else:
        raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
    return f"{year:04d}-{month:02d}-{day:02d}"
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
    """
    Truncate a time or datetime value to the hour, minute or second and
    return it as an "HH:MM:SS" string.

    dt is first parsed as a timestamp with _sqlite_datetime_parse(), which
    also applies the timezone arguments. If that yields None, dt is parsed
    as a bare time with typecast_time(). Return None when dt is NULL or
    unparsable. Raise ValueError for an unsupported lookup_type.
    """
    if dt is None:
        return None
    dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if dt_parsed is None:
        try:
            dt = typecast_time(dt)
        except (ValueError, TypeError):
            return None
        # typecast_time() may return None instead of raising (presumably for
        # empty input); without this guard dt.hour below raises
        # AttributeError.
        if dt is None:
            return None
    else:
        dt = dt_parsed
    if lookup_type == "hour":
        return f"{dt.hour:02d}:00:00"
    elif lookup_type == "minute":
        return f"{dt.hour:02d}:{dt.minute:02d}:00"
    elif lookup_type == "second":
        return f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
    raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
    """
    Return the date part of dt, after timezone conversion, as an ISO 8601
    string. Return None if dt is NULL or unparsable.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    return None if parsed is None else parsed.date().isoformat()
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
    """
    Return the time part of dt, after timezone conversion, as an ISO 8601
    string. Return None if dt is NULL or unparsable.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is not None:
        return parsed.time().isoformat()
    return None
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
    """
    Extract one component of a datetime value.

    A few lookups are computed specially (see the table below). Any other
    lookup_type is read directly as an attribute of the parsed value, e.g.
    "year" or "hour". Return None if dt is NULL or unparsable.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is None:
        return None
    computed = {
        # Maps Sunday to 1 through Saturday to 7.
        "week_day": lambda value: value.isoweekday() % 7 + 1,
        "iso_week_day": lambda value: value.isoweekday(),
        "week": lambda value: value.isocalendar().week,
        "quarter": lambda value: ceil(value.month / 3),
        "iso_year": lambda value: value.isocalendar().year,
    }
    extractor = computed.get(lookup_type)
    if extractor is None:
        return getattr(parsed, lookup_type)
    return extractor(parsed)
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
    """
    Truncate a datetime value to the requested precision and return it as a
    "YYYY-MM-DD HH:MM:SS" string.

    lookup_type is one of "year", "quarter", "month", "week" (the Monday
    starting that week), "day", "hour", "minute" or "second". Return None
    when dt is NULL or unparsable. Raise ValueError for any other
    lookup_type.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is None:
        return None
    units = ("year", "month", "day", "hour", "minute", "second")
    # Value each unit is reset to when it is finer than the requested one.
    floors = (1, 1, 1, 0, 0, 0)
    if lookup_type == "quarter":
        first_month = parsed.month - (parsed.month - 1) % 3
        fields = (parsed.year, first_month, 1, 0, 0, 0)
    elif lookup_type == "week":
        # weekday() is 0 on Monday, so this steps back to the week's Monday.
        monday = parsed - timedelta(days=parsed.weekday())
        fields = (monday.year, monday.month, monday.day, 0, 0, 0)
    elif lookup_type in units:
        kept = units.index(lookup_type) + 1
        # Only read the attributes that are kept.
        fields = tuple(getattr(parsed, unit) for unit in units[:kept])
        fields += floors[kept:]
    else:
        raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
    year, month, day, hour, minute, second = fields
    return (
        f"{year:04d}-{month:02d}-{day:02d} "
        f"{hour:02d}:{minute:02d}:{second:02d}"
    )
def _sqlite_time_extract(lookup_type, dt):
    """
    Return the ``lookup_type`` attribute (e.g. "hour") of a time value.

    dt is parsed with typecast_time(). Return None if dt is NULL or cannot
    be parsed as a time.
    """
    if dt is None:
        return None
    try:
        dt = typecast_time(dt)
    except (ValueError, TypeError):
        return None
    if dt is None:
        # typecast_time() may return None instead of raising (presumably for
        # empty input). Treat that like any other unparsable value rather
        # than letting getattr() raise AttributeError.
        return None
    return getattr(dt, lookup_type)
def _sqlite_prepare_dtdelta_param(conn, param):
    """
    Convert one operand of a combined datetime/duration expression.

    For the additive connectors "+" and "-", an int is taken as a number of
    microseconds and becomes a timedelta. Any other value is parsed with
    typecast_timestamp(). For every other connector, param is returned
    unchanged.
    """
    if conn not in ["+", "-"]:
        return param
    if isinstance(param, int):
        return timedelta(microseconds=param)
    return typecast_timestamp(param)
def _sqlite_format_dtdelta(connector, lhs, rhs):
    """
    Evaluate ``lhs <connector> rhs`` for Django's combined expressions.

    LHS and RHS can be either:
    - An integer number of microseconds
    - A string representing a datetime
    - A scalar value, e.g. float

    "+" and "-" return str() of the result. "*" returns the raw product, and
    any other connector divides. Return None if any argument is NULL or an
    operand cannot be prepared (ValueError/TypeError).
    """
    if connector is None or lhs is None or rhs is None:
        return None
    op = connector.strip()
    try:
        left = _sqlite_prepare_dtdelta_param(op, lhs)
        right = _sqlite_prepare_dtdelta_param(op, rhs)
    except (ValueError, TypeError):
        return None
    if op == "+":
        # typecast_timestamp() returns a date or a datetime without timezone.
        # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
        return str(left + right)
    if op == "-":
        return str(left - right)
    if op == "*":
        return left * right
    return left / right
def _sqlite_time_diff(lhs, rhs):
    """
    Return the difference lhs - rhs between two time values, in
    microseconds. Return None if either side is NULL.
    """
    if lhs is None or rhs is None:
        return None
    left, right = typecast_time(lhs), typecast_time(rhs)

    def micros(value):
        # Microseconds elapsed since midnight.
        seconds = (value.hour * 60 + value.minute) * 60 + value.second
        return seconds * 1_000_000 + value.microsecond

    return micros(left) - micros(right)
def _sqlite_timestamp_diff(lhs, rhs):
    """
    Return the difference lhs - rhs between two timestamp values, in
    microseconds. Return None if either side is NULL.
    """
    if lhs is not None and rhs is not None:
        delta = typecast_timestamp(lhs) - typecast_timestamp(rhs)
        return duration_microseconds(delta)
    return None
def _sqlite_regexp(pattern, string):
    """
    Back the SQL function regexp(pattern, string).

    Return True if pattern matches anywhere in string (re.search semantics)
    and False otherwise. Non-str values are converted with str() first.
    Return None if either argument is NULL.
    """
    if pattern is None or string is None:
        return None
    subject = string if isinstance(string, str) else str(string)
    return re_search(pattern, subject) is not None
def _sqlite_acos(x):
    """Return the arc cosine of x in radians; NULL in, NULL out."""
    return None if x is None else acos(x)
def _sqlite_asin(x):
    """Return the arc sine of x in radians; NULL in, NULL out."""
    if x is not None:
        return asin(x)
    return None
def _sqlite_atan(x):
    """Return the arc tangent of x in radians; NULL in, NULL out."""
    return None if x is None else atan(x)
def _sqlite_atan2(y, x):
    """Return atan2(y, x) in radians, or None if either argument is NULL."""
    if x is None or y is None:
        return None
    return atan2(y, x)
def _sqlite_bitxor(x, y):
    """Return x ^ y (bitwise XOR), or None if either argument is NULL."""
    return None if x is None or y is None else x ^ y
def _sqlite_ceiling(x):
    """Return the smallest integer >= x (math.ceil); NULL in, NULL out."""
    if x is not None:
        return ceil(x)
    return None
def _sqlite_cos(x):
    """Return the cosine of x (radians); NULL in, NULL out."""
    return None if x is None else cos(x)
def _sqlite_cot(x):
    """Return the cotangent of x (radians), computed as 1 / tan(x); NULL in, NULL out."""
    if x is not None:
        return 1 / tan(x)
    return None
def _sqlite_degrees(x):
    """Convert x from radians to degrees; NULL in, NULL out."""
    return None if x is None else degrees(x)
def _sqlite_exp(x):
    """Return e raised to the power x; NULL in, NULL out."""
    if x is not None:
        return exp(x)
    return None
def _sqlite_floor(x):
    """Return the largest integer <= x (math.floor); NULL in, NULL out."""
    return None if x is None else floor(x)
def _sqlite_ln(x):
    """Return the natural logarithm of x; NULL in, NULL out."""
    if x is not None:
        return log(x)
    return None
def _sqlite_log(base, x):
    """
    Return the logarithm of x in the given base, or None if either argument
    is NULL.

    The SQL-level order is LOG(base, x), the reverse of math.log(x, base).
    """
    if x is None or base is None:
        return None
    return log(x, base)
def _sqlite_lpad(text, length, fill_text):
    """
    Left-pad text with repetitions of fill_text until it is length characters.

    If text is already at least length characters, text[:length] is returned
    instead. Return None if any argument is NULL.
    """
    if text is None or length is None or fill_text is None:
        return None
    missing = length - len(text)
    if missing > 0:
        padding = (fill_text * length)[:missing]
        return padding + text
    return text[:length]
def _sqlite_md5(text):
    """Return the hex MD5 digest of text encoded with str.encode() (UTF-8); NULL in, NULL out."""
    if text is not None:
        return md5(text.encode()).hexdigest()
    return None
def _sqlite_mod(x, y):
    """
    Return the remainder of x / y via math.fmod (result takes the sign of x),
    or None if either argument is NULL.
    """
    return None if x is None or y is None else fmod(x, y)
def _sqlite_pi():
    """Return the constant pi (math.pi); takes no arguments."""
    value = pi
    return value
def _sqlite_power(x, y):
    """Return x raised to the power y, or None if either argument is NULL."""
    if x is None or y is None:
        return None
    return pow(x, y)
def _sqlite_radians(x):
    """Convert x from degrees to radians; NULL in, NULL out."""
    if x is not None:
        return radians(x)
    return None
def _sqlite_repeat(text, count):
    """Return text repeated count times, or None if either argument is NULL."""
    return None if text is None or count is None else text * count
def _sqlite_reverse(text):
    """Return text with its characters in reverse order; NULL in, NULL out."""
    if text is not None:
        return text[::-1]
    return None
def _sqlite_rpad(text, length, fill_text):
    """
    Right-pad text with repetitions of fill_text, then cut the result with
    [:length]. Return None if any argument is NULL.
    """
    if text is None or length is None or fill_text is None:
        return None
    padded = text + fill_text * length
    return padded[:length]
def _sqlite_sha1(text):
    """Return the hex SHA-1 digest of text encoded with str.encode() (UTF-8); NULL in, NULL out."""
    return None if text is None else sha1(text.encode()).hexdigest()
def _sqlite_sha224(text):
    """Return the hex SHA-224 digest of text encoded with str.encode() (UTF-8); NULL in, NULL out."""
    if text is not None:
        return sha224(text.encode()).hexdigest()
    return None
def _sqlite_sha256(text):
    """Return the hex SHA-256 digest of text encoded with str.encode() (UTF-8); NULL in, NULL out."""
    return None if text is None else sha256(text.encode()).hexdigest()
def _sqlite_sha384(text):
    """Return the hex SHA-384 digest of text encoded with str.encode() (UTF-8); NULL in, NULL out."""
    if text is not None:
        return sha384(text.encode()).hexdigest()
    return None
def _sqlite_sha512(text):
    """Return the hex SHA-512 digest of text encoded with str.encode() (UTF-8); NULL in, NULL out."""
    return None if text is None else sha512(text.encode()).hexdigest()
def _sqlite_sign(x):
    """Return -1, 0 or 1 according to the sign of x; NULL in, NULL out."""
    if x is None:
        return None
    return int(x > 0) - int(x < 0)
def _sqlite_sin(x):
    """Return the sine of x (radians); NULL in, NULL out."""
    return None if x is None else sin(x)
def _sqlite_sqrt(x):
    """Return the square root of x; NULL in, NULL out."""
    if x is not None:
        return sqrt(x)
    return None
def _sqlite_tan(x):
    """Return the tangent of x (radians); NULL in, NULL out."""
    return None if x is None else tan(x)
class ListAggregate(list):
    """
    Base class for Python-implemented SQLite aggregates (step/finalize).

    step() buffers each non-NULL input value. finalize() reduces the buffered
    values with the subclass's ``function``. This follows standard SQL
    aggregate semantics:
    - NULL inputs are ignored.
    - If fewer than ``min_count`` values were collected, the result is NULL
      (None) rather than an error from the statistics module.
    """

    # Minimum number of non-NULL values ``function`` needs for a result.
    min_count = 1
    # Callable taking the list of collected values; set by subclasses.
    function = None

    def step(self, value):
        # step() may receive None for SQL NULL inputs. Skip them so the
        # statistics functions don't fail with TypeError in finalize().
        if value is not None:
            self.append(value)

    def finalize(self):
        if len(self) < self.min_count:
            return None
        return self.function(self)


class StdDevPop(ListAggregate):
    """Population standard deviation (STDDEV_POP)."""

    function = staticmethod(statistics.pstdev)


class StdDevSamp(ListAggregate):
    """Sample standard deviation (STDDEV_SAMP); needs at least two values."""

    min_count = 2
    function = staticmethod(statistics.stdev)


class VarPop(ListAggregate):
    """Population variance (VAR_POP)."""

    function = staticmethod(statistics.pvariance)


class VarSamp(ListAggregate):
    """Sample variance (VAR_SAMP); needs at least two values."""

    min_count = 2
    function = staticmethod(statistics.variance)