forked from hans/Nominatim
enable all API tests for sqlite and port missing features
This commit is contained in:
@@ -188,6 +188,7 @@ def sqlite_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw
|
||||
return "json_each(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
|
||||
class Greatest(sa.sql.functions.GenericFunction[Any]):
|
||||
""" Function to compute maximum of all its input parameters.
|
||||
"""
|
||||
@@ -198,3 +199,23 @@ class Greatest(sa.sql.functions.GenericFunction[Any]):
|
||||
@compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "max(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
|
||||
class RegexpWord(sa.sql.functions.GenericFunction[Any]):
|
||||
""" Check if a full word is in a given string.
|
||||
"""
|
||||
name = 'RegexpWord'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(RegexpWord, 'postgresql') # type: ignore[no-untyped-call, misc]
|
||||
def postgres_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "%s ~* ('\\m(' || %s || ')\\M')::text" % (compiler.process(arg2, **kw), compiler.process(arg1, **kw))
|
||||
|
||||
|
||||
@compiles(RegexpWord, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "regexp('\\b(' || %s || ')\\b', %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
|
||||
|
||||
@@ -57,22 +57,16 @@ class IntArray(sa.types.TypeDecorator[Any]):
|
||||
""" Concate the array with the given array. If one of the
|
||||
operants is null, the value of the other will be returned.
|
||||
"""
|
||||
return sa.func.array_cat(self, other, type_=IntArray)
|
||||
return ArrayCat(self.expr, other)
|
||||
|
||||
|
||||
def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
|
||||
""" Return true if the array contains all the value of the argument
|
||||
array.
|
||||
"""
|
||||
return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other))
|
||||
return ArrayContains(self.expr, other)
|
||||
|
||||
|
||||
def overlaps(self, other: SaColumn) -> 'sa.Operators':
|
||||
""" Return true if at least one value of the argument is contained
|
||||
in the array.
|
||||
"""
|
||||
return self.op('&&', is_comparison=True)(other)
|
||||
|
||||
|
||||
class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
|
||||
""" Aggregate function to collect elements in an array.
|
||||
@@ -82,6 +76,48 @@ class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
|
||||
name = 'array_agg'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
|
||||
class ArrayContains(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Function to check if an array is fully contained in another.
|
||||
"""
|
||||
name = 'ArrayContains'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(ArrayContains) # type: ignore[no-untyped-call, misc]
|
||||
def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "(%s @> %s)" % (compiler.process(arg1, **kw),
|
||||
compiler.process(arg2, **kw))
|
||||
|
||||
|
||||
@compiles(ArrayContains, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "array_contains(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
|
||||
class ArrayCat(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Function to check if an array is fully contained in another.
|
||||
"""
|
||||
type = IntArray()
|
||||
identifier = 'ArrayCat'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(ArrayCat) # type: ignore[no-untyped-call, misc]
|
||||
def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "array_cat(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
@compiles(ArrayCat, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ A custom type that implements a simple key-value store of strings.
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import HSTORE
|
||||
from sqlalchemy.dialects.sqlite import JSON as sqlite_json
|
||||
|
||||
@@ -37,11 +38,25 @@ class KeyValueStore(sa.types.TypeDecorator[Any]):
|
||||
one, overwriting values where necessary. When the argument
|
||||
is null, nothing happens.
|
||||
"""
|
||||
return self.op('||')(sa.func.coalesce(other,
|
||||
sa.type_coerce('', KeyValueStore)))
|
||||
return KeyValueConcat(self.expr, other)
|
||||
|
||||
|
||||
class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Return the merged key-value store from the input parameters.
|
||||
"""
|
||||
type = KeyValueStore()
|
||||
name = 'JsonConcat'
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(KeyValueConcat) # type: ignore[no-untyped-call, misc]
|
||||
def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
|
||||
|
||||
@compiles(KeyValueConcat, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
|
||||
|
||||
|
||||
|
||||
def has_key(self, key: SaColumn) -> 'sa.Operators':
|
||||
""" Return true if the key is cotained in the store.
|
||||
"""
|
||||
return self.op('?', is_comparison=True)(key)
|
||||
|
||||
122
nominatim/db/sqlite_functions.py
Normal file
122
nominatim/db/sqlite_functions.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2023 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Custom functions for SQLite.
|
||||
"""
|
||||
from typing import cast, Optional, Set, Any
|
||||
import json
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float:
|
||||
""" Custom weight function for search results.
|
||||
"""
|
||||
if search_vector is not None:
|
||||
svec = [int(x) for x in search_vector.split(',')]
|
||||
for rank in json.loads(rankings):
|
||||
if all(r in svec for r in rank[1]):
|
||||
return cast(float, rank[0])
|
||||
|
||||
return default
|
||||
|
||||
|
||||
class ArrayIntersectFuzzy:
|
||||
""" Compute the array of common elements of all input integer arrays.
|
||||
Very large input paramenters may be ignored to speed up
|
||||
computation. Therefore, the result is a superset of common elements.
|
||||
|
||||
Input and output arrays are given as comma-separated lists.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
self.first = ''
|
||||
self.values: Optional[Set[int]] = None
|
||||
|
||||
def step(self, value: Optional[str]) -> None:
|
||||
""" Add the next array to the intersection.
|
||||
"""
|
||||
if value is not None:
|
||||
if not self.first:
|
||||
self.first = value
|
||||
elif len(value) < 10000000:
|
||||
if self.values is None:
|
||||
self.values = {int(x) for x in self.first.split(',')}
|
||||
self.values.intersection_update((int(x) for x in value.split(',')))
|
||||
|
||||
def finalize(self) -> str:
|
||||
""" Return the final result.
|
||||
"""
|
||||
if self.values is not None:
|
||||
return ','.join(map(str, self.values))
|
||||
|
||||
return self.first
|
||||
|
||||
|
||||
class ArrayUnion:
|
||||
""" Compute the set of all elements of the input integer arrays.
|
||||
|
||||
Input and output arrays are given as strings of comma-separated lists.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
self.values: Optional[Set[str]] = None
|
||||
|
||||
def step(self, value: Optional[str]) -> None:
|
||||
""" Add the next array to the union.
|
||||
"""
|
||||
if value is not None:
|
||||
if self.values is None:
|
||||
self.values = set(value.split(','))
|
||||
else:
|
||||
self.values.update(value.split(','))
|
||||
|
||||
def finalize(self) -> str:
|
||||
""" Return the final result.
|
||||
"""
|
||||
return '' if self.values is None else ','.join(self.values)
|
||||
|
||||
|
||||
def array_contains(container: Optional[str], containee: Optional[str]) -> Optional[bool]:
|
||||
""" Is the array 'containee' completely contained in array 'container'.
|
||||
"""
|
||||
if container is None or containee is None:
|
||||
return None
|
||||
|
||||
vset = container.split(',')
|
||||
return all(v in vset for v in containee.split(','))
|
||||
|
||||
|
||||
def array_pair_contains(container1: Optional[str], container2: Optional[str],
|
||||
containee: Optional[str]) -> Optional[bool]:
|
||||
""" Is the array 'containee' completely contained in the union of
|
||||
array 'container1' and array 'container2'.
|
||||
"""
|
||||
if container1 is None or container2 is None or containee is None:
|
||||
return None
|
||||
|
||||
vset = container1.split(',') + container2.split(',')
|
||||
return all(v in vset for v in containee.split(','))
|
||||
|
||||
|
||||
def install_custom_functions(conn: Any) -> None:
|
||||
""" Install helper functions for Nominatim into the given SQLite
|
||||
database connection.
|
||||
"""
|
||||
conn.create_function('weigh_search', 3, weigh_search, deterministic=True)
|
||||
conn.create_function('array_contains', 2, array_contains, deterministic=True)
|
||||
conn.create_function('array_pair_contains', 3, array_pair_contains, deterministic=True)
|
||||
_create_aggregate(conn, 'array_intersect_fuzzy', 1, ArrayIntersectFuzzy)
|
||||
_create_aggregate(conn, 'array_union', 1, ArrayUnion)
|
||||
|
||||
|
||||
async def _make_aggregate(aioconn: Any, *args: Any) -> None:
|
||||
await aioconn._execute(aioconn._conn.create_aggregate, *args)
|
||||
|
||||
|
||||
def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None:
|
||||
try:
|
||||
conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate))
|
||||
except Exception as error: # pylint: disable=broad-exception-caught
|
||||
conn._handle_exception(error)
|
||||
Reference in New Issue
Block a user