Files
california-equity-git/.venv/lib/python3.12/site-packages/geoalchemy2/admin/dialects/mysql.py
2024-09-28 23:12:43 -07:00

201 lines
6.7 KiB
Python

"""This module defines functions specific to the MySQL dialect (compilation rules are also registered for MariaDB)."""
from sqlalchemy import text
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.sqltypes import NullType
from geoalchemy2 import functions
from geoalchemy2.admin.dialects.common import _check_spatial_type
from geoalchemy2.admin.dialects.common import _spatial_idx_name
from geoalchemy2.admin.dialects.common import setup_create_drop
from geoalchemy2.types import Geography
from geoalchemy2.types import Geometry
# Lower-case DATA_TYPE values (as read from INFORMATION_SCHEMA.COLUMNS by
# reflect_geometry_column) that are reflected as a Geometry column; any other
# data type is left untouched.
_POSSIBLE_TYPES = """
    geometry point linestring polygon
    multipoint multilinestring multipolygon geometrycollection
""".split()
def reflect_geometry_column(inspector, table, column_info):
    """Reflect a column of type Geometry with MySQL dialect.

    Query ``INFORMATION_SCHEMA`` for the column's data type, SRID, nullability
    and whether it is covered by a SPATIAL index, then replace
    ``column_info["type"]`` in place with a matching :class:`Geometry`.

    Args:
        inspector: The inspector performing the reflection; ``inspector.bind``
            is used to run the queries.
        table: The table being reflected.
        column_info: The reflected column description. Only columns whose
            ``type`` is a ``Geometry`` or ``NullType`` are considered; the dict
            is updated in place and nothing is returned.
    """
    if not isinstance(column_info.get("type"), (Geometry, NullType)):
        return

    column_name = column_info.get("name")
    schema = table.schema or inspector.default_schema_name

    # Identifiers are passed as bound parameters (not formatted into the SQL)
    # so that names containing quotes can neither break nor inject into the
    # queries.
    params = {"table_name": table.name, "column_name": column_name}
    schema_filter = ""
    if schema is not None:
        schema_filter = " AND TABLE_SCHEMA = :schema"
        params["schema"] = schema

    # Check geometry type, SRID and if the column is nullable
    geometry_type_query = (
        "SELECT DATA_TYPE, SRS_ID, IS_NULLABLE "
        "FROM INFORMATION_SCHEMA.COLUMNS "
        "WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name" + schema_filter
    )
    geometry_type, srid, nullable_str = inspector.bind.execute(
        text(geometry_type_query).bindparams(**params)
    ).one()
    is_nullable = str(nullable_str).lower() == "yes"

    if geometry_type not in _POSSIBLE_TYPES:
        return

    # MySQL reports SRS_ID as NULL for a spatial column declared without an
    # SRID attribute. Fall back to -1 (Geometry's default "no SRID") so the
    # reflected type keeps an integer SRID that ``srid > 0`` checks can handle.
    if srid is None:
        srid = -1

    # Check if the column has a spatial index. The column may belong to several
    # indexes, so every index type is inspected rather than only the first row.
    has_index_query = (
        "SELECT DISTINCT INDEX_TYPE "
        "FROM INFORMATION_SCHEMA.STATISTICS "
        "WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name" + schema_filter
    )
    # .all() fully consumes the result before any further query is issued.
    index_types = (
        inspector.bind.execute(text(has_index_query).bindparams(**params)).scalars().all()
    )
    spatial_index = any(str(index_type).lower() == "spatial" for index_type in index_types)

    # Set attributes
    column_info["type"] = Geometry(
        geometry_type=geometry_type.upper(),
        srid=srid,
        spatial_index=spatial_index,
        nullable=is_nullable,
        _spatial_index_reflected=True,
    )
