split code into submodules

This commit is contained in:
Sarah Hoffmann
2024-05-16 11:55:17 +02:00
parent 0fb4fe8e4d
commit 6e89310a92
137 changed files with 757 additions and 716 deletions

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Module for forward search.
"""
# pylint: disable=useless-import-alias
from .geocoder import (ForwardGeocoder as ForwardGeocoder)
from .query import (Phrase as Phrase,
PhraseType as PhraseType)
from .query_analyzer_factory import (make_query_analyzer as make_query_analyzer)

View File

@@ -0,0 +1,459 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Conversion from token assignment to an abstract DB search.
"""
from typing import Optional, List, Tuple, Iterator, Dict
import heapq
from ..types import SearchDetails, DataLayer
from .query import QueryStruct, Token, TokenType, TokenRange, BreakType
from .token_assignment import TokenAssignment
from . import db_search_fields as dbf
from . import db_searches as dbs
from . import db_search_lookups as lookups
def wrap_near_search(categories: List[Tuple[str, str]],
                     search: dbs.AbstractSearch) -> dbs.NearSearch:
    """ Wrap the given search into a near search for places of the
        given categories.

        Every category is used with a penalty of 0.0. The resulting
        search takes over the penalty of the wrapped search.
    """
    unweighted = dbf.WeightedCategories(categories, [0.0 for _ in categories])
    return dbs.NearSearch(penalty=search.penalty,
                          categories=unweighted,
                          search=search)
def build_poi_search(category: List[Tuple[str, str]],
                     countries: Optional[List[str]]) -> dbs.PoiSearch:
    """ Create a new search for places of the given categories,
        optionally constrained to the given countries.

        All categories and countries are used with a penalty of 0.0.
    """
    country_names = countries if countries else []
    country_filter = dbf.WeightedStrings(country_names, [0.0] * len(country_names))
    category_filter = dbf.WeightedCategories(category, [0.0] * len(category))

    # The filters are computed outside of the class body because names
    # assigned inside a class body hide the function arguments.
    class _PoiData(dbf.SearchData):
        penalty = 0.0
        qualifiers = category_filter
        countries = country_filter

    return dbs.PoiSearch(_PoiData())
class SearchBuilder:
""" Build the abstract search queries from token assignments.
"""
def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
    """ Remember the query structure and the search parameters
        that the searches are built from.
    """
    self.query, self.details = query, details
@property
def configured_for_country(self) -> bool:
    """ Tell if the search details allow countries in the result:
        the requested rank range must include rank 4 and the
        address layer must be enabled.
    """
    params = self.details
    if not params.min_rank <= 4 <= params.max_rank:
        return False
    return params.layer_enabled(DataLayer.ADDRESS)
@property
def configured_for_postcode(self) -> bool:
    """ Tell if the search details allow postcodes in the result:
        the requested rank range must start at 5 or below, end at
        11 or above and the address layer must be enabled.
    """
    params = self.details
    ranks_covered = params.min_rank <= 5 and params.max_rank >= 11
    return ranks_covered and params.layer_enabled(DataLayer.ADDRESS)
@property
def configured_for_housenumbers(self) -> bool:
    """ Tell if the search details allow addresses in the result:
        the maximum rank must be at least 30 and the address layer
        must be enabled.
    """
    if self.details.max_rank < 30:
        return False
    return self.details.layer_enabled(DataLayer.ADDRESS)
def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
    """ Yield all abstract searches that are possible for the given
        token assignment.

        Nothing is yielded when no search data can be derived from
        the assignment or when the near items of the query cannot be
        combined with the category parameter.
    """
    sdata = self.get_search_data(assignment)
    if sdata is None:
        return

    near_items = self.get_near_items(assignment)
    if near_items is not None and not near_items:
        return # impossible combination of near items and category parameter

    if assignment.name is not None:
        builder = self.build_name_search(sdata, assignment.name, assignment.address,
                                         bool(near_items))
    elif near_items and not sdata.postcodes:
        # Without a name, the near items become the searched categories.
        sdata.qualifiers = near_items
        near_items = None
        builder = self.build_poi_search(sdata)
    elif assignment.housenumber:
        hnr_tokens = self.query.get_tokens(assignment.housenumber,
                                           TokenType.HOUSENUMBER)
        builder = self.build_housenumber_search(sdata, hnr_tokens, assignment.address)
    else:
        builder = self.build_special_search(sdata, assignment.address,
                                            bool(near_items))

    if not near_items:
        for search in builder:
            search.penalty += assignment.penalty
            yield search
        return

    # Move the smallest near-item penalty into the penalty of the search.
    base_penalty = min(near_items.penalties)
    near_items.penalties = [p - base_penalty for p in near_items.penalties]
    for inner in builder:
        total_penalty = base_penalty + assignment.penalty + inner.penalty
        inner.penalty = 0.0
        yield dbs.NearSearch(total_penalty, near_items, inner)
def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
    """ Build the abstract search for a simple category search.

        The search is only created when there are no housenumbers and
        when the search details provide a geographic constraint:
        either a bounded viewbox or a near point.
    """
    if sdata.housenumbers:
        return

    has_area = (self.details.viewbox and self.details.bounded_viewbox) \
               or self.details.near
    if has_area:
        yield dbs.PoiSearch(sdata)
def build_special_search(self, sdata: dbf.SearchData,
                         address: List[TokenRange],
                         is_category: bool) -> Iterator[dbs.AbstractSearch]:
    """ Build the abstract searches for queries without a named place:
        country searches and postcode searches.
    """
    if sdata.qualifiers:
        # Special searches cannot be combined with qualifiers.
        return

    country_only = sdata.countries and not address and not sdata.postcodes
    if country_only and self.configured_for_country:
        yield dbs.CountrySearch(sdata)

    if not sdata.postcodes:
        return
    if not (is_category or self.configured_for_postcode):
        return

    penalty = 0.1 if not sdata.countries else 0.0
    if address:
        # Address terms only restrict the postcode search.
        addr_tokens = [t.token for trange in address
                       for t in self.query.get_partials_list(trange)]
        sdata.lookups = [dbf.FieldLookup('nameaddress_vector', addr_tokens,
                                         lookups.Restrict)]
        penalty += 0.2
    yield dbs.PostcodeSearch(penalty, sdata)
def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
                             address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
    """ Build a simple address search for special entries where the
        housenumber is the main name token.

        The kind of address lookup depends on the expected number of
        housenumber matches and on the frequency of the address terms.
    """
    hnr_tokens = [t.token for t in hnrs]
    sdata.lookups = [dbf.FieldLookup('name_vector', hnr_tokens, lookups.LookupAny)]
    expected_count = sum(t.count for t in hnrs)

    addr_counts = {}
    for trange in address:
        for partial in self.query.get_partials_list(trange):
            addr_counts[partial.token] = partial.addr_count
    partial_tokens = list(addr_counts)

    if expected_count < 8000:
        addr_lookup = dbf.FieldLookup('nameaddress_vector',
                                      partial_tokens, lookups.Restrict)
    elif len(addr_counts) != 1 or next(iter(addr_counts.values())) < 10000:
        addr_lookup = dbf.FieldLookup('nameaddress_vector',
                                      partial_tokens, lookups.LookupAll)
    else:
        # A single, very frequent partial: use the full words instead.
        full_tokens = [t.token for t
                       in self.query.get_tokens(address[0], TokenType.WORD)]
        if len(full_tokens) > 5:
            return
        addr_lookup = dbf.FieldLookup('nameaddress_vector',
                                      full_tokens, lookups.LookupAny)
    sdata.lookups.append(addr_lookup)

    # The housenumber is already searched as a name, drop the filter.
    sdata.housenumbers = dbf.WeightedStrings([], [])
    yield dbs.PlaceSearch(0.05, sdata, expected_count)
def build_name_search(self, sdata: dbf.SearchData,
                      name: TokenRange, address: List[TokenRange],
                      is_category: bool) -> Iterator[dbs.AbstractSearch]:
    """ Build the abstract searches for simple name or address queries.

        Queries with housenumbers are skipped when the search details
        do not allow addresses, unless a category search is requested.
    """
    if not is_category and sdata.housenumbers and not self.configured_for_housenumbers:
        return

    ranking = self.get_name_ranking(name)
    name_penalty = ranking.normalize_penalty()
    if ranking.rankings:
        sdata.rankings.append(ranking)

    for extra_penalty, count, lookup in self.yield_lookups(name, address):
        sdata.lookups = lookup
        yield dbs.PlaceSearch(extra_penalty + name_penalty, sdata, count)
def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
                      -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
    """ Yield all variants how the given name and address should best
        be searched for. This takes into account how frequent the terms
        are and tries to find a lookup that optimizes index use.

        Each yielded tuple consists of an extra penalty, the expected
        number of results and the list of field lookups to apply.
    """
    # NOTE(review): the expected counts below are computed with true
    # division and are therefore floats, although the return type
    # annotation declares an int.
    penalty = 0.0 # extra penalty
    # Partial name tokens keyed by token id, which removes duplicates.
    name_partials = {t.token: t for t in self.query.get_partials_list(name)}

    addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
    addr_tokens = list({t.token for t in addr_partials})

    partials_indexed = all(t.is_indexed for t in name_partials.values()) \
                       and all(t.is_indexed for t in addr_partials)
    # Start with the count of the rarest name partial and halve it for
    # every additional name partial.
    # NOTE(review): min() raises a ValueError when the name range has no
    # partial tokens - presumably callers guarantee at least one; confirm.
    exp_count = min(t.count for t in name_partials.values()) / (2**(len(name_partials) - 1))

    # Rare enough (or sufficiently many) name partials: look up by name
    # partials only and use the address tokens as a restriction.
    if (len(name_partials) > 3 or exp_count < 8000) and partials_indexed:
        yield penalty, exp_count, dbf.lookup_by_names(list(name_partials.keys()), addr_tokens)
        return

    addr_count = min(t.addr_count for t in addr_partials) if addr_partials else 30000
    # Partial terms too frequent. Try looking up by rare full names first.
    name_fulls = self.query.get_tokens(name, TokenType.WORD)
    if name_fulls:
        fulls_count = sum(t.count for t in name_fulls)
        # NOTE(review): when partials_indexed is true, all address
        # partials are indexed by definition, so this sum is always 0.
        # Confirm whether `not partials_indexed` was intended.
        if partials_indexed:
            penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)

        if fulls_count < 50000 or addr_count < 30000:
            yield penalty,fulls_count / (2**len(addr_tokens)), \
                self.get_full_name_ranking(name_fulls, addr_partials,
                                           fulls_count > 30000 / max(1, len(addr_tokens)))

    # To catch remaining results, lookup by name and address
    # We only do this if there is a reasonable number of results expected.
    exp_count = exp_count / (2**len(addr_tokens)) if addr_tokens else exp_count
    if exp_count < 10000 and addr_count < 20000\
       and all(t.is_indexed for t in name_partials.values()):
        # The fewer terms the query has, the higher the extra penalty.
        penalty += 0.35 * max(1 if name_fulls else 0.1,
                              5 - len(name_partials) - len(addr_tokens))
        yield penalty, exp_count,\
              self.get_name_address_ranking(list(name_partials.keys()), addr_partials)
def get_name_address_ranking(self, name_tokens: List[int],
                             addr_partials: List[Token]) -> List[dbf.FieldLookup]:
    """ Create the lookups for a search by name and address.

        All name tokens must match. Indexed address partials with an
        address count above 20000 only restrict the result, the other
        indexed address partials are looked up. Unindexed address
        partials are ignored.
    """
    indexed = [t for t in addr_partials if t.is_indexed]
    frequent_tokens = [t.token for t in indexed if t.addr_count > 20000]
    rare_tokens = [t.token for t in indexed if not t.addr_count > 20000]

    lookup = [dbf.FieldLookup('name_vector', name_tokens, lookups.LookupAll)]
    if frequent_tokens:
        lookup.append(dbf.FieldLookup('nameaddress_vector',
                                      frequent_tokens, lookups.Restrict))
    if rare_tokens:
        lookup.append(dbf.FieldLookup('nameaddress_vector',
                                      rare_tokens, lookups.LookupAll))

    return lookup
def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
                          use_lookup: bool) -> List[dbf.FieldLookup]:
    """ Create the lookups for a search by full name terms with an
        additional address lookup. When 'use_lookup' is true, then
        address terms with not too many occurrences (an address count
        of at most 20000) are looked up, all others only restrict
        the result.
    """
    # At this point drop unindexed partials from the address.
    # This might yield wrong results, nothing we can do about that.
    indexed = [t for t in addr_partials if t.is_indexed]

    if use_lookup:
        addr_restrict_tokens = [t.token for t in indexed if t.addr_count > 20000]
        addr_lookup_tokens = [t.token for t in indexed if not t.addr_count > 20000]
    else:
        addr_restrict_tokens = [t.token for t in indexed]
        addr_lookup_tokens = []

    full_tokens = [t.token for t in name_fulls]
    return dbf.lookup_by_any_name(full_tokens, addr_restrict_tokens, addr_lookup_tokens)
def get_name_ranking(self, trange: TokenRange,
                     db_field: str = 'name_vector') -> dbf.FieldRanking:
    """ Create the ranking for a name term in the given range.

        Every full word token of the range becomes a ranking of its
        own, ordered by penalty. The default penalty is derived from
        the partial tokens of the range.
    """
    full_words = self.query.get_tokens(trange, TokenType.WORD)
    ranks = sorted((dbf.RankedTokens(t.penalty, [t.token]) for t in full_words),
                   key=lambda r: r.penalty)

    # Fallback, sum of penalty for partials
    partials = self.query.get_partials_list(trange)
    fallback = sum(t.penalty for t in partials) + 0.2

    return dbf.FieldRanking(db_field, fallback, ranks)
def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
    """ Create a list of ranking expressions for an address term
        for the given ranges.

        The function enumerates the sequences of partial and full word
        token lists that cover the range from trange.start to
        trange.end. Full word tokens are added to the token list of a
        ranking, partial tokens only contribute their penalty.
    """
    # Heap entries consist of the negated number of token lists used so
    # far, the current node position and the ranking collected so far.
    todo: List[Tuple[int, int, dbf.RankedTokens]] = []
    heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
    ranks: List[dbf.RankedTokens] = []
    while todo: # pylint: disable=too-many-nested-blocks
        neglen, pos, rank = heapq.heappop(todo)
        for tlist in self.query.nodes[pos].starting:
            if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
                if tlist.end < trange.end:
                    # Token list ends inside the range: continue from its
                    # end and add the penalty of the break found there.
                    chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
                    if tlist.ttype == TokenType.PARTIAL:
                        # Partials add their highest penalty but no token.
                        penalty = rank.penalty + chgpenalty \
                                  + max(t.penalty for t in tlist.tokens)
                        heapq.heappush(todo, (neglen - 1, tlist.end,
                                              dbf.RankedTokens(penalty, rank.tokens)))
                    else:
                        # Each full word token opens a variant of its own.
                        for t in tlist.tokens:
                            heapq.heappush(todo, (neglen - 1, tlist.end,
                                                  rank.with_token(t, chgpenalty)))
                elif tlist.end == trange.end:
                    # Token list completes the range: record the result.
                    if tlist.ttype == TokenType.PARTIAL:
                        ranks.append(dbf.RankedTokens(rank.penalty
                                                      + max(t.penalty for t in tlist.tokens),
                                                      rank.tokens))
                    else:
                        ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
                    if len(ranks) >= 10:
                        # Too many variants, bail out and only add
                        # Worst-case Fallback: sum of penalty of partials
                        name_partials = self.query.get_partials_list(trange)
                        default = sum(t.penalty for t in name_partials) + 0.2
                        ranks.append(dbf.RankedTokens(rank.penalty + default, []))
                        # Bail out of outer loop
                        todo.clear()
                        break

    # The ranking with the fewest tokens provides the default penalty
    # and is removed from the list of rankings.
    # NOTE(review): presumably this is the variant built from partials
    # only (empty token list); ranks[0] fails if no variant was found.
    ranks.sort(key=lambda r: len(r.tokens))
    default = ranks[0].penalty + 0.3
    del ranks[0]
    ranks.sort(key=lambda r: r.penalty)

    return dbf.FieldRanking('nameaddress_vector', default, ranks)
