improve typing for @compiles constructs

The first parameter is in fact the self parameter referring to
the function class.
This commit is contained in:
Sarah Hoffmann
2023-12-06 11:13:12 +01:00
parent df6eddebcd
commit b5c61e0b5b
2 changed files with 25 additions and 25 deletions

View File

@@ -29,7 +29,7 @@ class PlacexGeometryReverseLookuppolygon(sa.sql.functions.GenericFunction[Any]):
@compiles(PlacexGeometryReverseLookuppolygon) # type: ignore[no-untyped-call, misc]
def _default_intersects(element: SaColumn,
def _default_intersects(element: PlacexGeometryReverseLookuppolygon,
compiler: 'sa.Compiled', **kw: Any) -> str:
return ("(ST_GeometryType(placex.geometry) in ('ST_Polygon', 'ST_MultiPolygon')"
" AND placex.rank_address between 4 and 25"
@@ -40,7 +40,7 @@ def _default_intersects(element: SaColumn,
@compiles(PlacexGeometryReverseLookuppolygon, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_intersects(element: SaColumn,
def _sqlite_intersects(element: PlacexGeometryReverseLookuppolygon,
compiler: 'sa.Compiled', **kw: Any) -> str:
return ("(ST_GeometryType(placex.geometry) in ('POLYGON', 'MULTIPOLYGON')"
" AND placex.rank_address between 4 and 25"
@@ -61,7 +61,7 @@ class IntersectsReverseDistance(sa.sql.functions.GenericFunction[Any]):
@compiles(IntersectsReverseDistance) # type: ignore[no-untyped-call, misc]
def default_reverse_place_diameter(element: SaColumn,
def default_reverse_place_diameter(element: IntersectsReverseDistance,
compiler: 'sa.Compiled', **kw: Any) -> str:
table = element.tablename
return f"({table}.rank_address between 4 and 25"\
@@ -74,7 +74,7 @@ def default_reverse_place_diameter(element: SaColumn,
@compiles(IntersectsReverseDistance, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_reverse_place_diameter(element: SaColumn,
def sqlite_reverse_place_diameter(element: IntersectsReverseDistance,
compiler: 'sa.Compiled', **kw: Any) -> str:
geom1, rank, geom2 = list(element.clauses)
table = element.tablename
@@ -102,7 +102,7 @@ class IsBelowReverseDistance(sa.sql.functions.GenericFunction[Any]):
@compiles(IsBelowReverseDistance) # type: ignore[no-untyped-call, misc]
def default_is_below_reverse_distance(element: SaColumn,
def default_is_below_reverse_distance(element: IsBelowReverseDistance,
compiler: 'sa.Compiled', **kw: Any) -> str:
dist, rank = list(element.clauses)
return "%s < reverse_place_diameter(%s)" % (compiler.process(dist, **kw),
@@ -110,7 +110,7 @@ def default_is_below_reverse_distance(element: SaColumn,
@compiles(IsBelowReverseDistance, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_is_below_reverse_distance(element: SaColumn,
def sqlite_is_below_reverse_distance(element: IsBelowReverseDistance,
compiler: 'sa.Compiled', **kw: Any) -> str:
dist, rank = list(element.clauses)
return "%s < 14.0 * exp(-0.2 * %s) - 0.03" % (compiler.process(dist, **kw),
@@ -139,7 +139,7 @@ class IsAddressPoint(sa.sql.functions.GenericFunction[Any]):
@compiles(IsAddressPoint) # type: ignore[no-untyped-call, misc]
def default_is_address_point(element: SaColumn,
def default_is_address_point(element: IsAddressPoint,
compiler: 'sa.Compiled', **kw: Any) -> str:
rank, hnr, name = list(element.clauses)
return "(%s = 30 AND (%s IS NOT NULL OR %s ? 'addr:housename'))" % (
@@ -149,7 +149,7 @@ def default_is_address_point(element: SaColumn,
@compiles(IsAddressPoint, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_is_address_point(element: SaColumn,
def sqlite_is_address_point(element: IsAddressPoint,
compiler: 'sa.Compiled', **kw: Any) -> str:
rank, hnr, name = list(element.clauses)
return "(%s = 30 AND coalesce(%s, json_extract(%s, '$.addr:housename')) IS NOT NULL)" % (
@@ -166,7 +166,7 @@ class CrosscheckNames(sa.sql.functions.GenericFunction[Any]):
inherit_cache = True
@compiles(CrosscheckNames) # type: ignore[no-untyped-call, misc]
def compile_crosscheck_names(element: SaColumn,
def compile_crosscheck_names(element: CrosscheckNames,
compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "coalesce(avals(%s) && ARRAY(SELECT * FROM json_array_elements_text(%s)), false)" % (
@@ -174,7 +174,7 @@ def compile_crosscheck_names(element: SaColumn,
@compiles(CrosscheckNames, 'sqlite') # type: ignore[no-untyped-call, misc]
def compile_sqlite_crosscheck_names(element: SaColumn,
def compile_sqlite_crosscheck_names(element: CrosscheckNames,
compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "EXISTS(SELECT *"\
@@ -191,12 +191,12 @@ class JsonArrayEach(sa.sql.functions.GenericFunction[Any]):
@compiles(JsonArrayEach) # type: ignore[no-untyped-call, misc]
def default_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
def default_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw: Any) -> str:
return "json_array_elements(%s)" % compiler.process(element.clauses, **kw)
@compiles(JsonArrayEach, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_json_array_each(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
def sqlite_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw: Any) -> str:
return "json_each(%s)" % compiler.process(element.clauses, **kw)
@@ -208,5 +208,5 @@ class Greatest(sa.sql.functions.GenericFunction[Any]):
@compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_greatest(element: SaColumn, compiler: 'sa.Compiled', **kw: Any) -> str:
def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str:
return "max(%s)" % compiler.process(element.clauses, **kw)