fix style issue found by flake8

This commit is contained in:
Sarah Hoffmann
2024-11-10 22:47:14 +01:00
parent 8c14df55a6
commit 1f07967787
112 changed files with 656 additions and 1109 deletions

View File

@@ -42,7 +42,7 @@ def build_poi_search(category: List[Tuple[str, str]],
class _PoiData(dbf.SearchData):
penalty = 0.0
qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
countries=ccs
countries = ccs
return dbs.PoiSearch(_PoiData())
@@ -55,15 +55,13 @@ class SearchBuilder:
self.query = query
self.details = details
@property
def configured_for_country(self) -> bool:
""" Return true if the search details are configured to
allow countries in the result.
"""
return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
and self.details.layer_enabled(DataLayer.ADDRESS)
and self.details.layer_enabled(DataLayer.ADDRESS)
@property
def configured_for_postcode(self) -> bool:
@@ -71,8 +69,7 @@ class SearchBuilder:
allow postcodes in the result.
"""
return self.details.min_rank <= 5 and self.details.max_rank >= 11\
and self.details.layer_enabled(DataLayer.ADDRESS)
and self.details.layer_enabled(DataLayer.ADDRESS)
@property
def configured_for_housenumbers(self) -> bool:
@@ -80,8 +77,7 @@ class SearchBuilder:
allow addresses in the result.
"""
return self.details.max_rank >= 30 \
and self.details.layer_enabled(DataLayer.ADDRESS)
and self.details.layer_enabled(DataLayer.ADDRESS)
def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
""" Yield all possible abstract searches for the given token assignment.
@@ -92,7 +88,7 @@ class SearchBuilder:
near_items = self.get_near_items(assignment)
if near_items is not None and not near_items:
return # impossible compbination of near items and category parameter
return # impossible combination of near items and category parameter
if assignment.name is None:
if near_items and not sdata.postcodes:
@@ -123,7 +119,6 @@ class SearchBuilder:
search.penalty += assignment.penalty
yield search
def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
""" Build abstract search query for a simple category search.
This kind of search requires an additional geographic constraint.
@@ -132,7 +127,6 @@ class SearchBuilder:
and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
yield dbs.PoiSearch(sdata)
def build_special_search(self, sdata: dbf.SearchData,
address: List[TokenRange],
is_category: bool) -> Iterator[dbs.AbstractSearch]:
@@ -157,7 +151,6 @@ class SearchBuilder:
penalty += 0.2
yield dbs.PostcodeSearch(penalty, sdata)
def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
""" Build a simple address search for special entries where the
@@ -167,7 +160,7 @@ class SearchBuilder:
expected_count = sum(t.count for t in hnrs)
partials = {t.token: t.addr_count for trange in address
for t in self.query.get_partials_list(trange)}
for t in self.query.get_partials_list(trange)}
if not partials:
# can happen when none of the partials is indexed
@@ -190,7 +183,6 @@ class SearchBuilder:
sdata.housenumbers = dbf.WeightedStrings([], [])
yield dbs.PlaceSearch(0.05, sdata, expected_count)
def build_name_search(self, sdata: dbf.SearchData,
name: TokenRange, address: List[TokenRange],
is_category: bool) -> Iterator[dbs.AbstractSearch]:
@@ -205,14 +197,13 @@ class SearchBuilder:
sdata.lookups = lookup
yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)
def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
-> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
def yield_lookups(self, name: TokenRange, address: List[TokenRange]
) -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
""" Yield all variants how the given name and address should best
be searched for. This takes into account how frequent the terms
are and tries to find a lookup that optimizes index use.
"""
penalty = 0.0 # extra penalty
penalty = 0.0 # extra penalty
name_partials = {t.token: t for t in self.query.get_partials_list(name)}
addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
@@ -231,7 +222,7 @@ class SearchBuilder:
fulls_count = sum(t.count for t in name_fulls)
if fulls_count < 50000 or addr_count < 30000:
yield penalty,fulls_count / (2**len(addr_tokens)), \
yield penalty, fulls_count / (2**len(addr_tokens)), \
self.get_full_name_ranking(name_fulls, addr_partials,
fulls_count > 30000 / max(1, len(addr_tokens)))
@@ -241,9 +232,8 @@ class SearchBuilder:
if exp_count < 10000 and addr_count < 20000:
penalty += 0.35 * max(1 if name_fulls else 0.1,
5 - len(name_partials) - len(addr_tokens))
yield penalty, exp_count,\
self.get_name_address_ranking(list(name_partials.keys()), addr_partials)
yield penalty, exp_count, \
self.get_name_address_ranking(list(name_partials.keys()), addr_partials)
def get_name_address_ranking(self, name_tokens: List[int],
addr_partials: List[Token]) -> List[dbf.FieldLookup]:
@@ -268,7 +258,6 @@ class SearchBuilder:
return lookup
def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
use_lookup: bool) -> List[dbf.FieldLookup]:
""" Create a ranking expression with full name terms and
@@ -293,7 +282,6 @@ class SearchBuilder:
return dbf.lookup_by_any_name([t.token for t in name_fulls],
addr_restrict_tokens, addr_lookup_tokens)
def get_name_ranking(self, trange: TokenRange,
db_field: str = 'name_vector') -> dbf.FieldRanking:
""" Create a ranking expression for a name term in the given range.
@@ -306,7 +294,6 @@ class SearchBuilder:
default = sum(t.penalty for t in name_partials) + 0.2
return dbf.FieldRanking(db_field, default, ranks)
def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
""" Create a list of ranking expressions for an address term
for the given ranges.
@@ -315,7 +302,7 @@ class SearchBuilder:
heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
ranks: List[dbf.RankedTokens] = []
while todo: # pylint: disable=too-many-nested-blocks
while todo:
neglen, pos, rank = heapq.heappop(todo)
for tlist in self.query.nodes[pos].starting:
if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
@@ -354,7 +341,6 @@ class SearchBuilder:
return dbf.FieldRanking('nameaddress_vector', default, ranks)
def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
""" Collect the tokens for the non-name search fields in the
assignment.
@@ -401,7 +387,6 @@ class SearchBuilder:
return sdata
def get_country_tokens(self, trange: TokenRange) -> List[Token]:
""" Return the list of country tokens for the given range,
optionally filtered by the country list from the details
@@ -413,7 +398,6 @@ class SearchBuilder:
return tokens
def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
""" Return the list of qualifier tokens for the given range,
optionally filtered by the qualifier list from the details
@@ -425,7 +409,6 @@ class SearchBuilder:
return tokens
def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
""" Collect tokens for near items search or use the categories
requested per parameter.

View File

@@ -28,11 +28,9 @@ class WeightedStrings:
def __bool__(self) -> bool:
return bool(self.values)
def __iter__(self) -> Iterator[Tuple[str, float]]:
return iter(zip(self.values, self.penalties))
def get_penalty(self, value: str, default: float = 1000.0) -> float:
""" Get the penalty for the given value. Returns the given default
if the value does not exist.
@@ -54,11 +52,9 @@ class WeightedCategories:
def __bool__(self) -> bool:
return bool(self.values)
def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
return iter(zip(self.values, self.penalties))
def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
""" Get the penalty for the given value. Returns the given default
if the value does not exist.
@@ -69,7 +65,6 @@ class WeightedCategories:
pass
return default
def sql_restrict(self, table: SaFromClause) -> SaExpression:
""" Return an SQLAlcheny expression that restricts the
class and type columns of the given table to the values
@@ -125,7 +120,6 @@ class FieldRanking:
ranking.penalty -= min_penalty
return min_penalty
def sql_penalty(self, table: SaFromClause) -> SaColumn:
""" Create an SQL expression for the rankings.
"""
@@ -177,7 +171,6 @@ class SearchData:
qualifiers: WeightedCategories = WeightedCategories([], [])
def set_strings(self, field: str, tokens: List[Token]) -> None:
""" Set on of the WeightedStrings properties from the given
token list. Adapt the global penalty, so that the
@@ -191,7 +184,6 @@ class SearchData:
setattr(self, field, wstrs)
def set_qualifiers(self, tokens: List[Token]) -> None:
""" Set the qulaifier field from the given tokens.
"""
@@ -207,7 +199,6 @@ class SearchData:
self.qualifiers = WeightedCategories(list(categories.keys()),
list(categories.values()))
def set_ranking(self, rankings: List[FieldRanking]) -> None:
""" Set the list of rankings and normalize the ranking.
"""

View File

@@ -15,10 +15,10 @@ from sqlalchemy.ext.compiler import compiles
from ..typing import SaFromClause
from ..sql.sqlalchemy_types import IntArray
# pylint: disable=consider-using-f-string
LookupType = sa.sql.expression.FunctionElement[Any]
class LookupAll(LookupType):
""" Find all entries in search_name table that contain all of
a given list of tokens using an index for the search.
@@ -40,7 +40,7 @@ def _default_lookup_all(element: LookupAll,
@compiles(LookupAll, 'sqlite')
def _sqlite_lookup_all(element: LookupAll,
compiler: 'sa.Compiled', **kw: Any) -> str:
compiler: 'sa.Compiled', **kw: Any) -> str:
place, col, colname, tokens = list(element.clauses)
return "(%s IN (SELECT CAST(value as bigint) FROM"\
" (SELECT array_intersect_fuzzy(places) as p FROM"\
@@ -50,13 +50,11 @@ def _sqlite_lookup_all(element: LookupAll,
" ORDER BY length(places)) as x) as u,"\
" json_each('[' || u.p || ']'))"\
" AND array_contains(%s, %s))"\
% (compiler.process(place, **kw),
compiler.process(tokens, **kw),
compiler.process(colname, **kw),
compiler.process(col, **kw),
compiler.process(tokens, **kw)
)
% (compiler.process(place, **kw),
compiler.process(tokens, **kw),
compiler.process(colname, **kw),
compiler.process(col, **kw),
compiler.process(tokens, **kw))
class LookupAny(LookupType):
@@ -69,6 +67,7 @@ class LookupAny(LookupType):
super().__init__(table.c.place_id, getattr(table.c, column), column,
sa.type_coerce(tokens, IntArray))
@compiles(LookupAny)
def _default_lookup_any(element: LookupAny,
compiler: 'sa.Compiled', **kw: Any) -> str:
@@ -76,9 +75,10 @@ def _default_lookup_any(element: LookupAny,
return "(%s && %s)" % (compiler.process(col, **kw),
compiler.process(tokens, **kw))
@compiles(LookupAny, 'sqlite')
def _sqlite_lookup_any(element: LookupAny,
compiler: 'sa.Compiled', **kw: Any) -> str:
compiler: 'sa.Compiled', **kw: Any) -> str:
place, _, colname, tokens = list(element.clauses)
return "%s IN (SELECT CAST(value as bigint) FROM"\
" (SELECT array_union(places) as p FROM reverse_search_name"\
@@ -89,7 +89,6 @@ def _sqlite_lookup_any(element: LookupAny,
compiler.process(colname, **kw))
class Restrict(LookupType):
""" Find all entries that contain all of the given tokens.
Do not use an index for the search.
@@ -103,12 +102,13 @@ class Restrict(LookupType):
@compiles(Restrict)
def _default_restrict(element: Restrict,
compiler: 'sa.Compiled', **kw: Any) -> str:
compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw),
compiler.process(arg2, **kw))
@compiles(Restrict, 'sqlite')
def _sqlite_restrict(element: Restrict,
compiler: 'sa.Compiled', **kw: Any) -> str:
compiler: 'sa.Compiled', **kw: Any) -> str:
return "array_contains(%s)" % compiler.process(element.clauses, **kw)