def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
    """ Collect the tokens for the non-name search fields of the
        assignment.

        Returns None when the country or qualifier terms of the
        assignment yield no usable tokens.
    """
    sdata = dbf.SearchData()
    sdata.penalty = assignment.penalty

    if assignment.country:
        country_tokens = self.get_country_tokens(assignment.country)
        if not country_tokens:
            return None
        sdata.set_strings('countries', country_tokens)
    elif self.details.countries:
        allowed = self.details.countries
        sdata.countries = dbf.WeightedStrings(allowed, [0.0] * len(allowed))

    string_fields = (('housenumbers', assignment.housenumber, TokenType.HOUSENUMBER),
                     ('postcodes', assignment.postcode, TokenType.POSTCODE))
    for field, trange, ttype in string_fields:
        if trange:
            sdata.set_strings(field, self.query.get_tokens(trange, ttype))

    if assignment.qualifier:
        qualifier_tokens = self.get_qualifier_tokens(assignment.qualifier)
        if not qualifier_tokens:
            return None
        sdata.set_qualifiers(qualifier_tokens)
    elif self.details.categories:
        wanted = self.details.categories
        sdata.qualifiers = dbf.WeightedCategories(wanted, [0.0] * len(wanted))

    address = assignment.address
    if not address:
        sdata.rankings = []
    elif assignment.name or not assignment.housenumber:
        sdata.set_ranking([self.get_addr_ranking(r) for r in address])
    else:
        # housenumber search: the first item needs to be handled like
        # a name in ranking or penalties are not comparable with
        # normal searches.
        first = self.get_name_ranking(address[0], db_field='nameaddress_vector')
        sdata.set_ranking([first] + [self.get_addr_ranking(r) for r in address[1:]])

    return sdata
def get_country_tokens(self, trange: TokenRange) -> List[Token]:
    """ Return the country tokens of the given range. When the search
        details contain a country list, only tokens for countries
        from that list are returned.
    """
    tokens = self.query.get_tokens(trange, TokenType.COUNTRY)
    allowed = self.details.countries
    if not allowed:
        return tokens

    return [t for t in tokens if t.lookup_word in allowed]
def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
    """ Return the qualifier tokens of the given range. When the
        search details contain a category list, only tokens with a
        category from that list are returned.
    """
    tokens = self.query.get_tokens(trange, TokenType.QUALIFIER)
    allowed = self.details.categories
    if not allowed:
        return tokens

    return [t for t in tokens if t.get_category() in allowed]
def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
    """ Collect the categories for a near-item search from the tokens
        of the assignment. For each category the lowest token penalty
        is kept.

        Returns None if the assignment has no near item.
    """
    if not assignment.near_item:
        return None

    allowed = self.details.categories
    best: Dict[Tuple[str, str], float] = {}
    for token in self.query.get_tokens(assignment.near_item, TokenType.NEAR_ITEM):
        cat = token.get_category()
        # The category of a near search will be that of near_item.
        # Thus, if search is restricted to a category parameter,
        # the two sets must intersect.
        if allowed and cat not in allowed:
            continue
        if token.penalty < best.get(cat, 1000.0):
            best[cat] = token.penalty

    return dbf.WeightedCategories(list(best), list(best.values()))
# Penalty added by SearchBuilder.get_addr_ranking() when an address
# term continues across a break of the given type.
PENALTY_WORDCHANGE = {
    BreakType.START: 0.0,
    BreakType.END: 0.0,
    BreakType.PHRASE: 0.0,
    BreakType.WORD: 0.1,
    BreakType.PART: 0.2,
    BreakType.TOKEN: 0.4
}

View File

@@ -0,0 +1,254 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Data structures for more complex fields in abstract search descriptions.
"""
from typing import List, Tuple, Iterator, Dict, Type
import dataclasses
import sqlalchemy as sa
from nominatim_core.typing import SaFromClause, SaColumn, SaExpression
from .query import Token
from . import db_search_lookups as lookups
from nominatim_core.utils.json_writer import JsonWriter
@dataclasses.dataclass
class WeightedStrings:
    """ A list of strings, each paired with a penalty.

        `values` and `penalties` are parallel lists.
    """
    values: List[str]
    penalties: List[float]

    def __bool__(self) -> bool:
        return True if self.values else False

    def __iter__(self) -> Iterator[Tuple[str, float]]:
        return zip(self.values, self.penalties)

    def get_penalty(self, value: str, default: float = 1000.0) -> float:
        """ Return the penalty that belongs to the given value or
            the given default when the value is not in the list.
        """
        if value in self.values:
            return self.penalties[self.values.index(value)]

        return default
@dataclasses.dataclass
class WeightedCategories:
    """ A list of class/type tuples, each paired with a penalty.

        `values` and `penalties` are parallel lists.
    """
    values: List[Tuple[str, str]]
    penalties: List[float]

    def __bool__(self) -> bool:
        return True if self.values else False

    def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
        return zip(self.values, self.penalties)

    def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
        """ Return the penalty that belongs to the given value or
            the given default when the value is not in the list.
        """
        if value in self.values:
            return self.penalties[self.values.index(value)]

        return default

    def sql_restrict(self, table: SaFromClause) -> SaExpression:
        """ Return an SQLAlchemy expression that restricts the
            class and type columns of the given table to the values
            in the list.
            Must not be used with an empty list.
        """
        assert self.values
        matches = [sa.and_(table.c.class_ == cls, table.c.type == typ)
                   for cls, typ in self.values]
        if len(matches) == 1:
            return matches[0]

        return sa.or_(*matches)
@dataclasses.dataclass(order=True)
class RankedTokens:
    """ A list of tokens and the penalty for using this list.
    """
    penalty: float
    tokens: List[int]

    def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
        """ Return a new RankedTokens object, which has the given token
            appended to the token list. The penalty is increased by the
            penalty of the token and the given transition penalty.
        """
        new_penalty = self.penalty + t.penalty + transition_penalty
        return RankedTokens(new_penalty, [*self.tokens, t.token])
@dataclasses.dataclass
class FieldRanking:
    """ A list of rankings for a column. The rankings are tried in
        order until one matches, which then determines the penalty.
        The default penalty applies when none of them matches.
    """
    column: str
    default: float
    rankings: List[RankedTokens]

    def normalize_penalty(self) -> float:
        """ Lower the default and all ranking penalties by the smallest
            penalty among them, so that the minimum becomes 0.
            Return the amount that was subtracted.
        """
        lowest = min(self.default,
                     min((r.penalty for r in self.rankings), default=self.default))
        if lowest > 0.0:
            self.default -= lowest
            for entry in self.rankings:
                entry.penalty -= lowest

        return lowest

    def sql_penalty(self, table: SaFromClause) -> SaColumn:
        """ Create the SQL expression that computes the penalty from
            the rankings. The rankings are handed to the SQL function
            as an array of [penalty, [tokens]] entries built with the
            JsonWriter.
            Must not be used with an empty ranking list.
        """
        assert self.rankings

        writer = JsonWriter().start_array()
        for entry in self.rankings:
            writer.start_array().value(entry.penalty).next()
            writer.start_array()
            for token_id in entry.tokens:
                writer.value(token_id).next()
            writer.end_array()
            writer.end_array().next()
        writer.end_array()

        return sa.func.weigh_search(table.c[self.column], writer(), self.default)
@dataclasses.dataclass
class FieldLookup:
    """ A list of tokens to be searched for in a database column.
        'column' names the column to search in and 'lookup_type' is
        the lookup class that implements the match condition.
    """
    column: str
    tokens: List[int]
    lookup_type: Type[lookups.LookupType]

    def sql_condition(self, table: SaFromClause) -> SaColumn:
        """ Create the SQL expression for the match condition by
            instantiating the lookup type for the given table.
        """
        make_condition = self.lookup_type
        return make_condition(table, self.column, self.tokens)
class SearchData:
    """ Search fields derived from query and token assignment
        to be used with the SQL queries.
    """
    # The defaults below are class attributes and thus shared between
    # all instances. Assign new objects instead of changing them in place.
    penalty: float

    lookups: List[FieldLookup] = []
    rankings: List[FieldRanking]

    housenumbers: WeightedStrings = WeightedStrings([], [])
    postcodes: WeightedStrings = WeightedStrings([], [])
    countries: WeightedStrings = WeightedStrings([], [])

    qualifiers: WeightedCategories = WeightedCategories([], [])

    def set_strings(self, field: str, tokens: List[Token]) -> None:
        """ Set one of the WeightedStrings properties from the given
            token list. The smallest token penalty is moved into the
            global penalty, so that the minimum penalty in the list is 0.
            Does nothing when the token list is empty.
        """
        if not tokens:
            return

        token_penalties = [t.penalty for t in tokens]
        lowest = min(token_penalties)
        self.penalty += lowest
        weighted = WeightedStrings([t.lookup_word for t in tokens],
                                   [p - lowest for p in token_penalties])
        setattr(self, field, weighted)

    def set_qualifiers(self, tokens: List[Token]) -> None:
        """ Set the qualifier field from the given tokens. For each
            category the lowest token penalty is kept. The smallest
            penalty of all tokens is added to the global penalty.
            Does nothing when the token list is empty.
        """
        if not tokens:
            return

        best: Dict[Tuple[str, str], float] = {}
        for token in tokens:
            cat = token.get_category()
            if token.penalty < best.get(cat, 1000.0):
                best[cat] = token.penalty

        self.penalty += min(1000.0, *(t.penalty for t in tokens))
        self.qualifiers = WeightedCategories(list(best), list(best.values()))

    def set_ranking(self, rankings: List[FieldRanking]) -> None:
        """ Set the list of rankings and normalize each ranking.
            Rankings without ranked tokens are dropped and only add
            their default penalty to the global penalty.
        """
        kept: List[FieldRanking] = []
        self.rankings = kept
        for candidate in rankings:
            if not candidate.rankings:
                self.penalty += candidate.default
                continue
            self.penalty += candidate.normalize_penalty()
            kept.append(candidate)
def lookup_by_names(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where all name tokens are looked up via
        index. Address tokens, if there are any, only restrict the
        search further.
    """
    by_name = FieldLookup('name_vector', name_tokens, lookups.LookupAll)
    if not addr_tokens:
        return [by_name]

    return [by_name,
            FieldLookup('nameaddress_vector', addr_tokens, lookups.Restrict)]
