gonna figure it out
This commit is contained in:
@@ -0,0 +1,200 @@
|
||||
"""This module defines specific functions for MySQL dialect."""
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.sqltypes import NullType
|
||||
|
||||
from geoalchemy2 import functions
|
||||
from geoalchemy2.admin.dialects.common import _check_spatial_type
|
||||
from geoalchemy2.admin.dialects.common import _spatial_idx_name
|
||||
from geoalchemy2.admin.dialects.common import setup_create_drop
|
||||
from geoalchemy2.types import Geography
|
||||
from geoalchemy2.types import Geometry
|
||||
|
||||
_POSSIBLE_TYPES = [
|
||||
"geometry",
|
||||
"point",
|
||||
"linestring",
|
||||
"polygon",
|
||||
"multipoint",
|
||||
"multilinestring",
|
||||
"multipolygon",
|
||||
"geometrycollection",
|
||||
]
|
||||
|
||||
|
||||
def reflect_geometry_column(inspector, table, column_info):
|
||||
"""Reflect a column of type Geometry with Postgresql dialect."""
|
||||
if not isinstance(column_info.get("type"), (Geometry, NullType)):
|
||||
return
|
||||
|
||||
column_name = column_info.get("name")
|
||||
schema = table.schema or inspector.default_schema_name
|
||||
|
||||
# Check geometry type, SRID and if the column is nullable
|
||||
geometry_type_query = """SELECT DATA_TYPE, SRS_ID, IS_NULLABLE
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME = '{}' and COLUMN_NAME = '{}'""".format(
|
||||
table.name, column_name
|
||||
)
|
||||
if schema is not None:
|
||||
geometry_type_query += """ and table_schema = '{}'""".format(schema)
|
||||
geometry_type, srid, nullable_str = inspector.bind.execute(text(geometry_type_query)).one()
|
||||
is_nullable = str(nullable_str).lower() == "yes"
|
||||
|
||||
if geometry_type not in _POSSIBLE_TYPES:
|
||||
return
|
||||
|
||||
# Check if the column has spatial index
|
||||
has_index_query = """SELECT DISTINCT
|
||||
INDEX_TYPE
|
||||
FROM INFORMATION_SCHEMA.STATISTICS
|
||||
WHERE TABLE_NAME = '{}' and COLUMN_NAME = '{}'""".format(
|
||||
table.name, column_name
|
||||
)
|
||||
if schema is not None:
|
||||
has_index_query += """ and TABLE_SCHEMA = '{}'""".format(schema)
|
||||
spatial_index_res = inspector.bind.execute(text(has_index_query)).scalar()
|
||||
spatial_index = str(spatial_index_res).lower() == "spatial"
|
||||
|
||||
# Set attributes
|
||||
column_info["type"] = Geometry(
|
||||
geometry_type=geometry_type.upper(),
|
||||
srid=srid,
|
||||
spatial_index=spatial_index,
|
||||
nullable=is_nullable,
|
||||
_spatial_index_reflected=True,
|
||||
)
|
||||
|
||||
|
||||
def before_create(table, bind, **kw):
|
||||
"""Handle spatial indexes during the before_create event."""
|
||||
dialect, gis_cols, regular_cols = setup_create_drop(table, bind)
|
||||
|
||||
# Remove the spatial indexes from the table metadata because they should not be
|
||||
# created during the table.create() step since the associated columns do not exist
|
||||
# at this time.
|
||||
table.info["_after_create_indexes"] = []
|
||||
current_indexes = set(table.indexes)
|
||||
for idx in current_indexes:
|
||||
for col in table.info["_saved_columns"]:
|
||||
if (_check_spatial_type(col.type, Geometry, dialect)) and col in idx.columns.values():
|
||||
table.indexes.remove(idx)
|
||||
if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
|
||||
col.type, "spatial_index", False
|
||||
):
|
||||
table.info["_after_create_indexes"].append(idx)
|
||||
|
||||
table.columns = table.info.pop("_saved_columns")
|
||||
|
||||
|
||||
def after_create(table, bind, **kw):
|
||||
"""Handle spatial indexes during the after_create event."""
|
||||
# Restore original column list including managed Geometry columns
|
||||
dialect = bind.dialect
|
||||
|
||||
# table.columns = table.info.pop("_saved_columns")
|
||||
|
||||
for col in table.columns:
|
||||
# Add spatial indices for the Geometry and Geography columns
|
||||
if (
|
||||
_check_spatial_type(col.type, (Geometry, Geography), dialect)
|
||||
and col.type.spatial_index is True
|
||||
):
|
||||
# If the index does not exist, define it and create it
|
||||
if not [i for i in table.indexes if col in i.columns.values()]:
|
||||
sql = "ALTER TABLE {} ADD SPATIAL INDEX({});".format(table.name, col.name)
|
||||
q = text(sql)
|
||||
bind.execute(q)
|
||||
|
||||
for idx in table.info.pop("_after_create_indexes"):
|
||||
table.indexes.add(idx)
|
||||
|
||||
|
||||
def before_drop(table, bind, **kw):
|
||||
return
|
||||
|
||||
|
||||
def after_drop(table, bind, **kw):
|
||||
return
|
||||
|
||||
|
||||
_MYSQL_FUNCTIONS = {
|
||||
"ST_AsEWKB": "ST_AsBinary",
|
||||
}
|
||||
|
||||
|
||||
def _compiles_mysql(cls, fn):
|
||||
def _compile_mysql(element, compiler, **kw):
|
||||
return "{}({})".format(fn, compiler.process(element.clauses, **kw))
|
||||
|
||||
compiles(getattr(functions, cls), "mysql")(_compile_mysql)
|
||||
compiles(getattr(functions, cls), "mariadb")(_compile_mysql)
|
||||
|
||||
|
||||
def register_mysql_mapping(mapping):
|
||||
"""Register compilation mappings for the given functions.
|
||||
|
||||
Args:
|
||||
mapping: Should have the following form::
|
||||
|
||||
{
|
||||
"function_name_1": "mysql_function_name_1",
|
||||
"function_name_2": "mysql_function_name_2",
|
||||
...
|
||||
}
|
||||
"""
|
||||
for cls, fn in mapping.items():
|
||||
_compiles_mysql(cls, fn)
|
||||
|
||||
|
||||
register_mysql_mapping(_MYSQL_FUNCTIONS)
|
||||
|
||||
|
||||
def _compile_GeomFromText_MySql(element, compiler, **kw):
|
||||
element.identifier = "ST_GeomFromText"
|
||||
compiled = compiler.process(element.clauses, **kw)
|
||||
srid = element.type.srid
|
||||
|
||||
if srid > 0:
|
||||
return "{}({}, {})".format(element.identifier, compiled, srid)
|
||||
else:
|
||||
return "{}({})".format(element.identifier, compiled)
|
||||
|
||||
|
||||
def _compile_GeomFromWKB_MySql(element, compiler, **kw):
|
||||
element.identifier = "ST_GeomFromWKB"
|
||||
wkb_data = list(element.clauses)[0].value
|
||||
if isinstance(wkb_data, memoryview):
|
||||
list(element.clauses)[0].value = wkb_data.tobytes()
|
||||
compiled = compiler.process(element.clauses, **kw)
|
||||
srid = element.type.srid
|
||||
|
||||
if srid > 0:
|
||||
return "{}({}, {})".format(element.identifier, compiled, srid)
|
||||
else:
|
||||
return "{}({})".format(element.identifier, compiled)
|
||||
|
||||
|
||||
@compiles(functions.ST_GeomFromText, "mysql") # type: ignore
|
||||
@compiles(functions.ST_GeomFromText, "mariadb") # type: ignore
|
||||
def _MySQL_ST_GeomFromText(element, compiler, **kw):
|
||||
return _compile_GeomFromText_MySql(element, compiler, **kw)
|
||||
|
||||
|
||||
@compiles(functions.ST_GeomFromEWKT, "mysql") # type: ignore
|
||||
@compiles(functions.ST_GeomFromEWKT, "mariadb") # type: ignore
|
||||
def _MySQL_ST_GeomFromEWKT(element, compiler, **kw):
|
||||
return _compile_GeomFromText_MySql(element, compiler, **kw)
|
||||
|
||||
|
||||
@compiles(functions.ST_GeomFromWKB, "mysql") # type: ignore
|
||||
@compiles(functions.ST_GeomFromWKB, "mariadb") # type: ignore
|
||||
def _MySQL_ST_GeomFromWKB(element, compiler, **kw):
|
||||
return _compile_GeomFromWKB_MySql(element, compiler, **kw)
|
||||
|
||||
|
||||
@compiles(functions.ST_GeomFromEWKB, "mysql") # type: ignore
|
||||
@compiles(functions.ST_GeomFromEWKB, "mariadb") # type: ignore
|
||||
def _MySQL_ST_GeomFromEWKB(element, compiler, **kw):
|
||||
return _compile_GeomFromWKB_MySql(element, compiler, **kw)
|
||||
Reference in New Issue
Block a user