# django/db/backends/base/creation.py

import os
import sys
from io import StringIO

from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.db import router
from django.db.transaction import atomic
from django.utils.module_loading import import_string

# The prefix put on a database's NAME to build its test database name when
# no explicit TEST["NAME"] is configured (see _get_test_db_name()).
TEST_DATABASE_PREFIX = "test_"


class BaseDatabaseCreation:
    """
    Encapsulate backend-specific differences pertaining to creation and
    destruction of the test database.
    """

    def __init__(self, connection):
        """Store the connection whose test database this object manages."""
        # The other methods reach settings_dict, alias, ops and features
        # through this attribute.
        self.connection = connection

    def _nodb_cursor(self):
        """
        Return the result of the connection's own _nodb_cursor().

        Within this class the result is used as a context manager that yields
        the cursor on which CREATE DATABASE / DROP DATABASE are executed.
        """
        return self.connection._nodb_cursor()

    def log(self, msg):
        sys.stderr.write(msg + os.linesep)

    def create_test_db(
        self, verbosity=1, autoclobber=False, serialize=True, keepdb=False
    ):
        """
        Create the test database (or reuse it when keepdb is true), run the
        migrate command against it and return the test database's name.

        If creating the database fails, _create_test_db() may ask the user
        for confirmation before replacing it, unless autoclobber is set. When
        serialize is true, the contents of the freshly migrated database are
        stored on the connection as _test_serialized_contents.
        """
        # Deferred so django.core.management is only imported when needed.
        from django.core.management import call_command

        db_name = self._get_test_db_name()

        if verbosity >= 1:
            verb = "Using existing" if keepdb else "Creating"
            display = self._get_database_display_str(verbosity, db_name)
            self.log("%s test database for alias %s..." % (verb, display))

        # keepdb is forwarded instead of being used to skip this call: when
        # the test database doesn't exist yet it still has to be created (it
        # just won't be destroyed afterwards). Skipping the call would end in
        # an exception.
        self._create_test_db(verbosity, autoclobber, keepdb)

        # Close the connection, then point both the global settings and this
        # connection's settings at the test database.
        self.connection.close()
        settings.DATABASES[self.connection.alias]["NAME"] = db_name
        self.connection.settings_dict["NAME"] = db_name

        try:
            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
                # Disable migrations for all apps.
                saved_migration_modules = settings.MIGRATION_MODULES
                settings.MIGRATION_MODULES = {
                    app_config.label: None for app_config in apps.get_app_configs()
                }
            # Migrate output is reported one verbosity level below the one
            # requested, so test runs aren't flooded with messages (unless
            # the caller really asks for it).
            migrate_verbosity = max(verbosity - 1, 0)
            call_command(
                "migrate",
                verbosity=migrate_verbosity,
                interactive=False,
                database=self.connection.alias,
                run_syncdb=True,
            )
        finally:
            # Put the original MIGRATION_MODULES back if they were swapped out.
            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
                settings.MIGRATION_MODULES = saved_migration_modules

        # Snapshot the current database state as a string kept on the
        # connection. This is what lets tests on databases without
        # transactions, or TransactionTestCase tests, start from a clean
        # database on every test run.
        if serialize:
            snapshot = self.serialize_db_to_string()
            self.connection._test_serialized_contents = snapshot

        call_command("createcachetable", database=self.connection.alias)

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        running_django_suite = os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true"
        if running_django_suite:
            self.mark_expected_failures_and_skips()

        return db_name

    def set_as_test_mirror(self, primary_settings_dict):
        """
        Set this database up to be used in testing as a mirror of a primary
        database whose settings are given.
        """
        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]

    def serialize_db_to_string(self):
        """
        Dump the data held in this connection's database to a JSON string.

        Meant for test runner usage only: the whole dump is built up in
        memory, so it will not cope with large amounts of data.
        """

        def get_objects():
            # Lazily yield every row of every model that qualifies for
            # serialization on this connection.
            from django.db.migrations.loader import MigrationLoader

            loader = MigrationLoader(self.connection)
            for app_config in apps.get_app_configs():
                # Skip apps without models, apps that aren't migrated and
                # apps explicitly excluded via TEST_NON_SERIALIZED_APPS.
                if app_config.models_module is None:
                    continue
                if app_config.label not in loader.migrated_apps:
                    continue
                if app_config.name in settings.TEST_NON_SERIALIZED_APPS:
                    continue
                for model in app_config.get_models():
                    if not model._meta.can_migrate(self.connection):
                        continue
                    if not router.allow_migrate_model(self.connection.alias, model):
                        continue
                    manager = model._base_manager.using(self.connection.alias)
                    rows = manager.order_by(model._meta.pk.name)
                    # An explicit chunk size is only passed when the queryset
                    # carries prefetch_related lookups.
                    if rows._prefetch_related_lookups:
                        chunk_size = 2000
                    else:
                        chunk_size = None
                    yield from rows.iterator(chunk_size=chunk_size)

        # Serialize into an in-memory buffer and hand back its contents.
        buffer = StringIO()
        serializers.serialize("json", get_objects(), indent=None, stream=buffer)
        return buffer.getvalue()

    def deserialize_db_from_string(self, data):
        """
        Load the JSON string produced by serialize_db_to_string() back into
        this connection's database.
        """
        alias = self.connection.alias
        stream = StringIO(data)
        touched_tables = set()
        # Everything is loaded inside one transaction so that forward
        # references and cycles can be handled.
        with atomic(using=alias):
            # Constraint checks are disabled while saving because some
            # databases (MySQL) don't support deferred checks.
            with self.connection.constraint_checks_disabled():
                for deserialized in serializers.deserialize(
                    "json", stream, using=alias
                ):
                    deserialized.save()
                    model_class = deserialized.object.__class__
                    touched_tables.add(model_class._meta.db_table)
            # With checks disabled above, invalid keys may have slipped in:
            # verify the tables that were written to explicitly.
            self.connection.check_constraints(table_names=touched_tables)

    def _get_database_display_str(self, verbosity, database_name):
        """
        Return display string for a database for use in various actions.
        """
        return "'%s'%s" % (
            self.connection.alias,
            (" ('%s')" % database_name) if verbosity >= 2 else "",
        )

    def _get_test_db_name(self):
        """
        Internal implementation - return the name of the test DB that will be
        created. Only useful when called from create_test_db() and
        _create_test_db() and when no external munging is done with the 'NAME'
        settings.
        """
        if self.connection.settings_dict["TEST"]["NAME"]:
            return self.connection.settings_dict["TEST"]["NAME"]
        return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"]

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        cursor.execute("CREATE DATABASE %(dbname)s %(suffix)s" % parameters)

    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        """
        Internal implementation - create the test db tables.
        """
        test_database_name = self._get_test_db_name()
        test_db_params = {
            "dbname": self.connection.ops.quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database and connect to it.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception as e:
                # if we want to keep the db, then no need to do any of the below,
                # just return and skip it all.
                if keepdb:
                    return test_database_name

                self.log("Got an error creating the test database: %s" % e)
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        "database '%s', or 'no' to cancel: " % test_database_name
                    )
                if autoclobber or confirm == "yes":
                    try:
                        if verbosity >= 1:
                            self.log(
                                "Destroying old test database for alias %s..."
                                % (
                                    self._get_database_display_str(
                                        verbosity, test_database_name
                                    ),
                                )
                            )
                        cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
                        self._execute_create_test_db(cursor, test_db_params, keepdb)
                    except Exception as e:
                        self.log("Got an error recreating the test database: %s" % e)
                        sys.exit(2)
                else:
                    self.log("Tests cancelled.")
                    sys.exit(1)

        return test_database_name

    def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
        """
        Clone a test database.
        """
        source_database_name = self.connection.settings_dict["NAME"]

        if verbosity >= 1:
            action = "Cloning test database"
            if keepdb:
                action = "Using existing clone"
            self.log(
                "%s for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, source_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. See create_test_db for details.
        self._clone_test_db(suffix, verbosity, keepdb)

    def get_test_db_clone_settings(self, suffix):
        """
        Return a modified connection settings dict for the n-th clone of a DB.
        """
        # When this function is called, the test database has been created
        # already and its name has been copied to settings_dict['NAME'] so
        # we don't need to call _get_test_db_name.
        orig_settings_dict = self.connection.settings_dict
        return {
            **orig_settings_dict,
            "NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix),
        }

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        """
        Internal implementation - duplicate the test db tables.
        """
        raise NotImplementedError(
            "The database backend doesn't support cloning databases. "
            "Disable the option to run tests in parallel processes."
        )

    def destroy_test_db(
        self, old_database_name=None, verbosity=1, keepdb=False, suffix=None
    ):
        """
        Destroy a test database, prompting the user for confirmation if the
        database already exists.
        """
        self.connection.close()
        if suffix is None:
            test_database_name = self.connection.settings_dict["NAME"]
        else:
            test_database_name = self.get_test_db_clone_settings(suffix)["NAME"]

        if verbosity >= 1:
            action = "Destroying"
            if keepdb:
                action = "Preserving"
            self.log(
                "%s test database for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # if we want to preserve the database
        # skip the actual destroying piece.
        if not keepdb:
            self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name
        if old_database_name is not None:
            settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

    def _destroy_test_db(self, test_database_name, verbosity):
        """
        Internal implementation - remove the test db tables.
        """
        # Remove the test database to clean up after
        # ourselves. Connect to the previous database (not the test database)
        # to do so, because it's not allowed to delete a database while being
        # connected to it.
        with self._nodb_cursor() as cursor:
            cursor.execute(
                "DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)
            )

    def mark_expected_failures_and_skips(self):
        """
        Mark tests in Django's test suite which are expected failures on this
        database and test which should be skipped on this database.
        """
        # Only load unittest if we're actually testing.
        from unittest import expectedFailure, skip

        for test_name in self.connection.features.django_test_expected_failures:
            test_case_name, _, test_method_name = test_name.rpartition(".")
            test_app = test_name.split(".")[0]
            # Importing a test app that isn't installed raises RuntimeError.
            if test_app in settings.INSTALLED_APPS:
                test_case = import_string(test_case_name)
                test_method = getattr(test_case, test_method_name)
                setattr(test_case, test_method_name, expectedFailure(test_method))
        for reason, tests in self.connection.features.django_test_skips.items():
            for test_name in tests:
                test_case_name, _, test_method_name = test_name.rpartition(".")
                test_app = test_name.split(".")[0]
                # Importing a test app that isn't installed raises RuntimeError.
                if test_app in settings.INSTALLED_APPS:
                    test_case = import_string(test_case_name)
                    test_method = getattr(test_case, test_method_name)
                    setattr(test_case, test_method_name, skip(reason)(test_method))

    def sql_table_creation_suffix(self):
        """
        SQL to append to the end of the CREATE DATABASE statement issued for
        the test database (it is passed as the "suffix" parameter built in
        _create_test_db()). This implementation adds nothing.
        """
        return ""

    def test_db_signature(self):
        """
        Return a tuple with elements of self.connection.settings_dict (a
        DATABASES setting value) that uniquely identify a database
        accordingly to the RDBMS particularities.
        """
        settings_dict = self.connection.settings_dict
        return (
            settings_dict["HOST"],
            settings_dict["PORT"],
            settings_dict["ENGINE"],
            self._get_test_db_name(),
        )

    def setup_worker_connection(self, _worker_id):
        settings_dict = self.get_test_db_clone_settings(str(_worker_id))
        # connection.settings_dict must be updated in place for changes to be
        # reflected in django.db.connections. If the following line assigned
        # connection.settings_dict = settings_dict, new threads would connect
        # to the default database instead of the appropriate clone.
        self.connection.settings_dict.update(settings_dict)
        self.connection.close()
# Metadata
# View Raw File