def lookup_by_any_name(name_tokens: List[int], addr_restrict_tokens: List[int],
                       addr_lookup_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where name tokens are looked up via index
        and only one of the name tokens must be present.
        Address tokens are used to narrow the search further, either
        as a restriction or as an additional lookup.
    """
    result = [FieldLookup('name_vector', name_tokens, lookups.LookupAny)]
    address_parts = ((addr_restrict_tokens, lookups.Restrict),
                     (addr_lookup_tokens, lookups.LookupAll))
    for tokens, lookup_type in address_parts:
        if tokens:
            result.append(FieldLookup('nameaddress_vector', tokens, lookup_type))

    return result
def lookup_by_addr(name_tokens: List[int], addr_tokens: List[int]) -> List[FieldLookup]:
    """ Create a lookup list where all address tokens are looked up
        via index. The name tokens only restrict the search further.
    """
    name_restriction = FieldLookup('name_vector', name_tokens, lookups.Restrict)
    address_lookup = FieldLookup('nameaddress_vector', addr_tokens, lookups.LookupAll)
    return [name_restriction, address_lookup]

View File

@@ -0,0 +1,114 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of lookup functions for the search_name table.
"""
from typing import List, Any
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from nominatim_core.typing import SaFromClause
from nominatim_core.db.sqlalchemy_types import IntArray
# pylint: disable=consider-using-f-string
# Common base type of the lookup expressions defined in this module.
LookupType = sa.sql.expression.FunctionElement[Any]
class LookupAll(LookupType):
    """ Find all entries in search_name table that contain all of
        a given list of tokens using an index for the search.

        The clauses of the element are: the place_id column, the
        searched column, the name of the column and the token array.
    """
    inherit_cache = True

    def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
        place_id = table.c.place_id
        vector = getattr(table.c, column)
        token_array = sa.type_coerce(tokens, IntArray)
        super().__init__(place_id, vector, column, token_array)
@compiles(LookupAll) # type: ignore[no-untyped-call, misc]
def _default_lookup_all(element: LookupAll,
                        compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default rendering of LookupAll: the searched column must
        contain the token array.
    """
    _, col, _, tokens = list(element.clauses)
    col_sql = compiler.process(col, **kw)
    tokens_sql = compiler.process(tokens, **kw)
    return "(%s @> %s)" % (col_sql, tokens_sql)
@compiles(LookupAll, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_lookup_all(element: LookupAll,
                       compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering of LookupAll: candidate places are taken from
        the reverse_search_name table and then checked against the
        searched column with array_contains().
    """
    place, col, colname, tokens = list(element.clauses)
    # The clauses are processed in the order in which they appear in
    # the statement. The token array is used, and processed, twice.
    place_sql = compiler.process(place, **kw)
    lookup_tokens_sql = compiler.process(tokens, **kw)
    colname_sql = compiler.process(colname, **kw)
    col_sql = compiler.process(col, **kw)
    check_tokens_sql = compiler.process(tokens, **kw)
    template = ("(%s IN (SELECT CAST(value as bigint) FROM"
                " (SELECT array_intersect_fuzzy(places) as p FROM"
                "   (SELECT places FROM reverse_search_name"
                "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"
                "     AND column = %s"
                "   ORDER BY length(places)) as x) as u,"
                " json_each('[' || u.p || ']'))"
                " AND array_contains(%s, %s))")
    return template % (place_sql, lookup_tokens_sql, colname_sql,
                       col_sql, check_tokens_sql)
class LookupAny(LookupType):
    """ Find all entries that contain at least one of the given tokens.
        Use an index for the search.

        The clauses of the element are: the place_id column, the
        searched column, the name of the column and the token array.
    """
    inherit_cache = True

    def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
        place_id = table.c.place_id
        vector = getattr(table.c, column)
        token_array = sa.type_coerce(tokens, IntArray)
        super().__init__(place_id, vector, column, token_array)
@compiles(LookupAny) # type: ignore[no-untyped-call, misc]
def _default_lookup_any(element: LookupAny,
                        compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default rendering of LookupAny: the searched column must
        overlap with the token array.
    """
    _, col, _, tokens = list(element.clauses)
    col_sql = compiler.process(col, **kw)
    tokens_sql = compiler.process(tokens, **kw)
    return "(%s && %s)" % (col_sql, tokens_sql)
@compiles(LookupAny, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_lookup_any(element: LookupAny,
                       compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering of LookupAny: the matching places are taken
        from the reverse_search_name table.
    """
    place, _, colname, tokens = list(element.clauses)
    # Process the clauses in the order of their use in the statement.
    place_sql = compiler.process(place, **kw)
    tokens_sql = compiler.process(tokens, **kw)
    colname_sql = compiler.process(colname, **kw)
    template = ("%s IN (SELECT CAST(value as bigint) FROM"
                " (SELECT array_union(places) as p FROM reverse_search_name"
                "   WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"
                "     AND column = %s) as u,"
                " json_each('[' || u.p || ']'))")
    return template % (place_sql, tokens_sql, colname_sql)
class Restrict(LookupType):
    """ Find all entries that contain all of the given tokens.
        Do not use an index for the search.

        The clauses of the element are the searched column and
        the token array.
    """
    inherit_cache = True

    def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
        vector = getattr(table.c, column)
        token_array = sa.type_coerce(tokens, IntArray)
        super().__init__(vector, token_array)
@compiles(Restrict) # type: ignore[no-untyped-call, misc]
def _default_restrict(element: Restrict,
                      compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default rendering of Restrict: a containment check where the
        column is wrapped into coalesce().
    """
    column, tokens = list(element.clauses)
    column_sql = compiler.process(column, **kw)
    tokens_sql = compiler.process(tokens, **kw)
    return "(coalesce(null, %s) @> %s)" % (column_sql, tokens_sql)
@compiles(Restrict, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_restrict(element: Restrict,
                     compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite rendering of Restrict: all clauses of the element are
        handed to the array_contains() function.
    """
    arguments_sql = compiler.process(element.clauses, **kw)
    return "array_contains(%s)" % (arguments_sql,)

View File

@@ -0,0 +1,874 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of the actual database accesses for forward search.
"""
from typing import List, Tuple, AsyncIterator, Dict, Any, Callable, cast
import abc
import sqlalchemy as sa
from nominatim_core.typing import SaFromClause, SaScalarSelect, SaColumn, \
SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
from nominatim_core.db.sqlalchemy_types import Geometry, IntArray
from ..connection import SearchConnection
from ..types import SearchDetails, DataLayer, GeometryFormat, Bbox
from .. import results as nres
from .db_search_fields import SearchData, WeightedCategories
#pylint: disable=singleton-comparison,not-callable
#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
def no_index(expr: SaColumn) -> SaColumn:
    """ Wrap the given expression into a coalesce() call, so that the
        query planner will refrain from using the expression for
        index lookup.
    """
    wrapped = sa.func.coalesce(sa.null(), expr) # pylint: disable=not-callable
    return wrapped
def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]:
    """ Create the dictionary of bind parameters for SQL execute
        from the given search parameters.
    """
    # Pairs of bind parameter name and the attribute it is read from.
    sources = (('limit', 'max_results'),
               ('min_rank', 'min_rank'),
               ('max_rank', 'max_rank'),
               ('viewbox', 'viewbox'),
               ('viewbox2', 'viewbox_x2'),
               ('near', 'near'),
               ('near_radius', 'near_radius'),
               ('excluded', 'excluded'),
               ('countries', 'countries'))
    return {bind_name: getattr(details, attribute)
            for bind_name, attribute in sources}
# Bind parameters for use in the SQL statements. Their names correspond
# to keys of the dictionary created by _details_to_bind_params().
LIMIT_PARAM: SaBind = sa.bindparam('limit')
MIN_RANK_PARAM: SaBind = sa.bindparam('min_rank')
MAX_RANK_PARAM: SaBind = sa.bindparam('max_rank')
VIEWBOX_PARAM: SaBind = sa.bindparam('viewbox', type_=Geometry)
VIEWBOX2_PARAM: SaBind = sa.bindparam('viewbox2', type_=Geometry)
NEAR_PARAM: SaBind = sa.bindparam('near', type_=Geometry)
NEAR_RADIUS_PARAM: SaBind = sa.bindparam('near_radius')
COUNTRIES_PARAM: SaBind = sa.bindparam('countries')
def filter_by_area(sql: SaSelect, t: SaFromClause,
                   details: SearchDetails, avoid_index: bool = False) -> SaSelect:
    """ Add the filters for the near point and the bounded viewbox
        to the given SQL statement, when the search details ask for them.

        Index-using variants of the filters are only chosen for small
        areas and never when 'avoid_index' is set.
    """
    index_allowed = not avoid_index

    if details.near is not None and details.near_radius is not None:
        if index_allowed and details.near_radius < 0.1:
            near_filter = t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)
        else:
            near_filter = t.c.geometry.ST_Distance(NEAR_PARAM) <= NEAR_RADIUS_PARAM
        sql = sql.where(near_filter)

    if details.viewbox is not None and details.bounded_viewbox:
        use_index = index_allowed and details.viewbox.area < 0.2
        sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM, use_index=use_index))

    return sql
def _exclude_places(t: SaFromClause) -> Callable[[], SaExpression]:
    """ Return a callable that creates the filter which removes the
        places of the 'excluded' bind parameter from the results of
        the given table.
    """
    # NOTE(review): the expression is wrapped in a lambda, presumably so
    # that SQLAlchemy handles it as a lambda element - confirm before
    # replacing it with a plain expression.
    return lambda: t.c.place_id.not_in(sa.bindparam('excluded'))
def _select_placex(t: SaFromClause) -> SaSelect:
    """ Create a select over the given table with the columns needed
        to build a result. The expanded geometry is returned as
        column 'bbox'.
    """
    cols = t.c
    selected = (cols.place_id, cols.osm_type, cols.osm_id, cols.name,
                cols.class_, cols.type,
                cols.address, cols.extratags,
                cols.housenumber, cols.postcode, cols.country_code,
                cols.wikipedia,
                cols.parent_place_id, cols.rank_address, cols.rank_search,
                cols.linked_place_id, cols.admin_level,
                cols.centroid,
                cols.geometry.ST_Expand(0).label('bbox'))
    return sa.select(*selected)
def _add_geometry_columns(sql: SaLambdaSelect, col: SaColumn, details: SearchDetails) -> SaSelect:
    """ Add one output column to the statement for each geometry format
        requested in the search details. The geometry is simplified
        first when a simplification tolerance is set.
    """
    geometry = col
    if details.geometry_simplification > 0.0:
        geometry = sa.func.ST_SimplifyPreserveTopology(col, details.geometry_simplification)

    requested = details.geometry_output
    extra_columns = []
    if requested & GeometryFormat.GEOJSON:
        extra_columns.append(sa.func.ST_AsGeoJSON(geometry, 7).label('geometry_geojson'))
    if requested & GeometryFormat.TEXT:
        extra_columns.append(sa.func.ST_AsText(geometry).label('geometry_text'))
    if requested & GeometryFormat.KML:
        extra_columns.append(sa.func.ST_AsKML(geometry, 7).label('geometry_kml'))
    if requested & GeometryFormat.SVG:
        extra_columns.append(sa.func.ST_AsSVG(geometry, 0, 7).label('geometry_svg'))

    return sql.add_columns(*extra_columns)
def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause,
                                 numerals: List[int], details: SearchDetails) -> SaScalarSelect:
    """ Create a scalar subquery that aggregates the place IDs of all
        rows in 'table' which have the place from 'inner' as parent and
        which contain one of the given house numbers.
    """
    def _conditions(number: int) -> Tuple[SaExpression, SaExpression]:
        # The number must lie within the range and on the step raster.
        return (sa.between(number, table.c.startnumber, table.c.endnumber),
                (number - table.c.startnumber) % table.c.step == 0)

    sql = sa.select(sa.func.ArrayAgg(table.c.place_id))\
            .where(table.c.parent_place_id == inner.c.place_id)

    if len(numerals) == 1:
        in_range, on_step = _conditions(numerals[0])
        sql = sql.where(in_range).where(on_step)
    else:
        sql = sql.where(sa.or_(*(sa.and_(*_conditions(n)) for n in numerals)))

    if details.excluded:
        sql = sql.where(_exclude_places(table))

    return sql.scalar_subquery()
def _filter_by_layer(table: SaFromClause, layers: DataLayer) -> SaColumn:
    """ Create the SQL condition that restricts rows of `table` to the
        data layers selected in `layers`.

        The condition is an OR over one part for objects with an address
        rank (ADDRESS and POI layers) and one part for objects with
        address rank 0 (MANMADE, RAILWAY and NATURAL layers).
    """
    has_address = layers & DataLayer.ADDRESS
    has_poi = layers & DataLayer.POI
    has_railway = layers & DataLayer.RAILWAY
    has_natural = layers & DataLayer.NATURAL

    conditions: List[SaExpression] = []

    if has_address and has_poi:
        conditions.append(no_index(table.c.rank_address).between(1, 30))
    elif has_address:
        conditions.append(no_index(table.c.rank_address).between(1, 29))
        conditions.append(sa.func.IsAddressPoint(table))
    elif has_poi:
        conditions.append(sa.and_(no_index(table.c.rank_address) == 30,
                                  table.c.class_.not_in(('place', 'building'))))

    classes: List[str] = []
    if layers & DataLayer.MANMADE:
        # Everything with rank 0 is wanted, except for the classes
        # belonging to layers that were not requested.
        if not has_railway:
            classes.append('railway')
        if not has_natural:
            classes.extend(('natural', 'water', 'waterway'))
        class_condition = table.c.class_.not_in(tuple(classes))
    else:
        # Only the classes of explicitly requested layers are wanted.
        if has_railway:
            classes.append('railway')
        if has_natural:
            classes.extend(('natural', 'water', 'waterway'))
        class_condition = table.c.class_.in_(tuple(classes))

    conditions.append(sa.and_(class_condition,
                              no_index(table.c.rank_address) == 0))

    return conditions[0] if len(conditions) == 1 else sa.or_(*conditions)
def _interpolated_position(table: SaFromClause, nr: SaColumn) -> SaColumn:
    """ Create the SQL expression for the position of house number `nr`
        on the line geometry of a row in the interpolation table `table`.

        The position is interpolated linearly between start and end
        number. Rows where both numbers are equal use the centroid of
        the line instead. The expression is labelled 'centroid'.
    """
    number_span = table.c.endnumber - table.c.startnumber
    fraction = sa.cast(nr - table.c.startnumber, sa.Float) / number_span
    is_single_number = table.c.endnumber == table.c.startnumber

    position = sa.case(
        (is_single_number, table.c.linegeo.ST_Centroid()),
        else_=table.c.linegeo.ST_LineInterpolatePoint(fraction))

    return position.label('centroid')
async def _get_placex_housenumbers(conn: SearchConnection,
                                   place_ids: List[int],
                                   details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
    """ Load the placex rows with the given place ids and yield one
        search result for each of them.

        Geometry output columns are added when requested in `details`.
    """
    placex = conn.t.placex

    query = _select_placex(placex).add_columns(placex.c.importance)
    query = query.where(placex.c.place_id.in_(place_ids))

    if details.geometry_output:
        query = _add_geometry_columns(query, placex.c.geometry, details)

    rows = await conn.execute(query)
    for row in rows:
        housenumber_result = nres.create_from_placex_row(row, nres.SearchResult)
        assert housenumber_result
        housenumber_result.bbox = Bbox.from_wkb(row.bbox)
        yield housenumber_result
def _int_list_to_subquery(inp: List[int]) -> 'sa.Subquery':
    """ Turn the given list of integers into a subquery which returns
        each integer as a row in the column 'nr'.

        The list is sent to the database as a single JSON array and
        unpacked there.
    """
    json_list = sa.type_coerce(inp, sa.JSON)
    elements = sa.func.JsonArrayEach(json_list)\
                 .table_valued(sa.column('value', type_=sa.JSON))

    # The JSON value is converted via text into an integer.
    as_text = sa.cast(elements.c.value, sa.Text)
    number = sa.cast(as_text, sa.Integer).label('nr')

    return sa.select(number).subquery()
async def _get_osmline(conn: SearchConnection, place_ids: List[int],
                       numerals: List[int],
                       details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
    """ Yield search results for house numbers on the OSM interpolation
        lines with the given place ids.

        Every interpolation line is joined with all `numerals` that lie
        between its start and end number. Each combination yields one
        result with the number as house number and the interpolated
        position as centroid.
    """
    osmline = conn.t.osmline
    numbers = _int_list_to_subquery(numerals)
    covers_number = numbers.c.nr.between(osmline.c.startnumber, osmline.c.endnumber)

    query = sa.select(osmline.c.place_id, osmline.c.osm_id,
                      osmline.c.parent_place_id, osmline.c.address,
                      numbers.c.nr.label('housenumber'),
                      _interpolated_position(osmline, numbers.c.nr),
                      osmline.c.postcode, osmline.c.country_code)
    query = query.where(osmline.c.place_id.in_(place_ids))
    query = query.join(numbers, covers_number)

    if details.geometry_output:
        # The output geometry is the interpolated point, so the geometry
        # columns need to be computed on top of the finished query.
        base = query.subquery()
        query = _add_geometry_columns(sa.select(base), base.c.centroid, details)

    rows = await conn.execute(query)
    for row in rows:
        line_result = nres.create_from_osmline_row(row, nres.SearchResult)
        assert line_result
        yield line_result
async def _get_tiger(conn: SearchConnection, place_ids: List[int],
                     numerals: List[int], osm_id: int,
                     details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
    """ Yield search results for house numbers on the Tiger
        interpolation lines with the given place ids.

        Every line is joined with all `numerals` that lie between its
        start and end number. The results are reported with OSM type 'W'
        and the `osm_id` handed in by the caller.
    """
    tiger = conn.t.tiger
    numbers = _int_list_to_subquery(numerals)
    covers_number = numbers.c.nr.between(tiger.c.startnumber, tiger.c.endnumber)

    query = sa.select(tiger.c.place_id, tiger.c.parent_place_id,
                      sa.literal('W').label('osm_type'),
                      sa.literal(osm_id).label('osm_id'),
                      numbers.c.nr.label('housenumber'),
                      _interpolated_position(tiger, numbers.c.nr),
                      tiger.c.postcode)
    query = query.where(tiger.c.place_id.in_(place_ids))
    query = query.join(numbers, covers_number)

    if details.geometry_output:
        # The output geometry is the interpolated point, so the geometry
        # columns need to be computed on top of the finished query.
        base = query.subquery()
        query = _add_geometry_columns(sa.select(base), base.c.centroid, details)

    rows = await conn.execute(query)
    for row in rows:
        tiger_result = nres.create_from_tiger_row(row, nres.SearchResult)
        assert tiger_result
        yield tiger_result
class AbstractSearch(abc.ABC):
    """ Encapsulation of a single lookup in the database.

        Subclasses implement `lookup()` for a specific kind of search.
        Every search carries a `penalty` which the subclasses in this
        module add to the accuracy of the results they produce.
    """
    # Class-level priority of the search type. Subclasses may override it.
    SEARCH_PRIO: int = 2

    def __init__(self, penalty: float) -> None:
        """ Create a search with the given base penalty.
        """
        self.penalty = penalty

    @abc.abstractmethod
    async def lookup(self, conn: SearchConnection,
                     details: SearchDetails) -> nres.SearchResults:
        """ Find results for the search in the database.

            `conn` is the database connection to run the queries on,
            `details` contains the search parameters of the request.
        """
class NearSearch(AbstractSearch):
    """ Category search of a place type near the result of another search.

        The wrapped search is executed first. Its best results then serve
        as anchor places around which places of the requested categories
        are looked up.
    """
    def __init__(self, penalty: float, categories: WeightedCategories,
                 search: AbstractSearch) -> None:
        """ Create a near search.

            `penalty` is the base penalty of the search, `categories`
            the (class, type) pairs to look for together with their
            individual penalties and `search` the inner search that
            yields the anchor places.
        """
        super().__init__(penalty)
        self.search = search
        self.categories = categories

    async def lookup(self, conn: SearchConnection,
                     details: SearchDetails) -> nres.SearchResults:
        """ Find results for the search in the database.

            Returns an empty result list when the inner search finds
            nothing or none of its results qualifies as an anchor.
        """
        results = nres.SearchResults()
        base = await self.search.lookup(conn, details)

        if not base:
            return results

        # Best inner result first. It determines the accuracy threshold
        # and the address rank range the other anchors must fall into.
        base.sort(key=lambda r: (r.accuracy, r.rank_search))
        max_accuracy = base[0].accuracy + 0.5
        if base[0].rank_address == 0:
            min_rank = 0
            max_rank = 0
        elif base[0].rank_address < 26:
            min_rank = 1
            max_rank = min(25, base[0].rank_address + 4)
        else:
            min_rank = 26
            max_rank = 30
        # Only placex results with a bounding box of limited size that are
        # comparable to the best result are kept as anchors.
        base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX
                                  and r.accuracy <= max_accuracy
                                  and r.bbox and r.bbox.area < 20
                                  and r.rank_address >= min_rank
                                  and r.rank_address <= max_rank)

        if base:
            # At most five anchor places are used.
            baseids = [b.place_id for b in base[:5] if b.place_id]

            # Categories are processed in order until enough results
            # have been collected.
            for category, penalty in self.categories:
                await self.lookup_category(results, conn, baseids, category, penalty, details)
                if len(results) >= details.max_results:
                    break

        return results

    async def lookup_category(self, results: nres.SearchResults,
                              conn: SearchConnection, ids: List[int],
                              category: Tuple[str, str], penalty: float,
                              details: SearchDetails) -> None:
        """ Find places of the given category near the list of
            place ids and add the results to 'results'.

            The accuracy of each added result is the penalty of this
            search plus the `penalty` of the category. The negated
            distance to the closest anchor is reported as importance.
        """
        table = await conn.get_class_table(*category)

        # Second reference to placex which supplies the anchor geometries.
        tgeom = conn.t.placex.alias('pgeom')

        if table is None:
            # No classtype table available, do a simplified lookup in placex.
            table = conn.t.placex
            sql = sa.select(table.c.place_id,
                            sa.func.min(tgeom.c.centroid.ST_Distance(table.c.centroid))
                            .label('dist'))\
                    .join(tgeom, table.c.geometry.intersects(tgeom.c.centroid.ST_Expand(0.01)))\
                    .where(table.c.class_ == category[0])\
                    .where(table.c.type == category[1])
        else:
            # Use classtype table. We can afford to use a larger
            # radius for the lookup.
            sql = sa.select(table.c.place_id,
                            sa.func.min(tgeom.c.centroid.ST_Distance(table.c.centroid))
                            .label('dist'))\
                    .join(tgeom,
                          table.c.centroid.ST_CoveredBy(
                              sa.case((sa.and_(tgeom.c.rank_address > 9,
                                               tgeom.c.geometry.is_area()),
                                       tgeom.c.geometry),
                                      else_ = tgeom.c.centroid.ST_Expand(0.05))))

        # One row per candidate place with the distance to its closest anchor.
        inner = sql.where(tgeom.c.place_id.in_(ids))\
                   .group_by(table.c.place_id).subquery()

        t = conn.t.placex
        sql = _select_placex(t).add_columns((-inner.c.dist).label('importance'))\
                               .join(inner, inner.c.place_id == t.c.place_id)\
                               .order_by(inner.c.dist)

        sql = sql.where(no_index(t.c.rank_address).between(MIN_RANK_PARAM, MAX_RANK_PARAM))
        if details.countries:
            sql = sql.where(t.c.country_code.in_(COUNTRIES_PARAM))
        if details.excluded:
            sql = sql.where(_exclude_places(t))
        if details.layers is not None:
            sql = sql.where(_filter_by_layer(t, details.layers))

        sql = sql.limit(LIMIT_PARAM)
        for row in await conn.execute(sql, _details_to_bind_params(details)):
            result = nres.create_from_placex_row(row, nres.SearchResult)
            assert result
            result.accuracy = self.penalty + penalty
            result.bbox = Bbox.from_wkb(row.bbox)
            results.append(result)
class PoiSearch(AbstractSearch):
    """ Category search in a geographic area.

        The categories to search for are taken from the qualifiers of
        the search data, optionally restricted to a list of countries.
    """
    def __init__(self, sdata: SearchData) -> None:
        """ Create the search from the penalty, qualifiers and countries
            of the given search data.
        """
        super().__init__(sdata.penalty)
        self.qualifiers = sdata.qualifiers
        self.countries = sdata.countries

    async def lookup(self, conn: SearchConnection,
                     details: SearchDetails) -> nres.SearchResults:
        """ Find results for the search in the database.

            Searches with a near point and a radius below 0.2 query the
            placex table directly. All other searches go through the
            class type tables, one query per category.
        """
        bind_params = _details_to_bind_params(details)
        t = conn.t.placex

        rows: List[SaRow] = []

        if details.near and details.near_radius is not None and details.near_radius < 0.2:
            # simply search in placex table
            def _base_query() -> SaSelect:
                return _select_placex(t) \
                           .add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
                                        .label('importance'))\
                           .where(t.c.linked_place_id == None) \
                           .where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
                           .order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
                           .limit(LIMIT_PARAM)

            classtype = self.qualifiers.values
            if len(classtype) == 1:
                # A single category is built as a lambda statement.
                cclass, ctype = classtype[0]
                sql: SaLambdaSelect = sa.lambda_stmt(lambda: _base_query()
                                                     .where(t.c.class_ == cclass)
                                                     .where(t.c.type == ctype))
            else:
                sql = _base_query().where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
                                                   for cls, typ in classtype)))

            if self.countries:
                sql = sql.where(t.c.country_code.in_(self.countries.values))

            if details.viewbox is not None and details.bounded_viewbox:
                sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM))

            rows.extend(await conn.execute(sql, bind_params))
        else:
            # use the class type tables
            for category in self.qualifiers.values:
                table = await conn.get_class_table(*category)
                # Categories without a class type table yield no results.
                if table is not None:
                    sql = _select_placex(t)\
                               .add_columns(t.c.importance)\
                               .join(table, t.c.place_id == table.c.place_id)\
                               .where(t.c.class_ == category[0])\
                               .where(t.c.type == category[1])

                    if details.viewbox is not None and details.bounded_viewbox:
                        sql = sql.where(table.c.centroid.intersects(VIEWBOX_PARAM))

                    if details.near and details.near_radius is not None:
                        sql = sql.order_by(table.c.centroid.ST_Distance(NEAR_PARAM))\
                                 .where(table.c.centroid.within_distance(NEAR_PARAM,
                                                                         NEAR_RADIUS_PARAM))

                    if self.countries:
                        sql = sql.where(t.c.country_code.in_(self.countries.values))

                    # The limit applies to each category separately.
                    sql = sql.limit(LIMIT_PARAM)
                    rows.extend(await conn.execute(sql, bind_params))

        results = nres.SearchResults()
        for row in rows:
            result = nres.create_from_placex_row(row, nres.SearchResult)
            assert result
            # Search penalty plus the penalty of the category that matched.
            result.accuracy = self.penalty + self.qualifiers.get_penalty((row.class_, row.type))
            result.bbox = Bbox.from_wkb(row.bbox)
            results.append(result)

        return results
class CountrySearch(AbstractSearch):
    """ Search for a country name or country code.

        Countries are looked up in placex first. When nothing is found
        there, the fallback country tables are consulted.
    """
    SEARCH_PRIO = 0

    def __init__(self, sdata: SearchData) -> None:
        """ Create the search from the penalty and the countries of the
            given search data.
        """
        super().__init__(sdata.penalty)
        self.countries = sdata.countries

    async def lookup(self, conn: SearchConnection,
                     details: SearchDetails) -> nres.SearchResults:
        """ Find results for the search in the database.

            Side effect: when any result is found, the rank range in
            `details` is changed in place (see the end of the function).
        """
        t = conn.t.placex

        ccodes = self.countries.values
        # Only placex rows with address rank 4 are considered.
        sql = _select_placex(t)\
                .add_columns(t.c.importance)\
                .where(t.c.country_code.in_(ccodes))\
                .where(t.c.rank_address == 4)

        if details.geometry_output:
            sql = _add_geometry_columns(sql, t.c.geometry, details)

        if details.excluded:
            sql = sql.where(_exclude_places(t))

        sql = filter_by_area(sql, t, details)

        results = nres.SearchResults()
        for row in await conn.execute(sql, _details_to_bind_params(details)):
            result = nres.create_from_placex_row(row, nres.SearchResult)
            assert result
            result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0)
            result.bbox = Bbox.from_wkb(row.bbox)
            results.append(result)

        if not results:
            results = await self.lookup_in_country_table(conn, details)

        if results:
            # NOTE: this modifies the `details` object of the caller, so the
            # changed rank range remains in effect for everything that uses
            # the same object afterwards.
            details.min_rank = min(5, details.max_rank)
            details.max_rank = min(25, details.max_rank)

        return results

    async def lookup_in_country_table(self, conn: SearchConnection,
                                      details: SearchDetails) -> nres.SearchResults:
        """ Look up the country in the fallback country tables.

            Names come from the country_name table. Centroid and
            bounding box are computed from the geometries in the
            country_grid table.
        """
        # Avoid the fallback search when this is a more search. Country results
        # usually are in the first batch of results and it is not possible
        # to exclude these fallbacks.
        if details.excluded:
            return nres.SearchResults()

        t = conn.t.country_name
        tgrid = conn.t.country_grid

        # One row per country with the geometries of the country collected.
        sql = sa.select(tgrid.c.country_code,
                        tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid()
                        .label('centroid'),
                        tgrid.c.geometry.ST_Collect().ST_Expand(0).label('bbox'))\
                .where(tgrid.c.country_code.in_(self.countries.values))\
                .group_by(tgrid.c.country_code)

        sql = filter_by_area(sql, tgrid, details, avoid_index=True)

        sub = sql.subquery('grid')

        sql = sa.select(t.c.country_code,
                        t.c.name.merge(t.c.derived_name).label('name'),
                        sub.c.centroid, sub.c.bbox)\
                .join(sub, t.c.country_code == sub.c.country_code)

        if details.geometry_output:
            # The geometry output is created from the computed centroid.
            sql = _add_geometry_columns(sql, sub.c.centroid, details)

        results = nres.SearchResults()
        for row in await conn.execute(sql, _details_to_bind_params(details)):
            result = nres.create_from_country_row(row, nres.SearchResult)
            assert result
            result.bbox = Bbox.from_wkb(row.bbox)
            result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0)
            results.append(result)

        return results
class PostcodeSearch(AbstractSearch):
    """ Search for a postcode.

        Candidates come from the postcode table. For each candidate a
        matching postal code boundary relation in placex is preferred
        as the result when one exists.
    """
    def __init__(self, extra_penalty: float, sdata: SearchData) -> None:
        """ Create the search from the given search data. `extra_penalty`
            is added to the penalty of the search data.
        """
        super().__init__(sdata.penalty + extra_penalty)
        self.countries = sdata.countries
        self.postcodes = sdata.postcodes
        self.lookups = sdata.lookups
        self.rankings = sdata.rankings

    async def lookup(self, conn: SearchConnection,
                     details: SearchDetails) -> nres.SearchResults:
        """ Find results for the search in the database.

            The accuracy of the results is computed inside the database
            and returned in the column 'accuracy'.
        """
        t = conn.t.postcode
        pcs = self.postcodes.values

        sql = sa.select(t.c.place_id, t.c.parent_place_id,
                        t.c.rank_search, t.c.rank_address,
                        t.c.postcode, t.c.country_code,
                        t.c.geometry.label('centroid'))\
                .where(t.c.postcode.in_(pcs))

        if details.geometry_output:
            sql = _add_geometry_columns(sql, t.c.geometry, details)

        # SQL expression for the accuracy, extended step by step below.
        penalty: SaExpression = sa.literal(self.penalty)

        if details.viewbox is not None and not details.bounded_viewbox:
            # Unbounded viewbox: results outside are penalised, not removed.
            penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0),
                               (t.c.geometry.intersects(VIEWBOX2_PARAM), 0.5),
                               else_=1.0)

        if details.near is not None:
            sql = sql.order_by(t.c.geometry.ST_Distance(NEAR_PARAM))

        sql = filter_by_area(sql, t, details)

        if self.countries:
            sql = sql.where(t.c.country_code.in_(self.countries.values))

        if details.excluded:
            sql = sql.where(_exclude_places(t))

        if self.lookups:
            # The search terms are matched against the names and address
            # names of the parent place of the postcode.
            assert len(self.lookups) == 1
            tsearch = conn.t.search_name
            sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
                     .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
                            .contains(sa.type_coerce(self.lookups[0].tokens,
                                                     IntArray)))

        for ranking in self.rankings:
            penalty += ranking.sql_penalty(conn.t.search_name)
        # Each postcode variant comes with its own penalty.
        penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes),
                           else_=1.0)

        sql = sql.add_columns(penalty.label('accuracy'))
        sql = sql.order_by('accuracy').limit(LIMIT_PARAM)

        results = nres.SearchResults()
        for row in await conn.execute(sql, _details_to_bind_params(details)):
            # Check for a postal code boundary relation with the same
            # postcode and country in placex.
            p = conn.t.placex
            placex_sql = _select_placex(p).add_columns(p.c.importance)\
                             .where(sa.text("""class = 'boundary'
AND type = 'postal_code'
AND osm_type = 'R'"""))\
                             .where(p.c.country_code == row.country_code)\
                             .where(p.c.postcode == row.postcode)\
                             .limit(1)

            if details.geometry_output:
                placex_sql = _add_geometry_columns(placex_sql, p.c.geometry, details)

            for prow in await conn.execute(placex_sql, _details_to_bind_params(details)):
                result = nres.create_from_placex_row(prow, nres.SearchResult)
                break
            else:
                # The loop did not run: no boundary found, use the row
                # from the postcode table.
                result = nres.create_from_postcode_row(row, nres.SearchResult)

            assert result
            # The placex result has a different place_id than the postcode
            # row, so the exclusion list needs to be checked again.
            if result.place_id not in details.excluded:
                result.accuracy = row.accuracy
                results.append(result)

        return results
class PlaceSearch(AbstractSearch):
    """ Generic search for an address or named place.

        Candidates are preselected in the search_name table and then
        joined with placex. When house numbers are part of the search,
        matching house numbers are looked up below the places found, in
        placex, in the OSM interpolations and in the Tiger data.
    """
    SEARCH_PRIO = 1

    def __init__(self, extra_penalty: float, sdata: SearchData, expected_count: int) -> None:
        """ Create the search from the given search data.

            `extra_penalty` is added to the penalty of the search data.
            `expected_count` is compared against fixed thresholds when
            deciding which additional restrictions are applied early.
        """
        super().__init__(sdata.penalty + extra_penalty)
        self.countries = sdata.countries
        self.postcodes = sdata.postcodes
        self.housenumbers = sdata.housenumbers
        self.qualifiers = sdata.qualifiers
        self.lookups = sdata.lookups
        self.rankings = sdata.rankings
        self.expected_count = expected_count

    def _inner_search_name_cte(self, conn: SearchConnection,
                               details: SearchDetails) -> 'sa.CTE':
        """ Create a subquery that preselects the rows in the search_name
            table.

            The result is a CTE named 'searches' with the columns
            place_id, search_rank, address_rank, country_code, centroid,
            importance and penalty.
        """
        t = conn.t.search_name

        penalty: SaExpression = sa.literal(self.penalty)
        for ranking in self.rankings:
            penalty += ranking.sql_penalty(t)

        # Rows without a positive importance get a substitute value
        # derived from their search rank.
        sql = sa.select(t.c.place_id, t.c.search_rank, t.c.address_rank,
                        t.c.country_code, t.c.centroid,
                        t.c.name_vector, t.c.nameaddress_vector,
                        sa.case((t.c.importance > 0, t.c.importance),
                                else_=0.40001-(sa.cast(t.c.search_rank, sa.Float())/75))
                        .label('importance'),
                        penalty.label('penalty'))

        for lookup in self.lookups:
            sql = sql.where(lookup.sql_condition(t))

        if self.countries:
            sql = sql.where(t.c.country_code.in_(self.countries.values))

        if self.postcodes:
            # if a postcode is given, don't search for state or country level objects
            sql = sql.where(t.c.address_rank > 9)
            if self.expected_count > 10000:
                # Many results expected. Restrict by postcode.
                tpc = conn.t.postcode
                sql = sql.where(sa.select(tpc.c.postcode)
                                .where(tpc.c.postcode.in_(self.postcodes.values))
                                .where(t.c.centroid.within_distance(tpc.c.geometry, 0.4))
                                .exists())

        if details.viewbox is not None:
            if details.bounded_viewbox:
                sql = sql.where(t.c.centroid
                                .intersects(VIEWBOX_PARAM,
                                            use_index=details.viewbox.area < 0.2))
            elif not self.postcodes and not self.housenumbers and self.expected_count >= 10000:
                # Unbounded viewbox with many expected results: restrict
                # to the second, extended viewbox parameter.
                sql = sql.where(t.c.centroid
                                .intersects(VIEWBOX2_PARAM,
                                            use_index=details.viewbox.area < 0.5))

        if details.near is not None and details.near_radius is not None:
            if details.near_radius < 0.1:
                sql = sql.where(t.c.centroid.within_distance(NEAR_PARAM,
                                                             NEAR_RADIUS_PARAM))
            else:
                sql = sql.where(t.c.centroid
                                .ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM)

        if self.housenumbers:
            sql = sql.where(t.c.address_rank.between(16, 30))
        else:
            # Exclusions and rank restrictions are only applied here when
            # no house number is searched. For house number searches the
            # corresponding checks are done in lookup().
            if details.excluded:
                sql = sql.where(_exclude_places(t))
            if details.min_rank > 0:
                sql = sql.where(sa.or_(t.c.address_rank >= MIN_RANK_PARAM,
                                       t.c.search_rank >= MIN_RANK_PARAM))
            if details.max_rank < 30:
                sql = sql.where(sa.or_(t.c.address_rank <= MAX_RANK_PARAM,
                                       t.c.search_rank <= MAX_RANK_PARAM))

        inner = sql.limit(10000).order_by(sa.desc(sa.text('importance'))).subquery()

        sql = sa.select(inner.c.place_id, inner.c.search_rank, inner.c.address_rank,
                        inner.c.country_code, inner.c.centroid, inner.c.importance,
                        inner.c.penalty)

        # If the query is not an address search or has a geographic preference,
        # preselect most important items to restrict the number of places
        # that need to be looked up in placex.
        if not self.housenumbers\
           and (details.viewbox is None or details.bounded_viewbox)\
           and (details.near is None or details.near_radius is not None)\
           and not self.qualifiers:
            # Smallest value of (penalty - importance) over all rows,
            # attached to every row via a window function.
            sql = sql.add_columns(sa.func.first_value(inner.c.penalty - inner.c.importance)
                                  .over(order_by=inner.c.penalty - inner.c.importance)
                                  .label('min_penalty'))

            inner = sql.subquery()

            # Only keep rows that are less than 0.5 worse than the best one.
            sql = sa.select(inner.c.place_id, inner.c.search_rank, inner.c.address_rank,
                            inner.c.country_code, inner.c.centroid, inner.c.importance,
                            inner.c.penalty)\
                    .where(inner.c.penalty - inner.c.importance < inner.c.min_penalty + 0.5)

        return sql.cte('searches')

    async def lookup(self, conn: SearchConnection,
                     details: SearchDetails) -> nres.SearchResults:
        """ Find results for the search in the database.

            For searches with house numbers, places with an address rank
            below 30 are expanded into the house numbers found below
            them. The place itself is then only added with an additional
            penalty when it meets the other filter conditions.
        """
        t = conn.t.placex
        tsearch = self._inner_search_name_cte(conn, details)

        sql = _select_placex(t).join(tsearch, t.c.place_id == tsearch.c.place_id)

        if details.geometry_output:
            sql = _add_geometry_columns(sql, t.c.geometry, details)

        penalty: SaExpression = tsearch.c.penalty

        if self.postcodes:
            tpc = conn.t.postcode
            pcs = self.postcodes.values

            # No penalty when the postcode of the place matches. Otherwise
            # the distance to the closest searched postcode is used, or 2.0
            # when that distance is not available.
            pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(t.c.centroid)))\
                        .where(tpc.c.postcode.in_(pcs))\
                        .scalar_subquery()
            penalty += sa.case((t.c.postcode.in_(pcs), 0.0),
                               else_=sa.func.coalesce(pc_near, cast(SaColumn, 2.0)))

        if details.viewbox is not None and not details.bounded_viewbox:
            penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM, use_index=False), 0.0),
                               (t.c.geometry.intersects(VIEWBOX2_PARAM, use_index=False), 0.5),
                               else_=1.0)

        if details.near is not None:
            # The negated distance replaces the importance, so that
            # closer places are sorted first.
            sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM))
                                  .label('importance'))
            sql = sql.order_by(sa.desc(sa.text('importance')))
        else:
            sql = sql.order_by(penalty - tsearch.c.importance)
            sql = sql.add_columns(tsearch.c.importance)

        sql = sql.add_columns(penalty.label('accuracy'))\
                 .order_by(sa.text('accuracy'))

        if self.housenumbers:
            # All house number variants combined into a single pattern.
            hnr_list = '|'.join(self.housenumbers.values)
            inner = sql.where(sa.or_(tsearch.c.address_rank < 30,
                                     sa.func.RegexpWord(hnr_list, t.c.housenumber)))\
                       .subquery()

            # Housenumbers from placex
            thnr = conn.t.placex.alias('hnr')
            pid_list = sa.func.ArrayAgg(thnr.c.place_id)
            place_sql = sa.select(pid_list)\
                          .where(thnr.c.parent_place_id == inner.c.place_id)\
                          .where(sa.func.RegexpWord(hnr_list, thnr.c.housenumber))\
                          .where(thnr.c.linked_place_id == None)\
                          .where(thnr.c.indexed_status == 0)

            if details.excluded:
                place_sql = place_sql.where(thnr.c.place_id.not_in(sa.bindparam('excluded')))
            if self.qualifiers:
                place_sql = place_sql.where(self.qualifiers.sql_restrict(thnr))

            # Interpolations can only be searched for purely numerical
            # house numbers with fewer than 8 digits.
            numerals = [int(n) for n in self.housenumbers.values
                        if n.isdigit() and len(n) < 8]
            interpol_sql: SaColumn
            tiger_sql: SaColumn
            if numerals and \
               (not self.qualifiers or ('place', 'house') in self.qualifiers.values):
                # Housenumbers from interpolations
                interpol_sql = _make_interpolation_subquery(conn.t.osmline, inner,
                                                            numerals, details)
                # Housenumbers from Tiger
                tiger_sql = sa.case((inner.c.country_code == 'us',
                                     _make_interpolation_subquery(conn.t.tiger, inner,
                                                                  numerals, details)
                                     ), else_=None)
            else:
                interpol_sql = sa.null()
                tiger_sql = sa.null()

            unsort = sa.select(inner, place_sql.scalar_subquery().label('placex_hnr'),
                               interpol_sql.label('interpol_hnr'),
                               tiger_sql.label('tiger_hnr')).subquery('unsort')
            # Places with house numbers in placex come first, then those
            # with interpolations, then Tiger, then places without any.
            sql = sa.select(unsort)\
                    .order_by(sa.case((unsort.c.placex_hnr != None, 1),
                                      (unsort.c.interpol_hnr != None, 2),
                                      (unsort.c.tiger_hnr != None, 3),
                                      else_=4),
                              unsort.c.accuracy)
        else:
            sql = sql.where(t.c.linked_place_id == None)\
                     .where(t.c.indexed_status == 0)
            if self.qualifiers:
                sql = sql.where(self.qualifiers.sql_restrict(t))
            if details.layers is not None:
                sql = sql.where(_filter_by_layer(t, details.layers))

        sql = sql.limit(LIMIT_PARAM)

        results = nres.SearchResults()
        for row in await conn.execute(sql, _details_to_bind_params(details)):
            result = nres.create_from_placex_row(row, nres.SearchResult)
            assert result
            result.bbox = Bbox.from_wkb(row.bbox)
            result.accuracy = row.accuracy
            if self.housenumbers and row.rank_address < 30:
                # Use the first source that has house numbers for the place.
                if row.placex_hnr:
                    subs = _get_placex_housenumbers(conn, row.placex_hnr, details)
                elif row.interpol_hnr:
                    subs = _get_osmline(conn, row.interpol_hnr, numerals, details)
                elif row.tiger_hnr:
                    subs = _get_tiger(conn, row.tiger_hnr, numerals, row.osm_id, details)
                else:
                    subs = None

                if subs is not None:
                    async for sub in subs:
                        assert sub.housenumber
                        sub.accuracy = result.accuracy
                        # Penalty when none of the parts of the house number
                        # equals one of the searched house numbers exactly.
                        if not any(nr in self.housenumbers.values
                                   for nr in sub.housenumber.split(';')):
                            sub.accuracy += 0.6
                        results.append(sub)

                # Only add the street as a result, if it meets all other
                # filter conditions.
                if (not details.excluded or result.place_id not in details.excluded)\
                   and (not self.qualifiers or result.category in self.qualifiers.values)\
                   and result.rank_address >= details.min_rank:
                    result.accuracy += 1.0 # penalty for missing housenumber
                    results.append(result)
            else:
                results.append(result)

        return results

View File

@@ -0,0 +1,274 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Public interface to the search code.
"""
from typing import List, Any, Optional, Iterator, Tuple, Dict
import itertools
import re
import datetime as dt
import difflib
from ..connection import SearchConnection
from ..types import SearchDetails
from ..results import SearchResult, SearchResults, add_result_details
from ..logging import log
from .token_assignment import yield_token_assignments
from .db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
from .db_searches import AbstractSearch
from .query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
from .query import Phrase, QueryStruct
class ForwardGeocoder:
    """ Main class responsible for place search.

        The geocoder analyses the query, creates the list of possible
        database searches for it, executes them and post-processes the
        results.
    """

    def __init__(self, conn: SearchConnection,
                 params: SearchDetails, timeout: Optional[int]) -> None:
        """ Create a geocoder which runs on the connection `conn` with
            the search parameters `params`.

            `timeout` is the time in seconds after which no further
            searches are started. When it is None or 0, a very long
            timeout of 1000000 seconds is used.
        """
        self.conn = conn
        self.params = params
        self.timeout = dt.timedelta(seconds=timeout or 1000000)
        # Created on first use in build_searches().
        self.query_analyzer: Optional[AbstractQueryAnalyzer] = None

    @property
    def limit(self) -> int:
        """ Return the configured maximum number of search results.
        """
        return self.params.max_results

    async def build_searches(self,
                             phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
        """ Analyse the query and return the tokenized query and list of
            possible searches over it.

            The searches are sorted by penalty and search priority.
            The list is empty when the query has no token slots or no
            search can be built for it.
        """
        if self.query_analyzer is None:
            self.query_analyzer = await make_query_analyzer(self.conn)

        query = await self.query_analyzer.analyze_query(phrases)

        searches: List[AbstractSearch] = []
        if query.num_token_slots() > 0:
            # 2. Compute all possible search interpretations
            log().section('Compute abstract searches')
            search_builder = SearchBuilder(query, self.params)
            num_searches = 0
            for assignment in yield_token_assignments(query):
                searches.extend(search_builder.build(assignment))
                # Only log when the assignment has produced new searches.
                if num_searches < len(searches):
                    log().table_dump('Searches for assignment',
                                     _dump_searches(searches, query, num_searches))
                num_searches = len(searches)
            searches.sort(key=lambda s: (s.penalty, s.SEARCH_PRIO))

        return query, searches

    async def execute_searches(self, query: QueryStruct,
                               searches: List[AbstractSearch]) -> SearchResults:
        """ Run the abstract searches against the database until a result
            is found.

            The searches must be sorted by penalty. Execution stops when
            the penalty of the next search is too high compared to the
            results found so far, when more than 20 searches have been
            run and the penalty increases, or when the timeout has been
            reached. Results found by more than one search are merged
            and keep their best accuracy.

            Returns an empty result list when `searches` is empty.
        """
        log().section('Execute database searches')
        results: Dict[Any, SearchResult] = {}

        # Nothing to execute. Without this guard the access to the first
        # search below fails with an IndexError. This can happen when the
        # query analysis yields no usable search.
        if not searches:
            return SearchResults()

        end_time = dt.datetime.now() + self.timeout

        min_ranking = searches[0].penalty + 2.0
        prev_penalty = 0.0
        for i, search in enumerate(searches):
            # Searches with the same penalty as the previous one are
            # always executed, so that equally good searches are not
            # cut off arbitrarily.
            if search.penalty > prev_penalty and (search.penalty > min_ranking or i > 20):
                break
            log().table_dump(f"{i + 1}. Search", _dump_searches([search], query))
            log().var_dump('Params', self.params)
            lookup_results = await search.lookup(self.conn, self.params)
            for result in lookup_results:
                rhash = (result.source_table, result.place_id,
                         result.housenumber, result.country_code)
                prevresult = results.get(rhash)
                if prevresult:
                    prevresult.accuracy = min(prevresult.accuracy, result.accuracy)
                else:
                    results[rhash] = result
                # Good results lower the penalty limit for further searches.
                min_ranking = min(min_ranking, result.accuracy * 1.2, 2.0)
            log().result_dump('Results', ((r.accuracy, r) for r in lookup_results))
            prev_penalty = search.penalty
            if dt.datetime.now() >= end_time:
                break

        return SearchResults(results.values())

    def pre_filter_results(self, results: SearchResults) -> SearchResults:
        """ Remove results that are significantly worse than the
            best match.
        """
        if results:
            max_ranking = min(r.ranking for r in results) + 0.5
            results = SearchResults(r for r in results if r.ranking < max_ranking)

        return results

    def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
        """ Remove badly matching results, sort by ranking and
            limit to the configured number of results.

            Note that the list handed in is sorted in place.
        """
        if results:
            results.sort(key=lambda r: r.ranking)
            min_rank = results[0].rank_search
            min_ranking = results[0].ranking
            # Results with a higher search rank than the best result are
            # cut off earlier.
            results = SearchResults(r for r in results
                                    if r.ranking + 0.03 * (r.rank_search - min_rank)
                                    < min_ranking + 0.5)

            results = SearchResults(results[:self.limit])

        return results

    def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
        """ Adjust the accuracy of the localized result according to how well
            they match the original query.

            The accuracy of the results is changed in place. Results
            without a display name and results with a negative
            importance are left untouched.
        """
        assert self.query_analyzer is not None
        qwords = [word for phrase in query.source
                  for word in re.split('[, ]+', phrase.text) if word]
        if not qwords:
            return

        # Total length of the query words, used to normalise the distance.
        query_length = sum(len(w) for w in qwords)

        for result in results:
            # Negative importance indicates ordering by distance, which is
            # more important than word matching.
            if not result.display_name\
               or (result.importance is not None and result.importance < 0):
                continue
            distance = 0.0
            norm = self.query_analyzer.normalize_text(' '.join((result.display_name,
                                                                result.country_code or '')))
            words = {w for w in norm.split(' ') if w}
            if not words:
                continue
            for qword in qwords:
                # Similarity to the best matching word of the result.
                wdist = max(difflib.SequenceMatcher(a=qword, b=w).quick_ratio() for w in words)
                if wdist < 0.5:
                    # Too different: count the word as completely unmatched.
                    distance += len(qword)
                else:
                    distance += (1.0 - wdist) * len(qword)
            # Compensate for the fact that country names do not get a
            # match penalty yet by the tokenizer.
            # Temporary hack that needs to be removed!
            if result.rank_address == 4:
                distance *= 2
            result.accuracy += distance * 0.4 / query_length

    async def lookup_pois(self, categories: List[Tuple[str, str]],
                          phrases: List[Phrase]) -> SearchResults:
        """ Look up places by category. If phrases are given, a place search
            over the phrases will be executed first and places close to the
            results returned.
        """
        log().function('forward_lookup_pois', categories=categories, params=self.params)

        if phrases:
            query, searches = await self.build_searches(phrases)

            if query:
                searches = [wrap_near_search(categories, s) for s in searches[:50]]
                results = await self.execute_searches(query, searches)
                results = self.pre_filter_results(results)
                await add_result_details(self.conn, results, self.params)
                log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
                results = self.sort_and_cut_results(results)
            else:
                results = SearchResults()
        else:
            # No phrases: search for the categories in the area given
            # by the search parameters.
            search = build_poi_search(categories, self.params.countries)
            results = await search.lookup(self.conn, self.params)
            await add_result_details(self.conn, results, self.params)

        log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results

    async def lookup(self, phrases: List[Phrase]) -> SearchResults:
        """ Look up a single free-text query.

            Returns an empty result list when the search parameters
            cannot yield any result or no search can be built for the
            query.
        """
        log().function('forward_lookup', phrases=phrases, params=self.params)
        results = SearchResults()

        if self.params.is_impossible():
            return results

        query, searches = await self.build_searches(phrases)

        if searches:
            # Execute SQL until an appropriate result is found.
            results = await self.execute_searches(query, searches[:50])
            results = self.pre_filter_results(results)
            await add_result_details(self.conn, results, self.params)
            log().result_dump('Preliminary Results', ((r.accuracy, r) for r in results))
            self.rerank_by_query(query, results)
            log().result_dump('Results after reranking', ((r.accuracy, r) for r in results))
            results = self.sort_and_cut_results(results)
            log().result_dump('Final Results', ((r.accuracy, r) for r in results))

        return results
# pylint: disable=invalid-name,too-many-locals
def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
                   start: int = 0) -> Iterator[Optional[List[Any]]]:
    """ Create the rows of a table describing the given searches for
        the debug output.

        The first row yielded is the table header. Each search, starting
        at index `start` of the list, then yields one row per entry of
        its longest field, followed by None.
        For searches that wrap another search (they have a 'search'
        attribute), the fields of the inner search are shown together
        with the categories of the outer one.
    """
    yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
           'Qualifier', 'Category', 'Rankings']

    def tk(tl: List[int]) -> str:
        # Token list as '[word(id),...]'.
        tstr = [f"{query.find_lookup_word_by_id(t)}({t})" for t in tl]

        return f"[{','.join(tstr)}]"

    def fmt_ranking(f: Any) -> str:
        if not f:
            return ''
        ranks = ','.join((f"{tk(r.tokens)}^{r.penalty:.3g}" for r in f.rankings))
        # Keep the table readable for fields with many rankings.
        if len(ranks) > 100:
            ranks = ranks[:100] + '...'
        return f"{f.column}({ranks},def={f.default:.3g})"

    def fmt_lookup(lookup_field: Any) -> str:
        if not lookup_field:
            return ''

        return f"{lookup_field.lookup_type}({lookup_field.column}{tk(lookup_field.tokens)})"

    def fmt_cstr(c: Any) -> str:
        # Weighted value as 'value^penalty'.
        if not c:
            return ''

        return f'{c[0]}^{c[1]}'

    # The order must correspond to the unpacking of the rows below.
    fields = ('lookups', 'rankings', 'countries', 'housenumbers',
              'postcodes', 'qualifiers')

    for search in searches[start:]:
        if hasattr(search, 'search'):
            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
                                          *(getattr(search.search, attr, []) for attr in fields),
                                          getattr(search, 'categories', []),
                                          fillvalue='')
        else:
            iters = itertools.zip_longest([f"{search.penalty:.3g}"],
                                          *(getattr(search, attr, []) for attr in fields),
                                          [],
                                          fillvalue='')
        for penalty, lookup, rank, cc, hnr, pc, qual, cat in iters:
            yield [penalty, fmt_lookup(lookup), fmt_cstr(hnr),
                   fmt_cstr(pc), fmt_cstr(cc), fmt_cstr(qual), fmt_cstr(cat), fmt_ranking(rank)]
        yield None

View File

@@ -0,0 +1,314 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the ICU tokenizer.
"""
from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
from collections import defaultdict
import dataclasses
import difflib
from icu import Transliterator
import sqlalchemy as sa
from nominatim_core.typing import SaRow
from nominatim_core.db.sqlalchemy_types import Json
from ..connection import SearchConnection
from ..logging import log
from ..search import query as qmod
from ..search.query_analyzer_factory import AbstractQueryAnalyzer
# Maps the value of the 'type' column of a word table row to the token
# type used in the query. Rows of type 'S' are not listed here, they are
# handled separately in ICUQueryAnalyzer.analyze_query().
DB_TO_TOKEN_TYPE = {
    'W': qmod.TokenType.WORD,
    'w': qmod.TokenType.PARTIAL,
    'H': qmod.TokenType.HOUSENUMBER,
    'P': qmod.TokenType.POSTCODE,
    'C': qmod.TokenType.COUNTRY
}
class QueryPart(NamedTuple):
    """ Normalized and transliterated form of a single term in the query.
        When the term came out of a split during the transliteration,
        the normalized string is the full word before transliteration.
        The word number keeps track of the word before transliteration
        and can be used to identify partial transliterated terms.
    """
    token: str        # transliterated term
    normalized: str   # normalized word the term was created from
    word_number: int  # running number of that word within the query
# All transliterated terms of a query in the order they appear.
QueryParts = List[QueryPart]
# Maps a (possibly multi-term) lookup word to all token ranges it covers.
WordDict = Dict[str, List[qmod.TokenRange]]
def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
    """ Generate every word that can be built from consecutive terms
        beginning at position 'start' or later.

        Each result is the space-joined token string together with the
        range of terms it covers. A word spans at most 20 terms.
    """
    term_count = len(terms)
    for begin in range(start, term_count):
        collected: List[str] = []
        # Grow the word one term at a time, never beyond the 20-term window.
        for pos in range(begin, min(begin + 20, term_count)):
            collected.append(terms[pos].token)
            yield ' '.join(collected), qmod.TokenRange(begin, pos + 1)
@dataclasses.dataclass
class ICUToken(qmod.Token):
    """ Token implementation used by the ICU tokenizer.
    """
    # token string as saved in the word table
    word_token: str
    # content of the 'info' column of the word table row, if any
    info: Optional[Dict[str, Any]]

    def get_category(self) -> Tuple[str, str]:
        """ Return the (class, type) pair saved in the token info.
            Missing entries are returned as empty strings.
        """
        assert self.info
        details = self.info
        return details.get('class', ''), details.get('type', '')

    def rematch(self, norm: str) -> None:
        """ Compare the lookup word of the token against the given
            normalized string and increase the penalty according to
            the edit distance relative to the length of the lookup word.
            Tokens without lookup word are left untouched.
        """
        if not self.lookup_word:
            return

        word_len = len(self.lookup_word)
        matcher = difflib.SequenceMatcher(a=self.lookup_word, b=norm)

        distance = 0
        for tag, afrom, ato, bfrom, bto in matcher.get_opcodes():
            alen = ato - afrom
            blen = bto - bfrom
            if tag == 'replace':
                distance += max(alen, blen)
            elif tag in ('delete', 'insert'):
                # Additions or removals at either end of the lookup word
                # count as a single edit, whatever their length.
                at_edge = afrom == 0 or ato == word_len
                distance += 1 if at_edge else abs(alen - blen)

        self.penalty += (distance/word_len)

    @staticmethod
    def from_db_row(row: SaRow) -> 'ICUToken':
        """ Build an ICUToken from a row of the word table.

            The initial penalty depends on the type of the row and the
            form of the word token. Counts come from the row's info and
            are at least 1.
        """
        props = {} if row.info is None else row.info
        word_token = row.word_token
        row_type = row.type

        penalty = 0.0
        if row_type == 'w':
            penalty = 0.3
        elif row_type == 'W':
            if len(word_token) == 1 and word_token == row.word:
                penalty = 0.2 if row.word.isdigit() else 0.3
        elif row_type == 'H':
            penalty = sum(0.1 for c in word_token if c != ' ' and not c.isdigit())
            if not any(c.isdigit() for c in word_token):
                penalty += 0.2 * (len(word_token) - 1)
        elif row_type == 'C':
            if len(word_token) == 1:
                penalty = 0.3

        lookup_word = props.get('lookup', row.word)
        # Only the part in front of a '@' is used, the word token is the fallback.
        lookup_word = lookup_word.split('@', 1)[0] if lookup_word else word_token

        return ICUToken(penalty=penalty, token=row.word_id,
                        count=max(1, props.get('count', 1)),
                        addr_count=max(1, props.get('addr_count', 1)),
                        lookup_word=lookup_word, is_indexed=True,
                        word_token=word_token, info=row.info)
class ICUQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by a ICU tokenizer.
    """

    def __init__(self, conn: SearchConnection) -> None:
        """ Create an analyzer for the given connection. setup() must be
            awaited before the analyzer can be used.
        """
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.

            Creates the normalizer and transliterator from the rules saved
            in the database properties and registers the 'word' table in
            the connection metadata, if it isn't known yet.
        """
        async def _make_normalizer() -> Any:
            rules = await self.conn.get_property('tokenizer_import_normalisation')
            return Transliterator.createFromRules("normalization", rules)

        # NOTE(review): get_cached_value presumably only calls the factory
        # when no value is cached yet - confirm against SearchConnection.
        self.normalizer = await self.conn.get_cached_value('ICUTOK', 'normalizer',
                                                           _make_normalizer)

        async def _make_transliterator() -> Any:
            rules = await self.conn.get_property('tokenizer_import_transliteration')
            return Transliterator.createFromRules("transliteration", rules)

        self.transliterator = await self.conn.get_cached_value('ICUTOK', 'transliterator',
                                                               _make_transliterator)

        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('type', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('info', Json))

    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.
        """
        log().section('Analyze query (using ICU tokenizer)')
        # Phrases that are empty after normalization are dropped.
        normalized = list(filter(lambda p: p.text,
                                 (qmod.Phrase(p.ptype, self.normalize_text(p.text))
                                  for p in phrases)))
        query = qmod.QueryStruct(normalized)
        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))

        for row in await self.lookup_in_db(list(words.keys())):
            for trange in words[row.word_token]:
                # A separate token object is created for each range: the
                # penalties are adjusted later per token.
                token = ICUToken.from_db_row(row)
                if row.type == 'S':
                    if row.info['op'] in ('in', 'near'):
                        # terms with an operator only count at the very
                        # beginning of the query
                        if trange.start == 0:
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                    else:
                        # without operator: a near item when the term covers
                        # the whole query, a qualifier otherwise
                        if trange.start == 0 and trange.end == query.num_token_slots():
                            query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                        else:
                            query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                else:
                    query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query, parts)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query

    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form. That is the
            standardized form search will work with. All information removed
            at this stage is inevitably lost.
        """
        return cast(str, self.normalizer.transliterate(text))

    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
        """ Transliterate the phrases and split them into tokens.

            Returns the list of transliterated tokens together with their
            normalized form and a dictionary of words for lookup together
            with their position.

            A node is appended to 'query' for every term that is created.
        """
        parts: QueryParts = []
        phrase_start = 0
        words = defaultdict(list)
        wordnr = 0
        for phrase in query.source:
            # tokens starting at the last node so far belong to this phrase
            query.nodes[-1].ptype = phrase.ptype
            for word in phrase.text.split(' '):
                trans = self.transliterator.transliterate(word)
                if trans:
                    for term in trans.split(' '):
                        if term:
                            parts.append(QueryPart(term, word, wordnr))
                            query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    # the last term of a word is followed by a word break
                    query.nodes[-1].btype = qmod.BreakType.WORD
                # counted for every word, also when it transliterates to nothing
                wordnr += 1
            query.nodes[-1].btype = qmod.BreakType.PHRASE

            # lookup words are only built from terms of the same phrase
            for word, wrange in yield_words(parts, phrase_start):
                words[word].append(wrange)
            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words

    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.
        """
        t = self.conn.t.meta.tables['word']
        return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))

    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add tokens to query that are not saved in the database.

            Terms of up to four characters that start with a digit get a
            housenumber token when none is present for the term yet.
        """
        # range(1000) supplies the running index i; as a side effect zip()
        # stops after 1000 terms.
        for part, node, i in zip(parts, query.nodes, range(1000)):
            if len(part.token) <= 4 and part[0].isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                # positional arguments: penalty, token, count, addr_count,
                # lookup_word, is_indexed, word_token, info
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                ICUToken(0.5, 0, 1, 1, part.token, True, part.token, None))

    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
        """ Add penalties to tokens that depend on presence of other token.
        """
        for i, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                # Penalize other token lists over the same span as the
                # postcode. Housenumbers are spared unless the postcode
                # has more than 4 characters.
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
                 and len(tlist.tokens[0].lookup_word) <= 3:
                # Short housenumbers that contain a digit penalize all other
                # token lists over the same span.
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
            elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
                # Rebuild the normalized text the token list spans. Terms
                # from the same word share the normalized form, so it is
                # only added when the word number changes.
                norm = parts[i].normalized
                for j in range(i + 1, tlist.end):
                    if parts[j - 1].word_number != parts[j].word_number:
                        norm += ' ' + parts[j].normalized
                for token in tlist.tokens:
                    cast(ICUToken, token).rematch(norm)
def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
    """ Return a debug string of the transliterated terms, with the
        break character of each node between them.
    """
    pieces = [query.nodes[0].btype.value]
    for node, part in zip(query.nodes[1:], parts):
        pieces.append(part.token)
        pieces.append(node.btype.value)
    return ''.join(pieces)
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    """ Yield the rows of a debug table with all tokens of the query.
        The first row is the table header.
    """
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for _, _, token_list in query.iter_token_lists():
        type_name = token_list.ttype.name
        for entry in token_list.tokens:
            icu_token = cast(ICUToken, entry)
            yield [type_name, icu_token.token, icu_token.word_token or '',
                   icu_token.lookup_word or '', icu_token.penalty,
                   icu_token.count, icu_token.info]
async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Return a query analyzer for a database that was set up with
        the ICU tokenizer. The analyzer is fully initialised.
    """
    analyzer = ICUQueryAnalyzer(conn)
    await analyzer.setup()

    return analyzer

