from collections import defaultdict
from django.contrib.contenttypes.models import ContentType
from django.core import checks
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, models, router, transaction
from django.db.models import DO_NOTHING
from django.db.models.base import ModelBase, make_foreign_order_accessors
from django.db.models.fields.mixins import FieldCacheMixin
from django.db.models.fields.related import (
ForeignObject, ForeignObjectRel, ReverseManyToOneDescriptor,
lazy_related_operation,
)
from django.db.models.query_utils import PathInfo
from django.utils.functional import cached_property
class GenericForeignKey(FieldCacheMixin):
    """
    Provide a generic many-to-one relation through the ``content_type`` and
    ``object_id`` fields.

    This class also doubles as an accessor to the related object (similar to
    ForwardManyToOneDescriptor) by adding itself as a model attribute.

    It is not a concrete column (``concrete = False``): contribute_to_class()
    registers it as a private field, and the related object is resolved on
    attribute access from the values held in the ``ct_field`` and
    ``fk_field`` fields of the instance.
    """

    # Field flags
    auto_created = False
    concrete = False
    editable = False
    hidden = False

    is_relation = True
    many_to_many = False
    many_to_one = True
    one_to_many = False
    one_to_one = False
    related_model = None
    remote_field = None

    def __init__(self, ct_field='content_type', fk_field='object_id', for_concrete_model=True):
        """
        :param ct_field: name of the model field referencing ContentType.
        :param fk_field: name of the model field storing the related object's
            primary key value.
        :param for_concrete_model: forwarded as ``for_concrete_model`` to
            ``ContentType.objects.get_for_model()`` whenever the content type
            is derived from an object.
        """
        self.ct_field = ct_field
        self.fk_field = fk_field
        self.for_concrete_model = for_concrete_model
        self.editable = False
        self.rel = None
        self.column = None

    def contribute_to_class(self, cls, name, **kwargs):
        """Register this field on ``cls`` as the private field/attribute ``name``."""
        self.name = name
        self.model = cls
        cls._meta.add_field(self, private=True)
        # Installing self on the class makes __get__/__set__ the accessor.
        setattr(cls, name, self)

    def get_filter_kwargs_for_object(self, obj):
        """
        See corresponding method on Field.

        Return filter kwargs built from ``obj``'s own content type and
        object id attribute values.
        """
        return {
            self.fk_field: getattr(obj, self.fk_field),
            self.ct_field: getattr(obj, self.ct_field),
        }

    def get_forward_related_filter(self, obj):
        """
        See corresponding method on RelatedField.

        Return filter kwargs selecting rows whose generic relation points at
        ``obj`` (its pk plus the pk of its ContentType).
        """
        return {
            self.fk_field: obj.pk,
            self.ct_field: ContentType.objects.get_for_model(obj).pk,
        }

    def __str__(self):
        # Rendered as "<app_label>.<ModelName>.<field name>".
        model = self.model
        app = model._meta.app_label
        return '%s.%s.%s' % (app, model._meta.object_name, self.name)

    def check(self, **kwargs):
        """Run the system checks for this field and return the list of errors."""
        errors = []
        errors.extend(self._check_field_name())
        errors.extend(self._check_object_id_field())
        errors.extend(self._check_content_type_field())
        return errors

    def _check_field_name(self):
        """Reject field names ending with an underscore (fields.E001)."""
        if self.name.endswith("_"):
            return [
                checks.Error(
                    'Field names must not end with an underscore.',
                    obj=self,
                    id='fields.E001',
                )
            ]
        else:
            return []

    def _check_object_id_field(self):
        """Check that the ``fk_field`` named on this GFK exists (contenttypes.E001)."""
        try:
            self.model._meta.get_field(self.fk_field)
        except FieldDoesNotExist:
            return [
                checks.Error(
                    "The GenericForeignKey object ID references the "
                    "nonexistent field '%s'." % self.fk_field,
                    obj=self,
                    id='contenttypes.E001',
                )
            ]
        else:
            return []

    def _check_content_type_field(self):
        """
        Check that the field named by ``self.ct_field`` exists on
        ``self.model`` and is a ForeignKey to ContentType.

        Return contenttypes.E002 if the field is missing, E003 if it is not a
        ForeignKey, E004 if it is a ForeignKey to another model, otherwise [].
        """
        try:
            field = self.model._meta.get_field(self.ct_field)
        except FieldDoesNotExist:
            return [
                checks.Error(
                    "The GenericForeignKey content type references the "
                    "nonexistent field '%s.%s'." % (
                        self.model._meta.object_name, self.ct_field
                    ),
                    obj=self,
                    id='contenttypes.E002',
                )
            ]
        else:
            if not isinstance(field, models.ForeignKey):
                return [
                    checks.Error(
                        "'%s.%s' is not a ForeignKey." % (
                            self.model._meta.object_name, self.ct_field
                        ),
                        hint=(
                            "GenericForeignKeys must use a ForeignKey to "
                            "'contenttypes.ContentType' as the 'content_type' field."
                        ),
                        obj=self,
                        id='contenttypes.E003',
                    )
                ]
            elif field.remote_field.model != ContentType:
                return [
                    checks.Error(
                        "'%s.%s' is not a ForeignKey to 'contenttypes.ContentType'." % (
                            self.model._meta.object_name, self.ct_field
                        ),
                        hint=(
                            "GenericForeignKeys must use a ForeignKey to "
                            "'contenttypes.ContentType' as the 'content_type' field."
                        ),
                        obj=self,
                        id='contenttypes.E004',
                    )
                ]
            else:
                return []

    def get_cache_name(self):
        # The related object is cached under this field's name.
        return self.name

    def get_content_type(self, obj=None, id=None, using=None):
        """
        Return a ContentType.

        If ``obj`` is given, return the content type of ``obj`` (looked up on
        ``obj``'s database, honouring ``for_concrete_model``). Otherwise, if
        ``id`` is given, return the ContentType with that pk from database
        ``using``.

        Raise ValueError when neither ``obj`` nor ``id`` is provided.
        """
        if obj is not None:
            return ContentType.objects.db_manager(obj._state.db).get_for_model(
                obj, for_concrete_model=self.for_concrete_model)
        elif id is not None:
            return ContentType.objects.db_manager(using).get_for_id(id)
        else:
            # Every caller passes obj or id; reaching this is a programming
            # error. ValueError subclasses Exception, so handlers written for
            # the previous bare Exception keep working.
            raise ValueError("Impossible arguments to GFK.get_content_type!")

    def get_prefetch_queryset(self, instances, queryset=None):
        """
        Fetch the related objects for ``instances`` in bulk.

        Instances are grouped by content type id and one lookup per content
        type is issued. Instances whose content type id or object id is None
        are skipped.

        Return a 6-tuple: the list of fetched related objects, a key function
        for a related object (``(pk, class)``), a key function for an instance
        (``(prepared fk value, model class)`` or None), then ``True``,
        ``self.name`` and ``True``. NOTE(review): the tuple layout is assumed
        to follow Django's prefetch_related protocol
        (single, cache_name, is_descriptor) — confirm against
        ``prefetch_one_level``.

        Raise ValueError if a custom ``queryset`` is supplied.
        """
        if queryset is not None:
            raise ValueError("Custom queryset can't be used for this lookup.")

        # For efficiency, group the instances by content type and then do one
        # query per model
        fk_dict = defaultdict(set)
        # We need one instance for each group in order to get the right db:
        instance_dict = {}
        ct_attname = self.model._meta.get_field(self.ct_field).get_attname()
        for instance in instances:
            # We avoid looking for values if either ct_id or fkey value is None
            ct_id = getattr(instance, ct_attname)
            if ct_id is not None:
                fk_val = getattr(instance, self.fk_field)
                if fk_val is not None:
                    fk_dict[ct_id].add(fk_val)
                    instance_dict[ct_id] = instance

        ret_val = []
        for ct_id, fkeys in fk_dict.items():
            instance = instance_dict[ct_id]
            ct = self.get_content_type(id=ct_id, using=instance._state.db)
            ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))

        # For doing the join in Python, we have to match both the FK val and the
        # content type, so we use a callable that returns a (fk, class) pair.
        def gfk_key(obj):
            ct_id = getattr(obj, ct_attname)
            if ct_id is None:
                return None
            else:
                model = self.get_content_type(id=ct_id,
                                              using=obj._state.db).model_class()
                # Convert the stored object id with the target model's pk
                # field so it compares equal to the related object's pk.
                return (model._meta.pk.get_prep_value(getattr(obj, self.fk_field)),
                        model)

        return (
            ret_val,
            lambda obj: (obj.pk, obj.__class__),
            gfk_key,
            True,
            self.name,
            True,
        )

    def __get__(self, instance, cls=None):
        """
        Return the related object for ``instance`` (or self on class access).

        A cached related object is reused only while its content type and
        primary key still match the values stored on ``instance``; otherwise
        the object is re-fetched. If the content type id is None or the
        object does not exist, None is cached and returned.
        """
        if instance is None:
            return self

        # Don't use getattr(instance, self.ct_field) here because that might
        # reload the same ContentType over and over (#5570). Instead, get the
        # content type ID here, and later when the actual instance is needed,
        # use ContentType.objects.get_for_id(), which has a global cache.
        f = self.model._meta.get_field(self.ct_field)
        ct_id = getattr(instance, f.get_attname(), None)
        pk_val = getattr(instance, self.fk_field)

        rel_obj = self.get_cached_value(instance, default=None)
        if rel_obj is not None:
            if ct_id != self.get_content_type(obj=rel_obj, using=instance._state.db).id:
                # The content type changed since caching: discard the cache.
                rel_obj = None
            else:
                pk = rel_obj._meta.pk
                # If the primary key is a remote field, use the referenced
                # field's to_python().
                to_python_field = pk
                # Out of an abundance of caution, avoid infinite loops.
                seen = {to_python_field}
                while to_python_field.remote_field:
                    to_python_field = to_python_field.target_field
                    if to_python_field in seen:
                        break
                    seen.add(to_python_field)
                pk_to_python = to_python_field.to_python
                if pk_to_python(pk_val) != rel_obj._get_pk_val():
                    # The stored object id changed since caching.
                    rel_obj = None
                else:
                    return rel_obj

        if ct_id is not None:
            ct = self.get_content_type(id=ct_id, using=instance._state.db)
            try:
                rel_obj = ct.get_object_for_this_type(pk=pk_val)
            except ObjectDoesNotExist:
                # A dangling generic reference resolves to None.
                pass
        self.set_cached_value(instance, rel_obj)
        return rel_obj

    def __set__(self, instance, value):
        """
        Point ``instance``'s generic relation at ``value`` (or clear it when
        ``value`` is None) by writing the content type and object id fields,
        and cache ``value`` as the related object.
        """
        ct = None
        fk = None
        if value is not None:
            ct = self.get_content_type(obj=value)
            fk = value.pk

        setattr(instance, self.ct_field, ct)
        setattr(instance, self.fk_field, fk)
        self.set_cached_value(instance, value)
