login page
This commit is contained in:
21
Lib/site-packages/django/test/__init__.py
Normal file
21
Lib/site-packages/django/test/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Django Unit Test framework."""
|
||||
|
||||
from django.test.client import (
|
||||
AsyncClient, AsyncRequestFactory, Client, RequestFactory,
|
||||
)
|
||||
from django.test.testcases import (
|
||||
LiveServerTestCase, SimpleTestCase, TestCase, TransactionTestCase,
|
||||
skipIfDBFeature, skipUnlessAnyDBFeature, skipUnlessDBFeature,
|
||||
)
|
||||
from django.test.utils import (
|
||||
ignore_warnings, modify_settings, override_settings,
|
||||
override_system_checks, tag,
|
||||
)
|
||||
|
||||
# Public API of the django.test package: re-exported test clients/factories,
# test-case classes, DB-feature skip decorators, and settings/warning
# override utilities.
__all__ = [
    'AsyncClient', 'AsyncRequestFactory', 'Client', 'RequestFactory',
    'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase',
    'skipIfDBFeature', 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature',
    'ignore_warnings', 'modify_settings', 'override_settings',
    'override_system_checks', 'tag',
]
|
||||
Binary file not shown.
BIN
Lib/site-packages/django/test/__pycache__/client.cpython-38.pyc
Normal file
BIN
Lib/site-packages/django/test/__pycache__/client.cpython-38.pyc
Normal file
Binary file not shown.
BIN
Lib/site-packages/django/test/__pycache__/html.cpython-38.pyc
Normal file
BIN
Lib/site-packages/django/test/__pycache__/html.cpython-38.pyc
Normal file
Binary file not shown.
BIN
Lib/site-packages/django/test/__pycache__/runner.cpython-38.pyc
Normal file
BIN
Lib/site-packages/django/test/__pycache__/runner.cpython-38.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
Lib/site-packages/django/test/__pycache__/signals.cpython-38.pyc
Normal file
BIN
Lib/site-packages/django/test/__pycache__/signals.cpython-38.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
Lib/site-packages/django/test/__pycache__/utils.cpython-38.pyc
Normal file
BIN
Lib/site-packages/django/test/__pycache__/utils.cpython-38.pyc
Normal file
Binary file not shown.
913
Lib/site-packages/django/test/client.py
Normal file
913
Lib/site-packages/django/test/client.py
Normal file
@@ -0,0 +1,913 @@
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import sys
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from importlib import import_module
|
||||
from io import BytesIO
|
||||
from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.handlers.asgi import ASGIRequest
|
||||
from django.core.handlers.base import BaseHandler
|
||||
from django.core.handlers.wsgi import WSGIRequest
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.core.signals import (
|
||||
got_request_exception, request_finished, request_started,
|
||||
)
|
||||
from django.db import close_old_connections
|
||||
from django.http import HttpRequest, QueryDict, SimpleCookie
|
||||
from django.test import signals
|
||||
from django.test.utils import ContextList
|
||||
from django.urls import resolve
|
||||
from django.utils.encoding import force_bytes
|
||||
from django.utils.functional import SimpleLazyObject
|
||||
from django.utils.http import urlencode
|
||||
from django.utils.itercompat import is_iterable
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
# Names exported by ``from django.test.client import *``.
__all__ = ('Client', 'RedirectCycleError', 'RequestFactory', 'encode_file', 'encode_multipart')


# Fixed boundary string used for every multipart body built by the test
# client; a constant is fine here because the content is test-controlled.
BOUNDARY = 'BoUnDaRyStRiNg'
MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
# Extracts the charset parameter from a Content-Type header value.
CONTENT_TYPE_RE = _lazy_re_compile(r'.*; charset=([\w\d-]+);?')
# Structured suffix spec: https://tools.ietf.org/html/rfc6838#section-4.2.8
JSON_CONTENT_TYPE_RE = _lazy_re_compile(r'^application\/(.+\+)?json')
|
||||
|
||||
|
||||
class RedirectCycleError(Exception):
    """The test client has been asked to follow a redirect loop."""

    def __init__(self, message, last_response):
        super().__init__(message)
        # Keep the final response and the chain that led to it so tests can
        # inspect where the cycle occurred.
        self.redirect_chain = last_response.redirect_chain
        self.last_response = last_response
|
||||
|
||||
|
||||
class FakePayload:
    """
    A BytesIO wrapper that restricts reads the way a real network payload
    does: the stream cannot be rewound and cannot be read past its declared
    content length. This keeps views honest under the test client -- code
    that would fail against a live server fails here too.
    """

    def __init__(self, content=None):
        self.__buffer = BytesIO()
        self.__remaining = 0
        # Public flag: once a read happens, further writes are rejected.
        self.read_started = False
        if content is not None:
            self.write(content)

    def __len__(self):
        return self.__remaining

    def read(self, num_bytes=None):
        if not self.read_started:
            # First read rewinds to the start exactly once.
            self.__buffer.seek(0)
            self.read_started = True
        num_bytes = (self.__remaining or 0) if num_bytes is None else num_bytes
        assert self.__remaining >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data."
        chunk = self.__buffer.read(num_bytes)
        self.__remaining -= num_bytes
        return chunk

    def write(self, content):
        if self.read_started:
            raise ValueError("Unable to write a payload after it's been read")
        data = force_bytes(content)
        self.__buffer.write(data)
        self.__remaining += len(data)
|
||||
|
||||
|
||||
def closing_iterator_wrapper(iterable, close):
    # Wrap a streaming response's content so ``close()`` fires only after the
    # iterable is fully consumed, mirroring what a real WSGI server does.
    try:
        yield from iterable
    finally:
        # Temporarily detach close_old_connections so the test's DB
        # connection survives the request_finished signal fired by close().
        request_finished.disconnect(close_old_connections)
        close()  # will fire request_finished
        request_finished.connect(close_old_connections)
|
||||
|
||||
|
||||
def conditional_content_removal(request, response):
    """
    Simulate the behavior of most Web servers by removing the content of
    responses for HEAD requests, 1xx, 204, and 304 responses. Ensure
    compliance with RFC 7230, section 3.3.3.
    """
    def drop_body():
        # Streaming responses carry their body in an iterator attribute
        # rather than a bytestring.
        if response.streaming:
            response.streaming_content = []
        else:
            response.content = b''

    status = response.status_code
    if 100 <= status < 200 or status in (204, 304):
        drop_body()
    if request.method == 'HEAD':
        drop_body()
    return response
|
||||
|
||||
|
||||
class ClientHandler(BaseHandler):
    """
    An HTTP Handler that can be used for testing purposes. Use the WSGI
    interface to compose requests, but return the raw HttpResponse object with
    the originating WSGIRequest attached to its ``wsgi_request`` attribute.
    """
    def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
        # When False, CsrfViewMiddleware is bypassed for every request sent
        # through this handler (see _dont_enforce_csrf_checks below).
        self.enforce_csrf_checks = enforce_csrf_checks
        super().__init__(*args, **kwargs)

    def __call__(self, environ):
        # Set up middleware if needed. We couldn't do this earlier, because
        # settings weren't available.
        if self._middleware_chain is None:
            self.load_middleware()

        # Detach close_old_connections while firing request_started so the
        # test's DB connection isn't closed mid-test; reconnect right after.
        request_started.disconnect(close_old_connections)
        request_started.send(sender=self.__class__, environ=environ)
        request_started.connect(close_old_connections)
        request = WSGIRequest(environ)
        # sneaky little hack so that we can easily get round
        # CsrfViewMiddleware. This makes life easier, and is probably
        # required for backwards compatibility with external tests against
        # admin views.
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks

        # Request goes through middleware.
        response = self.get_response(request)

        # Simulate behaviors of most Web servers.
        conditional_content_removal(request, response)

        # Attach the originating request to the response so that it could be
        # later retrieved.
        response.wsgi_request = request

        # Emulate a WSGI server by calling the close method on completion.
        if response.streaming:
            # Defer close() until the streaming content is consumed.
            response.streaming_content = closing_iterator_wrapper(
                response.streaming_content, response.close)
        else:
            request_finished.disconnect(close_old_connections)
            response.close()  # will fire request_finished
            request_finished.connect(close_old_connections)

        return response
|
||||
|
||||
|
||||
class AsyncClientHandler(BaseHandler):
    """An async version of ClientHandler."""
    def __init__(self, enforce_csrf_checks=True, *args, **kwargs):
        # When False, CsrfViewMiddleware is bypassed for every request sent
        # through this handler (see _dont_enforce_csrf_checks below).
        self.enforce_csrf_checks = enforce_csrf_checks
        super().__init__(*args, **kwargs)

    async def __call__(self, scope):
        # Set up middleware if needed. We couldn't do this earlier, because
        # settings weren't available.
        if self._middleware_chain is None:
            self.load_middleware(is_async=True)
        # Extract body file from the scope, if provided.
        if '_body_file' in scope:
            body_file = scope.pop('_body_file')
        else:
            body_file = FakePayload('')

        # Detach close_old_connections while firing request_started so the
        # test's DB connection isn't closed mid-test; reconnect right after.
        request_started.disconnect(close_old_connections)
        await sync_to_async(request_started.send)(sender=self.__class__, scope=scope)
        request_started.connect(close_old_connections)
        request = ASGIRequest(scope, body_file)
        # Sneaky little hack so that we can easily get round
        # CsrfViewMiddleware. This makes life easier, and is probably required
        # for backwards compatibility with external tests against admin views.
        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
        # Request goes through middleware.
        response = await self.get_response_async(request)
        # Simulate behaviors of most Web servers.
        conditional_content_removal(request, response)
        # Attach the originating ASGI request to the response so that it could
        # be later retrieved.
        response.asgi_request = request
        # Emulate a server by calling the close method on completion.
        if response.streaming:
            # Defer close() until the streaming content is consumed.
            response.streaming_content = await sync_to_async(closing_iterator_wrapper)(
                response.streaming_content,
                response.close,
            )
        else:
            request_finished.disconnect(close_old_connections)
            # Will fire request_finished.
            await sync_to_async(response.close)()
            request_finished.connect(close_old_connections)
        return response
|
||||
|
||||
|
||||
def store_rendered_templates(store, signal, sender, template, context, **kwargs):
    """
    Signal receiver that records every rendered template and its context.

    The context is copied so the stored value reflects the state at render
    time rather than any later mutation.
    """
    templates = store.setdefault('templates', [])
    templates.append(template)
    if 'context' not in store:
        store['context'] = ContextList()
    store['context'].append(copy(context))
|
||||
|
||||
|
||||
def encode_multipart(boundary, data):
    """
    Encode multipart POST data from a dictionary of form values.

    The key will be used as the form data name; the value will be transmitted
    as content. If the value is a file, the contents of the file will be sent
    as an application/octet-stream; otherwise, str(value) will be sent.
    """
    def to_bytes(s):
        return force_bytes(s, settings.DEFAULT_CHARSET)

    # Not by any means perfect, but good enough for our purposes.
    def is_file(thing):
        return hasattr(thing, "read") and callable(thing.read)

    def field_lines(name, value):
        # The encoded lines for one plain (non-file) form field.
        return [
            to_bytes(part) for part in (
                '--%s' % boundary,
                'Content-Disposition: form-data; name="%s"' % name,
                '',
                value,
            )
        ]

    lines = []
    # Each bit of the multipart form data could be either a form value or a
    # file, or a *list* of form values and/or files. Remember that HTTP field
    # names can be duplicated!
    for key, value in data.items():
        if value is None:
            raise TypeError(
                "Cannot encode None for key '%s' as POST data. Did you mean "
                "to pass an empty string or omit the value?" % key
            )
        if is_file(value):
            lines.extend(encode_file(boundary, key, value))
        elif not isinstance(value, str) and is_iterable(value):
            for item in value:
                if is_file(item):
                    lines.extend(encode_file(boundary, key, item))
                else:
                    lines.extend(field_lines(key, item))
        else:
            lines.extend(field_lines(key, value))

    # Closing boundary plus the trailing CRLF mandated by the format.
    lines.append(to_bytes('--%s--' % boundary))
    lines.append(b'')
    return b'\r\n'.join(lines)
|
||||
|
||||
|
||||
def encode_file(boundary, key, file):
    """Return the encoded multipart lines for a single uploaded file field."""
    def to_bytes(s):
        return force_bytes(s, settings.DEFAULT_CHARSET)

    # file.name might not be a string. For example, it's an int for
    # tempfile.TemporaryFile().
    filename = ''
    if hasattr(file, 'name') and isinstance(file.name, str):
        filename = os.path.basename(file.name)

    # An explicit content_type attribute always wins (even if it is None);
    # otherwise guess from the filename when there is one.
    if hasattr(file, 'content_type'):
        content_type = file.content_type
    elif filename:
        content_type = mimetypes.guess_type(filename)[0]
    else:
        content_type = None

    if content_type is None:
        # Fall back to a generic binary type and make sure some filename is
        # transmitted (the field name doubles as the filename).
        content_type = 'application/octet-stream'
        filename = filename or key
    return [
        to_bytes('--%s' % boundary),
        to_bytes('Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename)),
        to_bytes('Content-Type: %s' % content_type),
        b'',
        to_bytes(file.read()),
    ]
|
||||
|
||||
|
||||
class RequestFactory:
    """
    Class that lets you create mock Request objects for use in testing.

    Usage:

    rf = RequestFactory()
    get_request = rf.get('/hello/')
    post_request = rf.post('/submit/', {'foo': 'bar'})

    Once you have a request object you can pass it to any view function,
    just as if that view had been hooked up using a URLconf.
    """
    def __init__(self, *, json_encoder=DjangoJSONEncoder, **defaults):
        # json_encoder: encoder class used when a dict/list/tuple body is
        # serialized for an application/json request (see _encode_json).
        # defaults: extra WSGI environ entries applied to every request.
        self.json_encoder = json_encoder
        self.defaults = defaults
        self.cookies = SimpleCookie()
        self.errors = BytesIO()

    def _base_environ(self, **request):
        """
        The base environment for a request.
        """
        # This is a minimal valid WSGI environ dictionary, plus:
        # - HTTP_COOKIE: for cookie support,
        # - REMOTE_ADDR: often useful, see #8551.
        # See https://www.python.org/dev/peps/pep-3333/#environ-variables
        return {
            'HTTP_COOKIE': '; '.join(sorted(
                '%s=%s' % (morsel.key, morsel.coded_value)
                for morsel in self.cookies.values()
            )),
            'PATH_INFO': '/',
            'REMOTE_ADDR': '127.0.0.1',
            'REQUEST_METHOD': 'GET',
            'SCRIPT_NAME': '',
            'SERVER_NAME': 'testserver',
            'SERVER_PORT': '80',
            'SERVER_PROTOCOL': 'HTTP/1.1',
            'wsgi.version': (1, 0),
            'wsgi.url_scheme': 'http',
            'wsgi.input': FakePayload(b''),
            'wsgi.errors': self.errors,
            'wsgi.multiprocess': True,
            'wsgi.multithread': False,
            'wsgi.run_once': False,
            # Instance-wide defaults first, then per-call overrides win.
            **self.defaults,
            **request,
        }

    def request(self, **request):
        "Construct a generic request object."
        return WSGIRequest(self._base_environ(**request))

    def _encode_data(self, data, content_type):
        # NOTE: the identity comparison is deliberate -- only the module's
        # MULTIPART_CONTENT constant itself triggers multipart encoding, not
        # an equal string built elsewhere.
        if content_type is MULTIPART_CONTENT:
            return encode_multipart(BOUNDARY, data)
        else:
            # Encode the content so that the byte representation is correct.
            match = CONTENT_TYPE_RE.match(content_type)
            if match:
                charset = match[1]
            else:
                charset = settings.DEFAULT_CHARSET
            return force_bytes(data, encoding=charset)

    def _encode_json(self, data, content_type):
        """
        Return encoded JSON if data is a dict, list, or tuple and content_type
        is application/json.
        """
        should_encode = JSON_CONTENT_TYPE_RE.match(content_type) and isinstance(data, (dict, list, tuple))
        return json.dumps(data, cls=self.json_encoder) if should_encode else data

    def _get_path(self, parsed):
        # Build PATH_INFO from a urlparse() result.
        path = parsed.path
        # If there are parameters, add them
        if parsed.params:
            path += ";" + parsed.params
        path = unquote_to_bytes(path)
        # Replace the behavior where non-ASCII values in the WSGI environ are
        # arbitrarily decoded with ISO-8859-1.
        # Refs comment in `get_bytes_from_wsgi()`.
        return path.decode('iso-8859-1')

    def get(self, path, data=None, secure=False, **extra):
        """Construct a GET request."""
        data = {} if data is None else data
        # GET data travels in the query string, never in a body.
        return self.generic('GET', path, secure=secure, **{
            'QUERY_STRING': urlencode(data, doseq=True),
            **extra,
        })

    def post(self, path, data=None, content_type=MULTIPART_CONTENT,
             secure=False, **extra):
        """Construct a POST request."""
        data = self._encode_json({} if data is None else data, content_type)
        post_data = self._encode_data(data, content_type)

        return self.generic('POST', path, post_data, content_type,
                            secure=secure, **extra)

    def head(self, path, data=None, secure=False, **extra):
        """Construct a HEAD request."""
        data = {} if data is None else data
        return self.generic('HEAD', path, secure=secure, **{
            'QUERY_STRING': urlencode(data, doseq=True),
            **extra,
        })

    def trace(self, path, secure=False, **extra):
        """Construct a TRACE request."""
        return self.generic('TRACE', path, secure=secure, **extra)

    def options(self, path, data='', content_type='application/octet-stream',
                secure=False, **extra):
        "Construct an OPTIONS request."
        return self.generic('OPTIONS', path, data, content_type,
                            secure=secure, **extra)

    def put(self, path, data='', content_type='application/octet-stream',
            secure=False, **extra):
        """Construct a PUT request."""
        data = self._encode_json(data, content_type)
        return self.generic('PUT', path, data, content_type,
                            secure=secure, **extra)

    def patch(self, path, data='', content_type='application/octet-stream',
              secure=False, **extra):
        """Construct a PATCH request."""
        data = self._encode_json(data, content_type)
        return self.generic('PATCH', path, data, content_type,
                            secure=secure, **extra)

    def delete(self, path, data='', content_type='application/octet-stream',
               secure=False, **extra):
        """Construct a DELETE request."""
        data = self._encode_json(data, content_type)
        return self.generic('DELETE', path, data, content_type,
                            secure=secure, **extra)

    def generic(self, method, path, data='',
                content_type='application/octet-stream', secure=False,
                **extra):
        """Construct an arbitrary HTTP request."""
        parsed = urlparse(str(path))  # path can be lazy
        data = force_bytes(data, settings.DEFAULT_CHARSET)
        r = {
            'PATH_INFO': self._get_path(parsed),
            'REQUEST_METHOD': method,
            'SERVER_PORT': '443' if secure else '80',
            'wsgi.url_scheme': 'https' if secure else 'http',
        }
        if data:
            # Only requests with a body carry content headers and a payload.
            r.update({
                'CONTENT_LENGTH': str(len(data)),
                'CONTENT_TYPE': content_type,
                'wsgi.input': FakePayload(data),
            })
        r.update(extra)
        # If QUERY_STRING is absent or empty, we want to extract it from the URL.
        if not r.get('QUERY_STRING'):
            # WSGI requires latin-1 encoded strings. See get_path_info().
            query_string = parsed[4].encode().decode('iso-8859-1')
            r['QUERY_STRING'] = query_string
        return self.request(**r)
|
||||
|
||||
|
||||
class AsyncRequestFactory(RequestFactory):
    """
    Class that lets you create mock ASGI-like Request objects for use in
    testing. Usage:

    rf = AsyncRequestFactory()
    get_request = await rf.get('/hello/')
    post_request = await rf.post('/submit/', {'foo': 'bar'})

    Once you have a request object you can pass it to any view function,
    including synchronous ones. The reason we have a separate class here is:
    a) this makes ASGIRequest subclasses, and
    b) AsyncTestClient can subclass it.
    """
    def _base_scope(self, **request):
        """The base scope for a request."""
        # This is a minimal valid ASGI scope, plus:
        # - headers['cookie'] for cookie support,
        # - 'client' often useful, see #8551.
        scope = {
            'asgi': {'version': '3.0'},
            'type': 'http',
            'http_version': '1.1',
            'client': ['127.0.0.1', 0],
            'server': ('testserver', '80'),
            'scheme': 'http',
            'method': 'GET',
            'headers': [],
            # Instance-wide defaults first, then per-call overrides win.
            **self.defaults,
            **request,
        }
        scope['headers'].append((
            b'cookie',
            b'; '.join(sorted(
                ('%s=%s' % (morsel.key, morsel.coded_value)).encode('ascii')
                for morsel in self.cookies.values()
            )),
        ))
        return scope

    def request(self, **request):
        """Construct a generic request object."""
        # This is synchronous, which means all methods on this class are.
        # AsyncClient, however, has an async request function, which makes all
        # its methods async.
        if '_body_file' in request:
            body_file = request.pop('_body_file')
        else:
            body_file = FakePayload('')
        return ASGIRequest(self._base_scope(**request), body_file)

    def generic(
        self, method, path, data='', content_type='application/octet-stream',
        secure=False, **extra,
    ):
        """Construct an arbitrary HTTP request."""
        parsed = urlparse(str(path))  # path can be lazy.
        data = force_bytes(data, settings.DEFAULT_CHARSET)
        s = {
            'method': method,
            'path': self._get_path(parsed),
            'server': ('127.0.0.1', '443' if secure else '80'),
            'scheme': 'https' if secure else 'http',
            'headers': [(b'host', b'testserver')],
        }
        if data:
            s['headers'].extend([
                # The header value must be the decimal length as ASCII bytes.
                # The previous ``bytes(len(data))`` was a bug: bytes(5) is
                # five NUL bytes, not b'5'.
                (b'content-length', str(len(data)).encode('ascii')),
                (b'content-type', content_type.encode('ascii')),
            ])
            s['_body_file'] = FakePayload(data)
        s.update(extra)
        # If QUERY_STRING is absent or empty, we want to extract it from the
        # URL.
        if not s.get('query_string'):
            s['query_string'] = parsed[4]
        return self.request(**s)
|
||||
|
||||
|
||||
class ClientMixin:
    """
    Mixin with common methods between Client and AsyncClient.
    """
    def store_exc_info(self, **kwargs):
        """Store exceptions when they are generated by a view."""
        # Connected to got_request_exception; check_exception() re-raises.
        self.exc_info = sys.exc_info()

    def check_exception(self, response):
        """
        Look for a signaled exception, clear the current context exception
        data, re-raise the signaled exception, and clear the signaled exception
        from the local cache.
        """
        response.exc_info = self.exc_info
        if self.exc_info:
            _, exc_value, _ = self.exc_info
            # Clear before raising so a later request starts clean.
            self.exc_info = None
            if self.raise_request_exception:
                raise exc_value

    @property
    def session(self):
        """Return the current session variables."""
        engine = import_module(settings.SESSION_ENGINE)
        cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
        if cookie:
            return engine.SessionStore(cookie.value)
        # No session cookie yet: create and persist a fresh session, and
        # point the cookie at it so subsequent requests reuse the same store.
        session = engine.SessionStore()
        session.save()
        self.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
        return session

    def login(self, **credentials):
        """
        Set the Factory to appear as if it has successfully logged into a site.

        Return True if login is possible or False if the provided credentials
        are incorrect.
        """
        from django.contrib.auth import authenticate
        user = authenticate(**credentials)
        if user:
            self._login(user)
            return True
        return False

    def force_login(self, user, backend=None):
        # Log the user in without checking credentials. When no backend is
        # given, pick the first configured backend that can return users.
        def get_backend():
            from django.contrib.auth import load_backend
            for backend_path in settings.AUTHENTICATION_BACKENDS:
                backend = load_backend(backend_path)
                if hasattr(backend, 'get_user'):
                    return backend_path

        if backend is None:
            backend = get_backend()
        user.backend = backend
        self._login(user, backend)

    def _login(self, user, backend=None):
        from django.contrib.auth import login

        # Create a fake request to store login details.
        request = HttpRequest()
        if self.session:
            request.session = self.session
        else:
            engine = import_module(settings.SESSION_ENGINE)
            request.session = engine.SessionStore()
        login(request, user, backend)
        # Save the session values.
        request.session.save()
        # Set the cookie to represent the session.
        session_cookie = settings.SESSION_COOKIE_NAME
        self.cookies[session_cookie] = request.session.session_key
        cookie_data = {
            'max-age': None,
            'path': '/',
            'domain': settings.SESSION_COOKIE_DOMAIN,
            'secure': settings.SESSION_COOKIE_SECURE or None,
            'expires': None,
        }
        self.cookies[session_cookie].update(cookie_data)

    def logout(self):
        """Log out the user by removing the cookies and session object."""
        from django.contrib.auth import get_user, logout
        request = HttpRequest()
        if self.session:
            request.session = self.session
            request.user = get_user(request)
        else:
            engine = import_module(settings.SESSION_ENGINE)
            request.session = engine.SessionStore()
        logout(request)
        # Drop every cookie, not just the session one.
        self.cookies = SimpleCookie()

    def _parse_json(self, response, **extra):
        # Lazily decode and cache the JSON body on the response; raise if the
        # response does not declare a JSON content type.
        if not hasattr(response, '_json'):
            if not JSON_CONTENT_TYPE_RE.match(response.get('Content-Type')):
                raise ValueError(
                    'Content-Type header is "%s", not "application/json"'
                    % response.get('Content-Type')
                )
            response._json = json.loads(response.content.decode(response.charset), **extra)
        return response._json
|
||||
|
||||
|
||||
class Client(ClientMixin, RequestFactory):
|
||||
"""
|
||||
A class that can act as a client for testing purposes.
|
||||
|
||||
It allows the user to compose GET and POST requests, and
|
||||
obtain the response that the server gave to those requests.
|
||||
The server Response objects are annotated with the details
|
||||
of the contexts and templates that were rendered during the
|
||||
process of serving the request.
|
||||
|
||||
Client objects are stateful - they will retain cookie (and
|
||||
thus session) details for the lifetime of the Client instance.
|
||||
|
||||
This is not intended as a replacement for Twill/Selenium or
|
||||
the like - it is here to allow testing against the
|
||||
contexts and templates produced by a view, rather than the
|
||||
HTML rendered to the end-user.
|
||||
"""
|
||||
def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
    """
    Set up a stateful test client.

    enforce_csrf_checks: forwarded to the handler; off by default so tests
        don't need CSRF tokens.
    raise_request_exception: when True, exceptions raised by views are
        re-raised in the test instead of rendered as a 500 response.
    defaults: extra WSGI environ entries applied to every request.
    """
    super().__init__(**defaults)
    self.raise_request_exception = raise_request_exception
    self.handler = ClientHandler(enforce_csrf_checks)
    # Captured view exception info (set via the got_request_exception signal).
    self.exc_info = None
    # The extra environ kwargs of the most recent request helper call.
    self.extra = None
|
||||
|
||||
def request(self, **request):
    """
    The master request method. Compose the environment dictionary and pass
    to the handler, return the result of the handler. Assume defaults for
    the query environment, which can be overridden using the arguments to
    the request.
    """
    environ = self._base_environ(**request)

    # Curry a data dictionary into an instance of the template renderer
    # callback function.
    data = {}
    on_template_render = partial(store_rendered_templates, data)
    # dispatch_uid keyed on this request dict so concurrent/nested requests
    # don't clobber each other's receivers.
    signal_uid = "template-render-%s" % id(request)
    signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
    # Capture exceptions created by the handler.
    exception_uid = "request-exception-%s" % id(request)
    got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
    try:
        response = self.handler(environ)
    finally:
        # Always detach the per-request receivers, even if the view raised.
        signals.template_rendered.disconnect(dispatch_uid=signal_uid)
        got_request_exception.disconnect(dispatch_uid=exception_uid)
    # Check for signaled exceptions.
    self.check_exception(response)
    # Save the client and request that stimulated the response.
    response.client = self
    response.request = request
    # Add any rendered template detail to the response.
    response.templates = data.get('templates', [])
    response.context = data.get('context')
    # Lazy JSON accessor bound to this response.
    response.json = partial(self._parse_json, response)
    # Attach the ResolverMatch instance to the response.
    response.resolver_match = SimpleLazyObject(lambda: resolve(request['PATH_INFO']))
    # Flatten a single context. Not really necessary anymore thanks to the
    # __getattr__ flattening in ContextList, but has some edge case
    # backwards compatibility implications.
    if response.context and len(response.context) == 1:
        response.context = response.context[0]
    # Update persistent cookie data.
    if response.cookies:
        self.cookies.update(response.cookies)
    return response
|
||||
|
||||
def get(self, path, data=None, follow=False, secure=False, **extra):
    """Request a response from the server using GET."""
    self.extra = extra
    response = super().get(path, data=data, secure=secure, **extra)
    if not follow:
        return response
    return self._handle_redirects(response, data=data, **extra)
|
||||
|
||||
def post(self, path, data=None, content_type=MULTIPART_CONTENT,
         follow=False, secure=False, **extra):
    """Request a response from the server using POST."""
    self.extra = extra
    response = super().post(path, data=data, content_type=content_type, secure=secure, **extra)
    if not follow:
        return response
    return self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
|
||||
def head(self, path, data=None, follow=False, secure=False, **extra):
    """Request a response from the server using HEAD."""
    self.extra = extra
    response = super().head(path, data=data, secure=secure, **extra)
    if not follow:
        return response
    return self._handle_redirects(response, data=data, **extra)
|
||||
|
||||
def options(self, path, data='', content_type='application/octet-stream',
            follow=False, secure=False, **extra):
    """Request a response from the server using OPTIONS."""
    self.extra = extra
    response = super().options(path, data=data, content_type=content_type, secure=secure, **extra)
    if not follow:
        return response
    return self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
|
||||
def put(self, path, data='', content_type='application/octet-stream',
        follow=False, secure=False, **extra):
    """Send a resource to the server using PUT."""
    self.extra = extra
    response = super().put(path, data=data, content_type=content_type, secure=secure, **extra)
    if not follow:
        return response
    return self._handle_redirects(response, data=data, content_type=content_type, **extra)
|
||||
|
||||
def patch(self, path, data='', content_type='application/octet-stream',
          follow=False, secure=False, **extra):
    """Send a resource to the server using PATCH and return the response.

    When ``follow`` is true, any redirects are resolved and the final
    response is returned instead.
    """
    # Remember the latest WSGI environ overrides used for this request.
    self.extra = extra
    response = super().patch(path, data=data, content_type=content_type, secure=secure, **extra)
    if not follow:
        return response
    return self._handle_redirects(response, data=data, content_type=content_type, **extra)
def delete(self, path, data='', content_type='application/octet-stream',
           follow=False, secure=False, **extra):
    """Send a DELETE request to the server and return the response.

    When ``follow`` is true, any redirects are resolved and the final
    response is returned instead.
    """
    # Remember the latest WSGI environ overrides used for this request.
    self.extra = extra
    response = super().delete(path, data=data, content_type=content_type, secure=secure, **extra)
    if not follow:
        return response
    return self._handle_redirects(response, data=data, content_type=content_type, **extra)
def trace(self, path, data='', follow=False, secure=False, **extra):
    """Send a TRACE request to the server and return the response.

    When ``follow`` is true, any redirects are resolved and the final
    response is returned instead.
    """
    # Remember the latest WSGI environ overrides used for this request.
    self.extra = extra
    response = super().trace(path, data=data, secure=secure, **extra)
    if not follow:
        return response
    return self._handle_redirects(response, data=data, **extra)
def _handle_redirects(self, response, data='', content_type='', **extra):
    """
    Follow any redirects by requesting responses from the server using GET.

    Builds ``response.redirect_chain`` as a list of (url, status_code)
    pairs and raises RedirectCycleError on a detected loop or when the
    chain exceeds 20 hops. 307/308 redirects re-issue the original method
    and body; all other redirect codes fall back to a GET.
    """
    response.redirect_chain = []
    redirect_status_codes = (
        HTTPStatus.MOVED_PERMANENTLY,
        HTTPStatus.FOUND,
        HTTPStatus.SEE_OTHER,
        HTTPStatus.TEMPORARY_REDIRECT,
        HTTPStatus.PERMANENT_REDIRECT,
    )
    while response.status_code in redirect_status_codes:
        response_url = response.url
        redirect_chain = response.redirect_chain
        redirect_chain.append((response_url, response.status_code))

        url = urlsplit(response_url)
        # Carry scheme/host/port from an absolute Location into the next
        # request's environ so cross-host redirects are honored.
        if url.scheme:
            extra['wsgi.url_scheme'] = url.scheme
        if url.hostname:
            extra['SERVER_NAME'] = url.hostname
        if url.port:
            extra['SERVER_PORT'] = str(url.port)

        # Prepend the request path to handle relative path redirects
        path = url.path
        if not path.startswith('/'):
            path = urljoin(response.request['PATH_INFO'], path)

        if response.status_code in (HTTPStatus.TEMPORARY_REDIRECT, HTTPStatus.PERMANENT_REDIRECT):
            # Preserve request method post-redirect for 307/308 responses.
            request_method = getattr(self, response.request['REQUEST_METHOD'].lower())
        else:
            # Any other redirect downgrades to GET; the original body is
            # dropped and the Location's query string becomes the data.
            request_method = self.get
            data = QueryDict(url.query)
            content_type = None

        # follow=False: each hop is handled by this loop, not recursively.
        response = request_method(path, data=data, content_type=content_type, follow=False, **extra)
        response.redirect_chain = redirect_chain

        if redirect_chain[-1] in redirect_chain[:-1]:
            # Check that we're not redirecting to somewhere we've already
            # been to, to prevent loops.
            raise RedirectCycleError("Redirect loop detected.", last_response=response)
        if len(redirect_chain) > 20:
            # Such a lengthy chain likely also means a loop, but one with
            # a growing path, changing view, or changing query argument;
            # 20 is the value of "network.http.redirection-limit" from Firefox.
            raise RedirectCycleError("Too many redirects.", last_response=response)

    return response
class AsyncClient(ClientMixin, AsyncRequestFactory):
    """
    An async version of Client that creates ASGIRequests and calls through an
    async request path.

    Does not currently support "follow" on its methods.
    """
    def __init__(self, enforce_csrf_checks=False, raise_request_exception=True, **defaults):
        super().__init__(**defaults)
        # Handler that processes the composed ASGI scope into a response.
        self.handler = AsyncClientHandler(enforce_csrf_checks)
        # When True, exceptions raised inside the request are re-raised here.
        self.raise_request_exception = raise_request_exception
        # Exception info captured via the got_request_exception signal
        # (presumably set by store_exc_info() from ClientMixin — see request()).
        self.exc_info = None
        # Extra environ/scope overrides from the most recent request.
        self.extra = None

    async def request(self, **request):
        """
        The master request method. Compose the scope dictionary and pass to the
        handler, return the result of the handler. Assume defaults for the
        query environment, which can be overridden using the arguments to the
        request.
        """
        if 'follow' in request:
            raise NotImplementedError(
                'AsyncClient request methods do not accept the follow '
                'parameter.'
            )
        scope = self._base_scope(**request)
        # Curry a data dictionary into an instance of the template renderer
        # callback function.
        data = {}
        on_template_render = partial(store_rendered_templates, data)
        # dispatch_uid keyed on id(request) keeps concurrent requests from
        # clobbering each other's signal receivers.
        signal_uid = 'template-render-%s' % id(request)
        signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
        # Capture exceptions created by the handler.
        exception_uid = 'request-exception-%s' % id(request)
        got_request_exception.connect(self.store_exc_info, dispatch_uid=exception_uid)
        try:
            response = await self.handler(scope)
        finally:
            # Always disconnect, even if the handler raised, so receivers
            # don't leak between requests.
            signals.template_rendered.disconnect(dispatch_uid=signal_uid)
            got_request_exception.disconnect(dispatch_uid=exception_uid)
        # Check for signaled exceptions.
        self.check_exception(response)
        # Save the client and request that stimulated the response.
        response.client = self
        response.request = request
        # Add any rendered template detail to the response.
        response.templates = data.get('templates', [])
        response.context = data.get('context')
        response.json = partial(self._parse_json, response)
        # Attach the ResolverMatch instance to the response.
        response.resolver_match = SimpleLazyObject(lambda: resolve(request['path']))
        # Flatten a single context. Not really necessary anymore thanks to the
        # __getattr__ flattening in ContextList, but has some edge case
        # backwards compatibility implications.
        if response.context and len(response.context) == 1:
            response.context = response.context[0]
        # Update persistent cookie data.
        if response.cookies:
            self.cookies.update(response.cookies)
        return response
229
Lib/site-packages/django/test/html.py
Normal file
229
Lib/site-packages/django/test/html.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Compare two HTML documents."""
|
||||
|
||||
from html.parser import HTMLParser
|
||||
|
||||
from django.utils.regex_helper import _lazy_re_compile
|
||||
|
||||
# ASCII whitespace is U+0009 TAB, U+000A LF, U+000C FF, U+000D CR, or U+0020
|
||||
# SPACE.
|
||||
# https://infra.spec.whatwg.org/#ascii-whitespace
|
||||
ASCII_WHITESPACE = _lazy_re_compile(r'[\t\n\f\r ]+')
|
||||
|
||||
|
||||
def normalize_whitespace(string):
    """Collapse each run of ASCII whitespace in *string* to a single space."""
    return ASCII_WHITESPACE.sub(' ', string)