View File

@@ -0,0 +1,272 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the legacy tokenizer.
"""
from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
from copy import copy
from collections import defaultdict
import dataclasses
import sqlalchemy as sa
from nominatim_core.typing import SaRow
from ..connection import SearchConnection
from ..logging import log
from . import query as qmod
from .query_analyzer_factory import AbstractQueryAnalyzer
def yield_words(terms: List[str], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
    """ Generate every word that can be built from consecutive terms
        beginning at position 'start' or later.

        Each result is the space-joined string together with the
        range of terms it covers. A word spans at most 20 terms.
    """
    term_count = len(terms)
    for begin in range(start, term_count):
        collected: List[str] = []
        # Grow the word one term at a time, never beyond the 20-term window.
        for pos in range(begin, min(begin + 20, term_count)):
            collected.append(terms[pos])
            yield ' '.join(collected), qmod.TokenRange(begin, pos + 1)
@dataclasses.dataclass
class LegacyToken(qmod.Token):
    """ Token implementation used by the legacy tokenizer.
    """
    # token string from the word table, without surrounding blanks
    word_token: str
    # (class, type) of the word table row, if it has a class
    category: Optional[Tuple[str, str]]
    # country code of the word table row, if any
    country: Optional[str]
    # operator of the word table row, if any
    operator: Optional[str]

    @property
    def info(self) -> Dict[str, Any]:
        """ Additional token properties collected in a dictionary.
            Meant for debug output only.
        """
        return dict(category=self.category,
                    country=self.country,
                    operator=self.operator)

    def get_category(self) -> Tuple[str, str]:
        """ Return the category of the token, which must be set.
        """
        category = self.category
        assert category
        return category
class LegacyQueryAnalyzer(AbstractQueryAnalyzer):
    """ Converter for query strings into a tokenized query
        using the tokens created by a legacy tokenizer.
    """

    def __init__(self, conn: SearchConnection) -> None:
        """ Create an analyzer for the given connection. setup() must be
            awaited before the analyzer can be used.
        """
        self.conn = conn

    async def setup(self) -> None:
        """ Set up static data structures needed for the analysis.

            Reads the maximum word frequency from the database properties
            and registers the 'word' table in the connection metadata,
            if it isn't known yet.
        """
        self.max_word_freq = int(await self.conn.get_property('tokenizer_maxwordfreq'))
        if 'word' not in self.conn.t.meta.tables:
            sa.Table('word', self.conn.t.meta,
                     sa.Column('word_id', sa.Integer),
                     sa.Column('word_token', sa.Text, nullable=False),
                     sa.Column('word', sa.Text),
                     sa.Column('class', sa.Text),
                     sa.Column('type', sa.Text),
                     sa.Column('country_code', sa.Text),
                     sa.Column('search_name_count', sa.Integer),
                     sa.Column('operator', sa.Text))

    async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
        """ Analyze the given list of phrases and return the
            tokenized query.
        """
        log().section('Analyze query (using Legacy tokenizer)')

        normalized = []
        if phrases:
            # Normalization is done in the database: one call of
            # make_standard_name() per phrase, all in a single SELECT.
            for row in await self.conn.execute(sa.select(*(sa.func.make_standard_name(p.text)
                                                           for p in phrases))):
                # phrases with an empty normalized form are dropped
                normalized = [qmod.Phrase(p.ptype, r) for r, p in zip(row, phrases) if r]
                break

        query = qmod.QueryStruct(normalized)
        log().var_dump('Normalized query', query.source)
        if not query.source:
            return query

        parts, words = self.split_query(query)
        lookup_words = list(words.keys())
        log().var_dump('Split query', parts)
        log().var_dump('Extracted words', lookup_words)

        for row in await self.lookup_in_db(lookup_words):
            # word tokens may be saved with a leading blank
            for trange in words[row.word_token.strip()]:
                token, ttype = self.make_token(row)
                if ttype == qmod.TokenType.NEAR_ITEM:
                    # only accepted at the very beginning of the query
                    if trange.start == 0:
                        query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                elif ttype == qmod.TokenType.QUALIFIER:
                    query.add_token(trange, qmod.TokenType.QUALIFIER, token)
                    # At the beginning or end of the query the term may also
                    # be a near item. That reading gets an extra penalty on
                    # a copy of the token.
                    if trange.start == 0 or trange.end == query.num_token_slots():
                        token = copy(token)
                        token.penalty += 0.1 * (query.num_token_slots())
                        query.add_token(trange, qmod.TokenType.NEAR_ITEM, token)
                # partial tokens are only used when they span a single term
                elif ttype != qmod.TokenType.PARTIAL or trange.start + 1 == trange.end:
                    query.add_token(trange, ttype, token)

        self.add_extra_tokens(query, parts)
        self.rerank_tokens(query)

        log().table_dump('Word tokens', _dump_word_tokens(query))

        return query

    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form.

            This only removes case, so some difference with the normalization
            in the phrase remains.
        """
        return text.lower()

    def split_query(self, query: qmod.QueryStruct) -> Tuple[List[str],
                                                            Dict[str, List[qmod.TokenRange]]]:
        """ Split the phrases of the query into terms.

            Returns the list of terms and a dictionary of words for lookup
            together with their position. A node is appended to 'query'
            for every term.
        """
        parts: List[str] = []
        phrase_start = 0
        words = defaultdict(list)
        for phrase in query.source:
            # tokens starting at the last node so far belong to this phrase
            query.nodes[-1].ptype = phrase.ptype
            for term in phrase.text.split(' '):
                # split(' ') returns empty strings for consecutive blanks
                if term:
                    parts.append(term)
                    query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                    query.nodes[-1].btype = qmod.BreakType.WORD
            query.nodes[-1].btype = qmod.BreakType.PHRASE

            # lookup words are only built from terms of the same phrase
            for word, wrange in yield_words(parts, phrase_start):
                words[word].append(wrange)
            phrase_start = len(parts)
        query.nodes[-1].btype = qmod.BreakType.END

        return parts, words

    async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
        """ Return the token information from the database for the
            given word tokens.

            Each word is looked up as is and with a leading blank.
        """
        t = self.conn.t.meta.tables['word']

        sql = t.select().where(t.c.word_token.in_(words + [' ' + w for w in words]))

        return await self.conn.execute(sql)

    def make_token(self, row: SaRow) -> Tuple[LegacyToken, qmod.TokenType]:
        """ Create a LegacyToken from the row of the word table.
            Also determines the type of token.
        """
        penalty = 0.0
        is_indexed = True

        # 'class' is a Python keyword and cannot be accessed as attribute
        rowclass = getattr(row, 'class')

        if row.country_code is not None:
            ttype = qmod.TokenType.COUNTRY
            lookup_word = row.country_code
        elif rowclass is not None:
            if rowclass == 'place' and row.type == 'house':
                ttype = qmod.TokenType.HOUSENUMBER
                lookup_word = row.word_token[1:]
            elif rowclass == 'place' and row.type == 'postcode':
                ttype = qmod.TokenType.POSTCODE
                lookup_word = row.word_token[1:]
            else:
                ttype = qmod.TokenType.NEAR_ITEM if row.operator in ('in', 'near')\
                        else qmod.TokenType.QUALIFIER
                lookup_word = row.word
        elif row.word_token.startswith(' '):
            ttype = qmod.TokenType.WORD
            lookup_word = row.word or row.word_token[1:]
        else:
            ttype = qmod.TokenType.PARTIAL
            lookup_word = row.word_token
            penalty = 0.21
            # search_name_count is nullable, treat a missing count as 0
            if (row.search_name_count or 0) > self.max_word_freq:
                is_indexed = False

        token = LegacyToken(penalty=penalty, token=row.word_id,
                            count=max(1, row.search_name_count or 1),
                            addr_count=1, # not supported
                            lookup_word=lookup_word,
                            word_token=row.word_token.strip(),
                            category=(rowclass, row.type) if rowclass is not None else None,
                            country=row.country_code,
                            operator=row.operator,
                            is_indexed=is_indexed)

        return token, ttype

    def add_extra_tokens(self, query: qmod.QueryStruct, parts: List[str]) -> None:
        """ Add tokens to query that are not saved in the database.

            Purely numerical terms of up to four characters get a
            housenumber token when none is present for the term yet.
        """
        # range(1000) supplies the running index i; as a side effect zip()
        # stops after 1000 terms.
        for part, node, i in zip(parts, query.nodes, range(1000)):
            if len(part) <= 4 and part.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                LegacyToken(penalty=0.5, token=0, count=1, addr_count=1,
                                            lookup_word=part, word_token=part,
                                            category=None, country=None,
                                            operator=None, is_indexed=True))

    def rerank_tokens(self, query: qmod.QueryStruct) -> None:
        """ Add penalties to tokens that depend on presence of other token.
        """
        for _, node, tlist in query.iter_token_lists():
            if tlist.ttype == qmod.TokenType.POSTCODE:
                # Penalize other token lists over the same span as the
                # postcode. Housenumbers are spared unless the postcode
                # has more than 4 characters.
                for repl in node.starting:
                    if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
                       and (repl.ttype != qmod.TokenType.HOUSENUMBER
                            or len(tlist.tokens[0].lookup_word) > 4):
                        repl.add_penalty(0.39)
            elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
                 and len(tlist.tokens[0].lookup_word) <= 3:
                # Short housenumbers that contain a digit penalize all other
                # token lists over the same span.
                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
                    for repl in node.starting:
                        if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
    """ Yield the rows of a debug table with all tokens of the query.
        The first row is the table header.
    """
    yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
    for _, _, token_list in query.iter_token_lists():
        type_name = token_list.ttype.name
        for entry in token_list.tokens:
            legacy_token = cast(LegacyToken, entry)
            yield [type_name, legacy_token.token, legacy_token.word_token or '',
                   legacy_token.lookup_word or '', legacy_token.penalty,
                   legacy_token.count, legacy_token.info]