class GenericRel(ForeignObjectRel):
    """
    Relation object stored on a GenericRelation field.

    It is always created with ``on_delete=DO_NOTHING``.
    """

    def __init__(self, field, to, related_name=None, related_query_name=None, limit_choices_to=None):
        # The ``related_name`` argument is accepted but not forwarded: the
        # name handed to ForeignObjectRel is ``related_query_name``, falling
        # back to '+' when no query name was supplied.
        effective_related_name = related_query_name or '+'
        super().__init__(
            field,
            to,
            related_name=effective_related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
            on_delete=DO_NOTHING,
        )
class GenericRelation(ForeignObject):
    """
    Provide a reverse to a relation created by a GenericForeignKey.

    The relation joins this model's primary key to the ``object_id_field``
    of the target model, restricted (via get_extra_restriction()) to rows
    whose ``content_type_field`` matches this model's content type.
    """

    # Field flags
    auto_created = False

    many_to_many = False
    many_to_one = False
    one_to_many = True
    one_to_one = False

    rel_class = GenericRel

    def __init__(self, to, object_id_field='object_id', content_type_field='content_type',
                 for_concrete_model=True, related_query_name=None, limit_choices_to=None, **kwargs):
        kwargs['rel'] = self.rel_class(
            self, to,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
        )
        # These options are forced for every generic relation, overriding
        # anything the caller passed.
        kwargs.update(
            blank=True,
            on_delete=models.CASCADE,
            editable=False,
            serialize=False,
        )

        # ForeignObject is bent a little here: the relation conceptually runs
        # from this model's pk to the remote object_id field, but the join is
        # generated in reverse along that field. Hence from_fields names the
        # object_id field and the pk target is filled in later by
        # resolve_related_fields().
        super().__init__(to, from_fields=[object_id_field], to_fields=[], **kwargs)

        self.object_id_field_name = object_id_field
        self.content_type_field_name = content_type_field
        self.for_concrete_model = for_concrete_model

    def check(self, **kwargs):
        """Return ForeignObject's check errors plus the GFK-existence check."""
        return [
            *super().check(**kwargs),
            *self._check_generic_foreign_key_existence(),
        ]

    def _is_matching_generic_foreign_key(self, field):
        """
        Tell whether ``field`` is a GenericForeignKey using the same content
        type and object id field names as this GenericRelation.
        """
        if not isinstance(field, GenericForeignKey):
            return False
        return (
            field.ct_field == self.content_type_field_name
            and field.fk_field == self.object_id_field_name
        )

    def _check_generic_foreign_key_existence(self):
        """
        Report contenttypes.E004 when the target model has no matching
        GenericForeignKey among its private fields.
        """
        target = self.remote_field.model
        if not isinstance(target, ModelBase):
            # The target is not a model class (presumably an unresolved lazy
            # reference), so there is nothing to inspect.
            return []

        private_fields = target._meta.private_fields
        if any(self._is_matching_generic_foreign_key(candidate) for candidate in private_fields):
            return []

        return [
            checks.Error(
                "The GenericRelation defines a relation with the model "
                "'%s.%s', but that model does not have a GenericForeignKey." % (
                    target._meta.app_label, target._meta.object_name
                ),
                obj=self,
                id='contenttypes.E004',
            )
        ]

    def resolve_related_fields(self):
        """Pair the remote object_id field with this model's primary key."""
        own_pk = self.model._meta.pk
        self.to_fields = [own_pk.name]
        object_id_field = self.remote_field.model._meta.get_field(self.object_id_field_name)
        return [(object_id_field, own_pk)]

    def _get_path_info_with_parent(self, filtered_relation):
        """
        Return the path that joins the current model through any parent models.
        The idea is that if you have a GFK defined on a parent model then we
        need to join the parent model first, then the child model.
        """
        # Example: with ChildTag -> Tag inheritance where Tag defines the GFK
        # and TaggedItem has a GenericRelation to ChildTag, first join
        # TaggedItem to Tag (Tag.object_id == TaggedItem.pk), then Tag to
        # ChildTag.
        child_opts = self.remote_field.model._meta.concrete_model._meta
        parent_opts = child_opts.get_field(self.object_id_field_name).model._meta

        path = [PathInfo(
            from_opts=self.model._meta,
            to_opts=parent_opts,
            target_fields=(parent_opts.pk,),
            join_field=self.remote_field,
            m2m=True,
            direct=False,
            filtered_relation=filtered_relation,
        )]

        # Walk child -> parent collecting ancestor links, then replay them in
        # reverse (parent -> child) using each link's remote path info.
        ancestor_links = []
        current_opts = child_opts
        while parent_opts != current_opts:
            link = current_opts.get_ancestor_link(parent_opts.model)
            ancestor_links.append(link)
            current_opts = link.remote_field.model._meta

        for link in reversed(ancestor_links):
            path.extend(link.remote_field.get_path_info())
        return path

    def get_path_info(self, filtered_relation=None):
        """Return the join path from this model to the related model."""
        remote_opts = self.remote_field.model._meta
        object_id_field = remote_opts.get_field(self.object_id_field_name)
        if object_id_field.model != remote_opts.model:
            # The object id lives on a parent model: route through it.
            return self._get_path_info_with_parent(filtered_relation)

        return [PathInfo(
            from_opts=self.model._meta,
            to_opts=remote_opts,
            target_fields=(remote_opts.pk,),
            join_field=self.remote_field,
            m2m=True,
            direct=False,
            filtered_relation=filtered_relation,
        )]

    def get_reverse_path_info(self, filtered_relation=None):
        """Return the join path from the related model back to this model."""
        own_opts = self.model._meta
        return [PathInfo(
            from_opts=self.remote_field.model._meta,
            to_opts=own_opts,
            target_fields=(own_opts.pk,),
            join_field=self,
            m2m=not self.unique,
            direct=False,
            filtered_relation=filtered_relation,
        )]

    def value_to_string(self, obj):
        """Serialize as the str() of the list of related objects' pks."""
        related_objects = getattr(obj, self.name).all()
        related_pks = [related.pk for related in related_objects]
        return str(related_pks)

    def contribute_to_class(self, cls, name, **kwargs):
        """
        Register as a private field on ``cls`` and install the reverse
        accessor descriptor. For non-abstract models, also schedule adding
        get_RELATED_order()/set_RELATED_order() when the related model is
        ordered with respect to its matching GenericForeignKey.
        """
        kwargs['private_only'] = True
        super().contribute_to_class(cls, name, **kwargs)
        self.model = cls
        setattr(cls, self.name, ReverseGenericManyToOneDescriptor(self.remote_field))

        if cls._meta.abstract:
            return

        def make_generic_foreign_order_accessors(related_model, model):
            if self._is_matching_generic_foreign_key(model._meta.order_with_respect_to):
                make_foreign_order_accessors(model, related_model)

        lazy_related_operation(make_generic_foreign_order_accessors, self.model, self.remote_field.model)

    def set_attributes_from_rel(self):
        # Intentionally a no-op for generic relations.
        pass

    def get_internal_type(self):
        return "ManyToManyField"

    def get_content_type(self):
        """
        Return the content type associated with this field's model.
        """
        return ContentType.objects.get_for_model(
            self.model,
            for_concrete_model=self.for_concrete_model,
        )

    def get_extra_restriction(self, where_class, alias, remote_alias):
        """
        Return a WHERE node restricting the remote content type column to
        this model's content type pk.
        """
        ct_field = self.remote_field.model._meta.get_field(self.content_type_field_name)
        own_ct_pk = self.get_content_type().pk
        condition = where_class()
        exact_lookup = ct_field.get_lookup('exact')
        condition.add(exact_lookup(ct_field.get_col(remote_alias), own_ct_pk), 'AND')
        return condition

    def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
        """
        Return all objects related to ``objs`` via this ``GenericRelation``.
        """
        content_type = ContentType.objects.db_manager(using).get_for_model(
            self.model, for_concrete_model=self.for_concrete_model)
        lookup = {
            "%s__pk" % self.content_type_field_name: content_type.pk,
            "%s__in" % self.object_id_field_name: [obj.pk for obj in objs],
        }
        return self.remote_field.model._base_manager.db_manager(using).filter(**lookup)
