# django/db/models/functions/datetime.py

from datetime import datetime

from django.conf import settings
from django.db.models.expressions import Func
from django.db.models.fields import (
    DateField,
    DateTimeField,
    DurationField,
    Field,
    IntegerField,
    TimeField,
)
from django.db.models.lookups import (
    Transform,
    YearExact,
    YearGt,
    YearGte,
    YearLt,
    YearLte,
)
from django.utils import timezone


class TimezoneMixin:
    """
    Mixin providing get_tzname(), which works out the time zone name to use
    from ``self.tzinfo`` and the USE_TZ setting.
    """

    tzinfo = None

    def get_tzname(self):
        """
        Return None when USE_TZ is disabled; otherwise return the current time
        zone's name if ``self.tzinfo`` is None, or the name of ``self.tzinfo``.
        """
        # Timezone conversions must happen to the input datetime *before*
        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
        # database as 2016-01-01 01:00:00 +00:00. Any results should be
        # based on the input datetime not the stored datetime.
        if not settings.USE_TZ:
            return None
        if self.tzinfo is None:
            return timezone.get_current_timezone_name()
        return timezone._get_timezone_name(self.tzinfo)


class Extract(TimezoneMixin, Transform):
    """
    Transform that pulls one named component (``lookup_name``) out of a
    DateField, DateTimeField, TimeField, or DurationField expression. The
    declared output field is an IntegerField.
    """

    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        """
        ``lookup_name`` is only used when the class (or instance) doesn't
        already define one; raise ValueError if no name is available at all.
        ``tzinfo`` is stored for get_tzname(); as_sql() rejects it for inputs
        that aren't a DateTimeField.
        """
        if self.lookup_name is None:
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        """
        Compile the left-hand side and wrap it with the backend's extract SQL
        helper chosen from the lhs output field type.

        Raise ValueError if ``tzinfo`` was given for a non-DateTimeField input
        or if a DurationField is used on a backend without native duration
        support.
        """
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        # DateTimeField must be tested before DateField because it is a
        # DateField subclass.
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = connection.ops.datetime_extract_sql(
                self.lookup_name, sql, tuple(params), tzname
            )
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(lhs_output_field, DateField):
            sql, params = connection.ops.date_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, TimeField):
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            # Durations go through the same backend helper as times.
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        else:
            # resolve_expression has already validated the output_field so this
            # branch should never be hit. Raise explicitly rather than using
            # ``assert False``: assert statements are removed under
            # ``python -O``, which would let the un-extracted SQL be returned
            # silently.
            raise AssertionError("Tried to Extract from an invalid type.")
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the expression, then validate the input field type against the
        requested component.

        Raise ValueError if the input isn't a date/datetime/time/duration
        field, if a time component is requested from a plain DateField, or if
        a calendar component is requested from a DurationField.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = getattr(copy.lhs, "output_field", None)
        if field is None:
            # Nothing to validate against.
            return copy
        if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )
        # Passing dates to functions expecting datetimes is most likely a mistake.
        # An exact type check is used so that DateTimeField (a DateField
        # subclass) isn't rejected here.
        if type(field) is DateField and copy.lookup_name in (
            "hour",
            "minute",
            "second",
        ):
            raise ValueError(
                "Cannot extract time component '%s' from DateField '%s'."
                % (copy.lookup_name, field.name)
            )
        if isinstance(field, DurationField) and copy.lookup_name in (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        ):
            raise ValueError(
                "Cannot extract component '%s' from DurationField '%s'."
                % (copy.lookup_name, field.name)
            )
        return copy