async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create and set up a new query analyzer for a database based
        on the legacy tokenizer.
    """
    out = LegacyQueryAnalyzer(conn)
    await out.setup()

    return out

View File

@@ -0,0 +1,297 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Datastructures for a tokenized query.
"""
from typing import List, Tuple, Optional, Iterator
from abc import ABC, abstractmethod
import dataclasses
import enum
class BreakType(enum.Enum):
    """ Type of break between tokens.

        The value of each member is a single character, which is used
        to represent the break in debug output.
    """
    START = '<'
    """ Begin of the query. """
    END = '>'
    """ End of the query. """
    PHRASE = ','
    """ Break between two phrases. """
    WORD = ' '
    """ Break between words. """
    PART = '-'
    """ Break inside a word, for example a hyphen or apostrophe. """
    TOKEN = '`'
    """ Break created as a result of tokenization.
        This may happen in languages without spaces between words.
    """
class TokenType(enum.Enum):
    """ Type of token.

        The type describes the function a token can take in a query.
    """
    WORD = enum.auto()
    """ Full name of a place. """
    PARTIAL = enum.auto()
    """ Word term without breaks, does not necessarily represent a full name. """
    HOUSENUMBER = enum.auto()
    """ Housenumber term. """
    POSTCODE = enum.auto()
    """ Postal code term. """
    COUNTRY = enum.auto()
    """ Country name or reference. """
    QUALIFIER = enum.auto()
    """ Special term used together with name (e.g. _Hotel_ Bellevue). """
    NEAR_ITEM = enum.auto()
    """ Special term used as searchable object(e.g. supermarket in ...). """