def before_create(table, bind, **kw):
    """Handle spatial indexes during the before_create event.

    Indexes that cover a spatial (Geometry) column are detached from
    ``table.indexes`` so that they are not created during the
    ``table.create()`` step. Detached indexes that are not the column's default
    spatial index (a different name, or the column type does not request
    ``spatial_index``) are stored in ``table.info["_after_create_indexes"]``.
    Finally ``table.columns`` is restored from ``table.info["_saved_columns"]``.

    Args:
        table: The table about to be created.
        bind: The connection the DDL is emitted on.
        **kw: Extra event keyword arguments (unused).
    """
    # NOTE(review): setup_create_drop presumably stashes the original columns in
    # table.info["_saved_columns"], which is read and popped below; confirm in
    # geoalchemy2.admin.dialects.common.
    dialect, _gis_cols, _regular_cols = setup_create_drop(table, bind)

    table.info["_after_create_indexes"] = []
    saved_columns = table.info["_saved_columns"]

    # Iterate over a snapshot because indexes are removed from the live set.
    for idx in set(table.indexes):
        for col in saved_columns:
            if not (
                _check_spatial_type(col.type, Geometry, dialect) and col in idx.columns.values()
            ):
                continue
            table.indexes.remove(idx)
            if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
                col.type, "spatial_index", False
            ):
                table.info["_after_create_indexes"].append(idx)
            # The index is detached now: stop so that an index covering several
            # spatial columns is neither removed twice (KeyError) nor queued twice.
            break

    table.columns = table.info.pop("_saved_columns")
def after_create(table, bind, **kw):
    """Handle spatial indexes during the after_create event.

    For every Geometry/Geography column whose type has ``spatial_index`` set to
    ``True`` and which is not covered by any index currently in
    ``table.indexes``, emit ``ALTER TABLE ... ADD SPATIAL INDEX(...)``. Then add
    the indexes listed in ``table.info["_after_create_indexes"]`` back to
    ``table.indexes``.

    Args:
        table: The table that was just created.
        bind: The connection the DDL is emitted on.
        **kw: Extra event keyword arguments (unused).
    """
    dialect = bind.dialect
    preparer = dialect.identifier_preparer
    # format_table() prefixes the table's schema (when set) and quotes names
    # that need it, so the ALTER targets the right table even outside the
    # default schema or with reserved/special-character names.
    table_name = preparer.format_table(table)

    for col in table.columns:
        # Add spatial indices for the Geometry and Geography columns
        if not (
            _check_spatial_type(col.type, (Geometry, Geography), dialect)
            and col.type.spatial_index is True
        ):
            continue
        # If the index does not exist, define it and create it
        if not any(col in idx.columns.values() for idx in table.indexes):
            sql = "ALTER TABLE {} ADD SPATIAL INDEX({});".format(
                table_name, preparer.quote(col.name)
            )
            bind.execute(text(sql))

    # Default to an empty list in case before_create did not run for this table.
    for idx in table.info.pop("_after_create_indexes", []):
        table.indexes.add(idx)
def before_drop(table, bind, **kw):
    """Handle the before_drop event: a no-op for this dialect."""
    return None
def after_drop(table, bind, **kw):
    """Handle the after_drop event: a no-op for this dialect."""
    return None
# Name of a function class in geoalchemy2.functions -> SQL function name that is
# emitted in its place on MySQL/MariaDB (registered via register_mysql_mapping).
_MYSQL_FUNCTIONS = dict(ST_AsEWKB="ST_AsBinary")
def _compiles_mysql(cls, fn):
    """Compile the geoalchemy2 function named ``cls`` as ``fn`` on MySQL and MariaDB.

    Args:
        cls: Name of a function class looked up in ``geoalchemy2.functions``.
        fn: SQL function name rendered in its place; the compiled arguments
            of the element are passed through unchanged.
    """
    function_class = getattr(functions, cls)

    def _compile_mysql(element, compiler, **kw):
        arguments = compiler.process(element.clauses, **kw)
        return f"{fn}({arguments})"

    for dialect_name in ("mysql", "mariadb"):
        compiles(function_class, dialect_name)(_compile_mysql)
def register_mysql_mapping(mapping):
    """Register compilation mappings for the given functions.

    Each key names a function class in ``geoalchemy2.functions``; each value is
    the SQL function name emitted for it on the MySQL and MariaDB dialects.

    Args:
        mapping: Should have the following form::

            {
                "function_name_1": "mysql_function_name_1",
                "function_name_2": "mysql_function_name_2",
                ...
            }
    """
    for geoalchemy_name, mysql_name in mapping.items():
        _compiles_mysql(geoalchemy_name, mysql_name)
