239 lines
8.1 KiB
Python
239 lines
8.1 KiB
Python
"""
|
|
Useful auxiliary data structures for query construction. Not useful outside
|
|
the SQL domain.
|
|
"""
|
|
|
|
import warnings
|
|
|
|
from django.core.exceptions import FullResultSet
|
|
from django.db.models.sql.constants import INNER, LOUTER
|
|
from django.utils.deprecation import RemovedInDjango60Warning
|
|
|
|
|
|
class MultiJoin(Exception):
    """
    Signal raised by join-construction code at the point where a
    multi-valued join was attempted, so that a caller wishing to treat such
    joins specially can intercept it.

    Attributes:
        level: the ``names_pos`` value passed to the constructor.
        names_with_path: the path travelled so far, including the path to
            the multijoin itself.
    """

    def __init__(self, names_pos, path_with_names):
        # ``level`` mirrors names_pos; ``names_with_path`` keeps the full
        # travelled path (including the multijoin segment).
        self.level, self.names_with_path = names_pos, path_with_names
|
|
|
|
|
|
class Empty:
    """Featureless placeholder class; it adds nothing beyond ``object``."""
|
|
|
|
|
|
class Join:
    """
    Used by sql.Query and sql.SQLCompiler to generate JOIN clauses into the
    FROM entry. For example, the SQL generated could be
        LEFT OUTER JOIN "sometable" T1
        ON ("othertable"."sometable_id" = "sometable"."id")

    This class is primarily used in Query.alias_map. All entries in alias_map
    must be Join compatible by providing the following attributes and methods:
        - table_name (string)
        - table_alias (possible alias for the table, can be None)
        - join_type (can be None for those entries that aren't joined from
          anything)
        - parent_alias (which table is this join's parent, can be None similarly
          to join_type)
        - as_sql()
        - relabeled_clone()

    Equality and hashing are based on ``identity``: the class, table_name,
    parent_alias, join_field and filtered_relation. ``table_alias``,
    ``join_type`` and ``nullable`` do not take part in the comparison.
    """

    def __init__(
        self,
        table_name,
        parent_alias,
        table_alias,
        join_type,
        join_field,
        nullable,
        filtered_relation=None,
    ):
        # Join table
        self.table_name = table_name
        self.parent_alias = parent_alias
        # Note: table_alias is not necessarily known at instantiation time.
        self.table_alias = table_alias
        # LOUTER or INNER
        self.join_type = join_type
        # A sequence of 2-tuples to use in the ON clause of the JOIN.
        # Each 2-tuple will create one join condition in the ON clause.
        if hasattr(join_field, "get_joining_fields"):
            # join_fields holds (lhs_field, rhs_field) pairs; join_cols is
            # derived from them as (lhs_column, rhs_column) name pairs.
            self.join_fields = join_field.get_joining_fields()
            self.join_cols = tuple(
                (lhs_field.column, rhs_field.column)
                for lhs_field, rhs_field in self.join_fields
            )
        else:
            # Legacy path: only column names are available, so join_fields
            # stays None and as_sql() falls back to join_cols.
            warnings.warn(
                "The usage of get_joining_columns() in Join is deprecated. Implement "
                "get_joining_fields() instead.",
                RemovedInDjango60Warning,
                # Attribute the warning to the code that builds the Join,
                # not to this line.
                stacklevel=2,
            )
            self.join_fields = None
            self.join_cols = join_field.get_joining_columns()
        # Along which field (or ForeignObjectRel in the reverse join case)
        self.join_field = join_field
        # Is this join nullable?
        self.nullable = nullable
        self.filtered_relation = filtered_relation

    def as_sql(self, compiler, connection):
        """
        Generate the full
           LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params
        clause for this join.

        Return a ``(sql, params)`` tuple. ``params`` only collects the
        parameters of the extra restriction and of the filtered relation;
        the column comparisons are rendered inline into the SQL.

        Raise ValueError when neither joining columns nor any extra
        restriction produced an ON condition.
        """
        join_conditions = []
        params = []
        qn = compiler.quote_name_unless_alias
        qn2 = connection.ops.quote_name
        # Add a join condition for each pair of joining columns.
        # RemovedInDjango60Warning: when the deprecation ends, replace with:
        # for lhs, rhs in self.join_fields:
        join_fields = self.join_fields or self.join_cols
        for lhs, rhs in join_fields:
            if isinstance(lhs, str):
                # RemovedInDjango60Warning: when the deprecation ends, remove
                # the branch for strings.
                lhs_full_name = "%s.%s" % (qn(self.parent_alias), qn2(lhs))
                rhs_full_name = "%s.%s" % (qn(self.table_alias), qn2(rhs))
            else:
                lhs, rhs = connection.ops.prepare_join_on_clause(
                    self.parent_alias, lhs, self.table_alias, rhs
                )
                # Any params returned for the compiled column expressions are
                # interpolated straight into the SQL, not added to ``params``.
                lhs_sql, lhs_params = compiler.compile(lhs)
                lhs_full_name = lhs_sql % lhs_params
                rhs_sql, rhs_params = compiler.compile(rhs)
                rhs_full_name = rhs_sql % rhs_params
            join_conditions.append(f"{lhs_full_name} = {rhs_full_name}")

        # Add a single condition inside parentheses for whatever
        # get_extra_restriction() returns.
        extra_cond = self.join_field.get_extra_restriction(
            self.table_alias, self.parent_alias
        )
        if extra_cond:
            extra_sql, extra_params = compiler.compile(extra_cond)
            join_conditions.append("(%s)" % extra_sql)
            params.extend(extra_params)
        if self.filtered_relation:
            try:
                extra_sql, extra_params = compiler.compile(self.filtered_relation)
            except FullResultSet:
                # On FullResultSet the filtered relation contributes no
                # condition to the ON clause.
                pass
            else:
                join_conditions.append("(%s)" % extra_sql)
                params.extend(extra_params)
        if not join_conditions:
            # This might be a rel on the other end of an actual declared field.
            declared_field = getattr(self.join_field, "field", self.join_field)
            raise ValueError(
                "Join generated an empty ON clause. %s did not yield either "
                "joining columns or extra restrictions." % declared_field.__class__
            )
        on_clause_sql = " AND ".join(join_conditions)
        # Only emit an alias when it differs from the table name.
        alias_str = (
            "" if self.table_alias == self.table_name else (" %s" % self.table_alias)
        )
        sql = "%s %s%s ON (%s)" % (
            self.join_type,
            qn(self.table_name),
            alias_str,
            on_clause_sql,
        )
        return sql, params

    def relabeled_clone(self, change_map):
        """
        Return a new Join with parent_alias and table_alias remapped through
        ``change_map`` (aliases absent from the map are kept). A non-None
        filtered_relation is relabeled as well.
        """
        new_parent_alias = change_map.get(self.parent_alias, self.parent_alias)
        new_table_alias = change_map.get(self.table_alias, self.table_alias)
        if self.filtered_relation is not None:
            filtered_relation = self.filtered_relation.relabeled_clone(change_map)
        else:
            filtered_relation = None
        return self.__class__(
            self.table_name,
            new_parent_alias,
            new_table_alias,
            self.join_type,
            self.join_field,
            self.nullable,
            filtered_relation=filtered_relation,
        )

    @property
    def identity(self):
        """Tuple used for equality and hashing (table_alias is excluded)."""
        return (
            self.__class__,
            self.table_name,
            self.parent_alias,
            self.join_field,
            self.filtered_relation,
        )

    def __eq__(self, other):
        if not isinstance(other, Join):
            return NotImplemented
        return self.identity == other.identity

    def __hash__(self):
        return hash(self.identity)

    def demote(self):
        """Return a copy of this join with join_type set to INNER."""
        new = self.relabeled_clone({})
        new.join_type = INNER
        return new

    def promote(self):
        """Return a copy of this join with join_type set to LOUTER."""
        new = self.relabeled_clone({})
        new.join_type = LOUTER
        return new
|
|
|
|
|
|
class BaseTable:
    """
    Represent the base table of a FROM clause. For example, the SQL "foo" in
        SELECT * FROM "foo" WHERE somecond
    could be generated by this class.

    ``join_type``, ``parent_alias`` and ``filtered_relation`` are fixed to
    None at class level, so a BaseTable exposes the same attribute names as
    a Join while representing a table that is not joined from anything.
    Equality and hashing use (class, table_name, table_alias).
    """

    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name, alias):
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(self, compiler, connection):
        """
        Return the quoted table name, followed by the alias when it differs
        from the table name, together with an empty params list.
        """
        rendered = compiler.quote_name_unless_alias(self.table_name)
        if self.table_alias != self.table_name:
            rendered += " %s" % self.table_alias
        return rendered, []

    def relabeled_clone(self, change_map):
        """Return a copy whose alias is remapped through ``change_map``, if listed."""
        current = self.table_alias
        remapped = change_map.get(current, current)
        return self.__class__(self.table_name, remapped)

    @property
    def identity(self):
        """Tuple used for equality and hashing."""
        return (self.__class__, self.table_name, self.table_alias)

    def __eq__(self, other):
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)
|