class PhraseType(enum.Enum):
    """ Designation of a phrase.
    """
    NONE = 0
    """ No specific designation (i.e. source is free-form query). """
    AMENITY = enum.auto()
    """ Contains name or type of a POI. """
    STREET = enum.auto()
    """ Contains a street name optionally with a housenumber. """
    CITY = enum.auto()
    """ Contains the postal city. """
    COUNTY = enum.auto()
    """ Contains the equivalent of a county. """
    STATE = enum.auto()
    """ Contains a state or province. """
    POSTCODE = enum.auto()
    """ Contains a postal code. """
    COUNTRY = enum.auto()
    """ Contains the country name or code. """

    def compatible_with(self, ttype: TokenType,
                        is_full_phrase: bool) -> bool:
        """ Return true if a token of the given type may appear in a
            phrase of this type. 'is_full_phrase' states if the token
            covers the complete phrase.
        """
        if self == PhraseType.NONE:
            # free-form: anything goes, except a qualifier covering the phrase
            return not (is_full_phrase and ttype == TokenType.QUALIFIER)

        is_name = ttype in (TokenType.WORD, TokenType.PARTIAL)

        if self == PhraseType.AMENITY:
            if is_name:
                return True
            # special terms: near item for the whole phrase, else qualifier
            special = TokenType.NEAR_ITEM if is_full_phrase else TokenType.QUALIFIER
            return ttype == special

        if self == PhraseType.STREET:
            return is_name or ttype == TokenType.HOUSENUMBER

        if self == PhraseType.POSTCODE:
            return ttype == TokenType.POSTCODE

        if self == PhraseType.COUNTRY:
            return ttype == TokenType.COUNTRY

        # all remaining phrase types only take names
        return is_name