class Element:
    """A parsed HTML element: a tag name, sorted attributes, and children.

    Children are a mix of Element instances and normalized text strings.
    Equality is semantic (attribute order and whitespace insensitive), which
    is what assertHTMLEqual-style comparisons rely on.
    """

    def __init__(self, name, attributes):
        self.name = name
        # Sorted so that attribute order never affects equality/hashing.
        self.attributes = sorted(attributes)
        self.children = []

    def append(self, element):
        """Add a child, merging/normalizing adjacent text nodes."""
        if isinstance(element, str):
            element = normalize_whitespace(element)
            if self.children:
                if isinstance(self.children[-1], str):
                    # Merge consecutive text nodes into one normalized string.
                    self.children[-1] += element
                    self.children[-1] = normalize_whitespace(self.children[-1])
                    return
        elif self.children:
            # removing last children if it is only whitespace
            # this can result in incorrect dom representations since
            # whitespace between inline tags like <span> is significant
            if isinstance(self.children[-1], str):
                if self.children[-1].isspace():
                    self.children.pop()
        if element:
            self.children.append(element)

    def finalize(self):
        """Strip insignificant whitespace throughout the subtree."""
        def rstrip_last_element(children):
            # Trim trailing whitespace-only text nodes, repeatedly.
            if children:
                if isinstance(children[-1], str):
                    children[-1] = children[-1].rstrip()
                    if not children[-1]:
                        children.pop()
                        children = rstrip_last_element(children)
            return children

        rstrip_last_element(self.children)
        for i, child in enumerate(self.children):
            if isinstance(child, str):
                self.children[i] = child.strip()
            elif hasattr(child, 'finalize'):
                child.finalize()

    def __eq__(self, element):
        if not hasattr(element, 'name') or self.name != element.name:
            return False
        if len(self.attributes) != len(element.attributes):
            return False
        if self.attributes != element.attributes:
            # attributes without a value is same as attribute with value that
            # equals the attributes name:
            # <input checked> == <input checked="checked">
            for i in range(len(self.attributes)):
                attr, value = self.attributes[i]
                other_attr, other_value = element.attributes[i]
                if value is None:
                    value = attr
                if other_value is None:
                    other_value = other_attr
                if attr != other_attr or value != other_value:
                    return False
        return self.children == element.children

    def __hash__(self):
        return hash((self.name, *self.attributes))

    def _count(self, element, count=True):
        """Count (or just detect, when count=False) occurrences of *element*
        in this subtree; *element* may be an Element or a text fragment."""
        if not isinstance(element, str):
            if self == element:
                return 1
        if isinstance(element, RootElement):
            # A whole parsed fragment matches when its children match ours.
            if self.children == element.children:
                return 1
        i = 0
        for child in self.children:
            # child is text content and element is also text content, then
            # make a simple "text" in "text"
            if isinstance(child, str):
                if isinstance(element, str):
                    if count:
                        i += child.count(element)
                    elif element in child:
                        return 1
            else:
                i += child._count(element, count=count)
                if not count and i:
                    # Existence check: stop at the first hit.
                    return i
        return i

    def __contains__(self, element):
        return self._count(element, count=False) > 0

    def count(self, element):
        """Return the number of occurrences of *element* in this subtree."""
        return self._count(element, count=True)

    def __getitem__(self, key):
        return self.children[key]

    def __str__(self):
        output = '<%s' % self.name
        for key, value in self.attributes:
            if value:
                output += ' %s="%s"' % (key, value)
            else:
                output += ' %s' % key
        if self.children:
            output += '>\n'
            output += ''.join(str(c) for c in self.children)
            output += '\n</%s>' % self.name
        else:
            output += '>'
        return output

    def __repr__(self):
        return str(self)
