forked from hans/Nominatim
fix style issue found by flake8
This commit is contained in:
@@ -42,7 +42,7 @@ def build_poi_search(category: List[Tuple[str, str]],
|
||||
class _PoiData(dbf.SearchData):
|
||||
penalty = 0.0
|
||||
qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
|
||||
countries=ccs
|
||||
countries = ccs
|
||||
|
||||
return dbs.PoiSearch(_PoiData())
|
||||
|
||||
@@ -55,15 +55,13 @@ class SearchBuilder:
|
||||
self.query = query
|
||||
self.details = details
|
||||
|
||||
|
||||
@property
|
||||
def configured_for_country(self) -> bool:
|
||||
""" Return true if the search details are configured to
|
||||
allow countries in the result.
|
||||
"""
|
||||
return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
|
||||
and self.details.layer_enabled(DataLayer.ADDRESS)
|
||||
|
||||
and self.details.layer_enabled(DataLayer.ADDRESS)
|
||||
|
||||
@property
|
||||
def configured_for_postcode(self) -> bool:
|
||||
@@ -71,8 +69,7 @@ class SearchBuilder:
|
||||
allow postcodes in the result.
|
||||
"""
|
||||
return self.details.min_rank <= 5 and self.details.max_rank >= 11\
|
||||
and self.details.layer_enabled(DataLayer.ADDRESS)
|
||||
|
||||
and self.details.layer_enabled(DataLayer.ADDRESS)
|
||||
|
||||
@property
|
||||
def configured_for_housenumbers(self) -> bool:
|
||||
@@ -80,8 +77,7 @@ class SearchBuilder:
|
||||
allow addresses in the result.
|
||||
"""
|
||||
return self.details.max_rank >= 30 \
|
||||
and self.details.layer_enabled(DataLayer.ADDRESS)
|
||||
|
||||
and self.details.layer_enabled(DataLayer.ADDRESS)
|
||||
|
||||
def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
|
||||
""" Yield all possible abstract searches for the given token assignment.
|
||||
@@ -92,7 +88,7 @@ class SearchBuilder:
|
||||
|
||||
near_items = self.get_near_items(assignment)
|
||||
if near_items is not None and not near_items:
|
||||
return # impossible compbination of near items and category parameter
|
||||
return # impossible combination of near items and category parameter
|
||||
|
||||
if assignment.name is None:
|
||||
if near_items and not sdata.postcodes:
|
||||
@@ -123,7 +119,6 @@ class SearchBuilder:
|
||||
search.penalty += assignment.penalty
|
||||
yield search
|
||||
|
||||
|
||||
def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
|
||||
""" Build abstract search query for a simple category search.
|
||||
This kind of search requires an additional geographic constraint.
|
||||
@@ -132,7 +127,6 @@ class SearchBuilder:
|
||||
and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
|
||||
yield dbs.PoiSearch(sdata)
|
||||
|
||||
|
||||
def build_special_search(self, sdata: dbf.SearchData,
|
||||
address: List[TokenRange],
|
||||
is_category: bool) -> Iterator[dbs.AbstractSearch]:
|
||||
@@ -157,7 +151,6 @@ class SearchBuilder:
|
||||
penalty += 0.2
|
||||
yield dbs.PostcodeSearch(penalty, sdata)
|
||||
|
||||
|
||||
def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
|
||||
address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
|
||||
""" Build a simple address search for special entries where the
|
||||
@@ -167,7 +160,7 @@ class SearchBuilder:
|
||||
expected_count = sum(t.count for t in hnrs)
|
||||
|
||||
partials = {t.token: t.addr_count for trange in address
|
||||
for t in self.query.get_partials_list(trange)}
|
||||
for t in self.query.get_partials_list(trange)}
|
||||
|
||||
if not partials:
|
||||
# can happen when none of the partials is indexed
|
||||
@@ -190,7 +183,6 @@ class SearchBuilder:
|
||||
sdata.housenumbers = dbf.WeightedStrings([], [])
|
||||
yield dbs.PlaceSearch(0.05, sdata, expected_count)
|
||||
|
||||
|
||||
def build_name_search(self, sdata: dbf.SearchData,
|
||||
name: TokenRange, address: List[TokenRange],
|
||||
is_category: bool) -> Iterator[dbs.AbstractSearch]:
|
||||
@@ -205,14 +197,13 @@ class SearchBuilder:
|
||||
sdata.lookups = lookup
|
||||
yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)
|
||||
|
||||
|
||||
def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
|
||||
-> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
|
||||
def yield_lookups(self, name: TokenRange, address: List[TokenRange]
|
||||
) -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
|
||||
""" Yield all variants how the given name and address should best
|
||||
be searched for. This takes into account how frequent the terms
|
||||
are and tries to find a lookup that optimizes index use.
|
||||
"""
|
||||
penalty = 0.0 # extra penalty
|
||||
penalty = 0.0 # extra penalty
|
||||
name_partials = {t.token: t for t in self.query.get_partials_list(name)}
|
||||
|
||||
addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
|
||||
@@ -231,7 +222,7 @@ class SearchBuilder:
|
||||
fulls_count = sum(t.count for t in name_fulls)
|
||||
|
||||
if fulls_count < 50000 or addr_count < 30000:
|
||||
yield penalty,fulls_count / (2**len(addr_tokens)), \
|
||||
yield penalty, fulls_count / (2**len(addr_tokens)), \
|
||||
self.get_full_name_ranking(name_fulls, addr_partials,
|
||||
fulls_count > 30000 / max(1, len(addr_tokens)))
|
||||
|
||||
@@ -241,9 +232,8 @@ class SearchBuilder:
|
||||
if exp_count < 10000 and addr_count < 20000:
|
||||
penalty += 0.35 * max(1 if name_fulls else 0.1,
|
||||
5 - len(name_partials) - len(addr_tokens))
|
||||
yield penalty, exp_count,\
|
||||
self.get_name_address_ranking(list(name_partials.keys()), addr_partials)
|
||||
|
||||
yield penalty, exp_count, \
|
||||
self.get_name_address_ranking(list(name_partials.keys()), addr_partials)
|
||||
|
||||
def get_name_address_ranking(self, name_tokens: List[int],
|
||||
addr_partials: List[Token]) -> List[dbf.FieldLookup]:
|
||||
@@ -268,7 +258,6 @@ class SearchBuilder:
|
||||
|
||||
return lookup
|
||||
|
||||
|
||||
def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
|
||||
use_lookup: bool) -> List[dbf.FieldLookup]:
|
||||
""" Create a ranking expression with full name terms and
|
||||
@@ -293,7 +282,6 @@ class SearchBuilder:
|
||||
return dbf.lookup_by_any_name([t.token for t in name_fulls],
|
||||
addr_restrict_tokens, addr_lookup_tokens)
|
||||
|
||||
|
||||
def get_name_ranking(self, trange: TokenRange,
|
||||
db_field: str = 'name_vector') -> dbf.FieldRanking:
|
||||
""" Create a ranking expression for a name term in the given range.
|
||||
@@ -306,7 +294,6 @@ class SearchBuilder:
|
||||
default = sum(t.penalty for t in name_partials) + 0.2
|
||||
return dbf.FieldRanking(db_field, default, ranks)
|
||||
|
||||
|
||||
def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
|
||||
""" Create a list of ranking expressions for an address term
|
||||
for the given ranges.
|
||||
@@ -315,7 +302,7 @@ class SearchBuilder:
|
||||
heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
|
||||
ranks: List[dbf.RankedTokens] = []
|
||||
|
||||
while todo: # pylint: disable=too-many-nested-blocks
|
||||
while todo:
|
||||
neglen, pos, rank = heapq.heappop(todo)
|
||||
for tlist in self.query.nodes[pos].starting:
|
||||
if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
|
||||
@@ -354,7 +341,6 @@ class SearchBuilder:
|
||||
|
||||
return dbf.FieldRanking('nameaddress_vector', default, ranks)
|
||||
|
||||
|
||||
def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
|
||||
""" Collect the tokens for the non-name search fields in the
|
||||
assignment.
|
||||
@@ -401,7 +387,6 @@ class SearchBuilder:
|
||||
|
||||
return sdata
|
||||
|
||||
|
||||
def get_country_tokens(self, trange: TokenRange) -> List[Token]:
|
||||
""" Return the list of country tokens for the given range,
|
||||
optionally filtered by the country list from the details
|
||||
@@ -413,7 +398,6 @@ class SearchBuilder:
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
|
||||
""" Return the list of qualifier tokens for the given range,
|
||||
optionally filtered by the qualifier list from the details
|
||||
@@ -425,7 +409,6 @@ class SearchBuilder:
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
|
||||
""" Collect tokens for near items search or use the categories
|
||||
requested per parameter.
|
||||
|
||||
@@ -28,11 +28,9 @@ class WeightedStrings:
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.values)
|
||||
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[str, float]]:
|
||||
return iter(zip(self.values, self.penalties))
|
||||
|
||||
|
||||
def get_penalty(self, value: str, default: float = 1000.0) -> float:
|
||||
""" Get the penalty for the given value. Returns the given default
|
||||
if the value does not exist.
|
||||
@@ -54,11 +52,9 @@ class WeightedCategories:
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.values)
|
||||
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
|
||||
return iter(zip(self.values, self.penalties))
|
||||
|
||||
|
||||
def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
|
||||
""" Get the penalty for the given value. Returns the given default
|
||||
if the value does not exist.
|
||||
@@ -69,7 +65,6 @@ class WeightedCategories:
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def sql_restrict(self, table: SaFromClause) -> SaExpression:
|
||||
""" Return an SQLAlchemy expression that restricts the
|
||||
class and type columns of the given table to the values
|
||||
@@ -125,7 +120,6 @@ class FieldRanking:
|
||||
ranking.penalty -= min_penalty
|
||||
return min_penalty
|
||||
|
||||
|
||||
def sql_penalty(self, table: SaFromClause) -> SaColumn:
|
||||
""" Create an SQL expression for the rankings.
|
||||
"""
|
||||
@@ -177,7 +171,6 @@ class SearchData:
|
||||
|
||||
qualifiers: WeightedCategories = WeightedCategories([], [])
|
||||
|
||||
|
||||
def set_strings(self, field: str, tokens: List[Token]) -> None:
|
||||
""" Set one of the WeightedStrings properties from the given
|
||||
token list. Adapt the global penalty, so that the
|
||||
@@ -191,7 +184,6 @@ class SearchData:
|
||||
|
||||
setattr(self, field, wstrs)
|
||||
|
||||
|
||||
def set_qualifiers(self, tokens: List[Token]) -> None:
|
||||
""" Set the qualifier field from the given tokens.
|
||||
"""
|
||||
@@ -207,7 +199,6 @@ class SearchData:
|
||||
self.qualifiers = WeightedCategories(list(categories.keys()),
|
||||
list(categories.values()))
|
||||
|
||||
|
||||
def set_ranking(self, rankings: List[FieldRanking]) -> None:
|
||||
""" Set the list of rankings and normalize the ranking.
|
||||
"""
|
||||
|
||||
@@ -15,10 +15,10 @@ from sqlalchemy.ext.compiler import compiles
|
||||
from ..typing import SaFromClause
|
||||
from ..sql.sqlalchemy_types import IntArray
|
||||
|
||||
# pylint: disable=consider-using-f-string
|
||||
|
||||
LookupType = sa.sql.expression.FunctionElement[Any]
|
||||
|
||||
|
||||
class LookupAll(LookupType):
|
||||
""" Find all entries in search_name table that contain all of
|
||||
a given list of tokens using an index for the search.
|
||||
@@ -40,7 +40,7 @@ def _default_lookup_all(element: LookupAll,
|
||||
|
||||
@compiles(LookupAll, 'sqlite')
|
||||
def _sqlite_lookup_all(element: LookupAll,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
place, col, colname, tokens = list(element.clauses)
|
||||
return "(%s IN (SELECT CAST(value as bigint) FROM"\
|
||||
" (SELECT array_intersect_fuzzy(places) as p FROM"\
|
||||
@@ -50,13 +50,11 @@ def _sqlite_lookup_all(element: LookupAll,
|
||||
" ORDER BY length(places)) as x) as u,"\
|
||||
" json_each('[' || u.p || ']'))"\
|
||||
" AND array_contains(%s, %s))"\
|
||||
% (compiler.process(place, **kw),
|
||||
compiler.process(tokens, **kw),
|
||||
compiler.process(colname, **kw),
|
||||
compiler.process(col, **kw),
|
||||
compiler.process(tokens, **kw)
|
||||
)
|
||||
|
||||
% (compiler.process(place, **kw),
|
||||
compiler.process(tokens, **kw),
|
||||
compiler.process(colname, **kw),
|
||||
compiler.process(col, **kw),
|
||||
compiler.process(tokens, **kw))
|
||||
|
||||
|
||||
class LookupAny(LookupType):
|
||||
@@ -69,6 +67,7 @@ class LookupAny(LookupType):
|
||||
super().__init__(table.c.place_id, getattr(table.c, column), column,
|
||||
sa.type_coerce(tokens, IntArray))
|
||||
|
||||
|
||||
@compiles(LookupAny)
|
||||
def _default_lookup_any(element: LookupAny,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
@@ -76,9 +75,10 @@ def _default_lookup_any(element: LookupAny,
|
||||
return "(%s && %s)" % (compiler.process(col, **kw),
|
||||
compiler.process(tokens, **kw))
|
||||
|
||||
|
||||
@compiles(LookupAny, 'sqlite')
|
||||
def _sqlite_lookup_any(element: LookupAny,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
place, _, colname, tokens = list(element.clauses)
|
||||
return "%s IN (SELECT CAST(value as bigint) FROM"\
|
||||
" (SELECT array_union(places) as p FROM reverse_search_name"\
|
||||
@@ -89,7 +89,6 @@ def _sqlite_lookup_any(element: LookupAny,
|
||||
compiler.process(colname, **kw))
|
||||
|
||||
|
||||
|
||||
class Restrict(LookupType):
|
||||
""" Find all entries that contain all of the given tokens.
|
||||
Do not use an index for the search.
|
||||
@@ -103,12 +102,13 @@ class Restrict(LookupType):
|
||||
|
||||
@compiles(Restrict)
|
||||
def _default_restrict(element: Restrict,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw),
|
||||
compiler.process(arg2, **kw))
|
||||
|
||||
|
||||
@compiles(Restrict, 'sqlite')
|
||||
def _sqlite_restrict(element: Restrict,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "array_contains(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
@@ -20,14 +20,12 @@ from ..types import SearchDetails, DataLayer, GeometryFormat, Bbox
|
||||
from .. import results as nres
|
||||
from .db_search_fields import SearchData, WeightedCategories
|
||||
|
||||
#pylint: disable=singleton-comparison,not-callable
|
||||
#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
|
||||
|
||||
def no_index(expr: SaColumn) -> SaColumn:
|
||||
""" Wrap the given expression, so that the query planner will
|
||||
refrain from using the expression for index lookup.
|
||||
"""
|
||||
return sa.func.coalesce(sa.null(), expr) # pylint: disable=not-callable
|
||||
return sa.func.coalesce(sa.null(), expr)
|
||||
|
||||
|
||||
def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]:
|
||||
@@ -68,7 +66,7 @@ def filter_by_area(sql: SaSelect, t: SaFromClause,
|
||||
if details.viewbox is not None and details.bounded_viewbox:
|
||||
sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM,
|
||||
use_index=not avoid_index and
|
||||
details.viewbox.area < 0.2))
|
||||
details.viewbox.area < 0.2))
|
||||
|
||||
return sql
|
||||
|
||||
@@ -190,7 +188,7 @@ def _int_list_to_subquery(inp: List[int]) -> 'sa.Subquery':
|
||||
as rows in the column 'nr'.
|
||||
"""
|
||||
vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\
|
||||
.table_valued(sa.column('value', type_=sa.JSON))
|
||||
.table_valued(sa.column('value', type_=sa.JSON))
|
||||
return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery()
|
||||
|
||||
|
||||
@@ -266,7 +264,6 @@ class NearSearch(AbstractSearch):
|
||||
self.search = search
|
||||
self.categories = categories
|
||||
|
||||
|
||||
async def lookup(self, conn: SearchConnection,
|
||||
details: SearchDetails) -> nres.SearchResults:
|
||||
""" Find results for the search in the database.
|
||||
@@ -288,11 +285,12 @@ class NearSearch(AbstractSearch):
|
||||
else:
|
||||
min_rank = 26
|
||||
max_rank = 30
|
||||
base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX
|
||||
and r.accuracy <= max_accuracy
|
||||
and r.bbox and r.bbox.area < 20
|
||||
and r.rank_address >= min_rank
|
||||
and r.rank_address <= max_rank)
|
||||
base = nres.SearchResults(r for r in base
|
||||
if (r.source_table == nres.SourceTable.PLACEX
|
||||
and r.accuracy <= max_accuracy
|
||||
and r.bbox and r.bbox.area < 20
|
||||
and r.rank_address >= min_rank
|
||||
and r.rank_address <= max_rank))
|
||||
|
||||
if base:
|
||||
baseids = [b.place_id for b in base[:5] if b.place_id]
|
||||
@@ -304,7 +302,6 @@ class NearSearch(AbstractSearch):
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def lookup_category(self, results: nres.SearchResults,
|
||||
conn: SearchConnection, ids: List[int],
|
||||
category: Tuple[str, str], penalty: float,
|
||||
@@ -334,9 +331,9 @@ class NearSearch(AbstractSearch):
|
||||
.join(tgeom,
|
||||
table.c.centroid.ST_CoveredBy(
|
||||
sa.case((sa.and_(tgeom.c.rank_address > 9,
|
||||
tgeom.c.geometry.is_area()),
|
||||
tgeom.c.geometry.is_area()),
|
||||
tgeom.c.geometry),
|
||||
else_ = tgeom.c.centroid.ST_Expand(0.05))))
|
||||
else_=tgeom.c.centroid.ST_Expand(0.05))))
|
||||
|
||||
inner = sql.where(tgeom.c.place_id.in_(ids))\
|
||||
.group_by(table.c.place_id).subquery()
|
||||
@@ -363,7 +360,6 @@ class NearSearch(AbstractSearch):
|
||||
results.append(result)
|
||||
|
||||
|
||||
|
||||
class PoiSearch(AbstractSearch):
|
||||
""" Category search in a geographic area.
|
||||
"""
|
||||
@@ -372,7 +368,6 @@ class PoiSearch(AbstractSearch):
|
||||
self.qualifiers = sdata.qualifiers
|
||||
self.countries = sdata.countries
|
||||
|
||||
|
||||
async def lookup(self, conn: SearchConnection,
|
||||
details: SearchDetails) -> nres.SearchResults:
|
||||
""" Find results for the search in the database.
|
||||
@@ -387,7 +382,7 @@ class PoiSearch(AbstractSearch):
|
||||
def _base_query() -> SaSelect:
|
||||
return _select_placex(t) \
|
||||
.add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
|
||||
.label('importance'))\
|
||||
.label('importance'))\
|
||||
.where(t.c.linked_place_id == None) \
|
||||
.where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
|
||||
.order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
|
||||
@@ -396,9 +391,9 @@ class PoiSearch(AbstractSearch):
|
||||
classtype = self.qualifiers.values
|
||||
if len(classtype) == 1:
|
||||
cclass, ctype = classtype[0]
|
||||
sql: SaLambdaSelect = sa.lambda_stmt(lambda: _base_query()
|
||||
.where(t.c.class_ == cclass)
|
||||
.where(t.c.type == ctype))
|
||||
sql: SaLambdaSelect = sa.lambda_stmt(
|
||||
lambda: _base_query().where(t.c.class_ == cclass)
|
||||
.where(t.c.type == ctype))
|
||||
else:
|
||||
sql = _base_query().where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
|
||||
for cls, typ in classtype)))
|
||||
@@ -455,7 +450,6 @@ class CountrySearch(AbstractSearch):
|
||||
super().__init__(sdata.penalty)
|
||||
self.countries = sdata.countries
|
||||
|
||||
|
||||
async def lookup(self, conn: SearchConnection,
|
||||
details: SearchDetails) -> nres.SearchResults:
|
||||
""" Find results for the search in the database.
|
||||
@@ -464,9 +458,9 @@ class CountrySearch(AbstractSearch):
|
||||
|
||||
ccodes = self.countries.values
|
||||
sql = _select_placex(t)\
|
||||
.add_columns(t.c.importance)\
|
||||
.where(t.c.country_code.in_(ccodes))\
|
||||
.where(t.c.rank_address == 4)
|
||||
.add_columns(t.c.importance)\
|
||||
.where(t.c.country_code.in_(ccodes))\
|
||||
.where(t.c.rank_address == 4)
|
||||
|
||||
if details.geometry_output:
|
||||
sql = _add_geometry_columns(sql, t.c.geometry, details)
|
||||
@@ -493,7 +487,6 @@ class CountrySearch(AbstractSearch):
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def lookup_in_country_table(self, conn: SearchConnection,
|
||||
details: SearchDetails) -> nres.SearchResults:
|
||||
""" Look up the country in the fallback country tables.
|
||||
@@ -509,7 +502,7 @@ class CountrySearch(AbstractSearch):
|
||||
|
||||
sql = sa.select(tgrid.c.country_code,
|
||||
tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid()
|
||||
.label('centroid'),
|
||||
.label('centroid'),
|
||||
tgrid.c.geometry.ST_Collect().ST_Expand(0).label('bbox'))\
|
||||
.where(tgrid.c.country_code.in_(self.countries.values))\
|
||||
.group_by(tgrid.c.country_code)
|
||||
@@ -537,7 +530,6 @@ class CountrySearch(AbstractSearch):
|
||||
return results
|
||||
|
||||
|
||||
|
||||
class PostcodeSearch(AbstractSearch):
|
||||
""" Search for a postcode.
|
||||
"""
|
||||
@@ -548,7 +540,6 @@ class PostcodeSearch(AbstractSearch):
|
||||
self.lookups = sdata.lookups
|
||||
self.rankings = sdata.rankings
|
||||
|
||||
|
||||
async def lookup(self, conn: SearchConnection,
|
||||
details: SearchDetails) -> nres.SearchResults:
|
||||
""" Find results for the search in the database.
|
||||
@@ -588,14 +579,13 @@ class PostcodeSearch(AbstractSearch):
|
||||
tsearch = conn.t.search_name
|
||||
sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
|
||||
.where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
|
||||
.contains(sa.type_coerce(self.lookups[0].tokens,
|
||||
IntArray)))
|
||||
.contains(sa.type_coerce(self.lookups[0].tokens,
|
||||
IntArray)))
|
||||
|
||||
for ranking in self.rankings:
|
||||
penalty += ranking.sql_penalty(conn.t.search_name)
|
||||
penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes),
|
||||
else_=1.0)
|
||||
|
||||
else_=1.0)
|
||||
|
||||
sql = sql.add_columns(penalty.label('accuracy'))
|
||||
sql = sql.order_by('accuracy').limit(LIMIT_PARAM)
|
||||
@@ -603,13 +593,14 @@ class PostcodeSearch(AbstractSearch):
|
||||
results = nres.SearchResults()
|
||||
for row in await conn.execute(sql, _details_to_bind_params(details)):
|
||||
p = conn.t.placex
|
||||
placex_sql = _select_placex(p).add_columns(p.c.importance)\
|
||||
.where(sa.text("""class = 'boundary'
|
||||
AND type = 'postal_code'
|
||||
AND osm_type = 'R'"""))\
|
||||
.where(p.c.country_code == row.country_code)\
|
||||
.where(p.c.postcode == row.postcode)\
|
||||
.limit(1)
|
||||
placex_sql = _select_placex(p)\
|
||||
.add_columns(p.c.importance)\
|
||||
.where(sa.text("""class = 'boundary'
|
||||
AND type = 'postal_code'
|
||||
AND osm_type = 'R'"""))\
|
||||
.where(p.c.country_code == row.country_code)\
|
||||
.where(p.c.postcode == row.postcode)\
|
||||
.limit(1)
|
||||
|
||||
if details.geometry_output:
|
||||
placex_sql = _add_geometry_columns(placex_sql, p.c.geometry, details)
|
||||
@@ -630,7 +621,6 @@ class PostcodeSearch(AbstractSearch):
|
||||
return results
|
||||
|
||||
|
||||
|
||||
class PlaceSearch(AbstractSearch):
|
||||
""" Generic search for an address or named place.
|
||||
"""
|
||||
@@ -646,7 +636,6 @@ class PlaceSearch(AbstractSearch):
|
||||
self.rankings = sdata.rankings
|
||||
self.expected_count = expected_count
|
||||
|
||||
|
||||
def _inner_search_name_cte(self, conn: SearchConnection,
|
||||
details: SearchDetails) -> 'sa.CTE':
|
||||
""" Create a subquery that preselects the rows in the search_name
|
||||
@@ -699,7 +688,7 @@ class PlaceSearch(AbstractSearch):
|
||||
NEAR_RADIUS_PARAM))
|
||||
else:
|
||||
sql = sql.where(t.c.centroid
|
||||
.ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM)
|
||||
.ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM)
|
||||
|
||||
if self.housenumbers:
|
||||
sql = sql.where(t.c.address_rank.between(16, 30))
|
||||
@@ -727,8 +716,8 @@ class PlaceSearch(AbstractSearch):
|
||||
and (details.near is None or details.near_radius is not None)\
|
||||
and not self.qualifiers:
|
||||
sql = sql.add_columns(sa.func.first_value(inner.c.penalty - inner.c.importance)
|
||||
.over(order_by=inner.c.penalty - inner.c.importance)
|
||||
.label('min_penalty'))
|
||||
.over(order_by=inner.c.penalty - inner.c.importance)
|
||||
.label('min_penalty'))
|
||||
|
||||
inner = sql.subquery()
|
||||
|
||||
@@ -739,7 +728,6 @@ class PlaceSearch(AbstractSearch):
|
||||
|
||||
return sql.cte('searches')
|
||||
|
||||
|
||||
async def lookup(self, conn: SearchConnection,
|
||||
details: SearchDetails) -> nres.SearchResults:
|
||||
""" Find results for the search in the database.
|
||||
@@ -759,8 +747,8 @@ class PlaceSearch(AbstractSearch):
|
||||
pcs = self.postcodes.values
|
||||
|
||||
pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(t.c.centroid)))\
|
||||
.where(tpc.c.postcode.in_(pcs))\
|
||||
.scalar_subquery()
|
||||
.where(tpc.c.postcode.in_(pcs))\
|
||||
.scalar_subquery()
|
||||
penalty += sa.case((t.c.postcode.in_(pcs), 0.0),
|
||||
else_=sa.func.coalesce(pc_near, cast(SaColumn, 2.0)))
|
||||
|
||||
@@ -771,13 +759,12 @@ class PlaceSearch(AbstractSearch):
|
||||
|
||||
if details.near is not None:
|
||||
sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM))
|
||||
.label('importance'))
|
||||
.label('importance'))
|
||||
sql = sql.order_by(sa.desc(sa.text('importance')))
|
||||
else:
|
||||
sql = sql.order_by(penalty - tsearch.c.importance)
|
||||
sql = sql.add_columns(tsearch.c.importance)
|
||||
|
||||
|
||||
sql = sql.add_columns(penalty.label('accuracy'))\
|
||||
.order_by(sa.text('accuracy'))
|
||||
|
||||
@@ -814,7 +801,7 @@ class PlaceSearch(AbstractSearch):
|
||||
tiger_sql = sa.case((inner.c.country_code == 'us',
|
||||
_make_interpolation_subquery(conn.t.tiger, inner,
|
||||
numerals, details)
|
||||
), else_=None)
|
||||
), else_=None)
|
||||
else:
|
||||
interpol_sql = sa.null()
|
||||
tiger_sql = sa.null()
|
||||
@@ -868,7 +855,7 @@ class PlaceSearch(AbstractSearch):
|
||||
if (not details.excluded or result.place_id not in details.excluded)\
|
||||
and (not self.qualifiers or result.category in self.qualifiers.values)\
|
||||
and result.rank_address >= details.min_rank:
|
||||
result.accuracy += 1.0 # penalty for missing housenumber
|
||||
result.accuracy += 1.0 # penalty for missing housenumber
|
||||
results.append(result)
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
@@ -23,6 +23,7 @@ from .db_searches import AbstractSearch
|
||||
from .query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
|
||||
from .query import Phrase, QueryStruct
|
||||
|
||||
|
||||
class ForwardGeocoder:
|
||||
""" Main class responsible for place search.
|
||||
"""
|
||||
@@ -34,14 +35,12 @@ class ForwardGeocoder:
|
||||
self.timeout = dt.timedelta(seconds=timeout or 1000000)
|
||||
self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
|
||||
|
||||
|
||||
@property
|
||||
def limit(self) -> int:
|
||||
""" Return the configured maximum number of search results.
|
||||
"""
|
||||
return self.params.max_results
|
||||
|
||||
|
||||
async def build_searches(self,
|
||||
phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
|
||||
""" Analyse the query and return the tokenized query and list of
|
||||
@@ -68,7 +67,6 @@ class ForwardGeocoder:
|
||||
|
||||
return query, searches
|
||||
|
||||
|
||||
async def execute_searches(self, query: QueryStruct,
|
||||
searches: List[AbstractSearch]) -> SearchResults:
|
||||
""" Run the abstract searches against the database until a result
|
||||
@@ -103,7 +101,6 @@ class ForwardGeocoder:
|
||||
|
||||
return SearchResults(results.values())
|
||||
|
||||
|
||||
def pre_filter_results(self, results: SearchResults) -> SearchResults:
|
||||
""" Remove results that are significantly worse than the
|
||||
best match.
|
||||
@@ -114,7 +111,6 @@ class ForwardGeocoder:
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
|
||||
""" Remove badly matching results, sort by ranking and
|
||||
limit to the configured number of results.
|
||||
@@ -124,21 +120,20 @@ class ForwardGeocoder:
|
||||
min_rank = results[0].rank_search
|
||||
min_ranking = results[0].ranking
|
||||
results = SearchResults(r for r in results
|
||||
if r.ranking + 0.03 * (r.rank_search - min_rank)
|
||||
< min_ranking + 0.5)
|
||||
if (r.ranking + 0.03 * (r.rank_search - min_rank)
|
||||
< min_ranking + 0.5))
|
||||
|
||||
results = SearchResults(results[:self.limit])
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
|
||||
""" Adjust the accuracy of the localized result according to how well
|
||||
they match the original query.
|
||||
"""
|
||||
assert self.query_analyzer is not None
|
||||
qwords = [word for phrase in query.source
|
||||
for word in re.split('[, ]+', phrase.text) if word]
|
||||
for word in re.split('[, ]+', phrase.text) if word]
|
||||
if not qwords:
|
||||
return
|
||||
|
||||
@@ -167,7 +162,6 @@ class ForwardGeocoder:
|
||||
distance *= 2
|
||||
result.accuracy += distance * 0.4 / sum(len(w) for w in qwords)
|
||||
|
||||
|
||||
async def lookup_pois(self, categories: List[Tuple[str, str]],
|
||||
phrases: List[Phrase]) -> SearchResults:
|
||||
""" Look up places by category. If phrase is given, a place search
|
||||
@@ -197,7 +191,6 @@ class ForwardGeocoder:
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def lookup(self, phrases: List[Phrase]) -> SearchResults:
|
||||
""" Look up a single free-text query.
|
||||
"""
|
||||
@@ -223,7 +216,6 @@ class ForwardGeocoder:
|
||||
return results
|
||||
|
||||
|
||||
# pylint: disable=invalid-name,too-many-locals
|
||||
def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
|
||||
start: int = 0) -> Iterator[Optional[List[Any]]]:
|
||||
yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
|
||||
@@ -242,12 +234,11 @@ def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
|
||||
ranks = ranks[:100] + '...'
|
||||
return f"{f.column}({ranks},def={f.default:.3g})"
|
||||
|
||||
def fmt_lookup(l: Any) -> str:
|
||||
if not l:
|
||||
def fmt_lookup(lk: Any) -> str:
|
||||
if not lk:
|
||||
return ''
|
||||
|
||||
return f"{l.lookup_type}({l.column}{tk(l.tokens)})"
|
||||
|
||||
return f"{lk.lookup_type}({lk.column}{tk(lk.tokens)})"
|
||||
|
||||
def fmt_cstr(c: Any) -> str:
|
||||
if not c:
|
||||
|
||||
@@ -48,6 +48,7 @@ class QueryPart(NamedTuple):
|
||||
QueryParts = List[QueryPart]
|
||||
WordDict = Dict[str, List[qmod.TokenRange]]
|
||||
|
||||
|
||||
def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
|
||||
""" Return all combinations of words in the terms list after the
|
||||
given position.
|
||||
@@ -72,7 +73,6 @@ class ICUToken(qmod.Token):
|
||||
assert self.info
|
||||
return self.info.get('class', ''), self.info.get('type', '')
|
||||
|
||||
|
||||
def rematch(self, norm: str) -> None:
|
||||
""" Check how well the token matches the given normalized string
|
||||
and add a penalty, if necessary.
|
||||
@@ -91,7 +91,6 @@ class ICUToken(qmod.Token):
|
||||
distance += abs((ato-afrom) - (bto-bfrom))
|
||||
self.penalty += (distance/len(self.lookup_word))
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_db_row(row: SaRow) -> 'ICUToken':
|
||||
""" Create an ICUToken from the row of the word table.
|
||||
@@ -128,16 +127,13 @@ class ICUToken(qmod.Token):
|
||||
addr_count=max(1, addr_count))
|
||||
|
||||
|
||||
|
||||
class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
||||
""" Converter for query strings into a tokenized query
|
||||
using the tokens created by an ICU tokenizer.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: SearchConnection) -> None:
|
||||
self.conn = conn
|
||||
|
||||
|
||||
async def setup(self) -> None:
|
||||
""" Set up static data structures needed for the analysis.
|
||||
"""
|
||||
@@ -163,7 +159,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
||||
sa.Column('word', sa.Text),
|
||||
sa.Column('info', Json))
|
||||
|
||||
|
||||
async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
|
||||
""" Analyze the given list of phrases and return the
|
||||
tokenized query.
|
||||
@@ -202,7 +197,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def normalize_text(self, text: str) -> str:
|
||||
""" Bring the given text into a normalized form. That is the
|
||||
standardized form search will work with. All information removed
|
||||
@@ -210,7 +204,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
||||
"""
|
||||
return cast(str, self.normalizer.transliterate(text))
|
||||
|
||||
|
||||
def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
|
||||
""" Transliterate the phrases and split them into tokens.
|
||||
|
||||
@@ -243,7 +236,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
||||
|
||||
return parts, words
|
||||
|
||||
|
||||
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
|
||||
""" Return the token information from the database for the
|
||||
given word tokens.
|
||||
@@ -251,7 +243,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
||||
t = self.conn.t.meta.tables['word']
|
||||
return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
|
||||
|
||||
|
||||
def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
|
||||
""" Add tokens to query that are not saved in the database.
|
||||
"""
|
||||
@@ -263,7 +254,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
||||
count=1, addr_count=1, lookup_word=part.token,
|
||||
word_token=part.token, info=None))
|
||||
|
||||
|
||||
def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
|
||||
""" Add penalties to tokens that depend on presence of other token.
|
||||
"""
|
||||
@@ -274,8 +264,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
||||
and (repl.ttype != qmod.TokenType.HOUSENUMBER
|
||||
or len(tlist.tokens[0].lookup_word) > 4):
|
||||
repl.add_penalty(0.39)
|
||||
elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
|
||||
and len(tlist.tokens[0].lookup_word) <= 3:
|
||||
elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
|
||||
and len(tlist.tokens[0].lookup_word) <= 3):
|
||||
if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
|
||||
for repl in node.starting:
|
||||
if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:
|
||||
|
||||
@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
|
||||
import dataclasses
|
||||
import enum
|
||||
|
||||
|
||||
class BreakType(enum.Enum):
|
||||
""" Type of break between tokens.
|
||||
"""
|
||||
@@ -102,13 +103,13 @@ class Token(ABC):
|
||||
addr_count: int
|
||||
lookup_word: str
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_category(self) -> Tuple[str, str]:
|
||||
""" Return the category restriction for qualifier terms and
|
||||
category objects.
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TokenRange:
|
||||
""" Indexes of query nodes over which a token spans.
|
||||
@@ -119,31 +120,25 @@ class TokenRange:
|
||||
def __lt__(self, other: 'TokenRange') -> bool:
|
||||
return self.end <= other.start
|
||||
|
||||
|
||||
def __le__(self, other: 'TokenRange') -> bool:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
def __gt__(self, other: 'TokenRange') -> bool:
|
||||
return self.start >= other.end
|
||||
|
||||
|
||||
def __ge__(self, other: 'TokenRange') -> bool:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
def replace_start(self, new_start: int) -> 'TokenRange':
|
||||
""" Return a new token range with the new start.
|
||||
"""
|
||||
return TokenRange(new_start, self.end)
|
||||
|
||||
|
||||
def replace_end(self, new_end: int) -> 'TokenRange':
|
||||
""" Return a new token range with the new end.
|
||||
"""
|
||||
return TokenRange(self.start, new_end)
|
||||
|
||||
|
||||
def split(self, index: int) -> Tuple['TokenRange', 'TokenRange']:
|
||||
""" Split the span into two spans at the given index.
|
||||
The index must be within the span.
|
||||
@@ -159,7 +154,6 @@ class TokenList:
|
||||
ttype: TokenType
|
||||
tokens: List[Token]
|
||||
|
||||
|
||||
def add_penalty(self, penalty: float) -> None:
|
||||
""" Add the given penalty to all tokens in the list.
|
||||
"""
|
||||
@@ -181,7 +175,6 @@ class QueryNode:
|
||||
"""
|
||||
return any(tl.end == end and tl.ttype in ttypes for tl in self.starting)
|
||||
|
||||
|
||||
def get_tokens(self, end: int, ttype: TokenType) -> Optional[List[Token]]:
|
||||
""" Get the list of tokens of the given type starting at this node
|
||||
and ending at the node 'end'. Returns 'None' if no such
|
||||
@@ -220,13 +213,11 @@ class QueryStruct:
|
||||
self.nodes: List[QueryNode] = \
|
||||
[QueryNode(BreakType.START, source[0].ptype if source else PhraseType.NONE)]
|
||||
|
||||
|
||||
def num_token_slots(self) -> int:
|
||||
""" Return the length of the query in vertice steps.
|
||||
"""
|
||||
return len(self.nodes) - 1
|
||||
|
||||
|
||||
def add_node(self, btype: BreakType, ptype: PhraseType) -> None:
|
||||
""" Append a new break node with the given break type.
|
||||
The phrase type denotes the type for any tokens starting
|
||||
@@ -234,7 +225,6 @@ class QueryStruct:
|
||||
"""
|
||||
self.nodes.append(QueryNode(btype, ptype))
|
||||
|
||||
|
||||
def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
|
||||
""" Add a token to the query. 'start' and 'end' are the indexes of the
|
||||
nodes from which to which the token spans. The indexes must exist
|
||||
@@ -247,7 +237,7 @@ class QueryStruct:
|
||||
"""
|
||||
snode = self.nodes[trange.start]
|
||||
full_phrase = snode.btype in (BreakType.START, BreakType.PHRASE)\
|
||||
and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)
|
||||
and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)
|
||||
if snode.ptype.compatible_with(ttype, full_phrase):
|
||||
tlist = snode.get_tokens(trange.end, ttype)
|
||||
if tlist is None:
|
||||
@@ -255,7 +245,6 @@ class QueryStruct:
|
||||
else:
|
||||
tlist.append(token)
|
||||
|
||||
|
||||
def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]:
|
||||
""" Get the list of tokens of a given type, spanning the given
|
||||
nodes. The nodes must exist. If no tokens exist, an
|
||||
@@ -263,7 +252,6 @@ class QueryStruct:
|
||||
"""
|
||||
return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
|
||||
|
||||
|
||||
def get_partials_list(self, trange: TokenRange) -> List[Token]:
|
||||
""" Create a list of partial tokens between the given nodes.
|
||||
The list is composed of the first token of type PARTIAL
|
||||
@@ -271,8 +259,7 @@ class QueryStruct:
|
||||
assumed to exist.
|
||||
"""
|
||||
return [next(iter(self.get_tokens(TokenRange(i, i+1), TokenType.PARTIAL)))
|
||||
for i in range(trange.start, trange.end)]
|
||||
|
||||
for i in range(trange.start, trange.end)]
|
||||
|
||||
def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
|
||||
""" Iterator over all token lists in the query.
|
||||
@@ -281,7 +268,6 @@ class QueryStruct:
|
||||
for tlist in node.starting:
|
||||
yield i, node, tlist
|
||||
|
||||
|
||||
def find_lookup_word_by_id(self, token: int) -> str:
|
||||
""" Find the first token with the given token ID and return
|
||||
its lookup word. Returns 'None' if no such token exists.
|
||||
|
||||
@@ -18,6 +18,7 @@ from ..connection import SearchConnection
|
||||
if TYPE_CHECKING:
|
||||
from .query import Phrase, QueryStruct
|
||||
|
||||
|
||||
class AbstractQueryAnalyzer(ABC):
|
||||
""" Class for analysing incoming queries.
|
||||
|
||||
@@ -29,7 +30,6 @@ class AbstractQueryAnalyzer(ABC):
|
||||
""" Analyze the given phrases and return the tokenized query.
|
||||
"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def normalize_text(self, text: str) -> str:
|
||||
""" Bring the given text into a normalized form. That is the
|
||||
@@ -38,7 +38,6 @@ class AbstractQueryAnalyzer(ABC):
|
||||
"""
|
||||
|
||||
|
||||
|
||||
async def make_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
|
||||
""" Create a query analyzer for the tokenizer used by the database.
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,6 @@ import dataclasses
|
||||
from ..logging import log
|
||||
from . import query as qmod
|
||||
|
||||
# pylint: disable=too-many-return-statements,too-many-branches
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TypedRange:
|
||||
@@ -35,8 +34,9 @@ PENALTY_TOKENCHANGE = {
|
||||
|
||||
TypedRangeSeq = List[TypedRange]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TokenAssignment: # pylint: disable=too-many-instance-attributes
|
||||
class TokenAssignment:
|
||||
""" Representation of a possible assignment of token types
|
||||
to the tokens in a tokenized query.
|
||||
"""
|
||||
@@ -49,7 +49,6 @@ class TokenAssignment: # pylint: disable=too-many-instance-attributes
|
||||
near_item: Optional[qmod.TokenRange] = None
|
||||
qualifier: Optional[qmod.TokenRange] = None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
|
||||
""" Create a new token assignment from a sequence of typed spans.
|
||||
@@ -83,34 +82,29 @@ class _TokenSequence:
|
||||
self.direction = direction
|
||||
self.penalty = penalty
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
|
||||
return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'
|
||||
|
||||
|
||||
@property
|
||||
def end_pos(self) -> int:
|
||||
""" Return the index of the global end of the current sequence.
|
||||
"""
|
||||
return self.seq[-1].trange.end if self.seq else 0
|
||||
|
||||
|
||||
def has_types(self, *ttypes: qmod.TokenType) -> bool:
|
||||
""" Check if the current sequence contains any typed ranges of
|
||||
the given types.
|
||||
"""
|
||||
return any(s.ttype in ttypes for s in self.seq)
|
||||
|
||||
|
||||
def is_final(self) -> bool:
|
||||
""" Return true when the sequence cannot be extended by any
|
||||
form of token anymore.
|
||||
"""
|
||||
# Country and category must be the final term for left-to-right
|
||||
return len(self.seq) > 1 and \
|
||||
self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)
|
||||
|
||||
self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)
|
||||
|
||||
def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
|
||||
""" Check if the give token type is appendable to the existing sequence.
|
||||
@@ -149,10 +143,10 @@ class _TokenSequence:
|
||||
return None
|
||||
if len(self.seq) > 2 \
|
||||
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
|
||||
return None # direction left-to-right: housenumber must come before anything
|
||||
elif self.direction == -1 \
|
||||
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
|
||||
return -1 # force direction right-to-left if after other terms
|
||||
return None # direction left-to-right: housenumber must come before anything
|
||||
elif (self.direction == -1
|
||||
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY)):
|
||||
return -1 # force direction right-to-left if after other terms
|
||||
|
||||
return self.direction
|
||||
|
||||
@@ -196,7 +190,6 @@ class _TokenSequence:
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def advance(self, ttype: qmod.TokenType, end_pos: int,
|
||||
btype: qmod.BreakType) -> Optional['_TokenSequence']:
|
||||
""" Return a new token sequence state with the given token type
|
||||
@@ -223,7 +216,6 @@ class _TokenSequence:
|
||||
|
||||
return _TokenSequence(newseq, newdir, self.penalty + new_penalty)
|
||||
|
||||
|
||||
def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
|
||||
if priors >= 2:
|
||||
if self.direction == 0:
|
||||
@@ -236,7 +228,6 @@ class _TokenSequence:
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def recheck_sequence(self) -> bool:
|
||||
""" Check that the sequence is a fully valid token assignment
|
||||
and adapt direction and penalties further if necessary.
|
||||
@@ -264,9 +255,8 @@ class _TokenSequence:
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _get_assignments_postcode(self, base: TokenAssignment,
|
||||
query_len: int) -> Iterator[TokenAssignment]:
|
||||
query_len: int) -> Iterator[TokenAssignment]:
|
||||
""" Yield possible assignments of Postcode searches with an
|
||||
address component.
|
||||
"""
|
||||
@@ -278,13 +268,12 @@ class _TokenSequence:
|
||||
# <address>,<postcode> should give preference to address search
|
||||
if base.postcode.start == 0:
|
||||
penalty = self.penalty
|
||||
self.direction = -1 # name searches are only possible backwards
|
||||
self.direction = -1 # name searches are only possible backwards
|
||||
else:
|
||||
penalty = self.penalty + 0.1
|
||||
self.direction = 1 # name searches are only possible forwards
|
||||
self.direction = 1 # name searches are only possible forwards
|
||||
yield dataclasses.replace(base, penalty=penalty)
|
||||
|
||||
|
||||
def _get_assignments_address_forward(self, base: TokenAssignment,
|
||||
query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
|
||||
""" Yield possible assignments of address searches with
|
||||
@@ -320,7 +309,6 @@ class _TokenSequence:
|
||||
yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
|
||||
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
|
||||
|
||||
|
||||
def _get_assignments_address_backward(self, base: TokenAssignment,
|
||||
query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
|
||||
""" Yield possible assignments of address searches with
|
||||
@@ -355,7 +343,6 @@ class _TokenSequence:
|
||||
yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
|
||||
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
|
||||
|
||||
|
||||
def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
|
||||
""" Yield possible assignments for the current sequence.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user