@dataclasses.dataclass
class Token(ABC):
    """ Base type for tokens.
        Specific query analyzers must implement the concrete token class.
    """

    # penalty of the token, may be increased later by the query analyzer
    penalty: float
    # ID of the token (the analyzers use the word_id of the word table
    # or 0 for tokens that are not saved in the database)
    token: int
    # NOTE(review): presumably the number of occurrences of the token,
    # the analyzers make sure it is at least 1 - confirm
    count: int
    # NOTE(review): presumably like 'count' but for use in addresses - confirm
    addr_count: int
    # word to use when comparing the token against the query text
    lookup_word: str
    # NOTE(review): the legacy analyzer unsets this for partial terms above
    # the maximum word frequency - exact meaning to be confirmed
    is_indexed: bool


    @abstractmethod
    def get_category(self) -> Tuple[str, str]:
        """ Return the category restriction for qualifier terms and
            category objects.
        """
@dataclasses.dataclass
class TokenRange:
    """ Span of a token, given as the indexes of the query nodes
        where the token starts and ends.
    """
    start: int
    end: int

    def __lt__(self, other: 'TokenRange') -> bool:
        # strictly before: this range ends where the other begins or earlier
        return other.start >= self.end

    def __le__(self, other: 'TokenRange') -> bool:
        return NotImplemented

    def __gt__(self, other: 'TokenRange') -> bool:
        # strictly after: this range begins where the other ends or later
        return other.end <= self.start

    def __ge__(self, other: 'TokenRange') -> bool:
        return NotImplemented

    def replace_start(self, new_start: int) -> 'TokenRange':
        """ Create a copy of the range that starts at 'new_start'.
        """
        return TokenRange(start=new_start, end=self.end)

    def replace_end(self, new_end: int) -> 'TokenRange':
        """ Create a copy of the range that ends at 'new_end'.
        """
        return TokenRange(start=self.start, end=new_end)

    def split(self, index: int) -> Tuple['TokenRange', 'TokenRange']:
        """ Cut the range at the given index and return the two parts.
            The index must be within the span.
        """
        head = self.replace_end(index)
        tail = self.replace_start(index)
        return head, tail
@dataclasses.dataclass
class TokenList:
    """ Collection of the tokens of one type that lead from one
        breakpoint to another.
    """
    # index of the node where the tokens end
    end: int
    # type shared by all tokens in the list
    ttype: TokenType
    tokens: List[Token]

    def add_penalty(self, penalty: float) -> None:
        """ Increase the penalty of each token in the list by the
            given amount.
        """
        for entry in self.tokens:
            entry.penalty = entry.penalty + penalty
@dataclasses.dataclass
class QueryNode:
    """ Node of the query, which stands for a break between terms.
    """
    # kind of break the node represents
    btype: BreakType
    # type of the phrase for tokens starting at this node
    ptype: PhraseType
    # token lists that start at this node
    starting: List[TokenList] = dataclasses.field(default_factory=list)

    def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
        """ Return true if the node has tokens of one of the given types,
            which end at node 'end'.
        """
        for candidate in self.starting:
            if candidate.end == end and candidate.ttype in ttypes:
                return True
        return False

    def get_tokens(self, end: int, ttype: TokenType) -> Optional[List[Token]]:
        """ Return the tokens of the given type that start at this node
            and end at node 'end'. When there are no such tokens,
            'None' is returned.
        """
        matching = (tlist.tokens for tlist in self.starting
                    if tlist.end == end and tlist.ttype == ttype)
        return next(matching, None)
@dataclasses.dataclass
class Phrase:
    """ A normalized query part. Phrases may be typed which means that
        they then represent a specific part of the address.
    """
    # designation of the phrase, PhraseType.NONE for free-form queries
    ptype: PhraseType
    # text of the phrase
    text: str
class QueryStruct:
    """ A tokenized search query together with the normalized source
        from which the tokens have been parsed.

        The nodes of the query stand for the breaks between words.
        Tokens lead from one node to another, where the two nodes don't
        have to be direct neighbours. The query thus forms a directed
        acyclic graph.

        A new query consists of a single node, the start of the query.
        More nodes are added by appending to 'nodes'.
    """

    def __init__(self, source: List[Phrase]) -> None:
        self.source = source
        # The start node takes over the type of the first phrase.
        first_ptype = source[0].ptype if source else PhraseType.NONE
        self.nodes: List[QueryNode] = [QueryNode(BreakType.START, first_ptype)]

    def num_token_slots(self) -> int:
        """ Return the number of steps between the nodes of the query.
        """
        return len(self.nodes) - 1

    def add_node(self, btype: BreakType, ptype: PhraseType) -> None:
        """ Add a break node of the given break type at the end of
            the query. The phrase type applies to all tokens that
            start at the new node.
        """
        node = QueryNode(btype, ptype)
        self.nodes.append(node)

    def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
        """ Insert 'token' with type 'ttype' into the query. 'trange'
            states the nodes between which the token spans. The nodes
            must exist and should belong to the same phrase.

            Tokens with a type that does not go together with the type
            of the phrase are dropped without notice.
        """
        first = self.nodes[trange.start]
        covers_phrase = first.btype in (BreakType.START, BreakType.PHRASE)\
                        and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)

        if not first.ptype.compatible_with(ttype, covers_phrase):
            return

        existing = first.get_tokens(trange.end, ttype)
        if existing is not None:
            existing.append(token)
        else:
            first.starting.append(TokenList(trange.end, ttype, [token]))

    def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]:
        """ Return the tokens of the given type that span the given
            nodes. The nodes must exist. The result is an empty list
            when there are no such tokens.
        """
        found = self.nodes[trange.start].get_tokens(trange.end, ttype)
        return found or []

    def get_partials_list(self, trange: TokenRange) -> List[Token]:
        """ Return one PARTIAL token for each step between the given
            nodes: the first PARTIAL token leading to the next node.
            These tokens are expected to exist.
        """
        partials = []
        for pos in range(trange.start, trange.end):
            step_tokens = self.get_tokens(TokenRange(pos, pos + 1), TokenType.PARTIAL)
            partials.append(next(iter(step_tokens)))
        return partials

    def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
        """ Go through all token lists of the query. Yields the index
            of the start node, the start node and the token list.
        """
        for pos, start_node in enumerate(self.nodes):
            for token_list in start_node.starting:
                yield pos, start_node, token_list

    def find_lookup_word_by_id(self, token: int) -> str:
        """ Return the lookup word of the first token with the given ID,
            prefixed with the first letter of its type. Returns 'None'
            when there is no such token.

            This is a slow function, which is for debugging only.
        """
        for _, _, token_list in self.iter_token_lists():
            for candidate in token_list.tokens:
                if candidate.token == token:
                    return f"[{token_list.ttype.name[0]}]{candidate.lookup_word}"
        return 'None'

View File

@@ -0,0 +1,54 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Factory for creating a query analyzer for the configured tokenizer.
"""
from typing import List, cast, TYPE_CHECKING
from abc import ABC, abstractmethod
from pathlib import Path
import importlib
from ..logging import log
from ..connection import SearchConnection
if TYPE_CHECKING:
from .query import Phrase, QueryStruct
class AbstractQueryAnalyzer(ABC):
    """ Class for analysing incoming queries.

        Query analyzers are tied to the tokenizer used on import.
        The module of each implementation must provide a coroutine
        'create_query_analyzer(conn)', which make_query_analyzer()
        calls to create the analyzer.
    """

    @abstractmethod
    async def analyze_query(self, phrases: List['Phrase']) -> 'QueryStruct':
        """ Analyze the given phrases and return the tokenized query.
        """


    @abstractmethod
    def normalize_text(self, text: str) -> str:
        """ Bring the given text into a normalized form. That is the
            standardized form search will work with. All information removed
            at this stage is inevitably lost.
        """
async def make_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
    """ Create a query analyzer for the tokenizer used by the database.

        The name of the tokenizer is read from the database properties.
        The analyzer is implemented in the module '<name>_tokenizer'
        next to this file.

        Raises a RuntimeError when there is no such module.
    """
    name = await conn.get_property('tokenizer')

    src_file = Path(__file__).parent / f'{name}_tokenizer.py'
    if not src_file.is_file():
        log().comment(f"No tokenizer named '{name}' available. Database not set up properly.")
        raise RuntimeError('Tokenizer not found')

    # Import relative to this package, so that the module is found in the
    # same place as the file checked above, whatever the package is called.
    module = importlib.import_module(f'.{name}_tokenizer', package=__package__)

    return cast(AbstractQueryAnalyzer, await module.create_query_analyzer(conn))

View File

@@ -0,0 +1,422 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Create query interpretations where each vertice in the query is assigned
a specific function (expressed as a token type).
"""
from typing import Optional, List, Iterator
import dataclasses
from ..logging import log
from . import query as qmod
# pylint: disable=too-many-return-statements,too-many-branches
@dataclasses.dataclass
class TypedRange:
    """ A token range for a specific type of tokens.
    """
    # type of the tokens to use from the range
    ttype: qmod.TokenType
    # nodes of the query between which the tokens span
    trange: qmod.TokenRange