class RootElement(Element):
    """Synthetic, nameless element wrapping the top level of a parsed document."""

    def __init__(self):
        # No tag name and no attributes: this node only holds children.
        super().__init__(None, ())

    def __str__(self):
        # Render children only; the root itself produces no markup.
        return ''.join(str(child) for child in self.children)
class HTMLParseError(Exception):
    """Raised by Parser when the HTML input cannot be parsed (e.g. a stray
    end tag); carries the message and, by convention, the parser position."""
class Parser(HTMLParser):
    """HTMLParser subclass that builds an Element tree rooted at self.root.

    Void (self-closing) tags never go on the open-tag stack; mismatched end
    tags implicitly close intervening open elements, and an end tag with no
    matching open element raises HTMLParseError.
    """

    # https://html.spec.whatwg.org/#void-elements
    SELF_CLOSING_TAGS = {
        'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta',
        'param', 'source', 'track', 'wbr',
        # Deprecated tags
        'frame', 'spacer',
    }

    def __init__(self):
        super().__init__()
        # Root of the tree being built.
        self.root = RootElement()
        # Stack of currently-open (unclosed) elements.
        self.open_tags = []
        # Maps each element to the (line, column) where it started.
        self.element_positions = {}

    def error(self, msg):
        """Abort parsing with an HTMLParseError that includes the position."""
        raise HTMLParseError(msg, self.getpos())

    def format_position(self, position=None, element=None):
        """Return a 'Line X, Column Y' string for *position* or *element*."""
        if not position and element:
            position = self.element_positions[element]
        if position is None:
            position = self.getpos()
        if hasattr(position, 'lineno'):
            position = position.lineno, position.offset
        return 'Line %d, Column %d' % position

    @property
    def current(self):
        # The element new children should be appended to.
        if self.open_tags:
            return self.open_tags[-1]
        else:
            return self.root

    def handle_startendtag(self, tag, attrs):
        self.handle_starttag(tag, attrs)
        # handle_starttag() already treats void tags as closed; only emit an
        # explicit close for non-void tags.
        if tag not in self.SELF_CLOSING_TAGS:
            self.handle_endtag(tag)

    def handle_starttag(self, tag, attrs):
        # Special case handling of 'class' attribute, so that comparisons of DOM
        # instances are not sensitive to ordering of classes.
        attrs = [
            (name, ' '.join(sorted(value for value in ASCII_WHITESPACE.split(value) if value)))
            if name == "class"
            else (name, value)
            for name, value in attrs
        ]
        element = Element(tag, attrs)
        self.current.append(element)
        if tag not in self.SELF_CLOSING_TAGS:
            self.open_tags.append(element)
            self.element_positions[element] = self.getpos()

    def handle_endtag(self, tag):
        if not self.open_tags:
            self.error("Unexpected end tag `%s` (%s)" % (
                tag, self.format_position()))
        element = self.open_tags.pop()
        # Pop until the matching tag is found, implicitly closing any
        # elements that were left open in between.
        while element.name != tag:
            if not self.open_tags:
                self.error("Unexpected end tag `%s` (%s)" % (
                    tag, self.format_position()))
            element = self.open_tags.pop()

    def handle_data(self, data):
        self.current.append(data)
