# django/db/models/fields/generated.py

from django.core import checks
from django.db import connections, router
from django.db.models.sql import Query
from django.utils.functional import cached_property

from . import NOT_PROVIDED, Field

__all__ = ["GeneratedField"]


class GeneratedField(Field):
    """
    A model field whose value comes from ``expression`` rather than from
    application code.

    ``output_field`` supplies the internal type, database type parameters and
    lookups for the column. ``db_persist`` must be True or False: True
    requires the backend's ``supports_stored_generated_columns`` feature,
    False requires ``supports_virtual_generated_columns``.
    """

    generated = True
    db_returning = True

    # Built in contribute_to_class() and used by generated_sql() to resolve
    # and compile ``expression``.
    _query = None
    output_field = None

    def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
        """
        Store the generated-column configuration after validating kwargs.

        ``editable`` defaults to False and ``blank`` to True.

        Raises ValueError for a truthy ``editable``, a falsy ``blank``, any
        ``default`` or ``db_default``, or a ``db_persist`` that is not
        True/False.
        """
        if kwargs.setdefault("editable", False):
            raise ValueError("GeneratedField cannot be editable.")
        if not kwargs.setdefault("blank", True):
            raise ValueError("GeneratedField must be blank.")
        if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("GeneratedField cannot have a default.")
        if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("GeneratedField cannot have a database default.")
        # db_persist defaults to None only so that omitting it is rejected here.
        if db_persist not in (True, False):
            raise ValueError("GeneratedField.db_persist must be True or False.")

        self.expression = expression
        self.output_field = output_field
        self.db_persist = db_persist
        super().__init__(**kwargs)

    @cached_property
    def cached_col(self):
        """Return a Col on the model's own table, typed by ``output_field``."""
        from django.db.models.expressions import Col

        return Col(self.model._meta.db_table, self, self.output_field)

    def get_col(self, alias, output_field=None):
        """
        Return a Col for this field under ``alias``.

        When ``alias`` is not the model's own table name and the caller did
        not request a different output field, ``output_field`` is used to type
        the column instead of this field.
        """
        if alias != self.model._meta.db_table and output_field in (None, self):
            output_field = self.output_field
        return super().get_col(alias, output_field)

    def contribute_to_class(self, *args, **kwargs):
        """
        Attach the field to its model and prepare for SQL generation.

        Also build the Query used by generated_sql(), and register the
        output field's class lookups on this field.
        """
        super().contribute_to_class(*args, **kwargs)

        self._query = Query(model=self.model, alias_cols=False)
        # Register lookups from the output_field class.
        for lookup_name, lookup in self.output_field.get_class_lookups().items():
            self.register_lookup(lookup, lookup_name=lookup_name)

    def generated_sql(self, connection):
        """
        Compile ``expression`` for ``connection`` and return ``(sql, params)``.

        On backends without boolean expressions in SELECT, a conditional
        expression is wrapped as ``CASE WHEN ... THEN 1 ELSE 0 END``.
        """
        compiler = connection.ops.compiler("SQLCompiler")(
            self._query, connection=connection, using=None
        )
        resolved_expression = self.expression.resolve_expression(
            self._query, allow_joins=False
        )
        sql, params = compiler.compile(resolved_expression)
        if (
            getattr(self.expression, "conditional", False)
            and not connection.features.supports_boolean_expr_in_select_clause
        ):
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def check(self, **kwargs):
        """
        Run the standard field checks plus the generated-field checks.

        Checks are run against the databases in ``kwargs["databases"]``, if
        given. Problems reported by a clone of ``output_field`` are
        aggregated into one fields.E223 error and one fields.W224 warning.
        """
        databases = kwargs.get("databases") or []
        errors = [
            *super().check(**kwargs),
            *self._check_supported(databases),
            *self._check_persistence(databases),
        ]
        # Check a clone bound to this model so output_field itself is not
        # mutated.
        output_field_clone = self.output_field.clone()
        output_field_clone.model = self.model
        output_field_checks = output_field_clone.check(databases=databases)
        if output_field_checks:
            separator = "\n    "
            error_messages = separator.join(
                f"{output_check.msg} ({output_check.id})"
                for output_check in output_field_checks
                # checks.Critical is not a subclass of checks.Error; list it
                # explicitly so the most severe problems aren't dropped.
                if isinstance(output_check, (checks.Error, checks.Critical))
            )
            if error_messages:
                errors.append(
                    checks.Error(
                        "GeneratedField.output_field has errors:"
                        f"{separator}{error_messages}",
                        obj=self,
                        id="fields.E223",
                    )
                )
            warning_messages = separator.join(
                f"{output_check.msg} ({output_check.id})"
                for output_check in output_field_checks
                if isinstance(output_check, checks.Warning)
            )
            if warning_messages:
                errors.append(
                    checks.Warning(
                        "GeneratedField.output_field has warnings:"
                        f"{separator}{warning_messages}",
                        obj=self,
                        id="fields.W224",
                    )
                )
        return errors

    def _connections_to_check(self, databases):
        """
        Yield the connections these checks apply to.

        An alias in ``databases`` is skipped when the router disallows
        migrating this model to it, or when the model's
        ``required_db_vendor`` is set and differs from the connection's
        vendor.
        """
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            required_vendor = self.model._meta.required_db_vendor
            if required_vendor and required_vendor != connection.vendor:
                continue
            yield connection

    def _check_supported(self, databases):
        """
        Return fields.E220 for each relevant backend lacking generated columns.

        A backend is accepted if it supports either virtual or stored
        generated columns, or if the model declares either feature in
        ``required_db_features``.
        """
        errors = []
        for connection in self._connections_to_check(databases):
            features = connection.features
            required_features = self.model._meta.required_db_features
            # Same evaluation order as the original nested form, so
            # short-circuiting of feature lookups is unchanged.
            if not (
                features.supports_virtual_generated_columns
                or "supports_stored_generated_columns" in required_features
                or features.supports_stored_generated_columns
                or "supports_virtual_generated_columns" in required_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support GeneratedFields.",
                        obj=self,
                        id="fields.E220",
                    )
                )
        return errors

    def _check_persistence(self, databases):
        """
        Return errors when a backend can't provide the column kind requested.

        Returns fields.E221 when ``db_persist`` is falsy and virtual columns
        are unavailable. Returns fields.E222 when ``db_persist`` is truthy and
        stored columns are unavailable. A feature also counts as available if
        the model lists it in ``required_db_features``.
        """
        errors = []
        for connection in self._connections_to_check(databases):
            features = connection.features
            required_features = self.model._meta.required_db_features
            if not self.db_persist and not (
                features.supports_virtual_generated_columns
                or "supports_virtual_generated_columns" in required_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support non-persisted "
                        "GeneratedFields.",
                        obj=self,
                        id="fields.E221",
                        hint="Set db_persist=True on the field.",
                    )
                )
            if self.db_persist and not (
                features.supports_stored_generated_columns
                or "supports_stored_generated_columns" in required_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support persisted "
                        "GeneratedFields.",
                        obj=self,
                        id="fields.E222",
                        hint="Set db_persist=False on the field.",
                    )
                )
        return errors

    def deconstruct(self):
        """
        Return the field's constructor arguments for serialization.

        ``blank`` and ``editable`` are removed because __init__ forces them.
        The keyword-only ``db_persist``, ``expression`` and ``output_field``
        are added.
        """
        name, path, args, kwargs = super().deconstruct()
        del kwargs["blank"]
        del kwargs["editable"]
        kwargs["db_persist"] = self.db_persist
        kwargs["expression"] = self.expression
        kwargs["output_field"] = self.output_field
        return name, path, args, kwargs

    def get_internal_type(self):
        """Delegate to ``output_field``."""
        return self.output_field.get_internal_type()

    def db_parameters(self, connection):
        """Delegate to ``output_field``."""
        return self.output_field.db_parameters(connection)

    def db_type_parameters(self, connection):
        """Delegate to ``output_field``."""
        return self.output_field.db_type_parameters(connection)
# Metadata
# View Raw File