# Penalty for each type of break.
# NOTE(review): going by the name, this is applied when the token type
# changes at a break of that type - confirm where the table is used.
PENALTY_TOKENCHANGE = {
    qmod.BreakType.START: 0.0,
    qmod.BreakType.END: 0.0,
    qmod.BreakType.PHRASE: 0.0,
    qmod.BreakType.WORD: 0.1,
    qmod.BreakType.PART: 0.2,
    qmod.BreakType.TOKEN: 0.4
}

# Sequence of typed ranges, as used to build a token assignment.
TypedRangeSeq = List[TypedRange]
@dataclasses.dataclass
class TokenAssignment: # pylint: disable=too-many-instance-attributes
    """ One possible way of assigning token types to the tokens
        of a tokenized query.
    """
    penalty: float = 0.0
    name: Optional[qmod.TokenRange] = None
    address: List[qmod.TokenRange] = dataclasses.field(default_factory=list)
    housenumber: Optional[qmod.TokenRange] = None
    postcode: Optional[qmod.TokenRange] = None
    country: Optional[qmod.TokenRange] = None
    near_item: Optional[qmod.TokenRange] = None
    qualifier: Optional[qmod.TokenRange] = None


    @staticmethod
    def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
        """ Build a token assignment out of a sequence of typed spans.

            Partial ranges are collected as address parts. All other
            types fill the field of the same name, where a later range
            replaces an earlier one. Ranges of type WORD are ignored.
        """
        # token types that fill exactly one field of the assignment
        single_fields = {
            qmod.TokenType.HOUSENUMBER: 'housenumber',
            qmod.TokenType.POSTCODE: 'postcode',
            qmod.TokenType.COUNTRY: 'country',
            qmod.TokenType.NEAR_ITEM: 'near_item',
            qmod.TokenType.QUALIFIER: 'qualifier'
        }

        assignment = TokenAssignment()
        for typed in ranges:
            if typed.ttype == qmod.TokenType.PARTIAL:
                assignment.address.append(typed.trange)
                continue
            field_name = single_fields.get(typed.ttype)
            if field_name is not None:
                setattr(assignment, field_name, typed.trange)

        return assignment
class _TokenSequence:
""" Working state used to put together the token assignments.
Represents an intermediate state while traversing the tokenized
query.
"""
def __init__(self, seq: TypedRangeSeq,
direction: int = 0, penalty: float = 0.0) -> None:
self.seq = seq
self.direction = direction
self.penalty = penalty
def __str__(self) -> str:
seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'
@property
def end_pos(self) -> int:
""" Return the index of the global end of the current sequence.
"""
return self.seq[-1].trange.end if self.seq else 0
def has_types(self, *ttypes: qmod.TokenType) -> bool:
""" Check if the current sequence contains any typed ranges of
the given types.
"""
return any(s.ttype in ttypes for s in self.seq)
def is_final(self) -> bool:
""" Return true when the sequence cannot be extended by any
form of token anymore.
"""
# Country and category must be the final term for left-to-right
return len(self.seq) > 1 and \
self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)
def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
""" Check if the given token type is appendable to the existing sequence.
Returns None if the token type is not appendable, otherwise the
new direction of the sequence after adding such a type. The
token is not added.
The direction is one of -1, 0 or 1. Going by the comments in the
housenumber handling below, 1 stands for left-to-right and -1 for
right-to-left reading.
"""
# WORD tokens are never accepted into a sequence.
if ttype == qmod.TokenType.WORD:
return None
if not self.seq:
# Append unconditionally to the empty list
# A leading COUNTRY yields direction -1, a leading HOUSENUMBER or
# QUALIFIER yields direction 1. Any other type keeps the direction
# the sequence was created with.
if ttype == qmod.TokenType.COUNTRY:
return -1
if ttype in (qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
return 1
return self.direction
# Name tokens never change the direction. They are acceptable with
# the one exception checked below.
if ttype == qmod.TokenType.PARTIAL:
# qualifiers cannot appear in the middle of the query. They need
# to be near the next phrase.
# The test looks for a QUALIFIER anywhere but in the last range
# and only rejects the name token when the direction is -1.
if self.direction == -1 \
and any(t.ttype == qmod.TokenType.QUALIFIER for t in self.seq[:-1]):
return None
return self.direction
# Other tokens may only appear once
if self.has_types(ttype):
return None
if ttype == qmod.TokenType.HOUSENUMBER:
if self.direction == 1:
# A housenumber must not directly follow a lone qualifier.
if len(self.seq) == 1 and self.seq[0].ttype == qmod.TokenType.QUALIFIER:
return None
if len(self.seq) > 2 \
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
return None # direction left-to-right: housenumber must come before anything
elif self.direction == -1 \
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
return -1 # force direction right-to-left if after other terms
return self.direction
if ttype == qmod.TokenType.POSTCODE:
if self.direction == -1:
# Rejected when a housenumber or qualifier is already present.
if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
return None
return -1
if self.direction == 1:
# Rejected when a country is already present.
return None if self.has_types(qmod.TokenType.COUNTRY) else 1
# A housenumber or qualifier already in the sequence yields direction 1.
if self.has_types(qmod.TokenType.HOUSENUMBER, qmod.TokenType.QUALIFIER):
return 1
return self.direction
if ttype == qmod.TokenType.COUNTRY:
# A country is rejected for direction -1 and yields direction 1
# in all other cases.
return None if self.direction == -1 else 1
if ttype == qmod.TokenType.NEAR_ITEM:
# No restrictions beyond the appear-once rule above.
return self.direction
if ttype == qmod.TokenType.QUALIFIER:
if self.direction == 1:
# Only accepted directly after a single PARTIAL or NEAR_ITEM
# range or after the pair NEAR_ITEM, PARTIAL.
if (len(self.seq) == 1
and self.seq[0].ttype in (qmod.TokenType.PARTIAL, qmod.TokenType.NEAR_ITEM)) \
or (len(self.seq) == 2
and self.seq[0].ttype == qmod.TokenType.NEAR_ITEM
and self.seq[1].ttype == qmod.TokenType.PARTIAL):
return 1
return None
if self.direction == -1:
return -1
# The remaining checks work on the sequence without a leading NEAR_ITEM.
tempseq = self.seq[1:] if self.seq[0].ttype == qmod.TokenType.NEAR_ITEM else self.seq
if len(tempseq) == 0:
return 1
# NOTE(review): this tests self.seq[0] and not tempseq[0]. Whenever a
# leading NEAR_ITEM was stripped above, self.seq[0] is that NEAR_ITEM
# and the condition cannot match - confirm that this is intended.
if len(tempseq) == 1 and self.seq[0].ttype == qmod.TokenType.HOUSENUMBER:
return None
if len(tempseq) > 1 or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
return -1
return 0
# Any token type not handled above is not appendable.
return None
def advance(self, ttype: qmod.TokenType, end_pos: int,
            btype: qmod.BreakType) -> Optional['_TokenSequence']:
    """ Create the follow-up state that results from adding a token
        of type `ttype` ending at position `end_pos`.

        Returns None when appendable() rejects the token type. The
        current state is left unchanged, the new state gets its own
        list of ranges. `btype` is the kind of break in front of the
        token: it decides if the last range may be extended and which
        penalty a newly started range costs.
    """
    direction = self.appendable(ttype)
    if direction is None:
        return None

    added_penalty = 0.0
    if self.seq:
        last = self.seq[-1]
        if btype != qmod.BreakType.PHRASE and last.ttype == ttype:
            # Same type and no phrase break in between: grow the last range.
            grown = TypedRange(ttype, last.trange.replace_end(end_pos))
            ranges = self.seq[:-1] + [grown]
        else:
            # Open a new range right where the last one ended.
            opened = TypedRange(ttype, qmod.TokenRange(last.trange.end, end_pos))
            ranges = list(self.seq) + [opened]
            added_penalty = PENALTY_TOKENCHANGE[btype]
    else:
        # The very first range starts at position 0 and is free of charge.
        ranges = [TypedRange(ttype, qmod.TokenRange(0, end_pos))]

    return _TokenSequence(ranges, direction, self.penalty + added_penalty)
def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
if priors >= 2:
if self.direction == 0:
self.direction = new_dir
else:
if priors == 2:
self.penalty += 0.8
else:
return False
return True
def recheck_sequence(self) -> bool:
""" Check that the sequence is a fully valid token assignment
and adapt direction and penalties further if necessary.
This function catches some impossible assignments that need
forward context and can therefore not be excluded when building
the assignment.
Returns False when the sequence must be discarded, True otherwise.
The direction and penalty of the state may be changed in place.
"""
# housenumbers may not be further than 2 words from the beginning.
# If there are two words in front, give it a penalty.
# ("Beginning" depends on the reading direction: name terms are counted
# in front of the housenumber as well as behind it, see below.)
# Index of the first HOUSENUMBER range, None if there is none.
hnrpos = next((i for i, tr in enumerate(self.seq)
if tr.ttype == qmod.TokenType.HOUSENUMBER),
None)
if hnrpos is not None:
if self.direction != -1:
# Count the PARTIAL ranges in front of the housenumber.
priors = sum(1 for t in self.seq[:hnrpos] if t.ttype == qmod.TokenType.PARTIAL)
if not self._adapt_penalty_from_priors(priors, -1):
return False
# The call above may have set self.direction, so the next test already
# sees the updated value.
if self.direction != 1:
# Count the PARTIAL ranges behind the housenumber.
priors = sum(1 for t in self.seq[hnrpos+1:] if t.ttype == qmod.TokenType.PARTIAL)
if not self._adapt_penalty_from_priors(priors, 1):
return False
# Extra penalty of 1.0 for sequences that contain a NEAR_ITEM range
# (presumably meant for the combination with a housenumber - TODO confirm).
if any(t.ttype == qmod.TokenType.NEAR_ITEM for t in self.seq):
self.penalty += 1.0
return True
def _get_assignments_postcode(self, base: TokenAssignment,
                              query_len: int) -> Iterator[TokenAssignment]:
    """ Yield the postcode-search variant for an assignment that has
        a postcode as well as address terms.

        Nothing is yielded unless the postcode starts the query (and
        the direction is not -1) or ends the query (and the direction
        is not 1). `query_len` is the position at which the query ends.

        As a side effect the direction of the sequence gets fixed:
        to -1 for a leading postcode and to 1 otherwise.
    """
    assert base.postcode is not None
    postcode = base.postcode

    if not ((postcode.start == 0 and self.direction != -1)
            or (postcode.end == query_len and self.direction != 1)):
        return

    log().comment('postcode search')
    # <address>,<postcode> should give preference to address search
    if postcode.start == 0:
        penalty = self.penalty
        # name searches are only possible backwards
        self.direction = -1
    else:
        penalty = self.penalty + 0.1
        # name searches are only possible forwards
        self.direction = 1

    yield dataclasses.replace(base, penalty=penalty)
def _get_assignments_address_forward(self, base: TokenAssignment,
                                     query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Yield the name/address variants for reading the address terms
        from left to right.

        The first variant uses the complete first address range as the
        name. Where permitted, further variants split the first range
        at each inner position into a name part and an address part.
    """
    first = base.address[0]

    log().comment('first word = name')
    yield dataclasses.replace(base, name=first, address=base.address[1:],
                              penalty=self.penalty)

    # The first range is not split any further when
    #  * another name term comes after the first one and before the
    #    housenumber
    #  * a qualifier comes after the name
    #  * the containing phrase is strictly typed
    if base.housenumber and first.end < base.housenumber.start:
        return
    if base.qualifier and base.qualifier > first:
        return
    if query.nodes[first.start].ptype != qmod.PhraseType.NONE:
        return

    # Penalty for:
    #  * <name>, <street>, <housenumber> , ...
    #  * queries that are comma-separated
    penalty = self.penalty
    if (base.housenumber and base.housenumber > first) or len(query.source) > 1:
        penalty += 0.25

    remaining = base.address[1:]
    for i in range(first.start + 1, first.end):
        name, addr = first.split(i)
        log().comment(f'split first word = name ({i - first.start})')
        yield dataclasses.replace(
            base, name=name, address=[addr] + remaining,
            penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
def _get_assignments_address_backward(self, base: TokenAssignment,
                                      query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Yield the name/address variants for reading the address terms
        from right to left.

        The complete last address range is used as the name when the
        direction is -1 or when there is more than one address range.
        Where permitted, further variants split the last range at each
        inner position into an address part and a name part.
    """
    last = base.address[-1]
    leading = base.address[:-1]

    if self.direction == -1 or len(base.address) > 1:
        log().comment('last word = name')
        yield dataclasses.replace(base, name=last, address=leading,
                                  penalty=self.penalty)

    # The last range is not split any further when
    #  * another name term comes before the last one and after the
    #    housenumber
    #  * a qualifier comes before the name
    #  * the containing phrase is strictly typed
    if base.housenumber and last.start > base.housenumber.end:
        return
    if base.qualifier and base.qualifier < last:
        return
    if query.nodes[last.start].ptype != qmod.PhraseType.NONE:
        return

    penalty = self.penalty
    # Extra penalty when the housenumber is in front of the last range.
    if base.housenumber and base.housenumber < last:
        penalty += 0.4
    # Extra penalty when the query consists of more than one phrase.
    if len(query.source) > 1:
        penalty += 0.25

    for i in range(last.start + 1, last.end):
        addr, name = last.split(i)
        log().comment(f'split last word = name ({i - last.start})')
        yield dataclasses.replace(
            base, name=name, address=leading + [addr],
            penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Yield all token assignments that can be derived from the
        current sequence.

        The general name ranges of the sequence are divided into name
        and address in every way the direction of the sequence allows.
        Sequences whose name ranges cover more than 50 tokens yield
        nothing. The penalty and direction of the state may be changed
        while the assignments are produced.
    """
    base = TokenAssignment.from_ranges(self.seq)

    addr_token_count = sum(rng.end - rng.start for rng in base.address)
    if addr_token_count > 50:
        return

    # Postcode search (postcode-only search is covered in next case)
    if base.postcode is not None and base.address:
        yield from self._get_assignments_postcode(base, query.num_token_slots())

    if not base.address:
        # Postcode or country-only search
        if not base.housenumber and (base.postcode or base.country or base.near_item):
            log().comment('postcode/country search')
            yield dataclasses.replace(base, penalty=self.penalty)
        return

    # <postcode>,<address> should give preference to postcode search
    if base.postcode and base.postcode.start == 0:
        self.penalty += 0.1

    # Left-to-right reading of the address
    if self.direction != -1:
        yield from self._get_assignments_address_forward(base, query)

    # Right-to-left reading of the address
    if self.direction != 1:
        yield from self._get_assignments_address_backward(base, query)

    # variant for special housenumber searches
    if base.housenumber and not base.qualifier:
        yield dataclasses.replace(base, penalty=self.penalty)
def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
    """ Return possible word type assignments to word positions.

        The assignments are computed from the concrete tokens listed
        in the tokenized query. A query without any phrases yields
        no assignments.

        The result includes the penalty for transitions from one word type to
        another. It does not include penalties for transitions within a
        type.
    """
    if not query.source:
        # Nothing to assign. Without this guard the lookup of the first
        # phrase below fails with an IndexError.
        return

    # A typed first phrase fixes the direction to 1 from the start.
    start_direction = 0 if query.source[0].ptype == qmod.PhraseType.NONE else 1
    todo = [_TokenSequence([], direction=start_direction)]

    # Depth-first traversal over all sequences of token lists that
    # cover the query without gaps.
    while todo:
        state = todo.pop()
        node = query.nodes[state.end_pos]

        for tlist in node.starting:
            newstate = state.advance(tlist.ttype, tlist.end, node.btype)
            if newstate is None:
                continue

            if newstate.end_pos == query.num_token_slots():
                # The sequence covers the complete query.
                if newstate.recheck_sequence():
                    log().var_dump('Assignment', newstate)
                    yield from newstate.get_assignments(query)
            elif not newstate.is_final():
                todo.append(newstate)