def parse_html(html):
    """
    Take a string that contains *valid* HTML and turn it into a Python object
    structure that can be easily compared against other HTML on semantic
    equivalence. Syntactical differences like which quotation is used on
    arguments will be ignored.
    """
    parser = Parser()
    parser.feed(html)
    parser.close()
    document = parser.root
    document.finalize()
    # Unwrap the synthetic root when the document consists of exactly one
    # element child (a lone text node keeps the root wrapper).
    if len(document.children) == 1 and not isinstance(document.children[0], str):
        document = document.children[0]
    return document
809
Lib/site-packages/django/test/runner.py
Normal file
809
Lib/site-packages/django/test/runner.py
Normal file
@@ -0,0 +1,809 @@
|
||||
import ctypes
|
||||
import itertools
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import textwrap
|
||||
import unittest
|
||||
from importlib import import_module
|
||||
from io import StringIO
|
||||
|
||||
from django.core.management import call_command
|
||||
from django.db import connections
|
||||
from django.test import SimpleTestCase, TestCase
|
||||
from django.test.utils import (
|
||||
setup_databases as _setup_databases, setup_test_environment,
|
||||
teardown_databases as _teardown_databases, teardown_test_environment,
|
||||
)
|
||||
from django.utils.datastructures import OrderedSet
|
||||
from django.utils.version import PY37
|
||||
|
||||
try:
|
||||
import ipdb as pdb
|
||||
except ImportError:
|
||||
import pdb
|
||||
|
||||
try:
|
||||
import tblib.pickling_support
|
||||
except ImportError:
|
||||
tblib = None
|
||||
|
||||
|
||||
class DebugSQLTextTestResult(unittest.TextTestResult):
    """TextTestResult that captures SQL debug logging per test and attaches
    the captured SQL to each error/failure so it is printed alongside the
    traceback."""

    def __init__(self, stream, descriptions, verbosity):
        # Capture everything the DB backends log at DEBUG level.
        self.logger = logging.getLogger('django.db.backends')
        self.logger.setLevel(logging.DEBUG)
        super().__init__(stream, descriptions, verbosity)

    def startTest(self, test):
        # Fresh buffer and handler per test so SQL is attributed correctly.
        self.debug_sql_stream = StringIO()
        self.handler = logging.StreamHandler(self.debug_sql_stream)
        self.logger.addHandler(self.handler)
        super().startTest(test)

    def stopTest(self, test):
        super().stopTest(test)
        self.logger.removeHandler(self.handler)
        if self.showAll:
            # Verbose mode: print the captured SQL for every test.
            self.debug_sql_stream.seek(0)
            self.stream.write(self.debug_sql_stream.read())
            self.stream.writeln(self.separator2)

    def addError(self, test, err):
        super().addError(test, err)
        # Extend the (test, traceback) tuple recorded by the parent class
        # with the SQL captured for this test.
        self.debug_sql_stream.seek(0)
        self.errors[-1] = self.errors[-1] + (self.debug_sql_stream.read(),)

    def addFailure(self, test, err):
        super().addFailure(test, err)
        self.debug_sql_stream.seek(0)
        self.failures[-1] = self.failures[-1] + (self.debug_sql_stream.read(),)

    def addSubTest(self, test, subtest, err):
        super().addSubTest(test, subtest, err)
        if err is not None:
            self.debug_sql_stream.seek(0)
            # Route to failures or errors to match where the parent recorded it.
            errors = self.failures if issubclass(err[0], test.failureException) else self.errors
            errors[-1] = errors[-1] + (self.debug_sql_stream.read(),)

    def printErrorList(self, flavour, errors):
        # Same layout as the parent, plus the captured SQL after the traceback.
        for test, err, sql_debug in errors:
            self.stream.writeln(self.separator1)
            self.stream.writeln("%s: %s" % (flavour, self.getDescription(test)))
            self.stream.writeln(self.separator2)
            self.stream.writeln(err)
            self.stream.writeln(self.separator2)
            self.stream.writeln(sql_debug)
class PDBDebugResult(unittest.TextTestResult):
    """
    Custom result class that triggers a PDB session when an error or failure
    occurs.
    """

    def addError(self, test, err):
        super().addError(test, err)
        self.debug(err)

    def addFailure(self, test, err):
        super().addFailure(test, err)
        self.debug(err)

    def debug(self, error):
        # error is a sys.exc_info()-style (type, value, traceback) triple.
        _, exc_value, tb = error
        print("\nOpening PDB: %r" % exc_value)
        pdb.post_mortem(tb)
class RemoteTestResult:
    """
    Record information about which tests have succeeded and which have failed.

    The sole purpose of this class is to record events in the child processes
    so they can be replayed in the master process. As a consequence it doesn't
    inherit unittest.TestResult and doesn't attempt to implement all its API.

    The implementation matches the unpythonic coding style of unittest2.
    """

    def __init__(self):
        if tblib is not None:
            # Make tracebacks picklable so they can cross the process boundary.
            tblib.pickling_support.install()

        # Replayable log of (event_name, test_index, *args) tuples.
        self.events = []
        self.failfast = False
        self.shouldStop = False
        self.testsRun = 0

    @property
    def test_index(self):
        # Index of the test currently running (startTest increments first).
        return self.testsRun - 1

    def _confirm_picklable(self, obj):
        """
        Confirm that obj can be pickled and unpickled as multiprocessing will
        need to pickle the exception in the child process and unpickle it in
        the parent process. Let the exception rise, if not.
        """
        pickle.loads(pickle.dumps(obj))

    def _print_unpicklable_subtest(self, test, subtest, pickle_exc):
        # Printed from the child process: stdout is the only reliable channel.
        print("""
Subtest failed:

    test: {}
 subtest: {}

Unfortunately, the subtest that failed cannot be pickled, so the parallel
test runner cannot handle it cleanly. Here is the pickling error:

> {}

You should re-run this test with --parallel=1 to reproduce the failure
with a cleaner failure message.
""".format(test, subtest, pickle_exc))

    def check_picklable(self, test, err):
        # Ensure that sys.exc_info() tuples are picklable. This displays a
        # clear multiprocessing.pool.RemoteTraceback generated in the child
        # process instead of a multiprocessing.pool.MaybeEncodingError, making
        # the root cause easier to figure out for users who aren't familiar
        # with the multiprocessing module. Since we're in a forked process,
        # our best chance to communicate with them is to print to stdout.
        try:
            self._confirm_picklable(err)
        except Exception as exc:
            original_exc_txt = repr(err[1])
            original_exc_txt = textwrap.fill(original_exc_txt, 75, initial_indent='    ', subsequent_indent='    ')
            pickle_exc_txt = repr(exc)
            pickle_exc_txt = textwrap.fill(pickle_exc_txt, 75, initial_indent='    ', subsequent_indent='    ')
            if tblib is None:
                print("""

{} failed:

{}

Unfortunately, tracebacks cannot be pickled, making it impossible for the
parallel test runner to handle this exception cleanly.

In order to see the traceback, you should install tblib:

    python -m pip install tblib
""".format(test, original_exc_txt))
            else:
                print("""

{} failed:

{}

Unfortunately, the exception it raised cannot be pickled, making it impossible
for the parallel test runner to handle it cleanly.

Here's the error encountered while trying to pickle the exception:

{}

You should re-run this test with the --parallel=1 option to reproduce the
failure and get a correct traceback.
""".format(test, original_exc_txt, pickle_exc_txt))
            raise

    def check_subtest_picklable(self, test, subtest):
        try:
            self._confirm_picklable(subtest)
        except Exception as exc:
            self._print_unpicklable_subtest(test, subtest, exc)
            raise

    def stop_if_failfast(self):
        if self.failfast:
            self.stop()

    def stop(self):
        self.shouldStop = True

    def startTestRun(self):
        self.events.append(('startTestRun',))

    def stopTestRun(self):
        self.events.append(('stopTestRun',))

    def startTest(self, test):
        self.testsRun += 1
        self.events.append(('startTest', self.test_index))

    def stopTest(self, test):
        self.events.append(('stopTest', self.test_index))

    def addError(self, test, err):
        self.check_picklable(test, err)
        self.events.append(('addError', self.test_index, err))
        self.stop_if_failfast()

    def addFailure(self, test, err):
        self.check_picklable(test, err)
        self.events.append(('addFailure', self.test_index, err))
        self.stop_if_failfast()

    def addSubTest(self, test, subtest, err):
        # Follow Python 3.5's implementation of unittest.TestResult.addSubTest()
        # by not doing anything when a subtest is successful.
        if err is not None:
            # Call check_picklable() before check_subtest_picklable() since
            # check_picklable() performs the tblib check.
            self.check_picklable(test, err)
            self.check_subtest_picklable(test, subtest)
            self.events.append(('addSubTest', self.test_index, subtest, err))
            self.stop_if_failfast()

    def addSuccess(self, test):
        self.events.append(('addSuccess', self.test_index))

    def addSkip(self, test, reason):
        self.events.append(('addSkip', self.test_index, reason))

    def addExpectedFailure(self, test, err):
        # If tblib isn't installed, pickling the traceback will always fail.
        # However we don't want tblib to be required for running the tests
        # when they pass or fail as expected. Drop the traceback when an
        # expected failure occurs.
        if tblib is None:
            err = err[0], err[1], None
        self.check_picklable(test, err)
        self.events.append(('addExpectedFailure', self.test_index, err))

    def addUnexpectedSuccess(self, test):
        self.events.append(('addUnexpectedSuccess', self.test_index))
        self.stop_if_failfast()
class RemoteTestRunner:
    """
    Run tests and record everything but don't display anything.

    The implementation matches the unpythonic coding style of unittest2.
    """

    # Default result class; may be overridden per instance via __init__.
    resultclass = RemoteTestResult

    def __init__(self, failfast=False, resultclass=None):
        self.failfast = failfast
        if resultclass is not None:
            self.resultclass = resultclass

    def run(self, test):
        result = self.resultclass()
        # Register so unittest's signal handling (Ctrl-C) can stop this result.
        unittest.registerResult(result)
        result.failfast = self.failfast
        test(result)
        return result
def default_test_processes():
    """Default number of test processes when using the --parallel option."""
    # The current implementation of the parallel test runner requires
    # multiprocessing to start subprocesses with fork().
    if multiprocessing.get_start_method() != 'fork':
        return 1
    env_count = os.environ.get('DJANGO_TEST_PROCESSES')
    if env_count is not None:
        return int(env_count)
    return multiprocessing.cpu_count()
_worker_id = 0
|
||||
|
||||
|
||||
def _init_worker(counter):
    """
    Switch to databases dedicated to this worker.

    This helper lives at module-level because of the multiprocessing module's
    requirements.
    """

    global _worker_id

    # Atomically claim a unique 1-based worker id from the shared counter.
    with counter.get_lock():
        counter.value += 1
        _worker_id = counter.value

    for alias in connections:
        connection = connections[alias]
        settings_dict = connection.creation.get_test_db_clone_settings(str(_worker_id))
        # connection.settings_dict must be updated in place for changes to be
        # reflected in django.db.connections. If the following line assigned
        # connection.settings_dict = settings_dict, new threads would connect
        # to the default database instead of the appropriate clone.
        connection.settings_dict.update(settings_dict)
        # Drop any connection inherited from the parent process so the worker
        # reconnects to its own clone.
        connection.close()
def _run_subsuite(args):
|
||||
"""
|
||||
Run a suite of tests with a RemoteTestRunner and return a RemoteTestResult.
|
||||
|
||||
This helper lives at module-level and its arguments are wrapped in a tuple
|
||||
because of the multiprocessing module's requirements.
|
||||
"""
|
||||
runner_class, subsuite_index, subsuite, failfast = args
|
||||
runner = runner_class(failfast=failfast)
|
||||
result = runner.run(subsuite)
|
||||
return subsuite_index, result.events
|
||||
|
||||
|
||||
class ParallelTestSuite(unittest.TestSuite):
    """
    Run a series of tests in parallel in several processes.

    While the unittest module's documentation implies that orchestrating the
    execution of tests is the responsibility of the test runner, in practice,
    it appears that TestRunner classes are more concerned with formatting and
    displaying test results.

    Since there are fewer use cases for customizing TestSuite than TestRunner,
    implementing parallelization at the level of the TestSuite improves
    interoperability with existing custom test runners. A single instance of a
    test runner can still collect results from all tests without being aware
    that they have been run in parallel.
    """

    # In case someone wants to modify these in a subclass.
    init_worker = _init_worker
    run_subsuite = _run_subsuite
    runner_class = RemoteTestRunner

    def __init__(self, suite, processes, failfast=False):
        # One sub-suite per test case class; each is dispatched as a unit.
        self.subsuites = partition_suite_by_case(suite)
        self.processes = processes
        self.failfast = failfast
        super().__init__()

    def run(self, result):
        """
        Distribute test cases across workers.

        Return an identifier of each test case with its result in order to use
        imap_unordered to show results as soon as they're available.

        To minimize pickling errors when getting results from workers:

        - pass back numeric indexes in self.subsuites instead of tests
        - make tracebacks picklable with tblib, if available

        Even with tblib, errors may still occur for dynamically created
        exception classes which cannot be unpickled.
        """
        counter = multiprocessing.Value(ctypes.c_int, 0)
        pool = multiprocessing.Pool(
            processes=self.processes,
            # __func__: pass the plain functions, which are picklable, rather
            # than the class-attribute wrappers.
            initializer=self.init_worker.__func__,
            initargs=[counter],
        )
        args = [
            (self.runner_class, index, subsuite, self.failfast)
            for index, subsuite in enumerate(self.subsuites)
        ]
        test_results = pool.imap_unordered(self.run_subsuite.__func__, args)

        while True:
            if result.shouldStop:
                pool.terminate()
                break

            try:
                # Short timeout so shouldStop is rechecked frequently.
                subsuite_index, events = test_results.next(timeout=0.1)
            except multiprocessing.TimeoutError:
                continue
            except StopIteration:
                pool.close()
                break

            # Replay the worker's recorded events against the real result,
            # mapping each event's numeric index back to the actual test.
            tests = list(self.subsuites[subsuite_index])
            for event in events:
                event_name = event[0]
                handler = getattr(result, event_name, None)
                if handler is None:
                    continue
                test = tests[event[1]]
                args = event[2:]
                handler(test, *args)

        pool.join()

        return result

    def __iter__(self):
        return iter(self.subsuites)