View File

@@ -20,14 +20,12 @@ from ..types import SearchDetails, DataLayer, GeometryFormat, Bbox
from .. import results as nres
from .db_search_fields import SearchData, WeightedCategories
#pylint: disable=singleton-comparison,not-callable
#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
def no_index(expr: SaColumn) -> SaColumn:
""" Wrap the given expression, so that the query planner will
refrain from using the expression for index lookup.
"""
return sa.func.coalesce(sa.null(), expr) # pylint: disable=not-callable
return sa.func.coalesce(sa.null(), expr)
def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]:
@@ -68,7 +66,7 @@ def filter_by_area(sql: SaSelect, t: SaFromClause,
if details.viewbox is not None and details.bounded_viewbox:
sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM,
use_index=not avoid_index and
details.viewbox.area < 0.2))
details.viewbox.area < 0.2))
return sql
@@ -190,7 +188,7 @@ def _int_list_to_subquery(inp: List[int]) -> 'sa.Subquery':
as rows in the column 'nr'.
"""
vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\
.table_valued(sa.column('value', type_=sa.JSON))
.table_valued(sa.column('value', type_=sa.JSON))
return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery()
@@ -266,7 +264,6 @@ class NearSearch(AbstractSearch):
self.search = search
self.categories = categories
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
@@ -288,11 +285,12 @@ class NearSearch(AbstractSearch):
else:
min_rank = 26
max_rank = 30
base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX
and r.accuracy <= max_accuracy
and r.bbox and r.bbox.area < 20
and r.rank_address >= min_rank
and r.rank_address <= max_rank)
base = nres.SearchResults(r for r in base
if (r.source_table == nres.SourceTable.PLACEX
and r.accuracy <= max_accuracy
and r.bbox and r.bbox.area < 20
and r.rank_address >= min_rank
and r.rank_address <= max_rank))
if base:
baseids = [b.place_id for b in base[:5] if b.place_id]
@@ -304,7 +302,6 @@ class NearSearch(AbstractSearch):
return results
async def lookup_category(self, results: nres.SearchResults,
conn: SearchConnection, ids: List[int],
category: Tuple[str, str], penalty: float,
@@ -334,9 +331,9 @@ class NearSearch(AbstractSearch):
.join(tgeom,
table.c.centroid.ST_CoveredBy(
sa.case((sa.and_(tgeom.c.rank_address > 9,
tgeom.c.geometry.is_area()),
tgeom.c.geometry.is_area()),
tgeom.c.geometry),
else_ = tgeom.c.centroid.ST_Expand(0.05))))
else_=tgeom.c.centroid.ST_Expand(0.05))))
inner = sql.where(tgeom.c.place_id.in_(ids))\
.group_by(table.c.place_id).subquery()
@@ -363,7 +360,6 @@ class NearSearch(AbstractSearch):
results.append(result)
class PoiSearch(AbstractSearch):
""" Category search in a geographic area.
"""
@@ -372,7 +368,6 @@ class PoiSearch(AbstractSearch):
self.qualifiers = sdata.qualifiers
self.countries = sdata.countries
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
@@ -387,7 +382,7 @@ class PoiSearch(AbstractSearch):
def _base_query() -> SaSelect:
return _select_placex(t) \
.add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
.label('importance'))\
.label('importance'))\
.where(t.c.linked_place_id == None) \
.where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
.order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
@@ -396,9 +391,9 @@ class PoiSearch(AbstractSearch):
classtype = self.qualifiers.values
if len(classtype) == 1:
cclass, ctype = classtype[0]
sql: SaLambdaSelect = sa.lambda_stmt(lambda: _base_query()
.where(t.c.class_ == cclass)
.where(t.c.type == ctype))
sql: SaLambdaSelect = sa.lambda_stmt(
lambda: _base_query().where(t.c.class_ == cclass)
.where(t.c.type == ctype))
else:
sql = _base_query().where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
for cls, typ in classtype)))
@@ -455,7 +450,6 @@ class CountrySearch(AbstractSearch):
super().__init__(sdata.penalty)
self.countries = sdata.countries
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
@@ -464,9 +458,9 @@ class CountrySearch(AbstractSearch):
ccodes = self.countries.values
sql = _select_placex(t)\
.add_columns(t.c.importance)\
.where(t.c.country_code.in_(ccodes))\
.where(t.c.rank_address == 4)
.add_columns(t.c.importance)\
.where(t.c.country_code.in_(ccodes))\
.where(t.c.rank_address == 4)
if details.geometry_output:
sql = _add_geometry_columns(sql, t.c.geometry, details)
@@ -493,7 +487,6 @@ class CountrySearch(AbstractSearch):
return results
async def lookup_in_country_table(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Look up the country in the fallback country tables.
@@ -509,7 +502,7 @@ class CountrySearch(AbstractSearch):
sql = sa.select(tgrid.c.country_code,
tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid()
.label('centroid'),
.label('centroid'),
tgrid.c.geometry.ST_Collect().ST_Expand(0).label('bbox'))\
.where(tgrid.c.country_code.in_(self.countries.values))\
.group_by(tgrid.c.country_code)
@@ -537,7 +530,6 @@ class CountrySearch(AbstractSearch):
return results
class PostcodeSearch(AbstractSearch):
""" Search for a postcode.
"""
@@ -548,7 +540,6 @@ class PostcodeSearch(AbstractSearch):
self.lookups = sdata.lookups
self.rankings = sdata.rankings
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
@@ -588,14 +579,13 @@ class PostcodeSearch(AbstractSearch):
tsearch = conn.t.search_name
sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
.where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
.contains(sa.type_coerce(self.lookups[0].tokens,
IntArray)))
.contains(sa.type_coerce(self.lookups[0].tokens,
IntArray)))
for ranking in self.rankings:
penalty += ranking.sql_penalty(conn.t.search_name)
penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes),
else_=1.0)
else_=1.0)
sql = sql.add_columns(penalty.label('accuracy'))
sql = sql.order_by('accuracy').limit(LIMIT_PARAM)
@@ -603,13 +593,14 @@ class PostcodeSearch(AbstractSearch):
results = nres.SearchResults()
for row in await conn.execute(sql, _details_to_bind_params(details)):
p = conn.t.placex
placex_sql = _select_placex(p).add_columns(p.c.importance)\
.where(sa.text("""class = 'boundary'
AND type = 'postal_code'
AND osm_type = 'R'"""))\
.where(p.c.country_code == row.country_code)\
.where(p.c.postcode == row.postcode)\
.limit(1)
placex_sql = _select_placex(p)\
.add_columns(p.c.importance)\
.where(sa.text("""class = 'boundary'
AND type = 'postal_code'
AND osm_type = 'R'"""))\
.where(p.c.country_code == row.country_code)\
.where(p.c.postcode == row.postcode)\
.limit(1)
if details.geometry_output:
placex_sql = _add_geometry_columns(placex_sql, p.c.geometry, details)
@@ -630,7 +621,6 @@ class PostcodeSearch(AbstractSearch):
return results
class PlaceSearch(AbstractSearch):
""" Generic search for an address or named place.
"""
@@ -646,7 +636,6 @@ class PlaceSearch(AbstractSearch):
self.rankings = sdata.rankings
self.expected_count = expected_count
def _inner_search_name_cte(self, conn: SearchConnection,
details: SearchDetails) -> 'sa.CTE':
""" Create a subquery that preselects the rows in the search_name
@@ -699,7 +688,7 @@ class PlaceSearch(AbstractSearch):
NEAR_RADIUS_PARAM))
else:
sql = sql.where(t.c.centroid
.ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM)
.ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM)
if self.housenumbers:
sql = sql.where(t.c.address_rank.between(16, 30))
@@ -727,8 +716,8 @@ class PlaceSearch(AbstractSearch):
and (details.near is None or details.near_radius is not None)\
and not self.qualifiers:
sql = sql.add_columns(sa.func.first_value(inner.c.penalty - inner.c.importance)
.over(order_by=inner.c.penalty - inner.c.importance)
.label('min_penalty'))
.over(order_by=inner.c.penalty - inner.c.importance)
.label('min_penalty'))
inner = sql.subquery()
@@ -739,7 +728,6 @@ class PlaceSearch(AbstractSearch):
return sql.cte('searches')
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
@@ -759,8 +747,8 @@ class PlaceSearch(AbstractSearch):
pcs = self.postcodes.values
pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(t.c.centroid)))\
.where(tpc.c.postcode.in_(pcs))\
.scalar_subquery()
.where(tpc.c.postcode.in_(pcs))\
.scalar_subquery()
penalty += sa.case((t.c.postcode.in_(pcs), 0.0),
else_=sa.func.coalesce(pc_near, cast(SaColumn, 2.0)))
@@ -771,13 +759,12 @@ class PlaceSearch(AbstractSearch):
if details.near is not None:
sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM))
.label('importance'))
.label('importance'))
sql = sql.order_by(sa.desc(sa.text('importance')))
else:
sql = sql.order_by(penalty - tsearch.c.importance)
sql = sql.add_columns(tsearch.c.importance)
sql = sql.add_columns(penalty.label('accuracy'))\
.order_by(sa.text('accuracy'))
@@ -814,7 +801,7 @@ class PlaceSearch(AbstractSearch):
tiger_sql = sa.case((inner.c.country_code == 'us',
_make_interpolation_subquery(conn.t.tiger, inner,
numerals, details)
), else_=None)
), else_=None)
else:
interpol_sql = sa.null()
tiger_sql = sa.null()
@@ -868,7 +855,7 @@ class PlaceSearch(AbstractSearch):
if (not details.excluded or result.place_id not in details.excluded)\
and (not self.qualifiers or result.category in self.qualifiers.values)\
and result.rank_address >= details.min_rank:
result.accuracy += 1.0 # penalty for missing housenumber
result.accuracy += 1.0 # penalty for missing housenumber
results.append(result)
else:
results.append(result)