class ExtractYear(Extract):
    """Extract transform with ``lookup_name`` fixed to "year"."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Return the ISO-8601 week-numbering year."""

    # Component name that Extract.as_sql() forwards to the backend's
    # *_extract_sql() helper.
    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract transform with ``lookup_name`` fixed to "month"."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract transform with ``lookup_name`` fixed to "day"."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
    week.
    """

    # Component name that Extract.as_sql() forwards to the backend's
    # *_extract_sql() helper.
    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Return Sunday=1 through Saturday=7.

    To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
    """

    # Component name that Extract.as_sql() forwards to the backend's
    # *_extract_sql() helper.
    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Return Monday=1 through Sunday=7, based on ISO-8601."""

    # Component name that Extract.as_sql() forwards to the backend's
    # *_extract_sql() helper.
    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract transform with ``lookup_name`` fixed to "quarter"."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """
    Extract transform with ``lookup_name`` fixed to "hour".

    Extract.resolve_expression() rejects this component for a plain DateField.
    """

    lookup_name = "hour"


class ExtractMinute(Extract):
    """
    Extract transform with ``lookup_name`` fixed to "minute".

    Extract.resolve_expression() rejects this component for a plain DateField.
    """

    lookup_name = "minute"


class ExtractSecond(Extract):
    """
    Extract transform with ``lookup_name`` fixed to "second".

    Extract.resolve_expression() rejects this component for a plain DateField.
    """

    lookup_name = "second"


# Register the calendar-component extractors on DateField, in this order.
for _extract in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_extract)

# Register the clock-component extractors on TimeField, then DateTimeField.
for _field in (TimeField, DateTimeField):
    for _extract in (ExtractHour, ExtractMinute, ExtractSecond):
        _field.register_lookup(_extract)

# Register the year comparison lookups on both year extractors.
for _extract in (ExtractYear, ExtractIsoYear):
    for _year_lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _extract.register_lookup(_year_lookup)

# Don't leave the loop variables behind in the module namespace.
del _extract, _field, _year_lookup


class Now(Func):
    """
    Func rendering the database's current timestamp, with a DateTimeField
    output. Backend-specific methods swap in a different SQL template.
    """

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        """Compile using STATEMENT_TIMESTAMP() as the template."""
        # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        statement_template = "STATEMENT_TIMESTAMP()"
        return self.as_sql(
            compiler, connection, template=statement_template, **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        """Compile using CURRENT_TIMESTAMP(6) as the template."""
        mysql_template = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=mysql_template, **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        """Compile using a STRFTIME(..., 'NOW') call as the template."""
        sqlite_template = "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        return self.as_sql(
            compiler, connection, template=sqlite_template, **extra_context
        )


class TruncBase(TimezoneMixin, Transform):
    """
    Base transform for truncating a date/time expression to the precision
    named by ``kind``. Subclasses (or Trunc instances) supply ``kind``.
    """

    kind = None
    tzinfo = None

    # RemovedInDjango50Warning: when the deprecation ends, remove is_dst
    # argument.
    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        is_dst=timezone.NOT_PASSED,
        **extra,
    ):
        """
        Store ``tzinfo`` (used by get_tzname() and convert_value()) and
        ``is_dst`` (forwarded to timezone.make_aware() in convert_value()),
        then initialize the transform with the given ``output_field``.
        """
        self.tzinfo = tzinfo
        self.is_dst = is_dst
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        """
        Compile the left-hand side and wrap it with the backend's trunc SQL
        helper chosen from this expression's output field type.

        Raise ValueError if ``tzinfo`` was given for a non-DateTimeField input
        or if the output field isn't a date, time, or datetime field.
        """
        sql, params = compiler.compile(self.lhs)
        # A time zone name is only computed for DateTimeField inputs.
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField must be tested before DateField because it is a
        # DateField subclass.
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the expression, then validate the input field, the output
        field, and the requested ``kind`` against each other.

        Raise TypeError if the input isn't a date/time/datetime field, and
        ValueError if the output field is of another type or if the
        truncation doesn't make sense for the input (e.g. a time ``kind`` on
        a plain DateField, or a date ``kind`` on a TimeField).
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, (DateField, TimeField)):
            raise TypeError(
                "%r isn't a DateField, TimeField, or DateTimeField." % field.name
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        # Only treat the class attribute as an output field when it is a real
        # Field instance (as set by subclasses such as TruncDate/TruncTime).
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        # Used only to pick the field type named in the error messages below.
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Exact type check so that DateTimeField inputs aren't matched here.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                "Cannot truncate DateField '%s' to %s."
                % (
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                "Cannot truncate TimeField '%s' to %s."
                % (
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        return copy

    def convert_value(self, value, expression, connection):
        """
        Adapt a value returned by the database to this expression's output
        field.

        For DateTimeField output with USE_TZ enabled, strip any tzinfo and
        make the value aware in ``self.tzinfo``; a None value raises
        ValueError when the backend reports no zoneinfo database. For other
        outputs, a datetime value is narrowed to its date or time part.
        Anything else is returned unchanged.
        """
        if isinstance(self.output_field, DateTimeField):
            if settings.USE_TZ:
                if value is not None:
                    value = value.replace(tzinfo=None)
                    value = timezone.make_aware(
                        value, self.tzinfo, is_dst=self.is_dst
                    )
                elif not connection.features.has_zoneinfo_database:
                    raise ValueError(
                        "Database returned an invalid datetime value. Are time "
                        "zone definitions for your database installed?"
                    )
        elif isinstance(value, datetime):
            # value can't be None here (it's a datetime instance), so no
            # None check is needed before narrowing it.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value


class Trunc(TruncBase):
    """TruncBase variant whose ``kind`` is supplied per instance."""

    # RemovedInDjango50Warning: when the deprecation ends, remove is_dst
    # argument.
    def __init__(
        self,
        expression,
        kind,
        output_field=None,
        tzinfo=None,
        is_dst=timezone.NOT_PASSED,
        **extra,
    ):
        """Record ``kind`` on the instance, then defer to TruncBase.__init__()."""
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            is_dst=is_dst,
            **extra,
        )


class TruncYear(TruncBase):
    """TruncBase transform with ``kind`` fixed to "year"."""

    kind = "year"


class TruncQuarter(TruncBase):
    """TruncBase transform with ``kind`` fixed to "quarter"."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """TruncBase transform with ``kind`` fixed to "month"."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week."""

    # Precision name that TruncBase.as_sql() forwards to the backend's
    # *_trunc_sql() helper.
    kind = "week"