class DiscoverRunner:
|
||||
"""A Django test runner that uses unittest2 test discovery."""
|
||||
|
||||
test_suite = unittest.TestSuite
|
||||
parallel_test_suite = ParallelTestSuite
|
||||
test_runner = unittest.TextTestRunner
|
||||
test_loader = unittest.defaultTestLoader
|
||||
reorder_by = (TestCase, SimpleTestCase)
|
||||
|
||||
def __init__(self, pattern=None, top_level=None, verbosity=1,
|
||||
interactive=True, failfast=False, keepdb=False,
|
||||
reverse=False, debug_mode=False, debug_sql=False, parallel=0,
|
||||
tags=None, exclude_tags=None, test_name_patterns=None,
|
||||
pdb=False, buffer=False, **kwargs):
|
||||
|
||||
self.pattern = pattern
|
||||
self.top_level = top_level
|
||||
self.verbosity = verbosity
|
||||
self.interactive = interactive
|
||||
self.failfast = failfast
|
||||
self.keepdb = keepdb
|
||||
self.reverse = reverse
|
||||
self.debug_mode = debug_mode
|
||||
self.debug_sql = debug_sql
|
||||
self.parallel = parallel
|
||||
self.tags = set(tags or [])
|
||||
self.exclude_tags = set(exclude_tags or [])
|
||||
self.pdb = pdb
|
||||
if self.pdb and self.parallel > 1:
|
||||
raise ValueError('You cannot use --pdb with parallel tests; pass --parallel=1 to use it.')
|
||||
self.buffer = buffer
|
||||
if self.buffer and self.parallel > 1:
|
||||
raise ValueError(
|
||||
'You cannot use -b/--buffer with parallel tests; pass '
|
||||
'--parallel=1 to use it.'
|
||||
)
|
||||
self.test_name_patterns = None
|
||||
if test_name_patterns:
|
||||
# unittest does not export the _convert_select_pattern function
|
||||
# that converts command-line arguments to patterns.
|
||||
self.test_name_patterns = {
|
||||
pattern if '*' in pattern else '*%s*' % pattern
|
||||
for pattern in test_name_patterns
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
parser.add_argument(
|
||||
'-t', '--top-level-directory', dest='top_level',
|
||||
help='Top level of project for unittest discovery.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-p', '--pattern', default="test*.py",
|
||||
help='The test matching pattern. Defaults to test*.py.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--keepdb', action='store_true',
|
||||
help='Preserves the test DB between runs.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-r', '--reverse', action='store_true',
|
||||
help='Reverses test cases order.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--debug-mode', action='store_true',
|
||||
help='Sets settings.DEBUG to True.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-d', '--debug-sql', action='store_true',
|
||||
help='Prints logged SQL queries on failure.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--parallel', nargs='?', default=1, type=int,
|
||||
const=default_test_processes(), metavar='N',
|
||||
help='Run tests using up to N parallel processes.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--tag', action='append', dest='tags',
|
||||
help='Run only tests with the specified tag. Can be used multiple times.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--exclude-tag', action='append', dest='exclude_tags',
|
||||
help='Do not run tests with the specified tag. Can be used multiple times.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--pdb', action='store_true',
|
||||
help='Runs a debugger (pdb, or ipdb if installed) on error or failure.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-b', '--buffer', action='store_true',
|
||||
help='Discard output from passing tests.',
|
||||
)
|
||||
if PY37:
|
||||
parser.add_argument(
|
||||
'-k', action='append', dest='test_name_patterns',
|
||||
help=(
|
||||
'Only run test methods and classes that match the pattern '
|
||||
'or substring. Can be used multiple times. Same as '
|
||||
'unittest -k option.'
|
||||
),
|
||||
)
|
||||
|
||||
def setup_test_environment(self, **kwargs):
    """Perform global pre-test setup.

    Delegates to the module-level setup_test_environment() (passing the
    --debug-mode flag) and installs unittest's Ctrl-C signal handler.
    """
    setup_test_environment(debug=self.debug_mode)
    unittest.installHandler()
|
||||
|
||||
def build_suite(self, test_labels=None, extra_tests=None, **kwargs):
    """Build the TestSuite to run from *test_labels* plus *extra_tests*.

    Labels may be dotted module/class/method paths or filesystem paths;
    each is either loaded by name or discovered. The resulting suite is
    tag-filtered, reordered, and (when --parallel applies) wrapped in a
    parallel suite.
    """
    suite = self.test_suite()
    # Default to discovering from the current directory.
    test_labels = test_labels or ['.']
    extra_tests = extra_tests or []
    self.test_loader.testNamePatterns = self.test_name_patterns

    discover_kwargs = {}
    if self.pattern is not None:
        discover_kwargs['pattern'] = self.pattern
    if self.top_level is not None:
        discover_kwargs['top_level_dir'] = self.top_level

    for label in test_labels:
        kwargs = discover_kwargs.copy()
        tests = None

        label_as_path = os.path.abspath(label)

        # if a module, or "module.ClassName[.method_name]", just run those
        if not os.path.exists(label_as_path):
            tests = self.test_loader.loadTestsFromName(label)
        elif os.path.isdir(label_as_path) and not self.top_level:
            # Try to be a bit smarter than unittest about finding the
            # default top-level for a given directory path, to avoid
            # breaking relative imports. (Unittest's default is to set
            # top-level equal to the path, which means relative imports
            # will result in "Attempted relative import in non-package.").

            # We'd be happy to skip this and require dotted module paths
            # (which don't cause this problem) instead of file paths (which
            # do), but in the case of a directory in the cwd, which would
            # be equally valid if considered as a top-level module or as a
            # directory path, unittest unfortunately prefers the latter.

            # Walk up past every package directory (one containing
            # __init__.py) so top_level lands on the first non-package
            # ancestor.
            top_level = label_as_path
            while True:
                init_py = os.path.join(top_level, '__init__.py')
                if os.path.exists(init_py):
                    try_next = os.path.dirname(top_level)
                    if try_next == top_level:
                        # __init__.py all the way down? give up.
                        break
                    top_level = try_next
                    continue
                break
            kwargs['top_level_dir'] = top_level

        if not (tests and tests.countTestCases()) and is_discoverable(label):
            # Try discovery if path is a package or directory
            tests = self.test_loader.discover(start_dir=label, **kwargs)

            # Make unittest forget the top-level dir it calculated from this
            # run, to support running tests from two different top-levels.
            self.test_loader._top_level_dir = None

        suite.addTests(tests)

    for test in extra_tests:
        suite.addTest(test)

    if self.tags or self.exclude_tags:
        if self.verbosity >= 2:
            if self.tags:
                print('Including test tag(s): %s.' % ', '.join(sorted(self.tags)))
            if self.exclude_tags:
                print('Excluding test tag(s): %s.' % ', '.join(sorted(self.exclude_tags)))
        suite = filter_tests_by_tags(suite, self.tags, self.exclude_tags)
    suite = reorder_suite(suite, self.reorder_by, self.reverse)

    if self.parallel > 1:
        parallel_suite = self.parallel_test_suite(suite, self.parallel, self.failfast)

        # Since tests are distributed across processes on a per-TestCase
        # basis, there's no need for more processes than TestCases.
        parallel_units = len(parallel_suite.subsuites)
        self.parallel = min(self.parallel, parallel_units)

        # If there's only one TestCase, parallelization isn't needed.
        if self.parallel > 1:
            suite = parallel_suite

    return suite
|
||||
|
||||
def setup_databases(self, **kwargs):
    """Create the test databases.

    Thin wrapper around the module-level ``_setup_databases`` helper
    (presumably an alias of django.test.utils.setup_databases — confirm at
    the imports), forwarding this runner's configuration.
    """
    return _setup_databases(
        self.verbosity, self.interactive, self.keepdb, self.debug_sql,
        self.parallel, **kwargs
    )
|
||||
|
||||
def get_resultclass(self):
    """Pick the unittest result class for the enabled debugging mode.

    --debug-sql wins over --pdb; with neither flag set, return None so the
    test runner falls back to its default result class.
    """
    result_class = None
    if self.debug_sql:
        result_class = DebugSQLTextTestResult
    elif self.pdb:
        result_class = PDBDebugResult
    return result_class
|
||||
|
||||
def get_test_runner_kwargs(self):
    """Build the keyword arguments used to instantiate self.test_runner."""
    runner_kwargs = {}
    runner_kwargs['failfast'] = self.failfast
    runner_kwargs['resultclass'] = self.get_resultclass()
    runner_kwargs['verbosity'] = self.verbosity
    runner_kwargs['buffer'] = self.buffer
    return runner_kwargs
|
||||
|
||||
def run_checks(self, databases):
    """Run Django's system check framework against *databases*."""
    # Checks are run after database creation since some checks require
    # database access.
    call_command('check', verbosity=self.verbosity, databases=databases)
|
||||
|
||||
def run_suite(self, suite, **kwargs):
    """Instantiate self.test_runner and execute *suite* with it.

    NOTE: any incoming **kwargs are deliberately discarded; the runner is
    always configured from get_test_runner_kwargs().
    """
    runner_kwargs = self.get_test_runner_kwargs()
    runner = self.test_runner(**runner_kwargs)
    return runner.run(suite)
|
||||
|
||||
def teardown_databases(self, old_config, **kwargs):
    """Destroy all the non-mirror databases."""
    # Delegates to the module-level _teardown_databases helper with this
    # runner's verbosity/parallel/keepdb configuration.
    _teardown_databases(
        old_config,
        verbosity=self.verbosity,
        parallel=self.parallel,
        keepdb=self.keepdb,
    )
|
||||
|
||||
def teardown_test_environment(self, **kwargs):
    """Undo setup_test_environment(): remove unittest's Ctrl-C handler and
    restore the global state saved by the module-level helper."""
    unittest.removeHandler()
    teardown_test_environment()
|
||||
|
||||
def suite_result(self, suite, result, **kwargs):
    """Reduce a unittest result object to the number of failed/errored tests."""
    failures = len(result.failures)
    errors = len(result.errors)
    return failures + errors
|
||||
|
||||
def _get_databases(self, suite):
|
||||
databases = set()
|
||||
for test in suite:
|
||||
if isinstance(test, unittest.TestCase):
|
||||
test_databases = getattr(test, 'databases', None)
|
||||
if test_databases == '__all__':
|
||||
return set(connections)
|
||||
if test_databases:
|
||||
databases.update(test_databases)
|
||||
else:
|
||||
databases.update(self._get_databases(test))
|
||||
return databases
|
||||
|
||||
def get_databases(self, suite):
    """Return the set of database aliases *suite* needs.

    At verbosity >= 2, report which configured aliases will be skipped.
    """
    required = self._get_databases(suite)
    if self.verbosity >= 2:
        skipped = sorted(alias for alias in connections if alias not in required)
        if skipped:
            print('Skipping setup of unused database(s): %s.' % ', '.join(skipped))
    return required
|
||||
|
||||
def run_tests(self, test_labels, extra_tests=None, **kwargs):
    """
    Run the unit tests for all the test labels in the provided list.

    Test labels should be dotted Python paths to test modules, test
    classes, or test methods.

    A list of 'extra' tests may also be provided; these tests
    will be added to the test suite.

    Return the number of tests that failed.
    """
    self.setup_test_environment()
    suite = self.build_suite(test_labels, extra_tests)
    # Only create the databases the suite actually declares a need for.
    databases = self.get_databases(suite)
    old_config = self.setup_databases(aliases=databases)
    run_failed = False
    try:
        self.run_checks(databases)
        result = self.run_suite(suite)
    except Exception:
        run_failed = True
        raise
    finally:
        try:
            self.teardown_databases(old_config)
            self.teardown_test_environment()
        except Exception:
            # Silence teardown exceptions if an exception was raised during
            # runs to avoid shadowing it.
            if not run_failed:
                raise
    return self.suite_result(suite, result)
|
||||
|
||||
|
||||
def is_discoverable(label):
    """
    Check if a test label points to a Python package or file directory.

    Relative labels like "." and ".." are seen as directories.
    """
    importable = True
    try:
        module = import_module(label)
    except (ImportError, TypeError):
        # Not importable as a dotted path (TypeError covers relative labels
        # like "."); fall back to a filesystem check below.
        importable = False
    if importable:
        return hasattr(module, '__path__')
    return os.path.isdir(os.path.abspath(label))
|
||||
|
||||
|
||||
def reorder_suite(suite, classes, reverse=False):
    """
    Reorder a test suite by test type.

    `classes` is a sequence of types.

    All tests of type classes[0] are placed first, then tests of type
    classes[1], etc. Tests with no match in classes are placed last.

    If `reverse` is True, sort tests within classes in opposite order but
    don't reverse test classes.
    """
    # One bin per class, plus a trailing bin for unmatched tests.
    bins = [OrderedSet() for _ in range(len(classes) + 1)]
    partition_suite_by_type(suite, classes, bins, reverse=reverse)
    reordered = type(suite)()
    for test_bin in bins:
        reordered.addTests(test_bin)
    return reordered
|
||||
|
||||
|
||||
def partition_suite_by_type(suite, classes, bins, reverse=False):
|
||||
"""
|
||||
Partition a test suite by test type. Also prevent duplicated tests.
|
||||
|
||||
classes is a sequence of types
|
||||
bins is a sequence of TestSuites, one more than classes
|
||||
reverse changes the ordering of tests within bins
|
||||
|
||||
Tests of type classes[i] are added to bins[i],
|
||||
tests with no match found in classes are place in bins[-1]
|
||||
"""
|
||||
suite_class = type(suite)
|
||||
if reverse:
|
||||
suite = reversed(tuple(suite))
|
||||
for test in suite:
|
||||
if isinstance(test, suite_class):
|
||||
partition_suite_by_type(test, classes, bins, reverse=reverse)
|
||||
else:
|
||||
for i in range(len(classes)):
|
||||
if isinstance(test, classes[i]):
|
||||
bins[i].add(test)
|
||||
break
|
||||
else:
|
||||
bins[-1].add(test)
|
||||
|
||||
|
||||
def partition_suite_by_case(suite):
    """Partition a test suite by test case, preserving the order of tests."""
    suite_class = type(suite)
    groups = []
    for test_type, test_group in itertools.groupby(suite, type):
        if not issubclass(test_type, unittest.TestCase):
            # A run of nested suites: flatten each one recursively.
            for nested in test_group:
                groups.extend(partition_suite_by_case(nested))
        else:
            groups.append(suite_class(test_group))
    return groups
|
||||
|
||||
|
||||
def filter_tests_by_tags(suite, tags, exclude_tags):
|
||||
suite_class = type(suite)
|
||||
filtered_suite = suite_class()
|
||||
|
||||
for test in suite:
|
||||
if isinstance(test, suite_class):
|
||||
filtered_suite.addTests(filter_tests_by_tags(test, tags, exclude_tags))
|
||||
else:
|
||||
test_tags = set(getattr(test, 'tags', set()))
|
||||
test_fn_name = getattr(test, '_testMethodName', str(test))
|
||||
test_fn = getattr(test, test_fn_name, test)
|
||||
test_fn_tags = set(getattr(test_fn, 'tags', set()))
|
||||
all_tags = test_tags.union(test_fn_tags)
|
||||
matched_tags = all_tags.intersection(tags)
|
||||
if (matched_tags or not tags) and not all_tags.intersection(exclude_tags):
|
||||
filtered_suite.addTest(test)
|
||||
|
||||
return filtered_suite
|
||||
132
Lib/site-packages/django/test/selenium.py
Normal file
132
Lib/site-packages/django/test/selenium.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import sys
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.test import LiveServerTestCase, tag
|
||||
from django.utils.functional import classproperty
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.text import capfirst
|
||||
|
||||
|
||||
class SeleniumTestCaseBase(type(LiveServerTestCase)):
    """Metaclass that multiplies a Selenium test class per requested browser."""

    # List of browsers to dynamically create test classes for.
    browsers = []
    # A selenium hub URL to test against.
    selenium_hub = None
    # The external host Selenium Hub can reach.
    external_host = None
    # Sentinel value to differentiate browser-specific instances.
    browser = None
    # Run browsers in headless mode.
    headless = False

    def __new__(cls, name, bases, attrs):
        """
        Dynamically create new classes and add them to the test module when
        multiple browsers specs are provided (e.g. --selenium=firefox,chrome).
        """
        test_class = super().__new__(cls, name, bases, attrs)
        # If the test class is either browser-specific or a test base, return it.
        if test_class.browser or not any(name.startswith('test') and callable(value) for name, value in attrs.items()):
            return test_class
        elif test_class.browsers:
            # Reuse the created test class to make it browser-specific.
            # We can't rename it to include the browser name or create a
            # subclass like we do with the remaining browsers as it would
            # either duplicate tests or prevent pickling of its instances.
            first_browser = test_class.browsers[0]
            test_class.browser = first_browser
            # Listen on an external interface if using a selenium hub.
            host = test_class.host if not test_class.selenium_hub else '0.0.0.0'
            test_class.host = host
            test_class.external_host = cls.external_host
            # Create subclasses for each of the remaining browsers and expose
            # them through the test's module namespace.
            module = sys.modules[test_class.__module__]
            for browser in test_class.browsers[1:]:
                browser_test_class = cls.__new__(
                    cls,
                    "%s%s" % (capfirst(browser), name),
                    (test_class,),
                    {
                        'browser': browser,
                        'host': host,
                        'external_host': cls.external_host,
                        '__module__': test_class.__module__,
                    }
                )
                setattr(module, browser_test_class.__name__, browser_test_class)
            return test_class
        # If no browsers were specified, skip this class (it'll still be discovered).
        return unittest.skip('No browsers specified.')(test_class)

    @classmethod
    def import_webdriver(cls, browser):
        """Import the WebDriver class for *browser* (e.g. 'firefox')."""
        return import_string("selenium.webdriver.%s.webdriver.WebDriver" % browser)

    @classmethod
    def import_options(cls, browser):
        """Import the browser-specific Options class."""
        return import_string('selenium.webdriver.%s.options.Options' % browser)

    @classmethod
    def get_capability(cls, browser):
        """Return the DesiredCapabilities entry for *browser* (hub mode)."""
        from selenium.webdriver.common.desired_capabilities import (
            DesiredCapabilities,
        )
        return getattr(DesiredCapabilities, browser.upper())

    def create_options(self):
        """Build browser Options, honoring the ``headless`` flag when supported."""
        options = self.import_options(self.browser)()
        if self.headless:
            try:
                options.headless = True
            except AttributeError:
                pass  # Only Chrome and Firefox support the headless mode.
        return options

    def create_webdriver(self):
        """Create the WebDriver: a Remote one when a hub is configured,
        otherwise a local browser driver built from create_options()."""
        if self.selenium_hub:
            from selenium import webdriver
            return webdriver.Remote(
                command_executor=self.selenium_hub,
                desired_capabilities=self.get_capability(self.browser),
            )
        return self.import_webdriver(self.browser)(options=self.create_options())
|
||||
|
||||
|
||||
@tag('selenium')
class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
    """LiveServerTestCase driving a real browser through Selenium."""

    # Default implicit wait (seconds) applied to the WebDriver.
    implicit_wait = 10
    # Overridden by the metaclass when a selenium hub is in use.
    external_host = None

    @classproperty
    def live_server_url(cls):
        # Prefer the externally reachable host (hub mode) over the bind host.
        return 'http://%s:%s' % (cls.external_host or cls.host, cls.server_thread.port)

    @classproperty
    def allowed_host(cls):
        return cls.external_host or cls.host

    @classmethod
    def setUpClass(cls):
        """Create the WebDriver before the live server thread starts."""
        cls.selenium = cls.create_webdriver()
        cls.selenium.implicitly_wait(cls.implicit_wait)
        super().setUpClass()

    @classmethod
    def _tearDownClassInternal(cls):
        # quit() the WebDriver before attempting to terminate and join the
        # single-threaded LiveServerThread to avoid a dead lock if the browser
        # kept a connection alive.
        if hasattr(cls, 'selenium'):
            cls.selenium.quit()
        super()._tearDownClassInternal()

    @contextmanager
    def disable_implicit_wait(self):
        """Disable the default implicit wait."""
        self.selenium.implicitly_wait(0)
        try:
            yield
        finally:
            # Restore the class default even if the body raised.
            self.selenium.implicitly_wait(self.implicit_wait)
|
||||
208
Lib/site-packages/django/test/signals.py
Normal file
208
Lib/site-packages/django/test/signals.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from asgiref.local import Local
|
||||
|
||||
from django.apps import apps
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.signals import setting_changed
|
||||
from django.db import connections, router
|
||||
from django.db.utils import ConnectionRouter
|
||||
from django.dispatch import Signal, receiver
|
||||
from django.utils import timezone
|
||||
from django.utils.formats import FORMAT_SETTINGS, reset_format_cache
|
||||
from django.utils.functional import empty
|
||||
|
||||
template_rendered = Signal()
|
||||
|
||||
# Most setting_changed receivers are supposed to be added below,
|
||||
# except for cases where the receiver is related to a contrib app.
|
||||
|
||||
# Settings that may not work well when using 'override_settings' (#19031)
|
||||
COMPLEX_OVERRIDE_SETTINGS = {'DATABASES'}
|
||||
|
||||
|
||||
@receiver(setting_changed)
def clear_cache_handlers(**kwargs):
    """Close and reset cache connections when CACHES is overridden."""
    if kwargs['setting'] == 'CACHES':
        from django.core.cache import caches, close_caches
        close_caches()
        # Replace the per-context cache storage wholesale so no stale
        # handler survives the override.
        caches._caches = Local()
|
||||
|
||||
|
||||
@receiver(setting_changed)
def update_installed_apps(**kwargs):
    """Invalidate every per-app cache when INSTALLED_APPS is overridden."""
    if kwargs['setting'] == 'INSTALLED_APPS':
        # Rebuild any AppDirectoriesFinder instance.
        from django.contrib.staticfiles.finders import get_finder
        get_finder.cache_clear()
        # Rebuild management commands cache
        from django.core.management import get_commands
        get_commands.cache_clear()
        # Rebuild get_app_template_dirs cache.
        from django.template.utils import get_app_template_dirs
        get_app_template_dirs.cache_clear()
        # Rebuild translations cache.
        from django.utils.translation import trans_real
        trans_real._translations = {}
|
||||
|
||||
|
||||
@receiver(setting_changed)
def update_connections_time_zone(**kwargs):
    """Propagate TIME_ZONE / USE_TZ overrides to the process and DB connections."""
    if kwargs['setting'] == 'TIME_ZONE':
        # Reset process time zone
        if hasattr(time, 'tzset'):
            if kwargs['value']:
                os.environ['TZ'] = kwargs['value']
            else:
                os.environ.pop('TZ', None)
            time.tzset()

        # Reset local time zone cache
        timezone.get_default_timezone.cache_clear()

    # Reset the database connections' time zone
    if kwargs['setting'] in {'TIME_ZONE', 'USE_TZ'}:
        for conn in connections.all():
            # Drop cached attributes so they're recomputed lazily.
            try:
                del conn.timezone
            except AttributeError:
                pass
            try:
                del conn.timezone_name
            except AttributeError:
                pass
            conn.ensure_timezone()
|
||||
|
||||
|
||||
@receiver(setting_changed)
def clear_routers_cache(**kwargs):
    """Rebuild the database router list when DATABASE_ROUTERS is overridden."""
    if kwargs['setting'] == 'DATABASE_ROUTERS':
        router.routers = ConnectionRouter().routers
|
||||
|
||||
|
||||
@receiver(setting_changed)
def reset_template_engines(**kwargs):
    """Drop every cached template engine when a template-affecting setting changes."""
    if kwargs['setting'] in {
        'TEMPLATES',
        'DEBUG',
        'INSTALLED_APPS',
    }:
        from django.template import engines
        try:
            del engines.templates
        except AttributeError:
            pass
        engines._templates = None
        engines._engines = {}
        from django.template.engine import Engine
        Engine.get_default.cache_clear()
        from django.forms.renderers import get_default_renderer
        get_default_renderer.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
def clear_serializers_cache(**kwargs):
    """Reset the serializer registry when SERIALIZATION_MODULES is overridden."""
    if kwargs['setting'] == 'SERIALIZATION_MODULES':
        from django.core import serializers
        serializers._serializers = {}
|
||||
|
||||
|
||||
@receiver(setting_changed)
def language_changed(**kwargs):
    """Reset translation machinery state when language settings change."""
    if kwargs['setting'] in {'LANGUAGES', 'LANGUAGE_CODE', 'LOCALE_PATHS'}:
        from django.utils.translation import trans_real
        trans_real._default = None
        trans_real._active = Local()
    # LANGUAGES / LOCALE_PATHS additionally invalidate the loaded catalogs.
    if kwargs['setting'] in {'LANGUAGES', 'LOCALE_PATHS'}:
        from django.utils.translation import trans_real
        trans_real._translations = {}
        trans_real.check_for_language.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
def localize_settings_changed(**kwargs):
    """Clear the localization format cache for any format-related override."""
    if kwargs['setting'] in FORMAT_SETTINGS or kwargs['setting'] == 'USE_THOUSAND_SEPARATOR':
        reset_format_cache()
|
||||
|
||||
|
||||
@receiver(setting_changed)
def file_storage_changed(**kwargs):
    """Reset the lazy default storage when DEFAULT_FILE_STORAGE is overridden."""
    if kwargs['setting'] == 'DEFAULT_FILE_STORAGE':
        from django.core.files.storage import default_storage
        # Reset the lazy wrapper so it re-initializes on next access.
        default_storage._wrapped = empty
|
||||
|
||||
|
||||
@receiver(setting_changed)
def complex_setting_changed(**kwargs):
    """Warn when overriding a setting known not to work well with override_settings."""
    # Only warn on enter, so exiting the override doesn't warn a second time.
    if kwargs['enter'] and kwargs['setting'] in COMPLEX_OVERRIDE_SETTINGS:
        # Considering the current implementation of the signals framework,
        # this stacklevel shows the line containing the override_settings call.
        warnings.warn("Overriding setting %s can lead to unexpected behavior."
                      % kwargs['setting'], stacklevel=6)
|
||||
|
||||
|
||||
@receiver(setting_changed)
def root_urlconf_changed(**kwargs):
    """Flush URL resolver caches when ROOT_URLCONF is overridden."""
    if kwargs['setting'] == 'ROOT_URLCONF':
        from django.urls import clear_url_caches, set_urlconf
        clear_url_caches()
        set_urlconf(None)
|
||||
|
||||
|
||||
@receiver(setting_changed)
def static_storage_changed(**kwargs):
    """Reset the lazy staticfiles storage when its settings change."""
    if kwargs['setting'] in {
        'STATICFILES_STORAGE',
        'STATIC_ROOT',
        'STATIC_URL',
    }:
        from django.contrib.staticfiles.storage import staticfiles_storage
        # Reset the lazy wrapper so it re-initializes on next access.
        staticfiles_storage._wrapped = empty
|
||||
|
||||
|
||||
@receiver(setting_changed)
def static_finders_changed(**kwargs):
    """Clear the staticfiles finder cache when finder settings change."""
    if kwargs['setting'] in {
        'STATICFILES_DIRS',
        'STATIC_ROOT',
    }:
        from django.contrib.staticfiles.finders import get_finder
        get_finder.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
def auth_password_validators_changed(**kwargs):
    """Clear the cached password validators when AUTH_PASSWORD_VALIDATORS changes."""
    if kwargs['setting'] == 'AUTH_PASSWORD_VALIDATORS':
        from django.contrib.auth.password_validation import (
            get_default_password_validators,
        )
        get_default_password_validators.cache_clear()
|
||||
|
||||
|
||||
@receiver(setting_changed)
def user_model_swapped(**kwargs):
    """Rebind module-level user-model references when AUTH_USER_MODEL changes."""
    if kwargs['setting'] == 'AUTH_USER_MODEL':
        apps.clear_cache()
        try:
            from django.contrib.auth import get_user_model
            UserModel = get_user_model()
        except ImproperlyConfigured:
            # Some tests set an invalid AUTH_USER_MODEL.
            pass
        else:
            # Several contrib.auth modules cache a module-level UserModel
            # reference; refresh each of them with the new model.
            from django.contrib.auth import backends
            backends.UserModel = UserModel

            from django.contrib.auth import forms
            forms.UserModel = UserModel

            from django.contrib.auth.handlers import modwsgi
            modwsgi.UserModel = UserModel

            from django.contrib.auth.management.commands import changepassword
            changepassword.UserModel = UserModel

            from django.contrib.auth import views
            views.UserModel = UserModel
|
||||
1502
Lib/site-packages/django/test/testcases.py
Normal file
1502
Lib/site-packages/django/test/testcases.py
Normal file
File diff suppressed because it is too large
Load Diff
867
Lib/site-packages/django/test/utils.py
Normal file
867
Lib/site-packages/django/test/utils.py
Normal file
@@ -0,0 +1,867 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from io import StringIO
|
||||
from itertools import chain
|
||||
from types import SimpleNamespace
|
||||
from unittest import TestCase, skipIf, skipUnless
|
||||
from xml.dom.minidom import Node, parseString
|
||||
|
||||
from django.apps import apps
|
||||
from django.apps.registry import Apps
|
||||
from django.conf import UserSettingsHolder, settings
|
||||
from django.core import mail
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.signals import request_started
|
||||
from django.db import DEFAULT_DB_ALIAS, connections, reset_queries
|
||||
from django.db.models.options import Options
|
||||
from django.template import Template
|
||||
from django.test.signals import setting_changed, template_rendered
|
||||
from django.urls import get_script_prefix, set_script_prefix
|
||||
from django.utils.translation import deactivate
|
||||
|
||||
try:
|
||||
import jinja2
|
||||
except ImportError:
|
||||
jinja2 = None
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Approximate', 'ContextList', 'isolate_lru_cache', 'get_runner',
|
||||
'modify_settings', 'override_settings',
|
||||
'requires_tz_support',
|
||||
'setup_test_environment', 'teardown_test_environment',
|
||||
)
|
||||
|
||||
TZ_SUPPORT = hasattr(time, 'tzset')
|
||||
|
||||
|
||||
class Approximate:
    """Value wrapper that compares equal within ``places`` decimal places."""

    def __init__(self, val, places=7):
        self.val = val
        self.places = places

    def __repr__(self):
        return repr(self.val)

    def __eq__(self, other):
        if self.val == other:
            return True
        # Fall back to a rounded absolute-difference comparison.
        return round(abs(self.val - other), self.places) == 0