View File

@@ -23,6 +23,7 @@ from .db_searches import AbstractSearch
from .query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
from .query import Phrase, QueryStruct
class ForwardGeocoder:
""" Main class responsible for place search.
"""
@@ -34,14 +35,12 @@ class ForwardGeocoder:
self.timeout = dt.timedelta(seconds=timeout or 1000000)
self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
@property
def limit(self) -> int:
""" Return the configured maximum number of search results.
"""
return self.params.max_results
async def build_searches(self,
phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
""" Analyse the query and return the tokenized query and list of
@@ -68,7 +67,6 @@ class ForwardGeocoder:
return query, searches
async def execute_searches(self, query: QueryStruct,
searches: List[AbstractSearch]) -> SearchResults:
""" Run the abstract searches against the database until a result
@@ -103,7 +101,6 @@ class ForwardGeocoder:
return SearchResults(results.values())
def pre_filter_results(self, results: SearchResults) -> SearchResults:
""" Remove results that are significantly worse than the
best match.
@@ -114,7 +111,6 @@ class ForwardGeocoder:
return results
def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
""" Remove badly matching results, sort by ranking and
limit to the configured number of results.
@@ -124,21 +120,20 @@ class ForwardGeocoder:
min_rank = results[0].rank_search
min_ranking = results[0].ranking
results = SearchResults(r for r in results
if r.ranking + 0.03 * (r.rank_search - min_rank)
< min_ranking + 0.5)
if (r.ranking + 0.03 * (r.rank_search - min_rank)
< min_ranking + 0.5))
results = SearchResults(results[:self.limit])
return results
def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
""" Adjust the accuracy of the localized result according to how well
they match the original query.
"""
assert self.query_analyzer is not None
qwords = [word for phrase in query.source
for word in re.split('[, ]+', phrase.text) if word]
for word in re.split('[, ]+', phrase.text) if word]
if not qwords:
return
@@ -167,7 +162,6 @@ class ForwardGeocoder:
distance *= 2
result.accuracy += distance * 0.4 / sum(len(w) for w in qwords)
async def lookup_pois(self, categories: List[Tuple[str, str]],
phrases: List[Phrase]) -> SearchResults:
""" Look up places by category. If phrase is given, a place search
@@ -197,7 +191,6 @@ class ForwardGeocoder:
return results
async def lookup(self, phrases: List[Phrase]) -> SearchResults:
""" Look up a single free-text query.
"""
@@ -223,7 +216,6 @@ class ForwardGeocoder:
return results
# pylint: disable=invalid-name,too-many-locals
def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
start: int = 0) -> Iterator[Optional[List[Any]]]:
yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
@@ -242,12 +234,11 @@ def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
ranks = ranks[:100] + '...'
return f"{f.column}({ranks},def={f.default:.3g})"
def fmt_lookup(l: Any) -> str:
if not l:
def fmt_lookup(lk: Any) -> str:
if not lk:
return ''
return f"{l.lookup_type}({l.column}{tk(l.tokens)})"
return f"{lk.lookup_type}({lk.column}{tk(lk.tokens)})"
def fmt_cstr(c: Any) -> str:
if not c:

View File

@@ -48,6 +48,7 @@ class QueryPart(NamedTuple):
QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]
def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
""" Return all combinations of words in the terms list after the
given position.
@@ -72,7 +73,6 @@ class ICUToken(qmod.Token):
assert self.info
return self.info.get('class', ''), self.info.get('type', '')
def rematch(self, norm: str) -> None:
""" Check how well the token matches the given normalized string
and add a penalty, if necessary.
@@ -91,7 +91,6 @@ class ICUToken(qmod.Token):
distance += abs((ato-afrom) - (bto-bfrom))
self.penalty += (distance/len(self.lookup_word))
@staticmethod
def from_db_row(row: SaRow) -> 'ICUToken':
""" Create a ICUToken from the row of the word table.
@@ -128,16 +127,13 @@ class ICUToken(qmod.Token):
addr_count=max(1, addr_count))
class ICUQueryAnalyzer(AbstractQueryAnalyzer):
""" Converter for query strings into a tokenized query
using the tokens created by a ICU tokenizer.
"""
def __init__(self, conn: SearchConnection) -> None:
self.conn = conn
async def setup(self) -> None:
""" Set up static data structures needed for the analysis.
"""
@@ -163,7 +159,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
sa.Column('word', sa.Text),
sa.Column('info', Json))
async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
""" Analyze the given list of phrases and return the
tokenized query.
@@ -202,7 +197,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
return query
def normalize_text(self, text: str) -> str:
""" Bring the given text into a normalized form. That is the
standardized form search will work with. All information removed
@@ -210,7 +204,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
"""
return cast(str, self.normalizer.transliterate(text))
def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
""" Transliterate the phrases and split them into tokens.
@@ -243,7 +236,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
return parts, words
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
""" Return the token information from the database for the
given word tokens.
@@ -251,7 +243,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
t = self.conn.t.meta.tables['word']
return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
""" Add tokens to query that are not saved in the database.
"""
@@ -263,7 +254,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
count=1, addr_count=1, lookup_word=part.token,
word_token=part.token, info=None))
def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
""" Add penalties to tokens that depend on presence of other token.
"""
@@ -274,8 +264,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
and (repl.ttype != qmod.TokenType.HOUSENUMBER
or len(tlist.tokens[0].lookup_word) > 4):
repl.add_penalty(0.39)
elif tlist.ttype == qmod.TokenType.HOUSENUMBER \
and len(tlist.tokens[0].lookup_word) <= 3:
elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
and len(tlist.tokens[0].lookup_word) <= 3):
if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
for repl in node.starting:
if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:

View File

@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
import dataclasses
import enum
class BreakType(enum.Enum):
""" Type of break between tokens.
"""
@@ -102,13 +103,13 @@ class Token(ABC):
addr_count: int
lookup_word: str
@abstractmethod
def get_category(self) -> Tuple[str, str]:
""" Return the category restriction for qualifier terms and
category objects.
"""
@dataclasses.dataclass
class TokenRange:
""" Indexes of query nodes over which a token spans.
@@ -119,31 +120,25 @@ class TokenRange:
def __lt__(self, other: 'TokenRange') -> bool:
return self.end <= other.start
def __le__(self, other: 'TokenRange') -> bool:
return NotImplemented
def __gt__(self, other: 'TokenRange') -> bool:
return self.start >= other.end
def __ge__(self, other: 'TokenRange') -> bool:
return NotImplemented
def replace_start(self, new_start: int) -> 'TokenRange':
""" Return a new token range with the new start.
"""
return TokenRange(new_start, self.end)
def replace_end(self, new_end: int) -> 'TokenRange':
""" Return a new token range with the new end.
"""
return TokenRange(self.start, new_end)
def split(self, index: int) -> Tuple['TokenRange', 'TokenRange']:
""" Split the span into two spans at the given index.
The index must be within the span.
@@ -159,7 +154,6 @@ class TokenList:
ttype: TokenType
tokens: List[Token]
def add_penalty(self, penalty: float) -> None:
""" Add the given penalty to all tokens in the list.
"""
@@ -181,7 +175,6 @@ class QueryNode:
"""
return any(tl.end == end and tl.ttype in ttypes for tl in self.starting)
def get_tokens(self, end: int, ttype: TokenType) -> Optional[List[Token]]:
""" Get the list of tokens of the given type starting at this node
and ending at the node 'end'. Returns 'None' if no such
@@ -220,13 +213,11 @@ class QueryStruct:
self.nodes: List[QueryNode] = \
[QueryNode(BreakType.START, source[0].ptype if source else PhraseType.NONE)]
def num_token_slots(self) -> int:
""" Return the length of the query in vertice steps.
"""
return len(self.nodes) - 1
def add_node(self, btype: BreakType, ptype: PhraseType) -> None:
""" Append a new break node with the given break type.
The phrase type denotes the type for any tokens starting
@@ -234,7 +225,6 @@ class QueryStruct:
"""
self.nodes.append(QueryNode(btype, ptype))
def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
""" Add a token to the query. 'start' and 'end' are the indexes of the
nodes from which to which the token spans. The indexes must exist
@@ -247,7 +237,7 @@ class QueryStruct:
"""
snode = self.nodes[trange.start]
full_phrase = snode.btype in (BreakType.START, BreakType.PHRASE)\
and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)
and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)
if snode.ptype.compatible_with(ttype, full_phrase):
tlist = snode.get_tokens(trange.end, ttype)
if tlist is None:
@@ -255,7 +245,6 @@ class QueryStruct:
else:
tlist.append(token)
def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]:
""" Get the list of tokens of a given type, spanning the given
nodes. The nodes must exist. If no tokens exist, an
@@ -263,7 +252,6 @@ class QueryStruct:
"""
return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
def get_partials_list(self, trange: TokenRange) -> List[Token]:
""" Create a list of partial tokens between the given nodes.
The list is composed of the first token of type PARTIAL
@@ -271,8 +259,7 @@ class QueryStruct:
assumed to exist.
"""
return [next(iter(self.get_tokens(TokenRange(i, i+1), TokenType.PARTIAL)))
for i in range(trange.start, trange.end)]
for i in range(trange.start, trange.end)]
def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
""" Iterator over all token lists in the query.
@@ -281,7 +268,6 @@ class QueryStruct:
for tlist in node.starting:
yield i, node, tlist
def find_lookup_word_by_id(self, token: int) -> str:
""" Find the first token with the given token ID and return
its lookup word. Returns 'None' if no such token exists.

View File

@@ -18,6 +18,7 @@ from ..connection import SearchConnection
if TYPE_CHECKING:
from .query import Phrase, QueryStruct
class AbstractQueryAnalyzer(ABC):
""" Class for analysing incoming queries.
@@ -29,7 +30,6 @@ class AbstractQueryAnalyzer(ABC):
""" Analyze the given phrases and return the tokenized query.
"""
@abstractmethod
def normalize_text(self, text: str) -> str:
""" Bring the given text into a normalized form. That is the
@@ -38,7 +38,6 @@ class AbstractQueryAnalyzer(ABC):
"""
async def make_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
""" Create a query analyzer for the tokenizer used by the database.
"""

