enable all API tests for sqlite and port missing features

This commit is contained in:
Sarah Hoffmann
2023-12-06 20:56:21 +01:00
parent 0d840c8d4e
commit 6d39563b87
15 changed files with 514 additions and 230 deletions

View File

@@ -188,6 +188,7 @@ def sqlite_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw
return "json_each(%s)" % compiler.process(element.clauses, **kw)
class Greatest(sa.sql.functions.GenericFunction[Any]):
""" Function to compute maximum of all its input parameters.
"""
@@ -198,3 +199,23 @@ class Greatest(sa.sql.functions.GenericFunction[Any]):
@compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str:
return "max(%s)" % compiler.process(element.clauses, **kw)
class RegexpWord(sa.sql.functions.GenericFunction[Any]):
""" Check if a full word is in a given string.
"""
name = 'RegexpWord'
inherit_cache = True
@compiles(RegexpWord, 'postgresql') # type: ignore[no-untyped-call, misc]
def postgres_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "%s ~* ('\\m(' || %s || ')\\M')::text" % (compiler.process(arg2, **kw), compiler.process(arg1, **kw))
@compiles(RegexpWord, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "regexp('\\b(' || %s || ')\\b', %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))

View File

@@ -57,22 +57,16 @@ class IntArray(sa.types.TypeDecorator[Any]):
""" Concate the array with the given array. If one of the
operants is null, the value of the other will be returned.
"""
return sa.func.array_cat(self, other, type_=IntArray)
return ArrayCat(self.expr, other)
def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
""" Return true if the array contains all the value of the argument
array.
"""
return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other))
return ArrayContains(self.expr, other)
def overlaps(self, other: SaColumn) -> 'sa.Operators':
""" Return true if at least one value of the argument is contained
in the array.
"""
return self.op('&&', is_comparison=True)(other)
class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
""" Aggregate function to collect elements in an array.
@@ -82,6 +76,48 @@ class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
name = 'array_agg'
inherit_cache = True
@compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
class ArrayContains(sa.sql.expression.FunctionElement[Any]):
""" Function to check if an array is fully contained in another.
"""
name = 'ArrayContains'
inherit_cache = True
@compiles(ArrayContains) # type: ignore[no-untyped-call, misc]
def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "(%s @> %s)" % (compiler.process(arg1, **kw),
compiler.process(arg2, **kw))
@compiles(ArrayContains, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
return "array_contains(%s)" % compiler.process(element.clauses, **kw)
class ArrayCat(sa.sql.expression.FunctionElement[Any]):
""" Function to check if an array is fully contained in another.
"""
type = IntArray()
identifier = 'ArrayCat'
inherit_cache = True
@compiles(ArrayCat) # type: ignore[no-untyped-call, misc]
def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
return "array_cat(%s)" % compiler.process(element.clauses, **kw)
@compiles(ArrayCat, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))

View File

@@ -10,6 +10,7 @@ A custom type that implements a simple key-value store of strings.
from typing import Any
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import HSTORE
from sqlalchemy.dialects.sqlite import JSON as sqlite_json
@@ -37,11 +38,25 @@ class KeyValueStore(sa.types.TypeDecorator[Any]):
one, overwriting values where necessary. When the argument
is null, nothing happens.
"""
return self.op('||')(sa.func.coalesce(other,
sa.type_coerce('', KeyValueStore)))
return KeyValueConcat(self.expr, other)
class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
""" Return the merged key-value store from the input parameters.
"""
type = KeyValueStore()
name = 'JsonConcat'
inherit_cache = True
@compiles(KeyValueConcat) # type: ignore[no-untyped-call, misc]
def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
@compiles(KeyValueConcat, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
def has_key(self, key: SaColumn) -> 'sa.Operators':
""" Return true if the key is cotained in the store.
"""
return self.op('?', is_comparison=True)(key)

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Custom functions for SQLite.
"""
from typing import cast, Optional, Set, Any
import json
# pylint: disable=protected-access
def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float:
""" Custom weight function for search results.
"""
if search_vector is not None:
svec = [int(x) for x in search_vector.split(',')]
for rank in json.loads(rankings):
if all(r in svec for r in rank[1]):
return cast(float, rank[0])
return default
class ArrayIntersectFuzzy:
""" Compute the array of common elements of all input integer arrays.
Very large input paramenters may be ignored to speed up
computation. Therefore, the result is a superset of common elements.
Input and output arrays are given as comma-separated lists.
"""
def __init__(self) -> None:
self.first = ''
self.values: Optional[Set[int]] = None
def step(self, value: Optional[str]) -> None:
""" Add the next array to the intersection.
"""
if value is not None:
if not self.first:
self.first = value
elif len(value) < 10000000:
if self.values is None:
self.values = {int(x) for x in self.first.split(',')}
self.values.intersection_update((int(x) for x in value.split(',')))
def finalize(self) -> str:
""" Return the final result.
"""
if self.values is not None:
return ','.join(map(str, self.values))
return self.first
class ArrayUnion:
""" Compute the set of all elements of the input integer arrays.
Input and output arrays are given as strings of comma-separated lists.
"""
def __init__(self) -> None:
self.values: Optional[Set[str]] = None
def step(self, value: Optional[str]) -> None:
""" Add the next array to the union.
"""
if value is not None:
if self.values is None:
self.values = set(value.split(','))
else:
self.values.update(value.split(','))
def finalize(self) -> str:
""" Return the final result.
"""
return '' if self.values is None else ','.join(self.values)
def array_contains(container: Optional[str], containee: Optional[str]) -> Optional[bool]:
""" Is the array 'containee' completely contained in array 'container'.
"""
if container is None or containee is None:
return None
vset = container.split(',')
return all(v in vset for v in containee.split(','))
def array_pair_contains(container1: Optional[str], container2: Optional[str],
containee: Optional[str]) -> Optional[bool]:
""" Is the array 'containee' completely contained in the union of
array 'container1' and array 'container2'.
"""
if container1 is None or container2 is None or containee is None:
return None
vset = container1.split(',') + container2.split(',')
return all(v in vset for v in containee.split(','))
def install_custom_functions(conn: Any) -> None:
""" Install helper functions for Nominatim into the given SQLite
database connection.
"""
conn.create_function('weigh_search', 3, weigh_search, deterministic=True)
conn.create_function('array_contains', 2, array_contains, deterministic=True)
conn.create_function('array_pair_contains', 3, array_pair_contains, deterministic=True)
_create_aggregate(conn, 'array_intersect_fuzzy', 1, ArrayIntersectFuzzy)
_create_aggregate(conn, 'array_union', 1, ArrayUnion)
async def _make_aggregate(aioconn: Any, *args: Any) -> None:
await aioconn._execute(aioconn._conn.create_aggregate, *args)
def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None:
try:
conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate))
except Exception as error: # pylint: disable=broad-exception-caught
conn._handle_exception(error)