|
||||
|
||||
|
||||
class ContextList(list):
    """
    A wrapper that provides direct key access to context items contained
    in a list of context objects.
    """
    def __getitem__(self, key):
        # Non-string keys keep plain list indexing/slicing semantics.
        if not isinstance(key, str):
            return super().__getitem__(key)
        for subcontext in self:
            if key in subcontext:
                return subcontext[key]
        raise KeyError(key)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        try:
            self[key]
        except KeyError:
            return False
        return True

    def keys(self):
        """
        Flattened keys of subcontexts.
        """
        flattened = set()
        for subcontext in self:
            flattened.update(subcontext)
        return flattened
|
||||
|
||||
|
||||
def instrumented_test_render(self, context):
    """
    An instrumented Template render method, providing a signal that can be
    intercepted by the test Client.
    """
    # Installed as Template._render by setup_test_environment().
    template_rendered.send(sender=self, template=self, context=context)
    return self.nodelist.render(context)
|
||||
|
||||
|
||||
class _TestState:
    # Namespace class: setup_test_environment() stores the pre-test global
    # state on its ``saved_data`` attribute and teardown_test_environment()
    # deletes it.
    pass
|
||||
|
||||
|
||||
def setup_test_environment(debug=None):
    """
    Perform global pre-test setup, such as installing the instrumented template
    renderer and setting the email backend to the locmem email backend.
    """
    if hasattr(_TestState, 'saved_data'):
        # Executing this function twice would overwrite the saved values.
        raise RuntimeError(
            "setup_test_environment() was already called and can't be called "
            "again without first calling teardown_test_environment()."
        )

    if debug is None:
        debug = settings.DEBUG

    # Stash the original values so teardown_test_environment() can restore
    # them.
    saved_data = SimpleNamespace()
    _TestState.saved_data = saved_data

    saved_data.allowed_hosts = settings.ALLOWED_HOSTS
    # Add the default host of the test client.
    settings.ALLOWED_HOSTS = [*settings.ALLOWED_HOSTS, 'testserver']

    saved_data.debug = settings.DEBUG
    settings.DEBUG = debug

    saved_data.email_backend = settings.EMAIL_BACKEND
    settings.EMAIL_BACKEND = 'django.core.mail.backends.locmem.EmailBackend'

    saved_data.template_render = Template._render
    Template._render = instrumented_test_render

    mail.outbox = []

    # Deactivate any active translation.
    deactivate()