# Install the default function-name mappings at import time.
register_mysql_mapping(mapping=_MYSQL_FUNCTIONS)
def _compile_GeomFromText_MySql(element, compiler, **kw):
    """Render a WKT-to-geometry function element as MySQL ``ST_GeomFromText``.

    Args:
        element: The function element; its ``clauses`` are compiled as the
            arguments and ``type.srid`` supplies the SRID.
        compiler: The SQL compiler used to render ``element.clauses``.
        **kw: Forwarded to ``compiler.process``.

    Returns:
        ``ST_GeomFromText(<args>, <srid>)`` when the SRID is a positive
        integer, otherwise ``ST_GeomFromText(<args>)``.
    """
    # Keep the function name local instead of assigning it to
    # ``element.identifier``: compiling must not mutate the element, which may
    # be cached or shared and compiled again later.
    identifier = "ST_GeomFromText"
    compiled = compiler.process(element.clauses, **kw)
    srid = element.type.srid
    if srid is not None and srid > 0:
        return "{}({}, {})".format(identifier, compiled, srid)
    return "{}({})".format(identifier, compiled)
def _compile_GeomFromWKB_MySql(element, compiler, **kw):
    """Render a WKB-to-geometry function element as MySQL ``ST_GeomFromWKB``.

    If the first argument is a literal carrying a ``memoryview`` value, that
    value is converted to ``bytes`` in place before compilation.

    Args:
        element: The function element; its ``clauses`` are compiled as the
            arguments and ``type.srid`` supplies the SRID.
        compiler: The SQL compiler used to render ``element.clauses``.
        **kw: Forwarded to ``compiler.process``.

    Returns:
        ``ST_GeomFromWKB(<args>, <srid>)`` when the SRID is a positive
        integer, otherwise ``ST_GeomFromWKB(<args>)``.
    """
    # Local name rather than mutating ``element.identifier`` during compilation.
    identifier = "ST_GeomFromWKB"
    # Only bound literals carry ``.value``; the first argument may also be e.g.
    # a column reference (or absent), which is passed through untouched.
    first_arg = next(iter(element.clauses), None)
    wkb_data = getattr(first_arg, "value", None)
    if isinstance(wkb_data, memoryview):
        # NOTE(review): presumably the DB drivers expect bytes, not memoryview.
        first_arg.value = wkb_data.tobytes()
    compiled = compiler.process(element.clauses, **kw)
    srid = element.type.srid
    if srid is not None and srid > 0:
        return "{}({}, {})".format(identifier, compiled, srid)
    return "{}({})".format(identifier, compiled)
@compiles(functions.ST_GeomFromText, "mysql")  # type: ignore
@compiles(functions.ST_GeomFromText, "mariadb")  # type: ignore
def _MySQL_ST_GeomFromText(element, compiler, **kw):
    """Compile ``ST_GeomFromText`` for MySQL/MariaDB through the shared WKT helper."""
    rendered = _compile_GeomFromText_MySql(element, compiler, **kw)
    return rendered
@compiles(functions.ST_GeomFromEWKT, "mysql")  # type: ignore
@compiles(functions.ST_GeomFromEWKT, "mariadb")  # type: ignore
def _MySQL_ST_GeomFromEWKT(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKT`` for MySQL/MariaDB through the shared WKT helper."""
    rendered = _compile_GeomFromText_MySql(element, compiler, **kw)
    return rendered
@compiles(functions.ST_GeomFromWKB, "mysql")  # type: ignore
@compiles(functions.ST_GeomFromWKB, "mariadb")  # type: ignore
def _MySQL_ST_GeomFromWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromWKB`` for MySQL/MariaDB through the shared WKB helper."""
    rendered = _compile_GeomFromWKB_MySql(element, compiler, **kw)
    return rendered
@compiles(functions.ST_GeomFromEWKB, "mysql")  # type: ignore
@compiles(functions.ST_GeomFromEWKB, "mariadb")  # type: ignore
def _MySQL_ST_GeomFromEWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKB`` for MySQL/MariaDB through the shared WKB helper."""
    rendered = _compile_GeomFromWKB_MySql(element, compiler, **kw)
    return rendered