View File

@@ -14,7 +14,6 @@ import dataclasses
from ..logging import log
from . import query as qmod
# pylint: disable=too-many-return-statements,too-many-branches
@dataclasses.dataclass
class TypedRange:
@@ -35,8 +34,9 @@ PENALTY_TOKENCHANGE = {
TypedRangeSeq = List[TypedRange]
@dataclasses.dataclass
class TokenAssignment: # pylint: disable=too-many-instance-attributes
class TokenAssignment:
""" Representation of a possible assignment of token types
to the tokens in a tokenized query.
"""
@@ -49,7 +49,6 @@ class TokenAssignment: # pylint: disable=too-many-instance-attributes
near_item: Optional[qmod.TokenRange] = None
qualifier: Optional[qmod.TokenRange] = None
@staticmethod
def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
""" Create a new token assignment from a sequence of typed spans.
@@ -83,34 +82,29 @@ class _TokenSequence:
self.direction = direction
self.penalty = penalty
def __str__(self) -> str:
seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'
@property
def end_pos(self) -> int:
""" Return the index of the global end of the current sequence.
"""
return self.seq[-1].trange.end if self.seq else 0
def has_types(self, *ttypes: qmod.TokenType) -> bool:
""" Check if the current sequence contains any typed ranges of
the given types.
"""
return any(s.ttype in ttypes for s in self.seq)
def is_final(self) -> bool:
""" Return true when the sequence cannot be extended by any
form of token anymore.
"""
# Country and category must be the final term for left-to-right
return len(self.seq) > 1 and \
self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)
self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)
def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
""" Check if the give token type is appendable to the existing sequence.
@@ -149,10 +143,10 @@ class _TokenSequence:
return None
if len(self.seq) > 2 \
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
return None # direction left-to-right: housenumber must come before anything
elif self.direction == -1 \
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
return -1 # force direction right-to-left if after other terms
return None # direction left-to-right: housenumber must come before anything
elif (self.direction == -1
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY)):
return -1 # force direction right-to-left if after other terms
return self.direction
@@ -196,7 +190,6 @@ class _TokenSequence:
return None
def advance(self, ttype: qmod.TokenType, end_pos: int,
btype: qmod.BreakType) -> Optional['_TokenSequence']:
""" Return a new token sequence state with the given token type
@@ -223,7 +216,6 @@ class _TokenSequence:
return _TokenSequence(newseq, newdir, self.penalty + new_penalty)
def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
if priors >= 2:
if self.direction == 0:
@@ -236,7 +228,6 @@ class _TokenSequence:
return True
def recheck_sequence(self) -> bool:
""" Check that the sequence is a fully valid token assignment
and adapt direction and penalties further if necessary.
@@ -264,9 +255,8 @@ class _TokenSequence:
return True
def _get_assignments_postcode(self, base: TokenAssignment,
query_len: int) -> Iterator[TokenAssignment]:
query_len: int) -> Iterator[TokenAssignment]:
""" Yield possible assignments of Postcode searches with an
address component.
"""
@@ -278,13 +268,12 @@ class _TokenSequence:
# <address>,<postcode> should give preference to address search
if base.postcode.start == 0:
penalty = self.penalty
self.direction = -1 # name searches are only possible backwards
self.direction = -1 # name searches are only possible backwards
else:
penalty = self.penalty + 0.1
self.direction = 1 # name searches are only possible forwards
self.direction = 1 # name searches are only possible forwards
yield dataclasses.replace(base, penalty=penalty)
def _get_assignments_address_forward(self, base: TokenAssignment,
query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
""" Yield possible assignments of address searches with
@@ -320,7 +309,6 @@ class _TokenSequence:
yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
def _get_assignments_address_backward(self, base: TokenAssignment,
query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
""" Yield possible assignments of address searches with
@@ -355,7 +343,6 @@ class _TokenSequence:
yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
""" Yield possible assignments for the current sequence.