|
||||
|
||||
|
||||
def teardown_test_environment():
    """
    Perform any global post-test teardown, such as restoring the original
    template renderer and restoring the email sending functions.
    """
    saved_data = _TestState.saved_data

    settings.ALLOWED_HOSTS = saved_data.allowed_hosts
    settings.DEBUG = saved_data.debug
    settings.EMAIL_BACKEND = saved_data.email_backend
    Template._render = saved_data.template_render

    # Drop the saved state so setup_test_environment() may run again.
    del _TestState.saved_data
    del mail.outbox
|
||||
|
||||
|
||||
def setup_databases(verbosity, interactive, keepdb=False, debug_sql=False, parallel=0, aliases=None, **kwargs):
    """Create the test databases.

    Deduplicates aliases that point at the same underlying database, clones
    databases for parallel workers, and configures test mirrors. Returns
    ``old_names``: a list of (connection, old_name, destroy) tuples that
    teardown_databases() consumes.
    """
    test_databases, mirrored_aliases = get_unique_databases_and_mirrors(aliases)

    old_names = []

    for db_name, aliases in test_databases.values():
        first_alias = None
        for alias in aliases:
            connection = connections[alias]
            old_names.append((connection, db_name, first_alias is None))

            # Actually create the database for the first connection
            if first_alias is None:
                first_alias = alias
                connection.creation.create_test_db(
                    verbosity=verbosity,
                    autoclobber=not interactive,
                    keepdb=keepdb,
                    serialize=connection.settings_dict['TEST'].get('SERIALIZE', True),
                )
                if parallel > 1:
                    # One clone per worker, suffixed _1.._N.
                    for index in range(parallel):
                        connection.creation.clone_test_db(
                            suffix=str(index + 1),
                            verbosity=verbosity,
                            keepdb=keepdb,
                        )
            # Configure all other connections as mirrors of the first one
            else:
                connections[alias].creation.set_as_test_mirror(connections[first_alias].settings_dict)

    # Configure the test mirrors.
    for alias, mirror_alias in mirrored_aliases.items():
        connections[alias].creation.set_as_test_mirror(
            connections[mirror_alias].settings_dict)

    if debug_sql:
        for alias in connections:
            connections[alias].force_debug_cursor = True

    return old_names
|
||||
|
||||
|
||||
def dependency_ordered(test_databases, dependencies):
    """
    Reorder test_databases into an order that honors the dependencies
    described in TEST[DEPENDENCIES].

    *test_databases* is an iterable of (signature, (db_name, aliases))
    pairs; *dependencies* maps an alias to the aliases it depends on.
    Raises ImproperlyConfigured on any circular dependency.
    """
    ordered = []
    resolved = set()

    # Map each database signature to the union of its aliases' dependencies,
    # rejecting databases whose aliases depend on one another.
    dependencies_map = {}
    for signature, (_, aliases) in test_databases:
        sig_deps = set()
        for alias in aliases:
            sig_deps.update(dependencies.get(alias, []))
        if sig_deps & set(aliases):
            raise ImproperlyConfigured(
                "Circular dependency: databases %r depend on each other, "
                "but are aliases." % aliases
            )
        dependencies_map[signature] = sig_deps

    remaining = list(test_databases)
    while remaining:
        deferred = []
        progressed = False

        # Resolve every database whose dependencies are all satisfied.
        for signature, db_info in remaining:
            if dependencies_map[signature] <= resolved:
                resolved.update(db_info[1])
                ordered.append((signature, db_info))
                progressed = True
            else:
                deferred.append((signature, db_info))

        if not progressed:
            raise ImproperlyConfigured("Circular dependency in TEST[DEPENDENCIES]")
        remaining = deferred
    return ordered
|
||||
|
||||
|
||||
def get_unique_databases_and_mirrors(aliases=None):
    """
    Figure out which databases actually need to be created.

    Deduplicate entries in DATABASES that correspond the same database or are
    configured as test mirrors.

    Return two values:
    - test_databases: ordered mapping of signatures to (name, list of aliases)
                      where all aliases share the same underlying database.
    - mirrored_aliases: mapping of mirror aliases to original aliases.
    """
    if aliases is None:
        aliases = connections
    mirrored_aliases = {}
    test_databases = {}
    dependencies = {}
    default_sig = connections[DEFAULT_DB_ALIAS].creation.test_db_signature()

    for alias in connections:
        connection = connections[alias]
        test_settings = connection.settings_dict['TEST']

        if test_settings['MIRROR']:
            # If the database is marked as a test mirror, save the alias.
            mirrored_aliases[alias] = test_settings['MIRROR']
        elif alias in aliases:
            # Store a tuple with DB parameters that uniquely identify it.
            # If we have two aliases with the same values for that tuple,
            # we only need to create the test database once.
            item = test_databases.setdefault(
                connection.creation.test_db_signature(),
                (connection.settings_dict['NAME'], set())
            )
            item[1].add(alias)

            if 'DEPENDENCIES' in test_settings:
                dependencies[alias] = test_settings['DEPENDENCIES']
            else:
                # Unless it shares the default database, a non-default alias
                # implicitly depends on the default one.
                if alias != DEFAULT_DB_ALIAS and connection.creation.test_db_signature() != default_sig:
                    dependencies[alias] = test_settings.get('DEPENDENCIES', [DEFAULT_DB_ALIAS])

    test_databases = dict(dependency_ordered(test_databases.items(), dependencies))
    return test_databases, mirrored_aliases
|
||||
|
||||
|
||||
def teardown_databases(old_config, verbosity, parallel=0, keepdb=False):
    """Destroy all the non-mirror databases."""
    for connection, old_name, destroy in old_config:
        if not destroy:
            # Mirror aliases share a database with their source; nothing to drop.
            continue
        creation = connection.creation
        if parallel > 1:
            # Each parallel worker got its own clone of the test database
            # (suffixes "1".."parallel"); drop those before the main one.
            for worker_id in range(1, parallel + 1):
                creation.destroy_test_db(
                    suffix=str(worker_id),
                    verbosity=verbosity,
                    keepdb=keepdb,
                )
        creation.destroy_test_db(old_name, verbosity, keepdb)
|
||||
|
||||
|
||||
def get_runner(settings, test_runner_class=None):
    """
    Resolve and return the test-runner class.

    Use ``test_runner_class`` (a dotted path string) when given; otherwise
    fall back to ``settings.TEST_RUNNER``.
    """
    dotted_path = test_runner_class or settings.TEST_RUNNER
    module_path, _, class_name = dotted_path.rpartition('.')
    # Allow for relative paths
    if not module_path:
        module_path = '.'
    module = __import__(module_path, {}, {}, class_name)
    return getattr(module, class_name)
|
||||
|
||||
|
||||
class TestContextDecorator:
    """
    A base class that can either be used as a context manager during tests
    or as a test function or unittest.TestCase subclass decorator to perform
    temporary alterations.

    `attr_name`: attribute assigned the return value of enable() if used as
    a class decorator.

    `kwarg_name`: keyword argument passing the return value of enable() if
    used as a function decorator.
    """
    def __init__(self, attr_name=None, kwarg_name=None):
        self.attr_name = attr_name
        self.kwarg_name = kwarg_name

    def enable(self):
        # Subclasses perform the temporary alteration here and may return a
        # context object (exposed via attr_name / kwarg_name).
        raise NotImplementedError

    def disable(self):
        # Subclasses undo whatever enable() did.
        raise NotImplementedError

    def __enter__(self):
        return self.enable()

    def __exit__(self, exc_type, exc_value, traceback):
        self.disable()

    def decorate_class(self, cls):
        if not issubclass(cls, TestCase):
            raise TypeError('Can only decorate subclasses of unittest.TestCase')

        original_setUp = cls.setUp
        original_tearDown = cls.tearDown

        def setUp(test_instance):
            context = self.enable()
            if self.attr_name:
                setattr(test_instance, self.attr_name, context)
            try:
                original_setUp(test_instance)
            except Exception:
                # Don't leave the alteration active if setUp() failed.
                self.disable()
                raise

        def tearDown(test_instance):
            original_tearDown(test_instance)
            self.disable()

        cls.setUp = setUp
        cls.tearDown = tearDown
        return cls

    def decorate_callable(self, func):
        if asyncio.iscoroutinefunction(func):
            # An async test must enter/exit the context inside the coroutine
            # so that the `with` statement executes at the right time.
            @wraps(func)
            async def inner(*args, **kwargs):
                with self as context:
                    if self.kwarg_name:
                        kwargs[self.kwarg_name] = context
                    return await func(*args, **kwargs)
        else:
            @wraps(func)
            def inner(*args, **kwargs):
                with self as context:
                    if self.kwarg_name:
                        kwargs[self.kwarg_name] = context
                    return func(*args, **kwargs)
        return inner

    def __call__(self, decorated):
        if isinstance(decorated, type):
            return self.decorate_class(decorated)
        if callable(decorated):
            return self.decorate_callable(decorated)
        raise TypeError('Cannot decorate object of type %s' % type(decorated))
|
||||
|
||||
|
||||
class override_settings(TestContextDecorator):
    """
    Act as either a decorator or a context manager. If it's a decorator, take a
    function and return a wrapped function. If it's a contextmanager, use it
    with the ``with`` statement. In either event, entering/exiting are called
    before and after, respectively, the function/block is executed.
    """
    # Exception raised by a setting_changed receiver during enable(); stored
    # so disable() can re-raise it after the settings have been restored.
    enable_exception = None

    def __init__(self, **kwargs):
        # Mapping of setting name -> overridden value.
        self.options = kwargs
        super().__init__()

    def enable(self):
        """Install the overridden settings and notify setting_changed listeners."""
        # Keep this code at the beginning to leave the settings unchanged
        # in case it raises an exception because INSTALLED_APPS is invalid.
        if 'INSTALLED_APPS' in self.options:
            try:
                apps.set_installed_apps(self.options['INSTALLED_APPS'])
            except Exception:
                apps.unset_installed_apps()
                raise
        # Layer the overrides on top of the currently active settings object.
        override = UserSettingsHolder(settings._wrapped)
        for key, new_value in self.options.items():
            setattr(override, key, new_value)
        self.wrapped = settings._wrapped
        settings._wrapped = override
        for key, new_value in self.options.items():
            try:
                setting_changed.send(
                    sender=settings._wrapped.__class__,
                    setting=key, value=new_value, enter=True,
                )
            except Exception as exc:
                # A receiver failed: remember the exception and roll back.
                # disable() re-raises it once cleanup has finished.
                self.enable_exception = exc
                self.disable()

    def disable(self):
        """Restore the original settings and notify setting_changed listeners."""
        if 'INSTALLED_APPS' in self.options:
            apps.unset_installed_apps()
        settings._wrapped = self.wrapped
        del self.wrapped
        responses = []
        for key in self.options:
            new_value = getattr(settings, key, None)
            # send_robust() collects receiver exceptions instead of raising,
            # so every receiver gets a chance to clean up before any re-raise.
            responses_for_setting = setting_changed.send_robust(
                sender=settings._wrapped.__class__,
                setting=key, value=new_value, enter=False,
            )
            responses.extend(responses_for_setting)
        if self.enable_exception is not None:
            exc = self.enable_exception
            self.enable_exception = None
            raise exc
        for _, response in responses:
            if isinstance(response, Exception):
                raise response

    def save_options(self, test_func):
        """Record the overrides on the decorated test class."""
        if test_func._overridden_settings is None:
            test_func._overridden_settings = self.options
        else:
            # Duplicate dict to prevent subclasses from altering their parent.
            test_func._overridden_settings = {
                **test_func._overridden_settings,
                **self.options,
            }

    def decorate_class(self, cls):
        # Class decoration stores the options on the class instead of wrapping
        # setUp/tearDown; only Django's SimpleTestCase knows how to apply them.
        from django.test import SimpleTestCase
        if not issubclass(cls, SimpleTestCase):
            raise ValueError(
                "Only subclasses of Django SimpleTestCase can be decorated "
                "with override_settings")
        self.save_options(cls)
        return cls