class ReverseGenericManyToOneDescriptor(ReverseManyToOneDescriptor):
    """
    Accessor to the related objects manager on the one-to-many relation
    created by a GenericRelation.

    In the example::

        class Post(Model):
            comments = GenericRelation(Comment)

    ``post.comments`` is a ReverseGenericManyToOneDescriptor instance.
    """

    @cached_property
    def related_manager_cls(self):
        # Subclass the related model's default manager class so custom
        # manager methods stay available on the generic related manager.
        base_manager_class = self.rel.model._default_manager.__class__
        return create_generic_related_manager(base_manager_class, self.rel)
def create_generic_related_manager(superclass, rel):
    """
    Build a manager class for the objects related to one instance through a
    GenericRelation.

    The returned class subclasses ``superclass`` (generally the related
    model's default manager class) and scopes every query to rows whose
    content type and object id point at the bound instance.
    """

    def _pin_kwargs_to_instance(manager, kwargs):
        # Shared by create()/get_or_create()/update_or_create(): force the
        # generic relation fields onto the bound instance and return the
        # write database alias for it.
        kwargs[manager.content_type_field_name] = manager.content_type
        kwargs[manager.object_id_field_name] = manager.pk_val
        return router.db_for_write(manager.model, instance=manager.instance)

    class GenericRelatedObjectManager(superclass):
        def __init__(self, instance=None):
            super().__init__()

            self.instance = instance

            self.model = rel.model

            relation_field = rel.field
            content_type = ContentType.objects.db_manager(instance._state.db).get_for_model(
                instance, for_concrete_model=relation_field.for_concrete_model)
            self.content_type = content_type
            self.content_type_field_name = relation_field.content_type_field_name
            self.object_id_field_name = relation_field.object_id_field_name
            self.prefetch_cache_name = relation_field.attname
            self.pk_val = instance.pk

            # Filters that restrict querysets to objects of this instance.
            self.core_filters = {
                '%s__pk' % self.content_type_field_name: content_type.id,
                self.object_id_field_name: self.pk_val,
            }

        def __call__(self, *, manager):
            # Build the same kind of related manager on top of another named
            # manager of the related model.
            other_manager = getattr(self.model, manager)
            other_manager_class = create_generic_related_manager(other_manager.__class__, rel)
            return other_manager_class(instance=self.instance)
        do_not_call_in_templates = True

        def __str__(self):
            return repr(self)

        def _apply_rel_filters(self, queryset):
            """
            Filter the queryset for the instance this manager is bound to.
            """
            db = self._db
            if not db:
                db = router.db_for_read(self.model, instance=self.instance)
            return queryset.using(db).filter(**self.core_filters)

        def get_queryset(self):
            # Prefer results already fetched by prefetch_related().
            try:
                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
            except (AttributeError, KeyError):
                return self._apply_rel_filters(super().get_queryset())

        def get_prefetch_queryset(self, instances, queryset=None):
            if queryset is None:
                queryset = super().get_queryset()

            queryset._add_hints(instance=instances[0])
            queryset = queryset.using(queryset._db or self._db)

            lookup = {
                '%s__pk' % self.content_type_field_name: self.content_type.id,
                '%s__in' % self.object_id_field_name: {obj.pk for obj in instances}
            }

            # Object ids may be stored in a different type than the pk, so
            # convert them with the instances' pk field to match up instances.
            pk_to_python = instances[0]._meta.pk.to_python

            def related_object_key(related_obj):
                return pk_to_python(getattr(related_obj, self.object_id_field_name))

            def instance_key(obj):
                return obj.pk

            return (
                queryset.filter(**lookup),
                related_object_key,
                instance_key,
                False,
                self.prefetch_cache_name,
                False,
            )

        def add(self, *objs, bulk=True):
            db = router.db_for_write(self.model, instance=self.instance)

            def attach(obj):
                if not isinstance(obj, self.model):
                    raise TypeError("'%s' instance expected, got %r" % (
                        self.model._meta.object_name, obj
                    ))
                setattr(obj, self.content_type_field_name, self.content_type)
                setattr(obj, self.object_id_field_name, self.pk_val)

            if not bulk:
                # Save objects one by one inside a single transaction.
                with transaction.atomic(using=db, savepoint=False):
                    for obj in objs:
                        attach(obj)
                        obj.save()
                return

            # Bulk mode: only already-saved objects on the write database are
            # accepted; they are re-pointed with a single UPDATE query.
            saved_pks = []
            for obj in objs:
                if obj._state.adding or obj._state.db != db:
                    raise ValueError(
                        "%r instance isn't saved. Use bulk=False or save "
                        "the object first." % obj
                    )
                attach(obj)
                saved_pks.append(obj.pk)

            self.model._base_manager.using(db).filter(pk__in=saved_pks).update(**{
                self.content_type_field_name: self.content_type,
                self.object_id_field_name: self.pk_val,
            })
        add.alters_data = True

        def remove(self, *objs, bulk=True):
            # Removing deletes the related objects (see _clear()).
            if objs:
                self._clear(self.filter(pk__in=[o.pk for o in objs]), bulk)
        remove.alters_data = True

        def clear(self, *, bulk=True):
            self._clear(self, bulk)
        clear.alters_data = True

        def _clear(self, queryset, bulk):
            db = router.db_for_write(self.model, instance=self.instance)
            queryset = queryset.using(db)
            if not bulk:
                with transaction.atomic(using=db, savepoint=False):
                    for obj in queryset:
                        obj.delete()
            else:
                # `QuerySet.delete()` creates its own atomic block which
                # contains the `pre_delete` and `post_delete` signal handlers.
                queryset.delete()
        _clear.alters_data = True

        def set(self, objs, *, bulk=True, clear=False):
            # Force evaluation of `objs` in case it's a queryset whose value
            # could be affected by `manager.clear()`. Refs #19816.
            objs = tuple(objs)

            db = router.db_for_write(self.model, instance=self.instance)
            with transaction.atomic(using=db, savepoint=False):
                if clear:
                    self.clear()
                    self.add(*objs, bulk=bulk)
                    return

                # Diff against the current objects: keep those still wanted,
                # remove the rest, add the missing ones.
                stale = set(self.using(db).all())
                missing = []
                for obj in objs:
                    if obj in stale:
                        stale.remove(obj)
                    else:
                        missing.append(obj)

                self.remove(*stale)
                self.add(*missing, bulk=bulk)
        set.alters_data = True

        def create(self, **kwargs):
            db = _pin_kwargs_to_instance(self, kwargs)
            return super().using(db).create(**kwargs)
        create.alters_data = True

        def get_or_create(self, **kwargs):
            db = _pin_kwargs_to_instance(self, kwargs)
            return super().using(db).get_or_create(**kwargs)
        get_or_create.alters_data = True

        def update_or_create(self, **kwargs):
            db = _pin_kwargs_to_instance(self, kwargs)
            return super().using(db).update_or_create(**kwargs)
        update_or_create.alters_data = True

    return GenericRelatedObjectManager