class TruncDay(TruncBase):
    """TruncBase transform with ``kind`` fixed to "day"."""

    kind = "day"


class TruncDate(TruncBase):
    """
    TruncBase transform with ``kind`` "date" and a DateField output, compiled
    through the backend's datetime-to-date cast.
    """

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        """Return the backend's date-cast SQL for the compiled left-hand side."""
        # Cast to date rather than truncate to date.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        cast_to_date = connection.ops.datetime_cast_date_sql
        return cast_to_date(lhs_sql, tuple(lhs_params), tzname)


class TruncTime(TruncBase):
    """
    TruncBase transform with ``kind`` "time" and a TimeField output, compiled
    through the backend's datetime-to-time cast.
    """

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        """Return the backend's time-cast SQL for the compiled left-hand side."""
        # Cast to time rather than truncate to time.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        cast_to_time = connection.ops.datetime_cast_time_sql
        return cast_to_time(lhs_sql, tuple(lhs_params), tzname)


class TruncHour(TruncBase):
    """
    TruncBase transform with ``kind`` fixed to "hour".

    TruncBase.resolve_expression() rejects this kind for a plain DateField.
    """

    kind = "hour"


class TruncMinute(TruncBase):
    """
    TruncBase transform with ``kind`` fixed to "minute".

    TruncBase.resolve_expression() rejects this kind for a plain DateField.
    """

    kind = "minute"


class TruncSecond(TruncBase):
    """
    TruncBase transform with ``kind`` fixed to "second".

    TruncBase.resolve_expression() rejects this kind for a plain DateField.
    """

    kind = "second"


# Register the date and time casts on DateTimeField, in this order.
for _trunc in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_trunc)

# Don't leave the loop variable behind in the module namespace.
del _trunc
# Metadata
# View Raw File