|
||||
|
||||
|
||||
class modify_settings(override_settings):
    """
    Like override_settings, but makes it possible to append, prepend, or remove
    items instead of redefining the entire list.
    """
    def __init__(self, *args, **kwargs):
        if args:
            # Hack used when instantiating from SimpleTestCase.setUpClass.
            assert not kwargs
            self.operations = args[0]
        else:
            assert not args
            self.operations = list(kwargs.items())
        # Deliberately skip override_settings.__init__ (which would populate
        # self.options from kwargs) — options are computed lazily in enable().
        super(override_settings, self).__init__()

    def save_options(self, test_func):
        """Record the list operations on the decorated test class."""
        if test_func._modified_settings is None:
            test_func._modified_settings = self.operations
        else:
            # Duplicate list to prevent subclasses from altering their parent.
            test_func._modified_settings = list(
                test_func._modified_settings) + self.operations

    def enable(self):
        # Turn the append/prepend/remove operations into concrete setting
        # values, then delegate to override_settings.enable().
        self.options = {}
        for name, operations in self.operations:
            try:
                # When called from SimpleTestCase.setUpClass, values may be
                # overridden several times; cumulate changes.
                value = self.options[name]
            except KeyError:
                value = list(getattr(settings, name, []))
            for action, items in operations.items():
                # items may be a single value or an iterable.
                if isinstance(items, str):
                    items = [items]
                if action == 'append':
                    value = value + [item for item in items if item not in value]
                elif action == 'prepend':
                    value = [item for item in items if item not in value] + value
                elif action == 'remove':
                    value = [item for item in value if item not in items]
                else:
                    raise ValueError("Unsupported action: %s" % action)
            self.options[name] = value
        super().enable()
|
||||
|
||||
|
||||
class override_system_checks(TestContextDecorator):
    """
    Act as a decorator. Override list of registered system checks.
    Useful when you override `INSTALLED_APPS`, e.g. if you exclude `auth` app,
    you also need to exclude its system checks.
    """
    def __init__(self, new_checks, deployment_checks=None):
        from django.core.checks.registry import registry
        self.registry = registry
        self.new_checks = new_checks
        self.deployment_checks = deployment_checks
        super().__init__()

    def enable(self):
        registry = self.registry
        # Stash the current check sets so disable() can restore them, then
        # register only the supplied checks.
        self.old_checks = registry.registered_checks
        registry.registered_checks = set()
        for check in self.new_checks:
            registry.register(check, *getattr(check, 'tags', ()))
        self.old_deployment_checks = registry.deployment_checks
        if self.deployment_checks is not None:
            registry.deployment_checks = set()
            for check in self.deployment_checks:
                registry.register(check, *getattr(check, 'tags', ()), deploy=True)

    def disable(self):
        # Put back the original (non-deployment and deployment) check sets.
        self.registry.registered_checks = self.old_checks
        self.registry.deployment_checks = self.old_deployment_checks
|
||||
|
||||
|
||||
def compare_xml(want, got):
    """
    Try to do a 'xml-comparison' of want and got. Plain string comparison
    doesn't always work because, for example, attribute ordering should not be
    important. Ignore comment nodes, processing instructions, document type
    node, and leading and trailing whitespaces.

    Based on https://github.com/lxml/lxml/blob/master/src/lxml/doctestcompare.py
    """
    _squeeze_ws = re.compile(r'[ \t\n][ \t\n]+')

    def collapse_whitespace(text):
        # Runs of whitespace compare equal to a single space.
        return _squeeze_ws.sub(' ', text)

    def direct_text(element):
        return ''.join(
            child.data for child in element.childNodes
            if child.nodeType == Node.TEXT_NODE
        )

    def element_children(element):
        return [
            child for child in element.childNodes
            if child.nodeType == Node.ELEMENT_NODE
        ]

    def attributes_of(element):
        # dict comparison makes attribute order irrelevant.
        return dict(element.attributes.items())

    def elements_match(expected, actual):
        if expected.tagName != actual.tagName:
            return False
        if collapse_whitespace(direct_text(expected)) != collapse_whitespace(direct_text(actual)):
            return False
        if attributes_of(expected) != attributes_of(actual):
            return False
        expected_children = element_children(expected)
        actual_children = element_children(actual)
        if len(expected_children) != len(actual_children):
            return False
        return all(
            elements_match(e, a)
            for e, a in zip(expected_children, actual_children)
        )

    def root_element(document):
        # Skip comments, the doctype, and processing instructions to reach
        # the document's actual root element.
        ignored = (
            Node.COMMENT_NODE,
            Node.DOCUMENT_TYPE_NODE,
            Node.PROCESSING_INSTRUCTION_NODE,
        )
        for node in document.childNodes:
            if node.nodeType not in ignored:
                return node

    want = want.strip().replace('\\n', '\n')
    got = got.strip().replace('\\n', '\n')

    # If the string is not a complete xml document, we may need to add a
    # root element. This allow us to compare fragments, like "<foo/><bar/>"
    if not want.startswith('<?xml'):
        want = '<root>%s</root>' % want
        got = '<root>%s</root>' % got

    # Parse the want and got strings, and compare the parsings.
    return elements_match(root_element(parseString(want)), root_element(parseString(got)))
|
||||
|
||||
|
||||
class CaptureQueriesContext:
    """
    Context manager that captures queries executed by the specified connection.
    """
    def __init__(self, connection):
        self.connection = connection

    def __iter__(self):
        yield from self.captured_queries

    def __getitem__(self, index):
        return self.captured_queries[index]

    def __len__(self):
        return len(self.captured_queries)

    @property
    def captured_queries(self):
        # The slice of the connection's query log recorded between entering
        # and exiting the context.
        return self.connection.queries[self.initial_queries:self.final_queries]

    def __enter__(self):
        self.force_debug_cursor = self.connection.force_debug_cursor
        self.connection.force_debug_cursor = True
        # Run any initialization queries if needed so that they won't be
        # included as part of the count.
        self.connection.ensure_connection()
        self.initial_queries = len(self.connection.queries_log)
        self.final_queries = None
        request_started.disconnect(reset_queries)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.force_debug_cursor = self.force_debug_cursor
        request_started.connect(reset_queries)
        if exc_type is not None:
            # Don't bother recording the final count when the block raised.
            return
        self.final_queries = len(self.connection.queries_log)
|
||||
|
||||
|
||||
class ignore_warnings(TestContextDecorator):
    """
    Context manager/decorator that silences warnings matching the given
    ``warnings.filterwarnings()`` / ``simplefilter()`` keyword arguments.
    """
    def __init__(self, **kwargs):
        self.ignore_kwargs = kwargs
        # filterwarnings() is required to honor the message/module criteria;
        # simplefilter() handles the remaining keyword arguments.
        if 'message' in kwargs or 'module' in kwargs:
            self.filter_func = warnings.filterwarnings
        else:
            self.filter_func = warnings.simplefilter
        super().__init__()

    def enable(self):
        # Snapshot the current warning filters, then add the 'ignore' filter.
        self.catch_warnings = warnings.catch_warnings()
        self.catch_warnings.__enter__()
        self.filter_func('ignore', **self.ignore_kwargs)

    def disable(self):
        # Restore the filters snapshotted by enable().
        self.catch_warnings.__exit__(*sys.exc_info())
|
||||
|
||||
|
||||
# On OSes that don't provide tzset (Windows), we can't set the timezone
# in which the program runs. As a consequence, we must skip tests that
# don't enforce a specific timezone (with timezone.override or equivalent),
# or attempt to interpret naive datetimes in the default timezone.
# NOTE(review): TZ_SUPPORT is defined earlier in this module — presumably it
# reflects the availability of time.tzset; confirm before relying on it.

requires_tz_support = skipUnless(
    TZ_SUPPORT,
    "This test relies on the ability to run a program in an arbitrary "
    "time zone, but your operating system isn't able to do that."
)
|
||||
|
||||
|
||||
@contextmanager
def extend_sys_path(*paths):
    """Context manager to temporarily add paths to sys.path."""
    # Copy, don't alias: the finally block must restore the pre-entry list
    # even if the wrapped code mutated sys.path further.
    saved_path = list(sys.path)
    sys.path.extend(paths)
    try:
        yield
    finally:
        sys.path = saved_path
|
||||
|
||||
|
||||
@contextmanager
def isolate_lru_cache(lru_cache_object):
    """Clear the cache of an LRU cache object on entering and exiting."""
    clear = lru_cache_object.cache_clear
    clear()  # enter with a cold cache
    try:
        yield
    finally:
        clear()  # leave no entries behind for later code
|
||||
|
||||
|
||||
@contextmanager
def captured_output(stream_name):
    """Return a context manager used by captured_stdout/stdin/stderr
    that temporarily replaces the sys stream *stream_name* with a StringIO.

    Note: This function and the following ``captured_std*`` are copied
    from CPython's ``test.support`` module."""
    original = getattr(sys, stream_name)
    replacement = StringIO()
    setattr(sys, stream_name, replacement)
    try:
        yield replacement
    finally:
        setattr(sys, stream_name, original)
|
||||
|
||||
|
||||
def captured_stdout():
    """Capture the output of sys.stdout:

       with captured_stdout() as stdout:
           print("hello")
       self.assertEqual(stdout.getvalue(), "hello\n")
    """
    # Thin wrapper fixing the stream name for captured_output() above.
    return captured_output("stdout")
|
||||
|
||||
|
||||
def captured_stderr():
    """Capture the output of sys.stderr:

       with captured_stderr() as stderr:
           print("hello", file=sys.stderr)
       self.assertEqual(stderr.getvalue(), "hello\n")
    """
    # Thin wrapper fixing the stream name for captured_output() above.
    return captured_output("stderr")
|
||||
|
||||
|
||||
def captured_stdin():
    """Capture the input to sys.stdin:

       with captured_stdin() as stdin:
           stdin.write('hello\n')
           stdin.seek(0)
           # call test code that consumes from sys.stdin
           captured = input()
       self.assertEqual(captured, "hello")
    """
    # Thin wrapper fixing the stream name for captured_output() above.
    return captured_output("stdin")
|
||||
|
||||
|
||||
@contextmanager
def freeze_time(t):
    """
    Context manager to temporarily freeze time.time(). This temporarily
    modifies the time function of the time module. Modules which import the
    time function directly (e.g. `from time import time`) won't be affected
    This isn't meant as a public API, but helps reduce some repetitive code in
    Django's test suite.
    """
    original_time = time.time

    def frozen_time():
        return t

    time.time = frozen_time
    try:
        yield
    finally:
        time.time = original_time
|
||||
|
||||
|
||||
def require_jinja2(test_func):
    """
    Decorator to enable a Jinja2 template engine in addition to the regular
    Django template engine for a test or skip it if Jinja2 isn't available.
    """
    # Skip outright when the optional jinja2 dependency is missing.
    test_func = skipIf(jinja2 is None, "this test requires jinja2")(test_func)
    jinja2_templates = override_settings(TEMPLATES=[{
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'APP_DIRS': True,
    }, {
        'BACKEND': 'django.template.backends.jinja2.Jinja2',
        'APP_DIRS': True,
        'OPTIONS': {'keep_trailing_newline': True},
    }])
    return jinja2_templates(test_func)
|
||||
|
||||
|
||||
class override_script_prefix(TestContextDecorator):
    """Decorator or context manager to temporary override the script prefix."""
    def __init__(self, prefix):
        # The URL prefix to install while the context is active.
        self.prefix = prefix
        super().__init__()

    def enable(self):
        # Remember the active prefix so disable() can restore it.
        self.old_prefix = get_script_prefix()
        set_script_prefix(self.prefix)

    def disable(self):
        set_script_prefix(self.old_prefix)
|
||||
|
||||
|
||||
class LoggingCaptureMixin:
    """
    Capture the output from the 'django' logger and store it on the class's
    logger_output attribute.
    """
    def setUp(self):
        self.logger = logging.getLogger('django')
        # Redirect the first handler's stream to an in-memory buffer;
        # tearDown() puts the original stream back.
        handler = self.logger.handlers[0]
        self.old_stream = handler.stream
        self.logger_output = StringIO()
        handler.stream = self.logger_output

    def tearDown(self):
        self.logger.handlers[0].stream = self.old_stream
|
||||
|
||||
|
||||
class isolate_apps(TestContextDecorator):
    """
    Act as either a decorator or a context manager to register models defined
    in its wrapped context to an isolated registry.

    The list of installed apps the isolated registry should contain must be
    passed as arguments.

    Two optional keyword arguments can be specified:

    `attr_name`: attribute assigned the isolated registry if used as a class
    decorator.

    `kwarg_name`: keyword argument passing the isolated registry if used as a
    function decorator.
    """
    def __init__(self, *installed_apps, **kwargs):
        self.installed_apps = installed_apps
        super().__init__(**kwargs)

    def enable(self):
        # Swap the registry that new model classes fall back to for an
        # isolated one, remembering the previous registry for disable().
        self.old_apps = Options.default_apps
        isolated_registry = Apps(self.installed_apps)
        Options.default_apps = isolated_registry
        return isolated_registry

    def disable(self):
        Options.default_apps = self.old_apps
|
||||
|
||||
|
||||
def tag(*tags):
    """Decorator to add tags to a test class or method."""
    def decorator(obj):
        # Union with any tags already present (possibly inherited from a base
        # class); union() returns a new set, so a parent's tags aren't mutated.
        _missing = object()
        existing = getattr(obj, 'tags', _missing)
        if existing is _missing:
            obj.tags = set(tags)
        else:
            obj.tags = existing.union(tags)
        return obj
    return decorator
|
||||
|
||||
|
||||
@contextmanager
def register_lookup(field, *lookups, lookup_name=None):
    """
    Context manager to temporarily register lookups on a model field using
    lookup_name (or the lookup's lookup_name if not provided).
    """
    try:
        # Registration happens inside the try block so a failure partway
        # through still runs the cleanup below.
        for lookup_class in lookups:
            field.register_lookup(lookup_class, lookup_name)
        yield
    finally:
        for lookup_class in lookups:
            field._unregister_lookup(lookup_class, lookup_name)
|
||||
Reference in New Issue
Block a user