fix style issue found by flake8

This commit is contained in:
Sarah Hoffmann
2024-11-10 22:47:14 +01:00
parent 8c14df55a6
commit 1f07967787
112 changed files with 656 additions and 1109 deletions

View File

@@ -1,22 +0,0 @@
[MASTER]
extension-pkg-whitelist=osmium,falcon
ignored-modules=icu,datrie
[MESSAGES CONTROL]
[TYPECHECK]
# closing added here because it sometimes triggers a false positive with
# 'with' statements.
ignored-classes=NominatimArgs,closing
# 'too-many-ancestors' is triggered already by deriving from UserDict
# 'not-context-manager' disabled because it causes false positives once
# typed Python is enabled. See also https://github.com/PyCQA/pylint/issues/5273
disable=too-few-public-methods,duplicate-code,too-many-ancestors,bad-option-value,no-self-use,not-context-manager,use-dict-literal,chained-comparison,attribute-defined-outside-init,too-many-boolean-expressions,contextmanager-generator-missing-cleanup,too-many-positional-arguments
good-names=i,j,x,y,m,t,fd,db,cc,x1,x2,y1,y2,pt,k,v,nr
[DESIGN]
max-returns=7

View File

@@ -8,5 +8,5 @@
# This file is just a placeholder to make the config module available # This file is just a placeholder to make the config module available
# during development. It will be replaced by nominatim_db/config.py on # during development. It will be replaced by nominatim_db/config.py on
# installation. # installation.
# pylint: skip-file # flake8: noqa
from nominatim_db.config import * from nominatim_db.config import *

View File

@@ -21,6 +21,7 @@ from .logging import log
T = TypeVar('T') T = TypeVar('T')
class SearchConnection: class SearchConnection:
""" An extended SQLAlchemy connection class, that also contains """ An extended SQLAlchemy connection class, that also contains
the table definitions. The underlying asynchronous SQLAlchemy the table definitions. The underlying asynchronous SQLAlchemy
@@ -32,37 +33,32 @@ class SearchConnection:
tables: SearchTables, tables: SearchTables,
properties: Dict[str, Any]) -> None: properties: Dict[str, Any]) -> None:
self.connection = conn self.connection = conn
self.t = tables # pylint: disable=invalid-name self.t = tables
self._property_cache = properties self._property_cache = properties
self._classtables: Optional[Set[str]] = None self._classtables: Optional[Set[str]] = None
self.query_timeout: Optional[int] = None self.query_timeout: Optional[int] = None
def set_query_timeout(self, timeout: Optional[int]) -> None: def set_query_timeout(self, timeout: Optional[int]) -> None:
""" Set the timeout after which a query over this connection """ Set the timeout after which a query over this connection
is cancelled. is cancelled.
""" """
self.query_timeout = timeout self.query_timeout = timeout
async def scalar(self, sql: sa.sql.base.Executable, async def scalar(self, sql: sa.sql.base.Executable,
params: Union[Mapping[str, Any], None] = None params: Union[Mapping[str, Any], None] = None) -> Any:
) -> Any:
""" Execute a 'scalar()' query on the connection. """ Execute a 'scalar()' query on the connection.
""" """
log().sql(self.connection, sql, params) log().sql(self.connection, sql, params)
return await asyncio.wait_for(self.connection.scalar(sql, params), self.query_timeout) return await asyncio.wait_for(self.connection.scalar(sql, params), self.query_timeout)
async def execute(self, sql: 'sa.Executable', async def execute(self, sql: 'sa.Executable',
params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None] = None params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None] = None
) -> 'sa.Result[Any]': ) -> 'sa.Result[Any]':
""" Execute a 'execute()' query on the connection. """ Execute a 'execute()' query on the connection.
""" """
log().sql(self.connection, sql, params) log().sql(self.connection, sql, params)
return await asyncio.wait_for(self.connection.execute(sql, params), self.query_timeout) return await asyncio.wait_for(self.connection.execute(sql, params), self.query_timeout)
async def get_property(self, name: str, cached: bool = True) -> str: async def get_property(self, name: str, cached: bool = True) -> str:
""" Get a property from Nominatim's property table. """ Get a property from Nominatim's property table.
@@ -89,7 +85,6 @@ class SearchConnection:
return cast(str, value) return cast(str, value)
async def get_db_property(self, name: str) -> Any: async def get_db_property(self, name: str) -> Any:
""" Get a setting from the database. At the moment, only """ Get a setting from the database. At the moment, only
'server_version', the version of the database software, can 'server_version', the version of the database software, can
@@ -102,7 +97,6 @@ class SearchConnection:
return self._property_cache['DB:server_version'] return self._property_cache['DB:server_version']
async def get_cached_value(self, group: str, name: str, async def get_cached_value(self, group: str, name: str,
factory: Callable[[], Awaitable[T]]) -> T: factory: Callable[[], Awaitable[T]]) -> T:
""" Access the cache for this Nominatim instance. """ Access the cache for this Nominatim instance.
@@ -125,7 +119,6 @@ class SearchConnection:
return value return value
async def get_class_table(self, cls: str, typ: str) -> Optional[SaFromClause]: async def get_class_table(self, cls: str, typ: str) -> Optional[SaFromClause]:
""" Lookup up if there is a classtype table for the given category """ Lookup up if there is a classtype table for the given category
and return a SQLAlchemy table for it, if it exists. and return a SQLAlchemy table for it, if it exists.

View File

@@ -7,7 +7,7 @@
""" """
Implementation of classes for API access via libraries. Implementation of classes for API access via libraries.
""" """
from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List,\ from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, \
Union, Tuple, cast Union, Tuple, cast
import asyncio import asyncio
import sys import sys
@@ -21,7 +21,7 @@ from .errors import UsageError
from .sql.sqlalchemy_schema import SearchTables from .sql.sqlalchemy_schema import SearchTables
from .sql.async_core_library import PGCORE_LIB, PGCORE_ERROR from .sql.async_core_library import PGCORE_LIB, PGCORE_ERROR
from .config import Configuration from .config import Configuration
from .sql import sqlite_functions, sqlalchemy_functions #pylint: disable=unused-import from .sql import sqlite_functions, sqlalchemy_functions # noqa
from .connection import SearchConnection from .connection import SearchConnection
from .status import get_status, StatusResult from .status import get_status, StatusResult
from .lookup import get_detailed_place, get_simple_place from .lookup import get_detailed_place, get_simple_place
@@ -31,7 +31,7 @@ from . import types as ntyp
from .results import DetailedResult, ReverseResult, SearchResults from .results import DetailedResult, ReverseResult, SearchResults
class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes class NominatimAPIAsync:
""" The main frontend to the Nominatim database implements the """ The main frontend to the Nominatim database implements the
functions for lookup, forward and reverse geocoding using functions for lookup, forward and reverse geocoding using
asynchronous functions. asynchronous functions.
@@ -61,19 +61,18 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
""" """
self.config = Configuration(project_dir, environ) self.config = Configuration(project_dir, environ)
self.query_timeout = self.config.get_int('QUERY_TIMEOUT') \ self.query_timeout = self.config.get_int('QUERY_TIMEOUT') \
if self.config.QUERY_TIMEOUT else None if self.config.QUERY_TIMEOUT else None
self.reverse_restrict_to_country_area = self.config.get_bool('SEARCH_WITHIN_COUNTRIES') self.reverse_restrict_to_country_area = self.config.get_bool('SEARCH_WITHIN_COUNTRIES')
self.server_version = 0 self.server_version = 0
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
self._engine_lock = asyncio.Lock() self._engine_lock = asyncio.Lock()
else: else:
self._engine_lock = asyncio.Lock(loop=loop) # pylint: disable=unexpected-keyword-arg self._engine_lock = asyncio.Lock(loop=loop)
self._engine: Optional[sa_asyncio.AsyncEngine] = None self._engine: Optional[sa_asyncio.AsyncEngine] = None
self._tables: Optional[SearchTables] = None self._tables: Optional[SearchTables] = None
self._property_cache: Dict[str, Any] = {'DB:server_version': 0} self._property_cache: Dict[str, Any] = {'DB:server_version': 0}
async def setup_database(self) -> None: async def setup_database(self) -> None:
""" Set up the SQL engine and connections. """ Set up the SQL engine and connections.
@@ -95,7 +94,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
extra_args['max_overflow'] = 0 extra_args['max_overflow'] = 0
extra_args['pool_size'] = self.config.get_int('API_POOL_SIZE') extra_args['pool_size'] = self.config.get_int('API_POOL_SIZE')
is_sqlite = self.config.DATABASE_DSN.startswith('sqlite:') is_sqlite = self.config.DATABASE_DSN.startswith('sqlite:')
if is_sqlite: if is_sqlite:
@@ -156,10 +154,9 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
self._property_cache['DB:server_version'] = server_version self._property_cache['DB:server_version'] = server_version
self._tables = SearchTables(sa.MetaData()) # pylint: disable=no-member self._tables = SearchTables(sa.MetaData())
self._engine = engine self._engine = engine
async def close(self) -> None: async def close(self) -> None:
""" Close all active connections to the database. The NominatimAPIAsync """ Close all active connections to the database. The NominatimAPIAsync
object remains usable after closing. If a new API functions is object remains usable after closing. If a new API functions is
@@ -168,15 +165,12 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
if self._engine is not None: if self._engine is not None:
await self._engine.dispose() await self._engine.dispose()
async def __aenter__(self) -> 'NominatimAPIAsync': async def __aenter__(self) -> 'NominatimAPIAsync':
return self return self
async def __aexit__(self, *_: Any) -> None: async def __aexit__(self, *_: Any) -> None:
await self.close() await self.close()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def begin(self) -> AsyncIterator[SearchConnection]: async def begin(self) -> AsyncIterator[SearchConnection]:
""" Create a new connection with automatic transaction handling. """ Create a new connection with automatic transaction handling.
@@ -194,7 +188,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
async with self._engine.begin() as conn: async with self._engine.begin() as conn:
yield SearchConnection(conn, self._tables, self._property_cache) yield SearchConnection(conn, self._tables, self._property_cache)
async def status(self) -> StatusResult: async def status(self) -> StatusResult:
""" Return the status of the database. """ Return the status of the database.
""" """
@@ -207,7 +200,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
return status return status
async def details(self, place: ntyp.PlaceRef, **params: Any) -> Optional[DetailedResult]: async def details(self, place: ntyp.PlaceRef, **params: Any) -> Optional[DetailedResult]:
""" Get detailed information about a place in the database. """ Get detailed information about a place in the database.
@@ -220,7 +212,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
await make_query_analyzer(conn) await make_query_analyzer(conn)
return await get_detailed_place(conn, place, details) return await get_detailed_place(conn, place, details)
async def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults: async def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults:
""" Get simple information about a list of places. """ Get simple information about a list of places.
@@ -234,7 +225,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
return SearchResults(filter(None, return SearchResults(filter(None,
[await get_simple_place(conn, p, details) for p in places])) [await get_simple_place(conn, p, details) for p in places]))
async def reverse(self, coord: ntyp.AnyPoint, **params: Any) -> Optional[ReverseResult]: async def reverse(self, coord: ntyp.AnyPoint, **params: Any) -> Optional[ReverseResult]:
""" Find a place by its coordinates. Also known as reverse geocoding. """ Find a place by its coordinates. Also known as reverse geocoding.
@@ -255,7 +245,6 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
self.reverse_restrict_to_country_area) self.reverse_restrict_to_country_area)
return await geocoder.lookup(coord) return await geocoder.lookup(coord)
async def search(self, query: str, **params: Any) -> SearchResults: async def search(self, query: str, **params: Any) -> SearchResults:
""" Find a place by free-text search. Also known as forward geocoding. """ Find a place by free-text search. Also known as forward geocoding.
""" """
@@ -266,13 +255,11 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
async with self.begin() as conn: async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout) conn.set_query_timeout(self.query_timeout)
geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params), geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params),
self.config.get_int('REQUEST_TIMEOUT') \ self.config.get_int('REQUEST_TIMEOUT')
if self.config.REQUEST_TIMEOUT else None) if self.config.REQUEST_TIMEOUT else None)
phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')] phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')]
return await geocoder.lookup(phrases) return await geocoder.lookup(phrases)
# pylint: disable=too-many-arguments,too-many-branches
async def search_address(self, amenity: Optional[str] = None, async def search_address(self, amenity: Optional[str] = None,
street: Optional[str] = None, street: Optional[str] = None,
city: Optional[str] = None, city: Optional[str] = None,
@@ -326,11 +313,10 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
details.layers |= ntyp.DataLayer.POI details.layers |= ntyp.DataLayer.POI
geocoder = ForwardGeocoder(conn, details, geocoder = ForwardGeocoder(conn, details,
self.config.get_int('REQUEST_TIMEOUT') \ self.config.get_int('REQUEST_TIMEOUT')
if self.config.REQUEST_TIMEOUT else None) if self.config.REQUEST_TIMEOUT else None)
return await geocoder.lookup(phrases) return await geocoder.lookup(phrases)
async def search_category(self, categories: List[Tuple[str, str]], async def search_category(self, categories: List[Tuple[str, str]],
near_query: Optional[str] = None, near_query: Optional[str] = None,
**params: Any) -> SearchResults: **params: Any) -> SearchResults:
@@ -352,12 +338,11 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
await make_query_analyzer(conn) await make_query_analyzer(conn)
geocoder = ForwardGeocoder(conn, details, geocoder = ForwardGeocoder(conn, details,
self.config.get_int('REQUEST_TIMEOUT') \ self.config.get_int('REQUEST_TIMEOUT')
if self.config.REQUEST_TIMEOUT else None) if self.config.REQUEST_TIMEOUT else None)
return await geocoder.lookup_pois(categories, phrases) return await geocoder.lookup_pois(categories, phrases)
class NominatimAPI: class NominatimAPI:
""" This class provides a thin synchronous wrapper around the asynchronous """ This class provides a thin synchronous wrapper around the asynchronous
Nominatim functions. It creates its own event loop and runs each Nominatim functions. It creates its own event loop and runs each
@@ -382,7 +367,6 @@ class NominatimAPI:
self._loop = asyncio.new_event_loop() self._loop = asyncio.new_event_loop()
self._async_api = NominatimAPIAsync(project_dir, environ, loop=self._loop) self._async_api = NominatimAPIAsync(project_dir, environ, loop=self._loop)
def close(self) -> None: def close(self) -> None:
""" Close all active connections to the database. """ Close all active connections to the database.
@@ -393,15 +377,12 @@ class NominatimAPI:
self._loop.run_until_complete(self._async_api.close()) self._loop.run_until_complete(self._async_api.close())
self._loop.close() self._loop.close()
def __enter__(self) -> 'NominatimAPI': def __enter__(self) -> 'NominatimAPI':
return self return self
def __exit__(self, *_: Any) -> None: def __exit__(self, *_: Any) -> None:
self.close() self.close()
@property @property
def config(self) -> Configuration: def config(self) -> Configuration:
""" Provide read-only access to the [configuration](Configuration.md) """ Provide read-only access to the [configuration](Configuration.md)
@@ -427,7 +408,6 @@ class NominatimAPI:
""" """
return self._loop.run_until_complete(self._async_api.status()) return self._loop.run_until_complete(self._async_api.status())
def details(self, place: ntyp.PlaceRef, **params: Any) -> Optional[DetailedResult]: def details(self, place: ntyp.PlaceRef, **params: Any) -> Optional[DetailedResult]:
""" Get detailed information about a place in the database. """ Get detailed information about a place in the database.
@@ -510,7 +490,6 @@ class NominatimAPI:
""" """
return self._loop.run_until_complete(self._async_api.details(place, **params)) return self._loop.run_until_complete(self._async_api.details(place, **params))
def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults: def lookup(self, places: Sequence[ntyp.PlaceRef], **params: Any) -> SearchResults:
""" Get simple information about a list of places. """ Get simple information about a list of places.
@@ -587,7 +566,6 @@ class NominatimAPI:
""" """
return self._loop.run_until_complete(self._async_api.lookup(places, **params)) return self._loop.run_until_complete(self._async_api.lookup(places, **params))
def reverse(self, coord: ntyp.AnyPoint, **params: Any) -> Optional[ReverseResult]: def reverse(self, coord: ntyp.AnyPoint, **params: Any) -> Optional[ReverseResult]:
""" Find a place by its coordinates. Also known as reverse geocoding. """ Find a place by its coordinates. Also known as reverse geocoding.
@@ -669,7 +647,6 @@ class NominatimAPI:
""" """
return self._loop.run_until_complete(self._async_api.reverse(coord, **params)) return self._loop.run_until_complete(self._async_api.reverse(coord, **params))
def search(self, query: str, **params: Any) -> SearchResults: def search(self, query: str, **params: Any) -> SearchResults:
""" Find a place by free-text search. Also known as forward geocoding. """ Find a place by free-text search. Also known as forward geocoding.
@@ -769,8 +746,6 @@ class NominatimAPI:
return self._loop.run_until_complete( return self._loop.run_until_complete(
self._async_api.search(query, **params)) self._async_api.search(query, **params))
# pylint: disable=too-many-arguments
def search_address(self, amenity: Optional[str] = None, def search_address(self, amenity: Optional[str] = None,
street: Optional[str] = None, street: Optional[str] = None,
city: Optional[str] = None, city: Optional[str] = None,
@@ -888,7 +863,6 @@ class NominatimAPI:
self._async_api.search_address(amenity, street, city, county, self._async_api.search_address(amenity, street, city, county,
state, country, postalcode, **params)) state, country, postalcode, **params))
def search_category(self, categories: List[Tuple[str, str]], def search_category(self, categories: List[Tuple[str, str]],
near_query: Optional[str] = None, near_query: Optional[str] = None,
**params: Any) -> SearchResults: **params: Any) -> SearchResults:

View File

@@ -8,6 +8,7 @@
Custom exception and error classes for Nominatim. Custom exception and error classes for Nominatim.
""" """
class UsageError(Exception): class UsageError(Exception):
""" An error raised because of bad user input. This error will usually """ An error raised because of bad user input. This error will usually
not cause a stack trace to be printed unless debugging is enabled. not cause a stack trace to be printed unless debugging is enabled.

View File

@@ -11,6 +11,7 @@ from typing import Mapping, List, Optional
import re import re
class Locales: class Locales:
""" Helper class for localization of names. """ Helper class for localization of names.
@@ -28,24 +29,20 @@ class Locales:
self._add_lang_tags('official_name', 'short_name') self._add_lang_tags('official_name', 'short_name')
self._add_tags('official_name', 'short_name', 'ref') self._add_tags('official_name', 'short_name', 'ref')
def __bool__(self) -> bool: def __bool__(self) -> bool:
return len(self.languages) > 0 return len(self.languages) > 0
def _add_tags(self, *tags: str) -> None: def _add_tags(self, *tags: str) -> None:
for tag in tags: for tag in tags:
self.name_tags.append(tag) self.name_tags.append(tag)
self.name_tags.append(f"_place_{tag}") self.name_tags.append(f"_place_{tag}")
def _add_lang_tags(self, *tags: str) -> None: def _add_lang_tags(self, *tags: str) -> None:
for tag in tags: for tag in tags:
for lang in self.languages: for lang in self.languages:
self.name_tags.append(f"{tag}:{lang}") self.name_tags.append(f"{tag}:{lang}")
self.name_tags.append(f"_place_{tag}:{lang}") self.name_tags.append(f"_place_{tag}:{lang}")
def display_name(self, names: Optional[Mapping[str, str]]) -> str: def display_name(self, names: Optional[Mapping[str, str]]) -> str:
""" Return the best matching name from a dictionary of names """ Return the best matching name from a dictionary of names
containing different name variants. containing different name variants.
@@ -64,7 +61,6 @@ class Locales:
# Nothing? Return any of the other names as a default. # Nothing? Return any of the other names as a default.
return next(iter(names.values())) return next(iter(names.values()))
@staticmethod @staticmethod
def from_accept_languages(langstr: str) -> 'Locales': def from_accept_languages(langstr: str) -> 'Locales':
""" Create a localization object from a language list in the """ Create a localization object from a language list in the

View File

@@ -49,41 +49,35 @@ class BaseLogger:
""" Start a new debug chapter for the given function and its parameters. """ Start a new debug chapter for the given function and its parameters.
""" """
def section(self, heading: str) -> None: def section(self, heading: str) -> None:
""" Start a new section with the given title. """ Start a new section with the given title.
""" """
def comment(self, text: str) -> None: def comment(self, text: str) -> None:
""" Add a simple comment to the debug output. """ Add a simple comment to the debug output.
""" """
def var_dump(self, heading: str, var: Any) -> None: def var_dump(self, heading: str, var: Any) -> None:
""" Print the content of the variable to the debug output prefixed by """ Print the content of the variable to the debug output prefixed by
the given heading. the given heading.
""" """
def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None: def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None:
""" Print the table generated by the generator function. """ Print the table generated by the generator function.
""" """
def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None: def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
""" Print a list of search results generated by the generator function. """ Print a list of search results generated by the generator function.
""" """
def sql(self, conn: AsyncConnection, statement: 'sa.Executable', def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None: params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
""" Print the SQL for the given statement. """ Print the SQL for the given statement.
""" """
def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable', def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable',
extra_params: Union[Mapping[str, Any], extra_params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]
Sequence[Mapping[str, Any]], None]) -> str: ) -> str:
""" Return the compiled version of the statement. """ Return the compiled version of the statement.
""" """
compiled = cast('sa.ClauseElement', statement).compile(conn.sync_engine) compiled = cast('sa.ClauseElement', statement).compile(conn.sync_engine)
@@ -108,7 +102,7 @@ class BaseLogger:
try: try:
sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr) sqlstr = re.sub(r'__\[POSTCOMPILE_[^]]*\]', '%s', sqlstr)
return sqlstr % tuple((repr(params.get(name, None)) return sqlstr % tuple((repr(params.get(name, None))
for name in compiled.positiontup)) # type: ignore for name in compiled.positiontup)) # type: ignore
except TypeError: except TypeError:
return sqlstr return sqlstr
@@ -121,28 +115,26 @@ class BaseLogger:
assert conn.dialect.name == 'sqlite' assert conn.dialect.name == 'sqlite'
# params in positional order # params in positional order
pparams = (repr(params.get(name, None)) for name in compiled.positiontup) # type: ignore pparams = (repr(params.get(name, None)) for name in compiled.positiontup) # type: ignore
sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', '?', sqlstr) sqlstr = re.sub(r'__\[POSTCOMPILE_([^]]*)\]', '?', sqlstr)
sqlstr = re.sub(r"\?", lambda m: next(pparams), sqlstr) sqlstr = re.sub(r"\?", lambda m: next(pparams), sqlstr)
return sqlstr return sqlstr
class HTMLLogger(BaseLogger): class HTMLLogger(BaseLogger):
""" Logger that formats messages in HTML. """ Logger that formats messages in HTML.
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.buffer = io.StringIO() self.buffer = io.StringIO()
def _timestamp(self) -> None: def _timestamp(self) -> None:
self._write(f'<p class="timestamp">[{dt.datetime.now()}]</p>') self._write(f'<p class="timestamp">[{dt.datetime.now()}]</p>')
def get_buffer(self) -> str: def get_buffer(self) -> str:
return HTML_HEADER + self.buffer.getvalue() + HTML_FOOTER return HTML_HEADER + self.buffer.getvalue() + HTML_FOOTER
def function(self, func: str, **kwargs: Any) -> None: def function(self, func: str, **kwargs: Any) -> None:
self._timestamp() self._timestamp()
self._write(f"<h1>Debug output for {func}()</h1>\n<p>Parameters:<dl>") self._write(f"<h1>Debug output for {func}()</h1>\n<p>Parameters:<dl>")
@@ -150,17 +142,14 @@ class HTMLLogger(BaseLogger):
self._write(f'<dt>{name}</dt><dd>{self._python_var(value)}</dd>') self._write(f'<dt>{name}</dt><dd>{self._python_var(value)}</dd>')
self._write('</dl></p>') self._write('</dl></p>')
def section(self, heading: str) -> None: def section(self, heading: str) -> None:
self._timestamp() self._timestamp()
self._write(f"<h2>{heading}</h2>") self._write(f"<h2>{heading}</h2>")
def comment(self, text: str) -> None: def comment(self, text: str) -> None:
self._timestamp() self._timestamp()
self._write(f"<p>{text}</p>") self._write(f"<p>{text}</p>")
def var_dump(self, heading: str, var: Any) -> None: def var_dump(self, heading: str, var: Any) -> None:
self._timestamp() self._timestamp()
if callable(var): if callable(var):
@@ -168,7 +157,6 @@ class HTMLLogger(BaseLogger):
self._write(f'<h5>{heading}</h5>{self._python_var(var)}') self._write(f'<h5>{heading}</h5>{self._python_var(var)}')
def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None: def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None:
self._timestamp() self._timestamp()
head = next(rows) head = next(rows)
@@ -185,11 +173,11 @@ class HTMLLogger(BaseLogger):
self._write('</tr>') self._write('</tr>')
self._write('</tbody></table>') self._write('</tbody></table>')
def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None: def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
""" Print a list of search results generated by the generator function. """ Print a list of search results generated by the generator function.
""" """
self._timestamp() self._timestamp()
def format_osm(osm_object: Optional[Tuple[str, int]]) -> str: def format_osm(osm_object: Optional[Tuple[str, int]]) -> str:
if not osm_object: if not osm_object:
return '-' return '-'
@@ -218,7 +206,6 @@ class HTMLLogger(BaseLogger):
total += 1 total += 1
self._write(f'</dl><b>TOTAL:</b> {total}</p>') self._write(f'</dl><b>TOTAL:</b> {total}</p>')
def sql(self, conn: AsyncConnection, statement: 'sa.Executable', def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None: params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
self._timestamp() self._timestamp()
@@ -230,7 +217,6 @@ class HTMLLogger(BaseLogger):
else: else:
self._write(f'<code class="lang-sql">{html.escape(sqlstr)}</code>') self._write(f'<code class="lang-sql">{html.escape(sqlstr)}</code>')
def _python_var(self, var: Any) -> str: def _python_var(self, var: Any) -> str:
if CODE_HIGHLIGHT: if CODE_HIGHLIGHT:
fmt = highlight(str(var), PythonLexer(), HtmlFormatter(nowrap=True)) fmt = highlight(str(var), PythonLexer(), HtmlFormatter(nowrap=True))
@@ -238,7 +224,6 @@ class HTMLLogger(BaseLogger):
return f'<code class="lang-python">{html.escape(str(var))}</code>' return f'<code class="lang-python">{html.escape(str(var))}</code>'
def _write(self, text: str) -> None: def _write(self, text: str) -> None:
""" Add the raw text to the debug output. """ Add the raw text to the debug output.
""" """
@@ -251,38 +236,31 @@ class TextLogger(BaseLogger):
def __init__(self) -> None: def __init__(self) -> None:
self.buffer = io.StringIO() self.buffer = io.StringIO()
def _timestamp(self) -> None: def _timestamp(self) -> None:
self._write(f'[{dt.datetime.now()}]\n') self._write(f'[{dt.datetime.now()}]\n')
def get_buffer(self) -> str: def get_buffer(self) -> str:
return self.buffer.getvalue() return self.buffer.getvalue()
def function(self, func: str, **kwargs: Any) -> None: def function(self, func: str, **kwargs: Any) -> None:
self._write(f"#### Debug output for {func}()\n\nParameters:\n") self._write(f"#### Debug output for {func}()\n\nParameters:\n")
for name, value in kwargs.items(): for name, value in kwargs.items():
self._write(f' {name}: {self._python_var(value)}\n') self._write(f' {name}: {self._python_var(value)}\n')
self._write('\n') self._write('\n')
def section(self, heading: str) -> None: def section(self, heading: str) -> None:
self._timestamp() self._timestamp()
self._write(f"\n# {heading}\n\n") self._write(f"\n# {heading}\n\n")
def comment(self, text: str) -> None: def comment(self, text: str) -> None:
self._write(f"{text}\n") self._write(f"{text}\n")
def var_dump(self, heading: str, var: Any) -> None: def var_dump(self, heading: str, var: Any) -> None:
if callable(var): if callable(var):
var = var() var = var()
self._write(f'{heading}:\n {self._python_var(var)}\n\n') self._write(f'{heading}:\n {self._python_var(var)}\n\n')
def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None: def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None:
self._write(f'{heading}:\n') self._write(f'{heading}:\n')
data = [list(map(self._python_var, row)) if row else None for row in rows] data = [list(map(self._python_var, row)) if row else None for row in rows]
@@ -291,7 +269,7 @@ class TextLogger(BaseLogger):
maxlens = [max(len(d[i]) for d in data if d) for i in range(num_cols)] maxlens = [max(len(d[i]) for d in data if d) for i in range(num_cols)]
tablewidth = sum(maxlens) + 3 * num_cols + 1 tablewidth = sum(maxlens) + 3 * num_cols + 1
row_format = '| ' +' | '.join(f'{{:<{l}}}' for l in maxlens) + ' |\n' row_format = '| ' + ' | '.join(f'{{:<{ln}}}' for ln in maxlens) + ' |\n'
self._write('-'*tablewidth + '\n') self._write('-'*tablewidth + '\n')
self._write(row_format.format(*data[0])) self._write(row_format.format(*data[0]))
self._write('-'*tablewidth + '\n') self._write('-'*tablewidth + '\n')
@@ -303,7 +281,6 @@ class TextLogger(BaseLogger):
if data[-1]: if data[-1]:
self._write('-'*tablewidth + '\n') self._write('-'*tablewidth + '\n')
def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None: def result_dump(self, heading: str, results: Iterator[Tuple[Any, Any]]) -> None:
self._timestamp() self._timestamp()
self._write(f'{heading}:\n') self._write(f'{heading}:\n')
@@ -318,18 +295,15 @@ class TextLogger(BaseLogger):
total += 1 total += 1
self._write(f'TOTAL: {total}\n\n') self._write(f'TOTAL: {total}\n\n')
def sql(self, conn: AsyncConnection, statement: 'sa.Executable', def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None: params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
self._timestamp() self._timestamp()
sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement, params), width=78)) sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement, params), width=78))
self._write(f"| {sqlstr}\n\n") self._write(f"| {sqlstr}\n\n")
def _python_var(self, var: Any) -> str: def _python_var(self, var: Any) -> str:
return str(var) return str(var)
def _write(self, text: str) -> None: def _write(self, text: str) -> None:
self.buffer.write(text) self.buffer.write(text)
@@ -368,8 +342,8 @@ HTML_HEADER: str = """<!DOCTYPE html>
<title>Nominatim - Debug</title> <title>Nominatim - Debug</title>
<style> <style>
""" + \ """ + \
(HtmlFormatter(nobackground=True).get_style_defs('.highlight') if CODE_HIGHLIGHT else '') +\ (HtmlFormatter(nobackground=True).get_style_defs('.highlight') if CODE_HIGHLIGHT else '') + \
""" """
h2 { font-size: x-large } h2 { font-size: x-large }
dl { dl {

View File

@@ -127,7 +127,7 @@ async def find_in_postcode(conn: SearchConnection, place: ntyp.PlaceRef,
async def find_in_all_tables(conn: SearchConnection, place: ntyp.PlaceRef, async def find_in_all_tables(conn: SearchConnection, place: ntyp.PlaceRef,
add_geometries: GeomFunc add_geometries: GeomFunc
) -> Tuple[Optional[SaRow], RowFunc[nres.BaseResultT]]: ) -> Tuple[Optional[SaRow], RowFunc[nres.BaseResultT]]:
""" Search for the given place in all data tables """ Search for the given place in all data tables
and return the base information. and return the base information.
""" """
@@ -219,7 +219,6 @@ async def get_simple_place(conn: SearchConnection, place: ntyp.PlaceRef,
return sql.add_columns(*out) return sql.add_columns(*out)
row_func: RowFunc[nres.SearchResult] row_func: RowFunc[nres.SearchResult]
row, row_func = await find_in_all_tables(conn, place, _add_geometry) row, row_func = await find_in_all_tables(conn, place, _add_geometry)

View File

@@ -14,7 +14,7 @@ import importlib
from .server.content_types import CONTENT_JSON from .server.content_types import CONTENT_JSON
T = TypeVar('T') # pylint: disable=invalid-name T = TypeVar('T') # pylint: disable=invalid-name
FormatFunc = Callable[[T, Mapping[str, Any]], str] FormatFunc = Callable[[T, Mapping[str, Any]], str]
ErrorFormatFunc = Callable[[str, str, int], str] ErrorFormatFunc = Callable[[str, str, int], str]
@@ -31,7 +31,6 @@ class FormatDispatcher:
self.content_types.update(content_types) self.content_types.update(content_types)
self.format_functions: Dict[Type[Any], Dict[str, FormatFunc[Any]]] = defaultdict(dict) self.format_functions: Dict[Type[Any], Dict[str, FormatFunc[Any]]] = defaultdict(dict)
def format_func(self, result_class: Type[T], def format_func(self, result_class: Type[T],
fmt: str) -> Callable[[FormatFunc[T]], FormatFunc[T]]: fmt: str) -> Callable[[FormatFunc[T]], FormatFunc[T]]:
""" Decorator for a function that formats a given type of result into the """ Decorator for a function that formats a given type of result into the
@@ -43,7 +42,6 @@ class FormatDispatcher:
return decorator return decorator
def error_format_func(self, func: ErrorFormatFunc) -> ErrorFormatFunc: def error_format_func(self, func: ErrorFormatFunc) -> ErrorFormatFunc:
""" Decorator for a function that formats error messges. """ Decorator for a function that formats error messges.
There is only one error formatter per dispatcher. Using There is only one error formatter per dispatcher. Using
@@ -52,19 +50,16 @@ class FormatDispatcher:
self.error_handler = func self.error_handler = func
return func return func
def list_formats(self, result_type: Type[Any]) -> List[str]: def list_formats(self, result_type: Type[Any]) -> List[str]:
""" Return a list of formats supported by this formatter. """ Return a list of formats supported by this formatter.
""" """
return list(self.format_functions[result_type].keys()) return list(self.format_functions[result_type].keys())
def supports_format(self, result_type: Type[Any], fmt: str) -> bool: def supports_format(self, result_type: Type[Any], fmt: str) -> bool:
""" Check if the given format is supported by this formatter. """ Check if the given format is supported by this formatter.
""" """
return fmt in self.format_functions[result_type] return fmt in self.format_functions[result_type]
def format_result(self, result: Any, fmt: str, options: Mapping[str, Any]) -> str: def format_result(self, result: Any, fmt: str, options: Mapping[str, Any]) -> str:
""" Convert the given result into a string using the given format. """ Convert the given result into a string using the given format.
@@ -73,7 +68,6 @@ class FormatDispatcher:
""" """
return self.format_functions[type(result)][fmt](result, options) return self.format_functions[type(result)][fmt](result, options)
def format_error(self, content_type: str, msg: str, status: int) -> str: def format_error(self, content_type: str, msg: str, status: int) -> str:
""" Convert the given error message into a response string """ Convert the given error message into a response string
taking the requested content_type into account. taking the requested content_type into account.
@@ -82,7 +76,6 @@ class FormatDispatcher:
""" """
return self.error_handler(content_type, msg, status) return self.error_handler(content_type, msg, status)
def set_content_type(self, fmt: str, content_type: str) -> None: def set_content_type(self, fmt: str, content_type: str) -> None:
""" Set the content type for the given format. This is the string """ Set the content type for the given format. This is the string
that will be returned in the Content-Type header of the HTML that will be returned in the Content-Type header of the HTML
@@ -90,7 +83,6 @@ class FormatDispatcher:
""" """
self.content_types[fmt] = content_type self.content_types[fmt] = content_type
def get_content_type(self, fmt: str) -> str: def get_content_type(self, fmt: str) -> str:
""" Return the content type for the given format. """ Return the content type for the given format.

View File

@@ -26,7 +26,7 @@ from .logging import log
from .localization import Locales from .localization import Locales
# This file defines complex result data classes. # This file defines complex result data classes.
# pylint: disable=too-many-instance-attributes
def _mingle_name_tags(names: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]: def _mingle_name_tags(names: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]:
""" Mix-in names from linked places, so that they show up """ Mix-in names from linked places, so that they show up
@@ -153,7 +153,6 @@ class AddressLines(List[AddressLine]):
return label_parts return label_parts
@dataclasses.dataclass @dataclasses.dataclass
class WordInfo: class WordInfo:
""" Each entry in the list of search terms contains the """ Each entry in the list of search terms contains the
@@ -183,7 +182,7 @@ class BaseResult:
category: Tuple[str, str] category: Tuple[str, str]
centroid: Point centroid: Point
place_id : Optional[int] = None place_id: Optional[int] = None
osm_object: Optional[Tuple[str, int]] = None osm_object: Optional[Tuple[str, int]] = None
parent_place_id: Optional[int] = None parent_place_id: Optional[int] = None
linked_place_id: Optional[int] = None linked_place_id: Optional[int] = None
@@ -220,14 +219,12 @@ class BaseResult:
""" """
return self.centroid[1] return self.centroid[1]
@property @property
def lon(self) -> float: def lon(self) -> float:
""" Get the longitude (or x) of the center point of the place. """ Get the longitude (or x) of the center point of the place.
""" """
return self.centroid[0] return self.centroid[0]
def calculated_importance(self) -> float: def calculated_importance(self) -> float:
""" Get a valid importance value. This is either the stored importance """ Get a valid importance value. This is either the stored importance
of the value or an artificial value computed from the place's of the value or an artificial value computed from the place's
@@ -235,7 +232,6 @@ class BaseResult:
""" """
return self.importance or (0.40001 - (self.rank_search/75.0)) return self.importance or (0.40001 - (self.rank_search/75.0))
def localize(self, locales: Locales) -> None: def localize(self, locales: Locales) -> None:
""" Fill the locale_name and the display_name field for the """ Fill the locale_name and the display_name field for the
place and, if available, its address information. place and, if available, its address information.
@@ -247,9 +243,9 @@ class BaseResult:
self.display_name = self.locale_name self.display_name = self.locale_name
BaseResultT = TypeVar('BaseResultT', bound=BaseResult) BaseResultT = TypeVar('BaseResultT', bound=BaseResult)
@dataclasses.dataclass @dataclasses.dataclass
class DetailedResult(BaseResult): class DetailedResult(BaseResult):
""" A search result with more internal information from the database """ A search result with more internal information from the database
@@ -279,13 +275,12 @@ class SearchResult(BaseResult):
bbox: Optional[Bbox] = None bbox: Optional[Bbox] = None
accuracy: float = 0.0 accuracy: float = 0.0
@property @property
def ranking(self) -> float: def ranking(self) -> float:
""" Return the ranking, a combined measure of accuracy and importance. """ Return the ranking, a combined measure of accuracy and importance.
""" """
return (self.accuracy if self.accuracy is not None else 1) \ return (self.accuracy if self.accuracy is not None else 1) \
- self.calculated_importance() - self.calculated_importance()
class SearchResults(List[SearchResult]): class SearchResults(List[SearchResult]):
@@ -295,7 +290,7 @@ class SearchResults(List[SearchResult]):
def _filter_geometries(row: SaRow) -> Dict[str, str]: def _filter_geometries(row: SaRow) -> Dict[str, str]:
return {k[9:]: v for k, v in row._mapping.items() # pylint: disable=W0212 return {k[9:]: v for k, v in row._mapping.items()
if k.startswith('geometry_')} if k.startswith('geometry_')}
@@ -312,9 +307,9 @@ def create_from_placex_row(row: Optional[SaRow],
place_id=row.place_id, place_id=row.place_id,
osm_object=(row.osm_type, row.osm_id), osm_object=(row.osm_type, row.osm_id),
category=(row.class_, row.type), category=(row.class_, row.type),
parent_place_id = row.parent_place_id, parent_place_id=row.parent_place_id,
linked_place_id = getattr(row, 'linked_place_id', None), linked_place_id=getattr(row, 'linked_place_id', None),
admin_level = getattr(row, 'admin_level', 15), admin_level=getattr(row, 'admin_level', 15),
names=_mingle_name_tags(row.name), names=_mingle_name_tags(row.name),
address=row.address, address=row.address,
extratags=row.extratags, extratags=row.extratags,
@@ -345,7 +340,7 @@ def create_from_osmline_row(row: Optional[SaRow],
res = class_type(source_table=SourceTable.OSMLINE, res = class_type(source_table=SourceTable.OSMLINE,
place_id=row.place_id, place_id=row.place_id,
parent_place_id = row.parent_place_id, parent_place_id=row.parent_place_id,
osm_object=('W', row.osm_id), osm_object=('W', row.osm_id),
category=('place', 'houses' if hnr is None else 'house'), category=('place', 'houses' if hnr is None else 'house'),
address=row.address, address=row.address,
@@ -382,7 +377,7 @@ def create_from_tiger_row(row: Optional[SaRow],
res = class_type(source_table=SourceTable.TIGER, res = class_type(source_table=SourceTable.TIGER,
place_id=row.place_id, place_id=row.place_id,
parent_place_id = row.parent_place_id, parent_place_id=row.parent_place_id,
osm_object=(osm_type or row.osm_type, osm_id or row.osm_id), osm_object=(osm_type or row.osm_type, osm_id or row.osm_id),
category=('place', 'houses' if hnr is None else 'house'), category=('place', 'houses' if hnr is None else 'house'),
postcode=row.postcode, postcode=row.postcode,
@@ -401,7 +396,7 @@ def create_from_tiger_row(row: Optional[SaRow],
def create_from_postcode_row(row: Optional[SaRow], def create_from_postcode_row(row: Optional[SaRow],
class_type: Type[BaseResultT]) -> Optional[BaseResultT]: class_type: Type[BaseResultT]) -> Optional[BaseResultT]:
""" Construct a new result and add the data from the result row """ Construct a new result and add the data from the result row
from the postcode table. 'class_type' defines from the postcode table. 'class_type' defines
the type of result to return. Returns None if the row is None. the type of result to return. Returns None if the row is None.
@@ -411,7 +406,7 @@ def create_from_postcode_row(row: Optional[SaRow],
return class_type(source_table=SourceTable.POSTCODE, return class_type(source_table=SourceTable.POSTCODE,
place_id=row.place_id, place_id=row.place_id,
parent_place_id = row.parent_place_id, parent_place_id=row.parent_place_id,
category=('place', 'postcode'), category=('place', 'postcode'),
names={'ref': row.postcode}, names={'ref': row.postcode},
rank_search=row.rank_search, rank_search=row.rank_search,
@@ -422,7 +417,7 @@ def create_from_postcode_row(row: Optional[SaRow],
def create_from_country_row(row: Optional[SaRow], def create_from_country_row(row: Optional[SaRow],
class_type: Type[BaseResultT]) -> Optional[BaseResultT]: class_type: Type[BaseResultT]) -> Optional[BaseResultT]:
""" Construct a new result and add the data from the result row """ Construct a new result and add the data from the result row
from the fallback country tables. 'class_type' defines from the fallback country tables. 'class_type' defines
the type of result to return. Returns None if the row is None. the type of result to return. Returns None if the row is None.
@@ -535,7 +530,7 @@ async def _finalize_entry(conn: SearchConnection, result: BaseResultT) -> None:
distance=0.0)) distance=0.0))
result.address_rows.append(AddressLine( result.address_rows.append(AddressLine(
category=('place', 'country_code'), category=('place', 'country_code'),
names={'ref': result.country_code}, extratags = {}, names={'ref': result.country_code}, extratags={},
fromarea=True, isaddress=False, rank_address=4, fromarea=True, isaddress=False, rank_address=4,
distance=0.0)) distance=0.0))
@@ -580,12 +575,12 @@ async def complete_address_details(conn: SearchConnection, results: List[BaseRes
for result in results: for result in results:
_setup_address_details(result) _setup_address_details(result)
### Lookup entries from place_address line # Lookup entries from place_address line
lookup_ids = [{'pid': r.place_id, lookup_ids = [{'pid': r.place_id,
'lid': _get_address_lookup_id(r), 'lid': _get_address_lookup_id(r),
'names': list(r.address.values()) if r.address else [], 'names': list(r.address.values()) if r.address else [],
'c': ('SRID=4326;' + r.centroid.to_wkt()) if r.centroid else '' } 'c': ('SRID=4326;' + r.centroid.to_wkt()) if r.centroid else ''}
for r in results if r.place_id] for r in results if r.place_id]
if not lookup_ids: if not lookup_ids:
@@ -621,7 +616,6 @@ async def complete_address_details(conn: SearchConnection, results: List[BaseRes
.order_by(taddr.c.distance.desc())\ .order_by(taddr.c.distance.desc())\
.order_by(t.c.rank_search.desc()) .order_by(t.c.rank_search.desc())
current_result = None current_result = None
current_rank_address = -1 current_rank_address = -1
for row in await conn.execute(sql): for row in await conn.execute(sql):
@@ -649,8 +643,7 @@ async def complete_address_details(conn: SearchConnection, results: List[BaseRes
for result in results: for result in results:
await _finalize_entry(conn, result) await _finalize_entry(conn, result)
# Finally add the record for the parent entry where necessary.
### Finally add the record for the parent entry where necessary.
parent_lookup_ids = list(filter(lambda e: e['pid'] != e['lid'], lookup_ids)) parent_lookup_ids = list(filter(lambda e: e['pid'] != e['lid'], lookup_ids))
if parent_lookup_ids: if parent_lookup_ids:
@@ -661,7 +654,7 @@ async def complete_address_details(conn: SearchConnection, results: List[BaseRes
t.c.class_, t.c.type, t.c.extratags, t.c.class_, t.c.type, t.c.extratags,
t.c.admin_level, t.c.admin_level,
t.c.rank_address)\ t.c.rank_address)\
.where(t.c.place_id == ltab.c.value['lid'].as_integer()) .where(t.c.place_id == ltab.c.value['lid'].as_integer())
for row in await conn.execute(sql): for row in await conn.execute(sql):
current_result = next((r for r in results if r.place_id == row.src_place_id), None) current_result = next((r for r in results if r.place_id == row.src_place_id), None)
@@ -677,7 +670,7 @@ async def complete_address_details(conn: SearchConnection, results: List[BaseRes
fromarea=True, isaddress=True, fromarea=True, isaddress=True,
rank_address=row.rank_address, distance=0.0)) rank_address=row.rank_address, distance=0.0))
### Now sort everything # Now sort everything
def mk_sort_key(place_id: Optional[int]) -> Callable[[AddressLine], Tuple[bool, int, bool]]: def mk_sort_key(place_id: Optional[int]) -> Callable[[AddressLine], Tuple[bool, int, bool]]:
return lambda a: (a.place_id != place_id, -a.rank_address, a.isaddress) return lambda a: (a.place_id != place_id, -a.rank_address, a.isaddress)
@@ -706,7 +699,7 @@ async def complete_linked_places(conn: SearchConnection, result: BaseResult) ->
return return
sql = _placex_select_address_row(conn, result.centroid)\ sql = _placex_select_address_row(conn, result.centroid)\
.where(conn.t.placex.c.linked_place_id == result.place_id) .where(conn.t.placex.c.linked_place_id == result.place_id)
for row in await conn.execute(sql): for row in await conn.execute(sql):
result.linked_rows.append(_result_row_to_address_row(row)) result.linked_rows.append(_result_row_to_address_row(row))
@@ -745,8 +738,8 @@ async def complete_parented_places(conn: SearchConnection, result: BaseResult) -
return return
sql = _placex_select_address_row(conn, result.centroid)\ sql = _placex_select_address_row(conn, result.centroid)\
.where(conn.t.placex.c.parent_place_id == result.place_id)\ .where(conn.t.placex.c.parent_place_id == result.place_id)\
.where(conn.t.placex.c.rank_search == 30) .where(conn.t.placex.c.rank_search == 30)
for row in await conn.execute(sql): for row in await conn.execute(sql):
result.parented_rows.append(_result_row_to_address_row(row)) result.parented_rows.append(_result_row_to_address_row(row))

View File

@@ -12,7 +12,7 @@ import functools
import sqlalchemy as sa import sqlalchemy as sa
from .typing import SaColumn, SaSelect, SaFromClause, SaLabel, SaRow,\ from .typing import SaColumn, SaSelect, SaFromClause, SaLabel, SaRow, \
SaBind, SaLambdaSelect SaBind, SaLambdaSelect
from .sql.sqlalchemy_types import Geometry from .sql.sqlalchemy_types import Geometry
from .connection import SearchConnection from .connection import SearchConnection
@@ -29,11 +29,12 @@ RowFunc = Callable[[Optional[SaRow], Type[nres.ReverseResult]], Optional[nres.Re
WKT_PARAM: SaBind = sa.bindparam('wkt', type_=Geometry) WKT_PARAM: SaBind = sa.bindparam('wkt', type_=Geometry)
MAX_RANK_PARAM: SaBind = sa.bindparam('max_rank') MAX_RANK_PARAM: SaBind = sa.bindparam('max_rank')
def no_index(expr: SaColumn) -> SaColumn: def no_index(expr: SaColumn) -> SaColumn:
""" Wrap the given expression, so that the query planner will """ Wrap the given expression, so that the query planner will
refrain from using the expression for index lookup. refrain from using the expression for index lookup.
""" """
return sa.func.coalesce(sa.null(), expr) # pylint: disable=not-callable return sa.func.coalesce(sa.null(), expr)
def _select_from_placex(t: SaFromClause, use_wkt: bool = True) -> SaSelect: def _select_from_placex(t: SaFromClause, use_wkt: bool = True) -> SaSelect:
@@ -48,7 +49,6 @@ def _select_from_placex(t: SaFromClause, use_wkt: bool = True) -> SaSelect:
centroid = sa.case((t.c.geometry.is_line_like(), t.c.geometry.ST_ClosestPoint(WKT_PARAM)), centroid = sa.case((t.c.geometry.is_line_like(), t.c.geometry.ST_ClosestPoint(WKT_PARAM)),
else_=t.c.centroid).label('centroid') else_=t.c.centroid).label('centroid')
return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name, return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
t.c.class_, t.c.type, t.c.class_, t.c.type,
t.c.address, t.c.extratags, t.c.address, t.c.extratags,
@@ -63,8 +63,8 @@ def _select_from_placex(t: SaFromClause, use_wkt: bool = True) -> SaSelect:
def _interpolated_housenumber(table: SaFromClause) -> SaLabel: def _interpolated_housenumber(table: SaFromClause) -> SaLabel:
return sa.cast(table.c.startnumber return sa.cast(table.c.startnumber
+ sa.func.round(((table.c.endnumber - table.c.startnumber) * table.c.position) + sa.func.round(((table.c.endnumber - table.c.startnumber) * table.c.position)
/ table.c.step) * table.c.step, / table.c.step) * table.c.step,
sa.Integer).label('housenumber') sa.Integer).label('housenumber')
@@ -72,8 +72,8 @@ def _interpolated_position(table: SaFromClause) -> SaLabel:
fac = sa.cast(table.c.step, sa.Float) / (table.c.endnumber - table.c.startnumber) fac = sa.cast(table.c.step, sa.Float) / (table.c.endnumber - table.c.startnumber)
rounded_pos = sa.func.round(table.c.position / fac) * fac rounded_pos = sa.func.round(table.c.position / fac) * fac
return sa.case( return sa.case(
(table.c.endnumber == table.c.startnumber, table.c.linegeo.ST_Centroid()), (table.c.endnumber == table.c.startnumber, table.c.linegeo.ST_Centroid()),
else_=table.c.linegeo.ST_LineInterpolatePoint(rounded_pos)).label('centroid') else_=table.c.linegeo.ST_LineInterpolatePoint(rounded_pos)).label('centroid')
def _locate_interpolation(table: SaFromClause) -> SaLabel: def _locate_interpolation(table: SaFromClause) -> SaLabel:
@@ -101,38 +101,32 @@ class ReverseGeocoder:
self.bind_params: Dict[str, Any] = {'max_rank': params.max_rank} self.bind_params: Dict[str, Any] = {'max_rank': params.max_rank}
@property @property
def max_rank(self) -> int: def max_rank(self) -> int:
""" Return the maximum configured rank. """ Return the maximum configured rank.
""" """
return self.params.max_rank return self.params.max_rank
def has_geometries(self) -> bool: def has_geometries(self) -> bool:
""" Check if any geometries are requested. """ Check if any geometries are requested.
""" """
return bool(self.params.geometry_output) return bool(self.params.geometry_output)
def layer_enabled(self, *layer: DataLayer) -> bool: def layer_enabled(self, *layer: DataLayer) -> bool:
""" Return true when any of the given layer types are requested. """ Return true when any of the given layer types are requested.
""" """
return any(self.params.layers & l for l in layer) return any(self.params.layers & ly for ly in layer)
def layer_disabled(self, *layer: DataLayer) -> bool: def layer_disabled(self, *layer: DataLayer) -> bool:
""" Return true when none of the given layer types is requested. """ Return true when none of the given layer types is requested.
""" """
return not any(self.params.layers & l for l in layer) return not any(self.params.layers & ly for ly in layer)
def has_feature_layers(self) -> bool: def has_feature_layers(self) -> bool:
""" Return true if any layer other than ADDRESS or POI is requested. """ Return true if any layer other than ADDRESS or POI is requested.
""" """
return self.layer_enabled(DataLayer.RAILWAY, DataLayer.MANMADE, DataLayer.NATURAL) return self.layer_enabled(DataLayer.RAILWAY, DataLayer.MANMADE, DataLayer.NATURAL)
def _add_geometry_columns(self, sql: SaLambdaSelect, col: SaColumn) -> SaSelect: def _add_geometry_columns(self, sql: SaLambdaSelect, col: SaColumn) -> SaSelect:
out = [] out = []
@@ -150,7 +144,6 @@ class ReverseGeocoder:
return sql.add_columns(*out) return sql.add_columns(*out)
def _filter_by_layer(self, table: SaFromClause) -> SaColumn: def _filter_by_layer(self, table: SaFromClause) -> SaColumn:
if self.layer_enabled(DataLayer.MANMADE): if self.layer_enabled(DataLayer.MANMADE):
exclude = [] exclude = []
@@ -167,7 +160,6 @@ class ReverseGeocoder:
include.extend(('natural', 'water', 'waterway')) include.extend(('natural', 'water', 'waterway'))
return table.c.class_.in_(tuple(include)) return table.c.class_.in_(tuple(include))
async def _find_closest_street_or_poi(self, distance: float) -> Optional[SaRow]: async def _find_closest_street_or_poi(self, distance: float) -> Optional[SaRow]:
""" Look up the closest rank 26+ place in the database, which """ Look up the closest rank 26+ place in the database, which
is closer than the given distance. is closer than the given distance.
@@ -179,14 +171,15 @@ class ReverseGeocoder:
# when used with prepared statements # when used with prepared statements
diststr = sa.text(f"{distance}") diststr = sa.text(f"{distance}")
sql: SaLambdaSelect = sa.lambda_stmt(lambda: _select_from_placex(t) sql: SaLambdaSelect = sa.lambda_stmt(
.where(t.c.geometry.within_distance(WKT_PARAM, diststr)) lambda: _select_from_placex(t)
.where(t.c.indexed_status == 0) .where(t.c.geometry.within_distance(WKT_PARAM, diststr))
.where(t.c.linked_place_id == None) .where(t.c.indexed_status == 0)
.where(sa.or_(sa.not_(t.c.geometry.is_area()), .where(t.c.linked_place_id == None)
t.c.centroid.ST_Distance(WKT_PARAM) < diststr)) .where(sa.or_(sa.not_(t.c.geometry.is_area()),
.order_by('distance') t.c.centroid.ST_Distance(WKT_PARAM) < diststr))
.limit(2)) .order_by('distance')
.limit(2))
if self.has_geometries(): if self.has_geometries():
sql = self._add_geometry_columns(sql, t.c.geometry) sql = self._add_geometry_columns(sql, t.c.geometry)
@@ -227,7 +220,6 @@ class ReverseGeocoder:
return prev_row return prev_row
async def _find_housenumber_for_street(self, parent_place_id: int) -> Optional[SaRow]: async def _find_housenumber_for_street(self, parent_place_id: int) -> Optional[SaRow]:
t = self.conn.t.placex t = self.conn.t.placex
@@ -249,7 +241,6 @@ class ReverseGeocoder:
return (await self.conn.execute(sql, self.bind_params)).one_or_none() return (await self.conn.execute(sql, self.bind_params)).one_or_none()
async def _find_interpolation_for_street(self, parent_place_id: Optional[int], async def _find_interpolation_for_street(self, parent_place_id: Optional[int],
distance: float) -> Optional[SaRow]: distance: float) -> Optional[SaRow]:
t = self.conn.t.osmline t = self.conn.t.osmline
@@ -268,11 +259,11 @@ class ReverseGeocoder:
inner = sql.subquery('ipol') inner = sql.subquery('ipol')
sql = sa.select(inner.c.place_id, inner.c.osm_id, sql = sa.select(inner.c.place_id, inner.c.osm_id,
inner.c.parent_place_id, inner.c.address, inner.c.parent_place_id, inner.c.address,
_interpolated_housenumber(inner), _interpolated_housenumber(inner),
_interpolated_position(inner), _interpolated_position(inner),
inner.c.postcode, inner.c.country_code, inner.c.postcode, inner.c.country_code,
inner.c.distance) inner.c.distance)
if self.has_geometries(): if self.has_geometries():
sub = sql.subquery('geom') sub = sql.subquery('geom')
@@ -280,7 +271,6 @@ class ReverseGeocoder:
return (await self.conn.execute(sql, self.bind_params)).one_or_none() return (await self.conn.execute(sql, self.bind_params)).one_or_none()
async def _find_tiger_number_for_street(self, parent_place_id: int) -> Optional[SaRow]: async def _find_tiger_number_for_street(self, parent_place_id: int) -> Optional[SaRow]:
t = self.conn.t.tiger t = self.conn.t.tiger
@@ -310,7 +300,6 @@ class ReverseGeocoder:
return (await self.conn.execute(sql, self.bind_params)).one_or_none() return (await self.conn.execute(sql, self.bind_params)).one_or_none()
async def lookup_street_poi(self) -> Tuple[Optional[SaRow], RowFunc]: async def lookup_street_poi(self) -> Tuple[Optional[SaRow], RowFunc]:
""" Find a street or POI/address for the given WKT point. """ Find a street or POI/address for the given WKT point.
""" """
@@ -365,7 +354,6 @@ class ReverseGeocoder:
return row, row_func return row, row_func
async def _lookup_area_address(self) -> Optional[SaRow]: async def _lookup_area_address(self) -> Optional[SaRow]:
""" Lookup large addressable areas for the given WKT point. """ Lookup large addressable areas for the given WKT point.
""" """
@@ -384,9 +372,9 @@ class ReverseGeocoder:
.subquery('area') .subquery('area')
return _select_from_placex(inner, False)\ return _select_from_placex(inner, False)\
.where(inner.c.geometry.ST_Contains(WKT_PARAM))\ .where(inner.c.geometry.ST_Contains(WKT_PARAM))\
.order_by(sa.desc(inner.c.rank_search))\ .order_by(sa.desc(inner.c.rank_search))\
.limit(1) .limit(1)
sql: SaLambdaSelect = sa.lambda_stmt(_base_query) sql: SaLambdaSelect = sa.lambda_stmt(_base_query)
if self.has_geometries(): if self.has_geometries():
@@ -403,15 +391,14 @@ class ReverseGeocoder:
def _place_inside_area_query() -> SaSelect: def _place_inside_area_query() -> SaSelect:
inner = \ inner = \
sa.select(t, sa.select(t, t.c.geometry.ST_Distance(WKT_PARAM).label('distance'))\
t.c.geometry.ST_Distance(WKT_PARAM).label('distance'))\ .where(t.c.rank_search > address_rank)\
.where(t.c.rank_search > address_rank)\ .where(t.c.rank_search <= MAX_RANK_PARAM)\
.where(t.c.rank_search <= MAX_RANK_PARAM)\ .where(t.c.indexed_status == 0)\
.where(t.c.indexed_status == 0)\ .where(sa.func.IntersectsReverseDistance(t, WKT_PARAM))\
.where(sa.func.IntersectsReverseDistance(t, WKT_PARAM))\ .order_by(sa.desc(t.c.rank_search))\
.order_by(sa.desc(t.c.rank_search))\ .limit(50)\
.limit(50)\ .subquery('places')
.subquery('places')
touter = t.alias('outer') touter = t.alias('outer')
return _select_from_placex(inner, False)\ return _select_from_placex(inner, False)\
@@ -435,7 +422,6 @@ class ReverseGeocoder:
return address_row return address_row
async def _lookup_area_others(self) -> Optional[SaRow]: async def _lookup_area_others(self) -> Optional[SaRow]:
t = self.conn.t.placex t = self.conn.t.placex
@@ -453,10 +439,10 @@ class ReverseGeocoder:
.subquery() .subquery()
sql = _select_from_placex(inner, False)\ sql = _select_from_placex(inner, False)\
.where(sa.or_(sa.not_(inner.c.geometry.is_area()), .where(sa.or_(sa.not_(inner.c.geometry.is_area()),
inner.c.geometry.ST_Contains(WKT_PARAM)))\ inner.c.geometry.ST_Contains(WKT_PARAM)))\
.order_by(sa.desc(inner.c.rank_search), inner.c.distance)\ .order_by(sa.desc(inner.c.rank_search), inner.c.distance)\
.limit(1) .limit(1)
if self.has_geometries(): if self.has_geometries():
sql = self._add_geometry_columns(sql, inner.c.geometry) sql = self._add_geometry_columns(sql, inner.c.geometry)
@@ -466,7 +452,6 @@ class ReverseGeocoder:
return row return row
async def lookup_area(self) -> Optional[SaRow]: async def lookup_area(self) -> Optional[SaRow]:
""" Lookup large areas for the current search. """ Lookup large areas for the current search.
""" """
@@ -484,7 +469,6 @@ class ReverseGeocoder:
return _get_closest(address_row, other_row) return _get_closest(address_row, other_row)
async def lookup_country_codes(self) -> List[str]: async def lookup_country_codes(self) -> List[str]:
""" Lookup the country for the current search. """ Lookup the country for the current search.
""" """
@@ -497,7 +481,6 @@ class ReverseGeocoder:
log().var_dump('Country codes', ccodes) log().var_dump('Country codes', ccodes)
return ccodes return ccodes
async def lookup_country(self, ccodes: List[str]) -> Optional[SaRow]: async def lookup_country(self, ccodes: List[str]) -> Optional[SaRow]:
""" Lookup the country for the current search. """ Lookup the country for the current search.
""" """
@@ -512,17 +495,15 @@ class ReverseGeocoder:
log().comment('Search for place nodes in country') log().comment('Search for place nodes in country')
def _base_query() -> SaSelect: def _base_query() -> SaSelect:
inner = \ inner = sa.select(t, t.c.geometry.ST_Distance(WKT_PARAM).label('distance'))\
sa.select(t, .where(t.c.rank_search > 4)\
t.c.geometry.ST_Distance(WKT_PARAM).label('distance'))\ .where(t.c.rank_search <= MAX_RANK_PARAM)\
.where(t.c.rank_search > 4)\ .where(t.c.indexed_status == 0)\
.where(t.c.rank_search <= MAX_RANK_PARAM)\ .where(t.c.country_code.in_(ccodes))\
.where(t.c.indexed_status == 0)\ .where(sa.func.IntersectsReverseDistance(t, WKT_PARAM))\
.where(t.c.country_code.in_(ccodes))\ .order_by(sa.desc(t.c.rank_search))\
.where(sa.func.IntersectsReverseDistance(t, WKT_PARAM))\ .limit(50)\
.order_by(sa.desc(t.c.rank_search))\ .subquery('area')
.limit(50)\
.subquery('area')
return _select_from_placex(inner, False)\ return _select_from_placex(inner, False)\
.where(sa.func.IsBelowReverseDistance(inner.c.distance, inner.c.rank_search))\ .where(sa.func.IsBelowReverseDistance(inner.c.distance, inner.c.rank_search))\
@@ -561,14 +542,12 @@ class ReverseGeocoder:
return address_row return address_row
async def lookup(self, coord: AnyPoint) -> Optional[nres.ReverseResult]: async def lookup(self, coord: AnyPoint) -> Optional[nres.ReverseResult]:
""" Look up a single coordinate. Returns the place information, """ Look up a single coordinate. Returns the place information,
if a place was found near the coordinates or None otherwise. if a place was found near the coordinates or None otherwise.
""" """
log().function('reverse_lookup', coord=coord, params=self.params) log().function('reverse_lookup', coord=coord, params=self.params)
self.bind_params['wkt'] = f'POINT({coord[0]} {coord[1]})' self.bind_params['wkt'] = f'POINT({coord[0]} {coord[1]})'
row: Optional[SaRow] = None row: Optional[SaRow] = None

View File

@@ -42,7 +42,7 @@ def build_poi_search(category: List[Tuple[str, str]],
class _PoiData(dbf.SearchData): class _PoiData(dbf.SearchData):
penalty = 0.0 penalty = 0.0
qualifiers = dbf.WeightedCategories(category, [0.0] * len(category)) qualifiers = dbf.WeightedCategories(category, [0.0] * len(category))
countries=ccs countries = ccs
return dbs.PoiSearch(_PoiData()) return dbs.PoiSearch(_PoiData())
@@ -55,15 +55,13 @@ class SearchBuilder:
self.query = query self.query = query
self.details = details self.details = details
@property @property
def configured_for_country(self) -> bool: def configured_for_country(self) -> bool:
""" Return true if the search details are configured to """ Return true if the search details are configured to
allow countries in the result. allow countries in the result.
""" """
return self.details.min_rank <= 4 and self.details.max_rank >= 4 \ return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
and self.details.layer_enabled(DataLayer.ADDRESS) and self.details.layer_enabled(DataLayer.ADDRESS)
@property @property
def configured_for_postcode(self) -> bool: def configured_for_postcode(self) -> bool:
@@ -71,8 +69,7 @@ class SearchBuilder:
allow postcodes in the result. allow postcodes in the result.
""" """
return self.details.min_rank <= 5 and self.details.max_rank >= 11\ return self.details.min_rank <= 5 and self.details.max_rank >= 11\
and self.details.layer_enabled(DataLayer.ADDRESS) and self.details.layer_enabled(DataLayer.ADDRESS)
@property @property
def configured_for_housenumbers(self) -> bool: def configured_for_housenumbers(self) -> bool:
@@ -80,8 +77,7 @@ class SearchBuilder:
allow addresses in the result. allow addresses in the result.
""" """
return self.details.max_rank >= 30 \ return self.details.max_rank >= 30 \
and self.details.layer_enabled(DataLayer.ADDRESS) and self.details.layer_enabled(DataLayer.ADDRESS)
def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]: def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
""" Yield all possible abstract searches for the given token assignment. """ Yield all possible abstract searches for the given token assignment.
@@ -92,7 +88,7 @@ class SearchBuilder:
near_items = self.get_near_items(assignment) near_items = self.get_near_items(assignment)
if near_items is not None and not near_items: if near_items is not None and not near_items:
return # impossible compbination of near items and category parameter return # impossible combination of near items and category parameter
if assignment.name is None: if assignment.name is None:
if near_items and not sdata.postcodes: if near_items and not sdata.postcodes:
@@ -123,7 +119,6 @@ class SearchBuilder:
search.penalty += assignment.penalty search.penalty += assignment.penalty
yield search yield search
def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]: def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
""" Build abstract search query for a simple category search. """ Build abstract search query for a simple category search.
This kind of search requires an additional geographic constraint. This kind of search requires an additional geographic constraint.
@@ -132,7 +127,6 @@ class SearchBuilder:
and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near): and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
yield dbs.PoiSearch(sdata) yield dbs.PoiSearch(sdata)
def build_special_search(self, sdata: dbf.SearchData, def build_special_search(self, sdata: dbf.SearchData,
address: List[TokenRange], address: List[TokenRange],
is_category: bool) -> Iterator[dbs.AbstractSearch]: is_category: bool) -> Iterator[dbs.AbstractSearch]:
@@ -157,7 +151,6 @@ class SearchBuilder:
penalty += 0.2 penalty += 0.2
yield dbs.PostcodeSearch(penalty, sdata) yield dbs.PostcodeSearch(penalty, sdata)
def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token], def build_housenumber_search(self, sdata: dbf.SearchData, hnrs: List[Token],
address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]: address: List[TokenRange]) -> Iterator[dbs.AbstractSearch]:
""" Build a simple address search for special entries where the """ Build a simple address search for special entries where the
@@ -167,7 +160,7 @@ class SearchBuilder:
expected_count = sum(t.count for t in hnrs) expected_count = sum(t.count for t in hnrs)
partials = {t.token: t.addr_count for trange in address partials = {t.token: t.addr_count for trange in address
for t in self.query.get_partials_list(trange)} for t in self.query.get_partials_list(trange)}
if not partials: if not partials:
# can happen when none of the partials is indexed # can happen when none of the partials is indexed
@@ -190,7 +183,6 @@ class SearchBuilder:
sdata.housenumbers = dbf.WeightedStrings([], []) sdata.housenumbers = dbf.WeightedStrings([], [])
yield dbs.PlaceSearch(0.05, sdata, expected_count) yield dbs.PlaceSearch(0.05, sdata, expected_count)
def build_name_search(self, sdata: dbf.SearchData, def build_name_search(self, sdata: dbf.SearchData,
name: TokenRange, address: List[TokenRange], name: TokenRange, address: List[TokenRange],
is_category: bool) -> Iterator[dbs.AbstractSearch]: is_category: bool) -> Iterator[dbs.AbstractSearch]:
@@ -205,14 +197,13 @@ class SearchBuilder:
sdata.lookups = lookup sdata.lookups = lookup
yield dbs.PlaceSearch(penalty + name_penalty, sdata, count) yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)
def yield_lookups(self, name: TokenRange, address: List[TokenRange]
def yield_lookups(self, name: TokenRange, address: List[TokenRange])\ ) -> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
-> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
""" Yield all variants how the given name and address should best """ Yield all variants how the given name and address should best
be searched for. This takes into account how frequent the terms be searched for. This takes into account how frequent the terms
are and tries to find a lookup that optimizes index use. are and tries to find a lookup that optimizes index use.
""" """
penalty = 0.0 # extra penalty penalty = 0.0 # extra penalty
name_partials = {t.token: t for t in self.query.get_partials_list(name)} name_partials = {t.token: t for t in self.query.get_partials_list(name)}
addr_partials = [t for r in address for t in self.query.get_partials_list(r)] addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
@@ -231,7 +222,7 @@ class SearchBuilder:
fulls_count = sum(t.count for t in name_fulls) fulls_count = sum(t.count for t in name_fulls)
if fulls_count < 50000 or addr_count < 30000: if fulls_count < 50000 or addr_count < 30000:
yield penalty,fulls_count / (2**len(addr_tokens)), \ yield penalty, fulls_count / (2**len(addr_tokens)), \
self.get_full_name_ranking(name_fulls, addr_partials, self.get_full_name_ranking(name_fulls, addr_partials,
fulls_count > 30000 / max(1, len(addr_tokens))) fulls_count > 30000 / max(1, len(addr_tokens)))
@@ -241,9 +232,8 @@ class SearchBuilder:
if exp_count < 10000 and addr_count < 20000: if exp_count < 10000 and addr_count < 20000:
penalty += 0.35 * max(1 if name_fulls else 0.1, penalty += 0.35 * max(1 if name_fulls else 0.1,
5 - len(name_partials) - len(addr_tokens)) 5 - len(name_partials) - len(addr_tokens))
yield penalty, exp_count,\ yield penalty, exp_count, \
self.get_name_address_ranking(list(name_partials.keys()), addr_partials) self.get_name_address_ranking(list(name_partials.keys()), addr_partials)
def get_name_address_ranking(self, name_tokens: List[int], def get_name_address_ranking(self, name_tokens: List[int],
addr_partials: List[Token]) -> List[dbf.FieldLookup]: addr_partials: List[Token]) -> List[dbf.FieldLookup]:
@@ -268,7 +258,6 @@ class SearchBuilder:
return lookup return lookup
def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token], def get_full_name_ranking(self, name_fulls: List[Token], addr_partials: List[Token],
use_lookup: bool) -> List[dbf.FieldLookup]: use_lookup: bool) -> List[dbf.FieldLookup]:
""" Create a ranking expression with full name terms and """ Create a ranking expression with full name terms and
@@ -293,7 +282,6 @@ class SearchBuilder:
return dbf.lookup_by_any_name([t.token for t in name_fulls], return dbf.lookup_by_any_name([t.token for t in name_fulls],
addr_restrict_tokens, addr_lookup_tokens) addr_restrict_tokens, addr_lookup_tokens)
def get_name_ranking(self, trange: TokenRange, def get_name_ranking(self, trange: TokenRange,
db_field: str = 'name_vector') -> dbf.FieldRanking: db_field: str = 'name_vector') -> dbf.FieldRanking:
""" Create a ranking expression for a name term in the given range. """ Create a ranking expression for a name term in the given range.
@@ -306,7 +294,6 @@ class SearchBuilder:
default = sum(t.penalty for t in name_partials) + 0.2 default = sum(t.penalty for t in name_partials) + 0.2
return dbf.FieldRanking(db_field, default, ranks) return dbf.FieldRanking(db_field, default, ranks)
def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking: def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
""" Create a list of ranking expressions for an address term """ Create a list of ranking expressions for an address term
for the given ranges. for the given ranges.
@@ -315,7 +302,7 @@ class SearchBuilder:
heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, []))) heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
ranks: List[dbf.RankedTokens] = [] ranks: List[dbf.RankedTokens] = []
while todo: # pylint: disable=too-many-nested-blocks while todo:
neglen, pos, rank = heapq.heappop(todo) neglen, pos, rank = heapq.heappop(todo)
for tlist in self.query.nodes[pos].starting: for tlist in self.query.nodes[pos].starting:
if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD): if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
@@ -354,7 +341,6 @@ class SearchBuilder:
return dbf.FieldRanking('nameaddress_vector', default, ranks) return dbf.FieldRanking('nameaddress_vector', default, ranks)
def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]: def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
""" Collect the tokens for the non-name search fields in the """ Collect the tokens for the non-name search fields in the
assignment. assignment.
@@ -401,7 +387,6 @@ class SearchBuilder:
return sdata return sdata
def get_country_tokens(self, trange: TokenRange) -> List[Token]: def get_country_tokens(self, trange: TokenRange) -> List[Token]:
""" Return the list of country tokens for the given range, """ Return the list of country tokens for the given range,
optionally filtered by the country list from the details optionally filtered by the country list from the details
@@ -413,7 +398,6 @@ class SearchBuilder:
return tokens return tokens
def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]: def get_qualifier_tokens(self, trange: TokenRange) -> List[Token]:
""" Return the list of qualifier tokens for the given range, """ Return the list of qualifier tokens for the given range,
optionally filtered by the qualifier list from the details optionally filtered by the qualifier list from the details
@@ -425,7 +409,6 @@ class SearchBuilder:
return tokens return tokens
def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]: def get_near_items(self, assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
""" Collect tokens for near items search or use the categories """ Collect tokens for near items search or use the categories
requested per parameter. requested per parameter.

View File

@@ -28,11 +28,9 @@ class WeightedStrings:
def __bool__(self) -> bool: def __bool__(self) -> bool:
return bool(self.values) return bool(self.values)
def __iter__(self) -> Iterator[Tuple[str, float]]: def __iter__(self) -> Iterator[Tuple[str, float]]:
return iter(zip(self.values, self.penalties)) return iter(zip(self.values, self.penalties))
def get_penalty(self, value: str, default: float = 1000.0) -> float: def get_penalty(self, value: str, default: float = 1000.0) -> float:
""" Get the penalty for the given value. Returns the given default """ Get the penalty for the given value. Returns the given default
if the value does not exist. if the value does not exist.
@@ -54,11 +52,9 @@ class WeightedCategories:
def __bool__(self) -> bool: def __bool__(self) -> bool:
return bool(self.values) return bool(self.values)
def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]: def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]:
return iter(zip(self.values, self.penalties)) return iter(zip(self.values, self.penalties))
def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float: def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float:
""" Get the penalty for the given value. Returns the given default """ Get the penalty for the given value. Returns the given default
if the value does not exist. if the value does not exist.
@@ -69,7 +65,6 @@ class WeightedCategories:
pass pass
return default return default
def sql_restrict(self, table: SaFromClause) -> SaExpression: def sql_restrict(self, table: SaFromClause) -> SaExpression:
""" Return an SQLAlcheny expression that restricts the """ Return an SQLAlcheny expression that restricts the
class and type columns of the given table to the values class and type columns of the given table to the values
@@ -125,7 +120,6 @@ class FieldRanking:
ranking.penalty -= min_penalty ranking.penalty -= min_penalty
return min_penalty return min_penalty
def sql_penalty(self, table: SaFromClause) -> SaColumn: def sql_penalty(self, table: SaFromClause) -> SaColumn:
""" Create an SQL expression for the rankings. """ Create an SQL expression for the rankings.
""" """
@@ -177,7 +171,6 @@ class SearchData:
qualifiers: WeightedCategories = WeightedCategories([], []) qualifiers: WeightedCategories = WeightedCategories([], [])
def set_strings(self, field: str, tokens: List[Token]) -> None: def set_strings(self, field: str, tokens: List[Token]) -> None:
""" Set on of the WeightedStrings properties from the given """ Set on of the WeightedStrings properties from the given
token list. Adapt the global penalty, so that the token list. Adapt the global penalty, so that the
@@ -191,7 +184,6 @@ class SearchData:
setattr(self, field, wstrs) setattr(self, field, wstrs)
def set_qualifiers(self, tokens: List[Token]) -> None: def set_qualifiers(self, tokens: List[Token]) -> None:
""" Set the qulaifier field from the given tokens. """ Set the qulaifier field from the given tokens.
""" """
@@ -207,7 +199,6 @@ class SearchData:
self.qualifiers = WeightedCategories(list(categories.keys()), self.qualifiers = WeightedCategories(list(categories.keys()),
list(categories.values())) list(categories.values()))
def set_ranking(self, rankings: List[FieldRanking]) -> None: def set_ranking(self, rankings: List[FieldRanking]) -> None:
""" Set the list of rankings and normalize the ranking. """ Set the list of rankings and normalize the ranking.
""" """

View File

@@ -15,10 +15,10 @@ from sqlalchemy.ext.compiler import compiles
from ..typing import SaFromClause from ..typing import SaFromClause
from ..sql.sqlalchemy_types import IntArray from ..sql.sqlalchemy_types import IntArray
# pylint: disable=consider-using-f-string
LookupType = sa.sql.expression.FunctionElement[Any] LookupType = sa.sql.expression.FunctionElement[Any]
class LookupAll(LookupType): class LookupAll(LookupType):
""" Find all entries in search_name table that contain all of """ Find all entries in search_name table that contain all of
a given list of tokens using an index for the search. a given list of tokens using an index for the search.
@@ -40,7 +40,7 @@ def _default_lookup_all(element: LookupAll,
@compiles(LookupAll, 'sqlite') @compiles(LookupAll, 'sqlite')
def _sqlite_lookup_all(element: LookupAll, def _sqlite_lookup_all(element: LookupAll,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
place, col, colname, tokens = list(element.clauses) place, col, colname, tokens = list(element.clauses)
return "(%s IN (SELECT CAST(value as bigint) FROM"\ return "(%s IN (SELECT CAST(value as bigint) FROM"\
" (SELECT array_intersect_fuzzy(places) as p FROM"\ " (SELECT array_intersect_fuzzy(places) as p FROM"\
@@ -50,13 +50,11 @@ def _sqlite_lookup_all(element: LookupAll,
" ORDER BY length(places)) as x) as u,"\ " ORDER BY length(places)) as x) as u,"\
" json_each('[' || u.p || ']'))"\ " json_each('[' || u.p || ']'))"\
" AND array_contains(%s, %s))"\ " AND array_contains(%s, %s))"\
% (compiler.process(place, **kw), % (compiler.process(place, **kw),
compiler.process(tokens, **kw), compiler.process(tokens, **kw),
compiler.process(colname, **kw), compiler.process(colname, **kw),
compiler.process(col, **kw), compiler.process(col, **kw),
compiler.process(tokens, **kw) compiler.process(tokens, **kw))
)
class LookupAny(LookupType): class LookupAny(LookupType):
@@ -69,6 +67,7 @@ class LookupAny(LookupType):
super().__init__(table.c.place_id, getattr(table.c, column), column, super().__init__(table.c.place_id, getattr(table.c, column), column,
sa.type_coerce(tokens, IntArray)) sa.type_coerce(tokens, IntArray))
@compiles(LookupAny) @compiles(LookupAny)
def _default_lookup_any(element: LookupAny, def _default_lookup_any(element: LookupAny,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
@@ -76,9 +75,10 @@ def _default_lookup_any(element: LookupAny,
return "(%s && %s)" % (compiler.process(col, **kw), return "(%s && %s)" % (compiler.process(col, **kw),
compiler.process(tokens, **kw)) compiler.process(tokens, **kw))
@compiles(LookupAny, 'sqlite') @compiles(LookupAny, 'sqlite')
def _sqlite_lookup_any(element: LookupAny, def _sqlite_lookup_any(element: LookupAny,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
place, _, colname, tokens = list(element.clauses) place, _, colname, tokens = list(element.clauses)
return "%s IN (SELECT CAST(value as bigint) FROM"\ return "%s IN (SELECT CAST(value as bigint) FROM"\
" (SELECT array_union(places) as p FROM reverse_search_name"\ " (SELECT array_union(places) as p FROM reverse_search_name"\
@@ -89,7 +89,6 @@ def _sqlite_lookup_any(element: LookupAny,
compiler.process(colname, **kw)) compiler.process(colname, **kw))
class Restrict(LookupType): class Restrict(LookupType):
""" Find all entries that contain all of the given tokens. """ Find all entries that contain all of the given tokens.
Do not use an index for the search. Do not use an index for the search.
@@ -103,12 +102,13 @@ class Restrict(LookupType):
@compiles(Restrict) @compiles(Restrict)
def _default_restrict(element: Restrict, def _default_restrict(element: Restrict,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses) arg1, arg2 = list(element.clauses)
return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw), return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw),
compiler.process(arg2, **kw)) compiler.process(arg2, **kw))
@compiles(Restrict, 'sqlite') @compiles(Restrict, 'sqlite')
def _sqlite_restrict(element: Restrict, def _sqlite_restrict(element: Restrict,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
return "array_contains(%s)" % compiler.process(element.clauses, **kw) return "array_contains(%s)" % compiler.process(element.clauses, **kw)

View File

@@ -20,14 +20,12 @@ from ..types import SearchDetails, DataLayer, GeometryFormat, Bbox
from .. import results as nres from .. import results as nres
from .db_search_fields import SearchData, WeightedCategories from .db_search_fields import SearchData, WeightedCategories
#pylint: disable=singleton-comparison,not-callable
#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
def no_index(expr: SaColumn) -> SaColumn: def no_index(expr: SaColumn) -> SaColumn:
""" Wrap the given expression, so that the query planner will """ Wrap the given expression, so that the query planner will
refrain from using the expression for index lookup. refrain from using the expression for index lookup.
""" """
return sa.func.coalesce(sa.null(), expr) # pylint: disable=not-callable return sa.func.coalesce(sa.null(), expr)
def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]: def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]:
@@ -68,7 +66,7 @@ def filter_by_area(sql: SaSelect, t: SaFromClause,
if details.viewbox is not None and details.bounded_viewbox: if details.viewbox is not None and details.bounded_viewbox:
sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM, sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM,
use_index=not avoid_index and use_index=not avoid_index and
details.viewbox.area < 0.2)) details.viewbox.area < 0.2))
return sql return sql
@@ -190,7 +188,7 @@ def _int_list_to_subquery(inp: List[int]) -> 'sa.Subquery':
as rows in the column 'nr'. as rows in the column 'nr'.
""" """
vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\ vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\
.table_valued(sa.column('value', type_=sa.JSON)) .table_valued(sa.column('value', type_=sa.JSON))
return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery() return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery()
@@ -266,7 +264,6 @@ class NearSearch(AbstractSearch):
self.search = search self.search = search
self.categories = categories self.categories = categories
async def lookup(self, conn: SearchConnection, async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults: details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database. """ Find results for the search in the database.
@@ -288,11 +285,12 @@ class NearSearch(AbstractSearch):
else: else:
min_rank = 26 min_rank = 26
max_rank = 30 max_rank = 30
base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX base = nres.SearchResults(r for r in base
and r.accuracy <= max_accuracy if (r.source_table == nres.SourceTable.PLACEX
and r.bbox and r.bbox.area < 20 and r.accuracy <= max_accuracy
and r.rank_address >= min_rank and r.bbox and r.bbox.area < 20
and r.rank_address <= max_rank) and r.rank_address >= min_rank
and r.rank_address <= max_rank))
if base: if base:
baseids = [b.place_id for b in base[:5] if b.place_id] baseids = [b.place_id for b in base[:5] if b.place_id]
@@ -304,7 +302,6 @@ class NearSearch(AbstractSearch):
return results return results
async def lookup_category(self, results: nres.SearchResults, async def lookup_category(self, results: nres.SearchResults,
conn: SearchConnection, ids: List[int], conn: SearchConnection, ids: List[int],
category: Tuple[str, str], penalty: float, category: Tuple[str, str], penalty: float,
@@ -334,9 +331,9 @@ class NearSearch(AbstractSearch):
.join(tgeom, .join(tgeom,
table.c.centroid.ST_CoveredBy( table.c.centroid.ST_CoveredBy(
sa.case((sa.and_(tgeom.c.rank_address > 9, sa.case((sa.and_(tgeom.c.rank_address > 9,
tgeom.c.geometry.is_area()), tgeom.c.geometry.is_area()),
tgeom.c.geometry), tgeom.c.geometry),
else_ = tgeom.c.centroid.ST_Expand(0.05)))) else_=tgeom.c.centroid.ST_Expand(0.05))))
inner = sql.where(tgeom.c.place_id.in_(ids))\ inner = sql.where(tgeom.c.place_id.in_(ids))\
.group_by(table.c.place_id).subquery() .group_by(table.c.place_id).subquery()
@@ -363,7 +360,6 @@ class NearSearch(AbstractSearch):
results.append(result) results.append(result)
class PoiSearch(AbstractSearch): class PoiSearch(AbstractSearch):
""" Category search in a geographic area. """ Category search in a geographic area.
""" """
@@ -372,7 +368,6 @@ class PoiSearch(AbstractSearch):
self.qualifiers = sdata.qualifiers self.qualifiers = sdata.qualifiers
self.countries = sdata.countries self.countries = sdata.countries
async def lookup(self, conn: SearchConnection, async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults: details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database. """ Find results for the search in the database.
@@ -387,7 +382,7 @@ class PoiSearch(AbstractSearch):
def _base_query() -> SaSelect: def _base_query() -> SaSelect:
return _select_placex(t) \ return _select_placex(t) \
.add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM)) .add_columns((-t.c.centroid.ST_Distance(NEAR_PARAM))
.label('importance'))\ .label('importance'))\
.where(t.c.linked_place_id == None) \ .where(t.c.linked_place_id == None) \
.where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \ .where(t.c.geometry.within_distance(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
.order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \ .order_by(t.c.centroid.ST_Distance(NEAR_PARAM)) \
@@ -396,9 +391,9 @@ class PoiSearch(AbstractSearch):
classtype = self.qualifiers.values classtype = self.qualifiers.values
if len(classtype) == 1: if len(classtype) == 1:
cclass, ctype = classtype[0] cclass, ctype = classtype[0]
sql: SaLambdaSelect = sa.lambda_stmt(lambda: _base_query() sql: SaLambdaSelect = sa.lambda_stmt(
.where(t.c.class_ == cclass) lambda: _base_query().where(t.c.class_ == cclass)
.where(t.c.type == ctype)) .where(t.c.type == ctype))
else: else:
sql = _base_query().where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ) sql = _base_query().where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
for cls, typ in classtype))) for cls, typ in classtype)))
@@ -455,7 +450,6 @@ class CountrySearch(AbstractSearch):
super().__init__(sdata.penalty) super().__init__(sdata.penalty)
self.countries = sdata.countries self.countries = sdata.countries
async def lookup(self, conn: SearchConnection, async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults: details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database. """ Find results for the search in the database.
@@ -464,9 +458,9 @@ class CountrySearch(AbstractSearch):
ccodes = self.countries.values ccodes = self.countries.values
sql = _select_placex(t)\ sql = _select_placex(t)\
.add_columns(t.c.importance)\ .add_columns(t.c.importance)\
.where(t.c.country_code.in_(ccodes))\ .where(t.c.country_code.in_(ccodes))\
.where(t.c.rank_address == 4) .where(t.c.rank_address == 4)
if details.geometry_output: if details.geometry_output:
sql = _add_geometry_columns(sql, t.c.geometry, details) sql = _add_geometry_columns(sql, t.c.geometry, details)
@@ -493,7 +487,6 @@ class CountrySearch(AbstractSearch):
return results return results
async def lookup_in_country_table(self, conn: SearchConnection, async def lookup_in_country_table(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults: details: SearchDetails) -> nres.SearchResults:
""" Look up the country in the fallback country tables. """ Look up the country in the fallback country tables.
@@ -509,7 +502,7 @@ class CountrySearch(AbstractSearch):
sql = sa.select(tgrid.c.country_code, sql = sa.select(tgrid.c.country_code,
tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid() tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid()
.label('centroid'), .label('centroid'),
tgrid.c.geometry.ST_Collect().ST_Expand(0).label('bbox'))\ tgrid.c.geometry.ST_Collect().ST_Expand(0).label('bbox'))\
.where(tgrid.c.country_code.in_(self.countries.values))\ .where(tgrid.c.country_code.in_(self.countries.values))\
.group_by(tgrid.c.country_code) .group_by(tgrid.c.country_code)
@@ -537,7 +530,6 @@ class CountrySearch(AbstractSearch):
return results return results
class PostcodeSearch(AbstractSearch): class PostcodeSearch(AbstractSearch):
""" Search for a postcode. """ Search for a postcode.
""" """
@@ -548,7 +540,6 @@ class PostcodeSearch(AbstractSearch):
self.lookups = sdata.lookups self.lookups = sdata.lookups
self.rankings = sdata.rankings self.rankings = sdata.rankings
async def lookup(self, conn: SearchConnection, async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults: details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database. """ Find results for the search in the database.
@@ -588,14 +579,13 @@ class PostcodeSearch(AbstractSearch):
tsearch = conn.t.search_name tsearch = conn.t.search_name
sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\ sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
.where((tsearch.c.name_vector + tsearch.c.nameaddress_vector) .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
.contains(sa.type_coerce(self.lookups[0].tokens, .contains(sa.type_coerce(self.lookups[0].tokens,
IntArray))) IntArray)))
for ranking in self.rankings: for ranking in self.rankings:
penalty += ranking.sql_penalty(conn.t.search_name) penalty += ranking.sql_penalty(conn.t.search_name)
penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes), penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes),
else_=1.0) else_=1.0)
sql = sql.add_columns(penalty.label('accuracy')) sql = sql.add_columns(penalty.label('accuracy'))
sql = sql.order_by('accuracy').limit(LIMIT_PARAM) sql = sql.order_by('accuracy').limit(LIMIT_PARAM)
@@ -603,13 +593,14 @@ class PostcodeSearch(AbstractSearch):
results = nres.SearchResults() results = nres.SearchResults()
for row in await conn.execute(sql, _details_to_bind_params(details)): for row in await conn.execute(sql, _details_to_bind_params(details)):
p = conn.t.placex p = conn.t.placex
placex_sql = _select_placex(p).add_columns(p.c.importance)\ placex_sql = _select_placex(p)\
.where(sa.text("""class = 'boundary' .add_columns(p.c.importance)\
AND type = 'postal_code' .where(sa.text("""class = 'boundary'
AND osm_type = 'R'"""))\ AND type = 'postal_code'
.where(p.c.country_code == row.country_code)\ AND osm_type = 'R'"""))\
.where(p.c.postcode == row.postcode)\ .where(p.c.country_code == row.country_code)\
.limit(1) .where(p.c.postcode == row.postcode)\
.limit(1)
if details.geometry_output: if details.geometry_output:
placex_sql = _add_geometry_columns(placex_sql, p.c.geometry, details) placex_sql = _add_geometry_columns(placex_sql, p.c.geometry, details)
@@ -630,7 +621,6 @@ class PostcodeSearch(AbstractSearch):
return results return results
class PlaceSearch(AbstractSearch): class PlaceSearch(AbstractSearch):
""" Generic search for an address or named place. """ Generic search for an address or named place.
""" """
@@ -646,7 +636,6 @@ class PlaceSearch(AbstractSearch):
self.rankings = sdata.rankings self.rankings = sdata.rankings
self.expected_count = expected_count self.expected_count = expected_count
def _inner_search_name_cte(self, conn: SearchConnection, def _inner_search_name_cte(self, conn: SearchConnection,
details: SearchDetails) -> 'sa.CTE': details: SearchDetails) -> 'sa.CTE':
""" Create a subquery that preselects the rows in the search_name """ Create a subquery that preselects the rows in the search_name
@@ -699,7 +688,7 @@ class PlaceSearch(AbstractSearch):
NEAR_RADIUS_PARAM)) NEAR_RADIUS_PARAM))
else: else:
sql = sql.where(t.c.centroid sql = sql.where(t.c.centroid
.ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM) .ST_Distance(NEAR_PARAM) < NEAR_RADIUS_PARAM)
if self.housenumbers: if self.housenumbers:
sql = sql.where(t.c.address_rank.between(16, 30)) sql = sql.where(t.c.address_rank.between(16, 30))
@@ -727,8 +716,8 @@ class PlaceSearch(AbstractSearch):
and (details.near is None or details.near_radius is not None)\ and (details.near is None or details.near_radius is not None)\
and not self.qualifiers: and not self.qualifiers:
sql = sql.add_columns(sa.func.first_value(inner.c.penalty - inner.c.importance) sql = sql.add_columns(sa.func.first_value(inner.c.penalty - inner.c.importance)
.over(order_by=inner.c.penalty - inner.c.importance) .over(order_by=inner.c.penalty - inner.c.importance)
.label('min_penalty')) .label('min_penalty'))
inner = sql.subquery() inner = sql.subquery()
@@ -739,7 +728,6 @@ class PlaceSearch(AbstractSearch):
return sql.cte('searches') return sql.cte('searches')
async def lookup(self, conn: SearchConnection, async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults: details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database. """ Find results for the search in the database.
@@ -759,8 +747,8 @@ class PlaceSearch(AbstractSearch):
pcs = self.postcodes.values pcs = self.postcodes.values
pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(t.c.centroid)))\ pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(t.c.centroid)))\
.where(tpc.c.postcode.in_(pcs))\ .where(tpc.c.postcode.in_(pcs))\
.scalar_subquery() .scalar_subquery()
penalty += sa.case((t.c.postcode.in_(pcs), 0.0), penalty += sa.case((t.c.postcode.in_(pcs), 0.0),
else_=sa.func.coalesce(pc_near, cast(SaColumn, 2.0))) else_=sa.func.coalesce(pc_near, cast(SaColumn, 2.0)))
@@ -771,13 +759,12 @@ class PlaceSearch(AbstractSearch):
if details.near is not None: if details.near is not None:
sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM)) sql = sql.add_columns((-tsearch.c.centroid.ST_Distance(NEAR_PARAM))
.label('importance')) .label('importance'))
sql = sql.order_by(sa.desc(sa.text('importance'))) sql = sql.order_by(sa.desc(sa.text('importance')))
else: else:
sql = sql.order_by(penalty - tsearch.c.importance) sql = sql.order_by(penalty - tsearch.c.importance)
sql = sql.add_columns(tsearch.c.importance) sql = sql.add_columns(tsearch.c.importance)
sql = sql.add_columns(penalty.label('accuracy'))\ sql = sql.add_columns(penalty.label('accuracy'))\
.order_by(sa.text('accuracy')) .order_by(sa.text('accuracy'))
@@ -814,7 +801,7 @@ class PlaceSearch(AbstractSearch):
tiger_sql = sa.case((inner.c.country_code == 'us', tiger_sql = sa.case((inner.c.country_code == 'us',
_make_interpolation_subquery(conn.t.tiger, inner, _make_interpolation_subquery(conn.t.tiger, inner,
numerals, details) numerals, details)
), else_=None) ), else_=None)
else: else:
interpol_sql = sa.null() interpol_sql = sa.null()
tiger_sql = sa.null() tiger_sql = sa.null()
@@ -868,7 +855,7 @@ class PlaceSearch(AbstractSearch):
if (not details.excluded or result.place_id not in details.excluded)\ if (not details.excluded or result.place_id not in details.excluded)\
and (not self.qualifiers or result.category in self.qualifiers.values)\ and (not self.qualifiers or result.category in self.qualifiers.values)\
and result.rank_address >= details.min_rank: and result.rank_address >= details.min_rank:
result.accuracy += 1.0 # penalty for missing housenumber result.accuracy += 1.0 # penalty for missing housenumber
results.append(result) results.append(result)
else: else:
results.append(result) results.append(result)

View File

@@ -23,6 +23,7 @@ from .db_searches import AbstractSearch
from .query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer from .query_analyzer_factory import make_query_analyzer, AbstractQueryAnalyzer
from .query import Phrase, QueryStruct from .query import Phrase, QueryStruct
class ForwardGeocoder: class ForwardGeocoder:
""" Main class responsible for place search. """ Main class responsible for place search.
""" """
@@ -34,14 +35,12 @@ class ForwardGeocoder:
self.timeout = dt.timedelta(seconds=timeout or 1000000) self.timeout = dt.timedelta(seconds=timeout or 1000000)
self.query_analyzer: Optional[AbstractQueryAnalyzer] = None self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
@property @property
def limit(self) -> int: def limit(self) -> int:
""" Return the configured maximum number of search results. """ Return the configured maximum number of search results.
""" """
return self.params.max_results return self.params.max_results
async def build_searches(self, async def build_searches(self,
phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]: phrases: List[Phrase]) -> Tuple[QueryStruct, List[AbstractSearch]]:
""" Analyse the query and return the tokenized query and list of """ Analyse the query and return the tokenized query and list of
@@ -68,7 +67,6 @@ class ForwardGeocoder:
return query, searches return query, searches
async def execute_searches(self, query: QueryStruct, async def execute_searches(self, query: QueryStruct,
searches: List[AbstractSearch]) -> SearchResults: searches: List[AbstractSearch]) -> SearchResults:
""" Run the abstract searches against the database until a result """ Run the abstract searches against the database until a result
@@ -103,7 +101,6 @@ class ForwardGeocoder:
return SearchResults(results.values()) return SearchResults(results.values())
def pre_filter_results(self, results: SearchResults) -> SearchResults: def pre_filter_results(self, results: SearchResults) -> SearchResults:
""" Remove results that are significantly worse than the """ Remove results that are significantly worse than the
best match. best match.
@@ -114,7 +111,6 @@ class ForwardGeocoder:
return results return results
def sort_and_cut_results(self, results: SearchResults) -> SearchResults: def sort_and_cut_results(self, results: SearchResults) -> SearchResults:
""" Remove badly matching results, sort by ranking and """ Remove badly matching results, sort by ranking and
limit to the configured number of results. limit to the configured number of results.
@@ -124,21 +120,20 @@ class ForwardGeocoder:
min_rank = results[0].rank_search min_rank = results[0].rank_search
min_ranking = results[0].ranking min_ranking = results[0].ranking
results = SearchResults(r for r in results results = SearchResults(r for r in results
if r.ranking + 0.03 * (r.rank_search - min_rank) if (r.ranking + 0.03 * (r.rank_search - min_rank)
< min_ranking + 0.5) < min_ranking + 0.5))
results = SearchResults(results[:self.limit]) results = SearchResults(results[:self.limit])
return results return results
def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None: def rerank_by_query(self, query: QueryStruct, results: SearchResults) -> None:
""" Adjust the accuracy of the localized result according to how well """ Adjust the accuracy of the localized result according to how well
they match the original query. they match the original query.
""" """
assert self.query_analyzer is not None assert self.query_analyzer is not None
qwords = [word for phrase in query.source qwords = [word for phrase in query.source
for word in re.split('[, ]+', phrase.text) if word] for word in re.split('[, ]+', phrase.text) if word]
if not qwords: if not qwords:
return return
@@ -167,7 +162,6 @@ class ForwardGeocoder:
distance *= 2 distance *= 2
result.accuracy += distance * 0.4 / sum(len(w) for w in qwords) result.accuracy += distance * 0.4 / sum(len(w) for w in qwords)
async def lookup_pois(self, categories: List[Tuple[str, str]], async def lookup_pois(self, categories: List[Tuple[str, str]],
phrases: List[Phrase]) -> SearchResults: phrases: List[Phrase]) -> SearchResults:
""" Look up places by category. If phrase is given, a place search """ Look up places by category. If phrase is given, a place search
@@ -197,7 +191,6 @@ class ForwardGeocoder:
return results return results
async def lookup(self, phrases: List[Phrase]) -> SearchResults: async def lookup(self, phrases: List[Phrase]) -> SearchResults:
""" Look up a single free-text query. """ Look up a single free-text query.
""" """
@@ -223,7 +216,6 @@ class ForwardGeocoder:
return results return results
# pylint: disable=invalid-name,too-many-locals
def _dump_searches(searches: List[AbstractSearch], query: QueryStruct, def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
start: int = 0) -> Iterator[Optional[List[Any]]]: start: int = 0) -> Iterator[Optional[List[Any]]]:
yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries', yield ['Penalty', 'Lookups', 'Housenr', 'Postcode', 'Countries',
@@ -242,12 +234,11 @@ def _dump_searches(searches: List[AbstractSearch], query: QueryStruct,
ranks = ranks[:100] + '...' ranks = ranks[:100] + '...'
return f"{f.column}({ranks},def={f.default:.3g})" return f"{f.column}({ranks},def={f.default:.3g})"
def fmt_lookup(l: Any) -> str: def fmt_lookup(lk: Any) -> str:
if not l: if not lk:
return '' return ''
return f"{l.lookup_type}({l.column}{tk(l.tokens)})" return f"{lk.lookup_type}({lk.column}{tk(lk.tokens)})"
def fmt_cstr(c: Any) -> str: def fmt_cstr(c: Any) -> str:
if not c: if not c:

View File

@@ -48,6 +48,7 @@ class QueryPart(NamedTuple):
QueryParts = List[QueryPart] QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]] WordDict = Dict[str, List[qmod.TokenRange]]
def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]: def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
""" Return all combinations of words in the terms list after the """ Return all combinations of words in the terms list after the
given position. given position.
@@ -72,7 +73,6 @@ class ICUToken(qmod.Token):
assert self.info assert self.info
return self.info.get('class', ''), self.info.get('type', '') return self.info.get('class', ''), self.info.get('type', '')
def rematch(self, norm: str) -> None: def rematch(self, norm: str) -> None:
""" Check how well the token matches the given normalized string """ Check how well the token matches the given normalized string
and add a penalty, if necessary. and add a penalty, if necessary.
@@ -91,7 +91,6 @@ class ICUToken(qmod.Token):
distance += abs((ato-afrom) - (bto-bfrom)) distance += abs((ato-afrom) - (bto-bfrom))
self.penalty += (distance/len(self.lookup_word)) self.penalty += (distance/len(self.lookup_word))
@staticmethod @staticmethod
def from_db_row(row: SaRow) -> 'ICUToken': def from_db_row(row: SaRow) -> 'ICUToken':
""" Create a ICUToken from the row of the word table. """ Create a ICUToken from the row of the word table.
@@ -128,16 +127,13 @@ class ICUToken(qmod.Token):
addr_count=max(1, addr_count)) addr_count=max(1, addr_count))
class ICUQueryAnalyzer(AbstractQueryAnalyzer): class ICUQueryAnalyzer(AbstractQueryAnalyzer):
""" Converter for query strings into a tokenized query """ Converter for query strings into a tokenized query
using the tokens created by a ICU tokenizer. using the tokens created by a ICU tokenizer.
""" """
def __init__(self, conn: SearchConnection) -> None: def __init__(self, conn: SearchConnection) -> None:
self.conn = conn self.conn = conn
async def setup(self) -> None: async def setup(self) -> None:
""" Set up static data structures needed for the analysis. """ Set up static data structures needed for the analysis.
""" """
@@ -163,7 +159,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
sa.Column('word', sa.Text), sa.Column('word', sa.Text),
sa.Column('info', Json)) sa.Column('info', Json))
async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct: async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
""" Analyze the given list of phrases and return the """ Analyze the given list of phrases and return the
tokenized query. tokenized query.
@@ -202,7 +197,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
return query return query
def normalize_text(self, text: str) -> str: def normalize_text(self, text: str) -> str:
""" Bring the given text into a normalized form. That is the """ Bring the given text into a normalized form. That is the
standardized form search will work with. All information removed standardized form search will work with. All information removed
@@ -210,7 +204,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
""" """
return cast(str, self.normalizer.transliterate(text)) return cast(str, self.normalizer.transliterate(text))
def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]: def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
""" Transliterate the phrases and split them into tokens. """ Transliterate the phrases and split them into tokens.
@@ -243,7 +236,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
return parts, words return parts, words
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]': async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
""" Return the token information from the database for the """ Return the token information from the database for the
given word tokens. given word tokens.
@@ -251,7 +243,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
t = self.conn.t.meta.tables['word'] t = self.conn.t.meta.tables['word']
return await self.conn.execute(t.select().where(t.c.word_token.in_(words))) return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None: def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
""" Add tokens to query that are not saved in the database. """ Add tokens to query that are not saved in the database.
""" """
@@ -263,7 +254,6 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
count=1, addr_count=1, lookup_word=part.token, count=1, addr_count=1, lookup_word=part.token,
word_token=part.token, info=None)) word_token=part.token, info=None))
def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None: def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
""" Add penalties to tokens that depend on presence of other token. """ Add penalties to tokens that depend on presence of other token.
""" """
@@ -274,8 +264,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
and (repl.ttype != qmod.TokenType.HOUSENUMBER and (repl.ttype != qmod.TokenType.HOUSENUMBER
or len(tlist.tokens[0].lookup_word) > 4): or len(tlist.tokens[0].lookup_word) > 4):
repl.add_penalty(0.39) repl.add_penalty(0.39)
elif tlist.ttype == qmod.TokenType.HOUSENUMBER \ elif (tlist.ttype == qmod.TokenType.HOUSENUMBER
and len(tlist.tokens[0].lookup_word) <= 3: and len(tlist.tokens[0].lookup_word) <= 3):
if any(c.isdigit() for c in tlist.tokens[0].lookup_word): if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
for repl in node.starting: for repl in node.starting:
if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER: if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER:

View File

@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
import dataclasses import dataclasses
import enum import enum
class BreakType(enum.Enum): class BreakType(enum.Enum):
""" Type of break between tokens. """ Type of break between tokens.
""" """
@@ -102,13 +103,13 @@ class Token(ABC):
addr_count: int addr_count: int
lookup_word: str lookup_word: str
@abstractmethod @abstractmethod
def get_category(self) -> Tuple[str, str]: def get_category(self) -> Tuple[str, str]:
""" Return the category restriction for qualifier terms and """ Return the category restriction for qualifier terms and
category objects. category objects.
""" """
@dataclasses.dataclass @dataclasses.dataclass
class TokenRange: class TokenRange:
""" Indexes of query nodes over which a token spans. """ Indexes of query nodes over which a token spans.
@@ -119,31 +120,25 @@ class TokenRange:
def __lt__(self, other: 'TokenRange') -> bool: def __lt__(self, other: 'TokenRange') -> bool:
return self.end <= other.start return self.end <= other.start
def __le__(self, other: 'TokenRange') -> bool: def __le__(self, other: 'TokenRange') -> bool:
return NotImplemented return NotImplemented
def __gt__(self, other: 'TokenRange') -> bool: def __gt__(self, other: 'TokenRange') -> bool:
return self.start >= other.end return self.start >= other.end
def __ge__(self, other: 'TokenRange') -> bool: def __ge__(self, other: 'TokenRange') -> bool:
return NotImplemented return NotImplemented
def replace_start(self, new_start: int) -> 'TokenRange': def replace_start(self, new_start: int) -> 'TokenRange':
""" Return a new token range with the new start. """ Return a new token range with the new start.
""" """
return TokenRange(new_start, self.end) return TokenRange(new_start, self.end)
def replace_end(self, new_end: int) -> 'TokenRange': def replace_end(self, new_end: int) -> 'TokenRange':
""" Return a new token range with the new end. """ Return a new token range with the new end.
""" """
return TokenRange(self.start, new_end) return TokenRange(self.start, new_end)
def split(self, index: int) -> Tuple['TokenRange', 'TokenRange']: def split(self, index: int) -> Tuple['TokenRange', 'TokenRange']:
""" Split the span into two spans at the given index. """ Split the span into two spans at the given index.
The index must be within the span. The index must be within the span.
@@ -159,7 +154,6 @@ class TokenList:
ttype: TokenType ttype: TokenType
tokens: List[Token] tokens: List[Token]
def add_penalty(self, penalty: float) -> None: def add_penalty(self, penalty: float) -> None:
""" Add the given penalty to all tokens in the list. """ Add the given penalty to all tokens in the list.
""" """
@@ -181,7 +175,6 @@ class QueryNode:
""" """
return any(tl.end == end and tl.ttype in ttypes for tl in self.starting) return any(tl.end == end and tl.ttype in ttypes for tl in self.starting)
def get_tokens(self, end: int, ttype: TokenType) -> Optional[List[Token]]: def get_tokens(self, end: int, ttype: TokenType) -> Optional[List[Token]]:
""" Get the list of tokens of the given type starting at this node """ Get the list of tokens of the given type starting at this node
and ending at the node 'end'. Returns 'None' if no such and ending at the node 'end'. Returns 'None' if no such
@@ -220,13 +213,11 @@ class QueryStruct:
self.nodes: List[QueryNode] = \ self.nodes: List[QueryNode] = \
[QueryNode(BreakType.START, source[0].ptype if source else PhraseType.NONE)] [QueryNode(BreakType.START, source[0].ptype if source else PhraseType.NONE)]
def num_token_slots(self) -> int: def num_token_slots(self) -> int:
""" Return the length of the query in vertice steps. """ Return the length of the query in vertice steps.
""" """
return len(self.nodes) - 1 return len(self.nodes) - 1
def add_node(self, btype: BreakType, ptype: PhraseType) -> None: def add_node(self, btype: BreakType, ptype: PhraseType) -> None:
""" Append a new break node with the given break type. """ Append a new break node with the given break type.
The phrase type denotes the type for any tokens starting The phrase type denotes the type for any tokens starting
@@ -234,7 +225,6 @@ class QueryStruct:
""" """
self.nodes.append(QueryNode(btype, ptype)) self.nodes.append(QueryNode(btype, ptype))
def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None: def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
""" Add a token to the query. 'start' and 'end' are the indexes of the """ Add a token to the query. 'start' and 'end' are the indexes of the
nodes from which to which the token spans. The indexes must exist nodes from which to which the token spans. The indexes must exist
@@ -247,7 +237,7 @@ class QueryStruct:
""" """
snode = self.nodes[trange.start] snode = self.nodes[trange.start]
full_phrase = snode.btype in (BreakType.START, BreakType.PHRASE)\ full_phrase = snode.btype in (BreakType.START, BreakType.PHRASE)\
and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END) and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)
if snode.ptype.compatible_with(ttype, full_phrase): if snode.ptype.compatible_with(ttype, full_phrase):
tlist = snode.get_tokens(trange.end, ttype) tlist = snode.get_tokens(trange.end, ttype)
if tlist is None: if tlist is None:
@@ -255,7 +245,6 @@ class QueryStruct:
else: else:
tlist.append(token) tlist.append(token)
def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]: def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]:
""" Get the list of tokens of a given type, spanning the given """ Get the list of tokens of a given type, spanning the given
nodes. The nodes must exist. If no tokens exist, an nodes. The nodes must exist. If no tokens exist, an
@@ -263,7 +252,6 @@ class QueryStruct:
""" """
return self.nodes[trange.start].get_tokens(trange.end, ttype) or [] return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
def get_partials_list(self, trange: TokenRange) -> List[Token]: def get_partials_list(self, trange: TokenRange) -> List[Token]:
""" Create a list of partial tokens between the given nodes. """ Create a list of partial tokens between the given nodes.
The list is composed of the first token of type PARTIAL The list is composed of the first token of type PARTIAL
@@ -271,8 +259,7 @@ class QueryStruct:
assumed to exist. assumed to exist.
""" """
return [next(iter(self.get_tokens(TokenRange(i, i+1), TokenType.PARTIAL))) return [next(iter(self.get_tokens(TokenRange(i, i+1), TokenType.PARTIAL)))
for i in range(trange.start, trange.end)] for i in range(trange.start, trange.end)]
def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]: def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
""" Iterator over all token lists in the query. """ Iterator over all token lists in the query.
@@ -281,7 +268,6 @@ class QueryStruct:
for tlist in node.starting: for tlist in node.starting:
yield i, node, tlist yield i, node, tlist
def find_lookup_word_by_id(self, token: int) -> str: def find_lookup_word_by_id(self, token: int) -> str:
""" Find the first token with the given token ID and return """ Find the first token with the given token ID and return
its lookup word. Returns 'None' if no such token exists. its lookup word. Returns 'None' if no such token exists.

View File

@@ -18,6 +18,7 @@ from ..connection import SearchConnection
if TYPE_CHECKING: if TYPE_CHECKING:
from .query import Phrase, QueryStruct from .query import Phrase, QueryStruct
class AbstractQueryAnalyzer(ABC): class AbstractQueryAnalyzer(ABC):
""" Class for analysing incoming queries. """ Class for analysing incoming queries.
@@ -29,7 +30,6 @@ class AbstractQueryAnalyzer(ABC):
""" Analyze the given phrases and return the tokenized query. """ Analyze the given phrases and return the tokenized query.
""" """
@abstractmethod @abstractmethod
def normalize_text(self, text: str) -> str: def normalize_text(self, text: str) -> str:
""" Bring the given text into a normalized form. That is the """ Bring the given text into a normalized form. That is the
@@ -38,7 +38,6 @@ class AbstractQueryAnalyzer(ABC):
""" """
async def make_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer: async def make_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
""" Create a query analyzer for the tokenizer used by the database. """ Create a query analyzer for the tokenizer used by the database.
""" """

View File

@@ -14,7 +14,6 @@ import dataclasses
from ..logging import log from ..logging import log
from . import query as qmod from . import query as qmod
# pylint: disable=too-many-return-statements,too-many-branches
@dataclasses.dataclass @dataclasses.dataclass
class TypedRange: class TypedRange:
@@ -35,8 +34,9 @@ PENALTY_TOKENCHANGE = {
TypedRangeSeq = List[TypedRange] TypedRangeSeq = List[TypedRange]
@dataclasses.dataclass @dataclasses.dataclass
class TokenAssignment: # pylint: disable=too-many-instance-attributes class TokenAssignment:
""" Representation of a possible assignment of token types """ Representation of a possible assignment of token types
to the tokens in a tokenized query. to the tokens in a tokenized query.
""" """
@@ -49,7 +49,6 @@ class TokenAssignment: # pylint: disable=too-many-instance-attributes
near_item: Optional[qmod.TokenRange] = None near_item: Optional[qmod.TokenRange] = None
qualifier: Optional[qmod.TokenRange] = None qualifier: Optional[qmod.TokenRange] = None
@staticmethod @staticmethod
def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment': def from_ranges(ranges: TypedRangeSeq) -> 'TokenAssignment':
""" Create a new token assignment from a sequence of typed spans. """ Create a new token assignment from a sequence of typed spans.
@@ -83,34 +82,29 @@ class _TokenSequence:
self.direction = direction self.direction = direction
self.penalty = penalty self.penalty = penalty
def __str__(self) -> str: def __str__(self) -> str:
seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq) seq = ''.join(f'[{r.trange.start} - {r.trange.end}: {r.ttype.name}]' for r in self.seq)
return f'{seq} (dir: {self.direction}, penalty: {self.penalty})' return f'{seq} (dir: {self.direction}, penalty: {self.penalty})'
@property @property
def end_pos(self) -> int: def end_pos(self) -> int:
""" Return the index of the global end of the current sequence. """ Return the index of the global end of the current sequence.
""" """
return self.seq[-1].trange.end if self.seq else 0 return self.seq[-1].trange.end if self.seq else 0
def has_types(self, *ttypes: qmod.TokenType) -> bool: def has_types(self, *ttypes: qmod.TokenType) -> bool:
""" Check if the current sequence contains any typed ranges of """ Check if the current sequence contains any typed ranges of
the given types. the given types.
""" """
return any(s.ttype in ttypes for s in self.seq) return any(s.ttype in ttypes for s in self.seq)
def is_final(self) -> bool: def is_final(self) -> bool:
""" Return true when the sequence cannot be extended by any """ Return true when the sequence cannot be extended by any
form of token anymore. form of token anymore.
""" """
# Country and category must be the final term for left-to-right # Country and category must be the final term for left-to-right
return len(self.seq) > 1 and \ return len(self.seq) > 1 and \
self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM) self.seq[-1].ttype in (qmod.TokenType.COUNTRY, qmod.TokenType.NEAR_ITEM)
def appendable(self, ttype: qmod.TokenType) -> Optional[int]: def appendable(self, ttype: qmod.TokenType) -> Optional[int]:
""" Check if the give token type is appendable to the existing sequence. """ Check if the give token type is appendable to the existing sequence.
@@ -149,10 +143,10 @@ class _TokenSequence:
return None return None
if len(self.seq) > 2 \ if len(self.seq) > 2 \
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY): or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY):
return None # direction left-to-right: housenumber must come before anything return None # direction left-to-right: housenumber must come before anything
elif self.direction == -1 \ elif (self.direction == -1
or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY): or self.has_types(qmod.TokenType.POSTCODE, qmod.TokenType.COUNTRY)):
return -1 # force direction right-to-left if after other terms return -1 # force direction right-to-left if after other terms
return self.direction return self.direction
@@ -196,7 +190,6 @@ class _TokenSequence:
return None return None
def advance(self, ttype: qmod.TokenType, end_pos: int, def advance(self, ttype: qmod.TokenType, end_pos: int,
btype: qmod.BreakType) -> Optional['_TokenSequence']: btype: qmod.BreakType) -> Optional['_TokenSequence']:
""" Return a new token sequence state with the given token type """ Return a new token sequence state with the given token type
@@ -223,7 +216,6 @@ class _TokenSequence:
return _TokenSequence(newseq, newdir, self.penalty + new_penalty) return _TokenSequence(newseq, newdir, self.penalty + new_penalty)
def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool: def _adapt_penalty_from_priors(self, priors: int, new_dir: int) -> bool:
if priors >= 2: if priors >= 2:
if self.direction == 0: if self.direction == 0:
@@ -236,7 +228,6 @@ class _TokenSequence:
return True return True
def recheck_sequence(self) -> bool: def recheck_sequence(self) -> bool:
""" Check that the sequence is a fully valid token assignment """ Check that the sequence is a fully valid token assignment
and adapt direction and penalties further if necessary. and adapt direction and penalties further if necessary.
@@ -264,9 +255,8 @@ class _TokenSequence:
return True return True
def _get_assignments_postcode(self, base: TokenAssignment, def _get_assignments_postcode(self, base: TokenAssignment,
query_len: int) -> Iterator[TokenAssignment]: query_len: int) -> Iterator[TokenAssignment]:
""" Yield possible assignments of Postcode searches with an """ Yield possible assignments of Postcode searches with an
address component. address component.
""" """
@@ -278,13 +268,12 @@ class _TokenSequence:
# <address>,<postcode> should give preference to address search # <address>,<postcode> should give preference to address search
if base.postcode.start == 0: if base.postcode.start == 0:
penalty = self.penalty penalty = self.penalty
self.direction = -1 # name searches are only possible backwards self.direction = -1 # name searches are only possible backwards
else: else:
penalty = self.penalty + 0.1 penalty = self.penalty + 0.1
self.direction = 1 # name searches are only possible forwards self.direction = 1 # name searches are only possible forwards
yield dataclasses.replace(base, penalty=penalty) yield dataclasses.replace(base, penalty=penalty)
def _get_assignments_address_forward(self, base: TokenAssignment, def _get_assignments_address_forward(self, base: TokenAssignment,
query: qmod.QueryStruct) -> Iterator[TokenAssignment]: query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
""" Yield possible assignments of address searches with """ Yield possible assignments of address searches with
@@ -320,7 +309,6 @@ class _TokenSequence:
yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:], yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype]) penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
def _get_assignments_address_backward(self, base: TokenAssignment, def _get_assignments_address_backward(self, base: TokenAssignment,
query: qmod.QueryStruct) -> Iterator[TokenAssignment]: query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
""" Yield possible assignments of address searches with """ Yield possible assignments of address searches with
@@ -355,7 +343,6 @@ class _TokenSequence:
yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr], yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype]) penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]: def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
""" Yield possible assignments for the current sequence. """ Yield possible assignments for the current sequence.

View File

@@ -16,13 +16,13 @@ from ..core import NominatimAPIAsync
from ..result_formatting import FormatDispatcher from ..result_formatting import FormatDispatcher
from .content_types import CONTENT_TEXT from .content_types import CONTENT_TEXT
class ASGIAdaptor(abc.ABC): class ASGIAdaptor(abc.ABC):
""" Adapter class for the different ASGI frameworks. """ Adapter class for the different ASGI frameworks.
Wraps functionality over concrete requests and responses. Wraps functionality over concrete requests and responses.
""" """
content_type: str = CONTENT_TEXT content_type: str = CONTENT_TEXT
@abc.abstractmethod @abc.abstractmethod
def get(self, name: str, default: Optional[str] = None) -> Optional[str]: def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
""" Return an input parameter as a string. If the parameter was """ Return an input parameter as a string. If the parameter was
@@ -35,14 +35,12 @@ class ASGIAdaptor(abc.ABC):
not provided, return the 'default' value. not provided, return the 'default' value.
""" """
@abc.abstractmethod @abc.abstractmethod
def error(self, msg: str, status: int = 400) -> Exception: def error(self, msg: str, status: int = 400) -> Exception:
""" Construct an appropriate exception from the given error message. """ Construct an appropriate exception from the given error message.
The exception must result in a HTTP error with the given status. The exception must result in a HTTP error with the given status.
""" """
@abc.abstractmethod @abc.abstractmethod
def create_response(self, status: int, output: str, num_results: int) -> Any: def create_response(self, status: int, output: str, num_results: int) -> Any:
""" Create a response from the given parameters. The result will """ Create a response from the given parameters. The result will
@@ -55,25 +53,21 @@ class ASGIAdaptor(abc.ABC):
body of the response to 'output'. body of the response to 'output'.
""" """
@abc.abstractmethod @abc.abstractmethod
def base_uri(self) -> str: def base_uri(self) -> str:
""" Return the URI of the original request. """ Return the URI of the original request.
""" """
@abc.abstractmethod @abc.abstractmethod
def config(self) -> Configuration: def config(self) -> Configuration:
""" Return the current configuration object. """ Return the current configuration object.
""" """
@abc.abstractmethod @abc.abstractmethod
def formatting(self) -> FormatDispatcher: def formatting(self) -> FormatDispatcher:
""" Return the formatting object to use. """ Return the formatting object to use.
""" """
def get_int(self, name: str, default: Optional[int] = None) -> int: def get_int(self, name: str, default: Optional[int] = None) -> int:
""" Return an input parameter as an int. Raises an exception if """ Return an input parameter as an int. Raises an exception if
the parameter is given but not in an integer format. the parameter is given but not in an integer format.
@@ -97,7 +91,6 @@ class ASGIAdaptor(abc.ABC):
return intval return intval
def get_float(self, name: str, default: Optional[float] = None) -> float: def get_float(self, name: str, default: Optional[float] = None) -> float:
""" Return an input parameter as a flaoting-point number. Raises an """ Return an input parameter as a flaoting-point number. Raises an
exception if the parameter is given but not in an float format. exception if the parameter is given but not in an float format.
@@ -124,7 +117,6 @@ class ASGIAdaptor(abc.ABC):
return fval return fval
def get_bool(self, name: str, default: Optional[bool] = None) -> bool: def get_bool(self, name: str, default: Optional[bool] = None) -> bool:
""" Return an input parameter as bool. Only '0' is accepted as """ Return an input parameter as bool. Only '0' is accepted as
an input for 'false' all other inputs will be interpreted as 'true'. an input for 'false' all other inputs will be interpreted as 'true'.
@@ -143,7 +135,6 @@ class ASGIAdaptor(abc.ABC):
return value != '0' return value != '0'
def raise_error(self, msg: str, status: int = 400) -> NoReturn: def raise_error(self, msg: str, status: int = 400) -> NoReturn:
""" Raise an exception resulting in the given HTTP status and """ Raise an exception resulting in the given HTTP status and
message. The message will be formatted according to the message. The message will be formatted according to the

View File

@@ -21,6 +21,7 @@ from ...result_formatting import FormatDispatcher, load_format_dispatcher
from ... import logging as loglib from ... import logging as loglib
from ..asgi_adaptor import ASGIAdaptor, EndpointFunc from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
class HTTPNominatimError(Exception): class HTTPNominatimError(Exception):
""" A special exception class for errors raised during processing. """ A special exception class for errors raised during processing.
""" """
@@ -30,7 +31,7 @@ class HTTPNominatimError(Exception):
self.content_type = content_type self.content_type = content_type
async def nominatim_error_handler(req: Request, resp: Response, #pylint: disable=unused-argument async def nominatim_error_handler(req: Request, resp: Response,
exception: HTTPNominatimError, exception: HTTPNominatimError,
_: Any) -> None: _: Any) -> None:
""" Special error handler that passes message and content type as """ Special error handler that passes message and content type as
@@ -41,8 +42,8 @@ async def nominatim_error_handler(req: Request, resp: Response, #pylint: disable
resp.content_type = exception.content_type resp.content_type = exception.content_type
async def timeout_error_handler(req: Request, resp: Response, #pylint: disable=unused-argument async def timeout_error_handler(req: Request, resp: Response,
exception: TimeoutError, #pylint: disable=unused-argument exception: TimeoutError,
_: Any) -> None: _: Any) -> None:
""" Special error handler that passes message and content type as """ Special error handler that passes message and content type as
per exception info. per exception info.
@@ -70,26 +71,21 @@ class ParamWrapper(ASGIAdaptor):
self._config = config self._config = config
self._formatter = formatter self._formatter = formatter
def get(self, name: str, default: Optional[str] = None) -> Optional[str]: def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
return self.request.get_param(name, default=default) return self.request.get_param(name, default=default)
def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]: def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
return self.request.get_header(name, default=default) return self.request.get_header(name, default=default)
def error(self, msg: str, status: int = 400) -> HTTPNominatimError: def error(self, msg: str, status: int = 400) -> HTTPNominatimError:
return HTTPNominatimError(msg, status, self.content_type) return HTTPNominatimError(msg, status, self.content_type)
def create_response(self, status: int, output: str, num_results: int) -> None: def create_response(self, status: int, output: str, num_results: int) -> None:
self.response.context.num_results = num_results self.response.context.num_results = num_results
self.response.status = status self.response.status = status
self.response.text = output self.response.text = output
self.response.content_type = self.content_type self.response.content_type = self.content_type
def base_uri(self) -> str: def base_uri(self) -> str:
return self.request.forwarded_prefix return self.request.forwarded_prefix
@@ -111,7 +107,6 @@ class EndpointWrapper:
self.api = api self.api = api
self.formatter = formatter self.formatter = formatter
async def on_get(self, req: Request, resp: Response) -> None: async def on_get(self, req: Request, resp: Response) -> None:
""" Implementation of the endpoint. """ Implementation of the endpoint.
""" """
@@ -124,15 +119,13 @@ class FileLoggingMiddleware:
""" """
def __init__(self, file_name: str): def __init__(self, file_name: str):
self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732 self.fd = open(file_name, 'a', buffering=1, encoding='utf8')
async def process_request(self, req: Request, _: Response) -> None: async def process_request(self, req: Request, _: Response) -> None:
""" Callback before the request starts timing. """ Callback before the request starts timing.
""" """
req.context.start = dt.datetime.now(tz=dt.timezone.utc) req.context.start = dt.datetime.now(tz=dt.timezone.utc)
async def process_response(self, req: Request, resp: Response, async def process_response(self, req: Request, resp: Response,
resource: Optional[EndpointWrapper], resource: Optional[EndpointWrapper],
req_succeeded: bool) -> None: req_succeeded: bool) -> None:
@@ -140,7 +133,7 @@ class FileLoggingMiddleware:
writes logs for successful requests for search, reverse and lookup. writes logs for successful requests for search, reverse and lookup.
""" """
if not req_succeeded or resource is None or resp.status != 200\ if not req_succeeded or resource is None or resp.status != 200\
or resource.name not in ('reverse', 'search', 'lookup', 'details'): or resource.name not in ('reverse', 'search', 'lookup', 'details'):
return return
finish = dt.datetime.now(tz=dt.timezone.utc) finish = dt.datetime.now(tz=dt.timezone.utc)
@@ -183,7 +176,7 @@ def get_application(project_dir: Path,
app.add_error_handler(HTTPNominatimError, nominatim_error_handler) app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
app.add_error_handler(TimeoutError, timeout_error_handler) app.add_error_handler(TimeoutError, timeout_error_handler)
# different from TimeoutError in Python <= 3.10 # different from TimeoutError in Python <= 3.10
app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type] app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type]
legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS') legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', project_dir) formatter = load_format_dispatcher('v1', project_dir)

View File

@@ -28,6 +28,7 @@ from ...result_formatting import FormatDispatcher, load_format_dispatcher
from ..asgi_adaptor import ASGIAdaptor, EndpointFunc from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
from ... import logging as loglib from ... import logging as loglib
class ParamWrapper(ASGIAdaptor): class ParamWrapper(ASGIAdaptor):
""" Adaptor class for server glue to Starlette framework. """ Adaptor class for server glue to Starlette framework.
""" """
@@ -35,25 +36,20 @@ class ParamWrapper(ASGIAdaptor):
def __init__(self, request: Request) -> None: def __init__(self, request: Request) -> None:
self.request = request self.request = request
def get(self, name: str, default: Optional[str] = None) -> Optional[str]: def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
return self.request.query_params.get(name, default=default) return self.request.query_params.get(name, default=default)
def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]: def get_header(self, name: str, default: Optional[str] = None) -> Optional[str]:
return self.request.headers.get(name, default) return self.request.headers.get(name, default)
def error(self, msg: str, status: int = 400) -> HTTPException: def error(self, msg: str, status: int = 400) -> HTTPException:
return HTTPException(status, detail=msg, return HTTPException(status, detail=msg,
headers={'content-type': self.content_type}) headers={'content-type': self.content_type})
def create_response(self, status: int, output: str, num_results: int) -> Response: def create_response(self, status: int, output: str, num_results: int) -> Response:
self.request.state.num_results = num_results self.request.state.num_results = num_results
return Response(output, status_code=status, media_type=self.content_type) return Response(output, status_code=status, media_type=self.content_type)
def base_uri(self) -> str: def base_uri(self) -> str:
scheme = self.request.url.scheme scheme = self.request.url.scheme
host = self.request.url.hostname host = self.request.url.hostname
@@ -66,11 +62,9 @@ class ParamWrapper(ASGIAdaptor):
return f"{scheme}://{host}{root}" return f"{scheme}://{host}{root}"
def config(self) -> Configuration: def config(self) -> Configuration:
return cast(Configuration, self.request.app.state.API.config) return cast(Configuration, self.request.app.state.API.config)
def formatting(self) -> FormatDispatcher: def formatting(self) -> FormatDispatcher:
return cast(FormatDispatcher, self.request.app.state.API.formatter) return cast(FormatDispatcher, self.request.app.state.API.formatter)
@@ -89,7 +83,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
def __init__(self, app: Starlette, file_name: str = ''): def __init__(self, app: Starlette, file_name: str = ''):
super().__init__(app) super().__init__(app)
self.fd = open(file_name, 'a', buffering=1, encoding='utf8') # pylint: disable=R1732 self.fd = open(file_name, 'a', buffering=1, encoding='utf8')
async def dispatch(self, request: Request, async def dispatch(self, request: Request,
call_next: RequestResponseEndpoint) -> Response: call_next: RequestResponseEndpoint) -> Response:
@@ -118,7 +112,7 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
return response return response
async def timeout_error(request: Request, #pylint: disable=unused-argument async def timeout_error(request: Request,
_: Exception) -> Response: _: Exception) -> Response:
""" Error handler for query timeouts. """ Error handler for query timeouts.
""" """

View File

@@ -7,10 +7,10 @@
""" """
Import the base library to use with asynchronous SQLAlchemy. Import the base library to use with asynchronous SQLAlchemy.
""" """
# pylint: disable=invalid-name, ungrouped-imports, unused-import
from typing import Any from typing import Any
# flake8: noqa
try: try:
import sqlalchemy.dialects.postgresql.psycopg import sqlalchemy.dialects.postgresql.psycopg
import psycopg import psycopg

View File

@@ -15,7 +15,6 @@ from sqlalchemy.ext.compiler import compiles
from ..typing import SaColumn from ..typing import SaColumn
# pylint: disable=all
class PlacexGeometryReverseLookuppolygon(sa.sql.functions.GenericFunction[Any]): class PlacexGeometryReverseLookuppolygon(sa.sql.functions.GenericFunction[Any]):
""" Check for conditions that allow partial index use on """ Check for conditions that allow partial index use on
@@ -69,8 +68,8 @@ def default_reverse_place_diameter(element: IntersectsReverseDistance,
f" AND {table}.name is not null"\ f" AND {table}.name is not null"\
f" AND {table}.linked_place_id is null"\ f" AND {table}.linked_place_id is null"\
f" AND {table}.osm_type = 'N'" + \ f" AND {table}.osm_type = 'N'" + \
" AND ST_Buffer(%s, reverse_place_diameter(%s)) && %s)" % \ " AND ST_Buffer(%s, reverse_place_diameter(%s)) && %s)" \
tuple(map(lambda c: compiler.process(c, **kw), element.clauses)) % tuple(map(lambda c: compiler.process(c, **kw), element.clauses))
@compiles(IntersectsReverseDistance, 'sqlite') @compiles(IntersectsReverseDistance, 'sqlite')
@@ -79,17 +78,17 @@ def sqlite_reverse_place_diameter(element: IntersectsReverseDistance,
geom1, rank, geom2 = list(element.clauses) geom1, rank, geom2 = list(element.clauses)
table = element.tablename table = element.tablename
return (f"({table}.rank_address between 4 and 25"\ return (f"({table}.rank_address between 4 and 25"
f" AND {table}.type != 'postcode'"\ f" AND {table}.type != 'postcode'"
f" AND {table}.name is not null"\ f" AND {table}.name is not null"
f" AND {table}.linked_place_id is null"\ f" AND {table}.linked_place_id is null"
f" AND {table}.osm_type = 'N'"\ f" AND {table}.osm_type = 'N'"
" AND MbrIntersects(%s, ST_Expand(%s, 14.0 * exp(-0.2 * %s) - 0.03))"\ " AND MbrIntersects(%s, ST_Expand(%s, 14.0 * exp(-0.2 * %s) - 0.03))"
f" AND {table}.place_id IN"\ f" AND {table}.place_id IN"
" (SELECT place_id FROM placex_place_node_areas"\ " (SELECT place_id FROM placex_place_node_areas"
" WHERE ROWID IN (SELECT ROWID FROM SpatialIndex"\ " WHERE ROWID IN (SELECT ROWID FROM SpatialIndex"
" WHERE f_table_name = 'placex_place_node_areas'"\ " WHERE f_table_name = 'placex_place_node_areas'"
" AND search_frame = %s)))") % ( " AND search_frame = %s)))") % (
compiler.process(geom1, **kw), compiler.process(geom1, **kw),
compiler.process(geom2, **kw), compiler.process(geom2, **kw),
compiler.process(rank, **kw), compiler.process(rank, **kw),
@@ -153,6 +152,7 @@ class CrosscheckNames(sa.sql.functions.GenericFunction[Any]):
name = 'CrosscheckNames' name = 'CrosscheckNames'
inherit_cache = True inherit_cache = True
@compiles(CrosscheckNames) @compiles(CrosscheckNames)
def compile_crosscheck_names(element: CrosscheckNames, def compile_crosscheck_names(element: CrosscheckNames,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
@@ -188,7 +188,6 @@ def sqlite_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw
return "json_each(%s)" % compiler.process(element.clauses, **kw) return "json_each(%s)" % compiler.process(element.clauses, **kw)
class Greatest(sa.sql.functions.GenericFunction[Any]): class Greatest(sa.sql.functions.GenericFunction[Any]):
""" Function to compute maximum of all its input parameters. """ Function to compute maximum of all its input parameters.
""" """
@@ -201,7 +200,6 @@ def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> st
return "max(%s)" % compiler.process(element.clauses, **kw) return "max(%s)" % compiler.process(element.clauses, **kw)
class RegexpWord(sa.sql.functions.GenericFunction[Any]): class RegexpWord(sa.sql.functions.GenericFunction[Any]):
""" Check if a full word is in a given string. """ Check if a full word is in a given string.
""" """
@@ -212,10 +210,12 @@ class RegexpWord(sa.sql.functions.GenericFunction[Any]):
@compiles(RegexpWord, 'postgresql') @compiles(RegexpWord, 'postgresql')
def postgres_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str: def postgres_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses) arg1, arg2 = list(element.clauses)
return "%s ~* ('\\m(' || %s || ')\\M')::text" % (compiler.process(arg2, **kw), compiler.process(arg1, **kw)) return "%s ~* ('\\m(' || %s || ')\\M')::text" \
% (compiler.process(arg2, **kw), compiler.process(arg1, **kw))
@compiles(RegexpWord, 'sqlite') @compiles(RegexpWord, 'sqlite')
def sqlite_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str: def sqlite_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses) arg1, arg2 = list(element.clauses)
return "regexp('\\b(' || %s || ')\\b', %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw)) return "regexp('\\b(' || %s || ')\\b', %s)"\
% (compiler.process(arg1, **kw), compiler.process(arg2, **kw))

View File

@@ -11,7 +11,7 @@ import sqlalchemy as sa
from .sqlalchemy_types import Geometry, KeyValueStore, IntArray from .sqlalchemy_types import Geometry, KeyValueStore, IntArray
#pylint: disable=too-many-instance-attributes
class SearchTables: class SearchTables:
""" Data class that holds the tables of the Nominatim database. """ Data class that holds the tables of the Nominatim database.
@@ -22,16 +22,19 @@ class SearchTables:
def __init__(self, meta: sa.MetaData) -> None: def __init__(self, meta: sa.MetaData) -> None:
self.meta = meta self.meta = meta
self.import_status = sa.Table('import_status', meta, self.import_status = sa.Table(
'import_status', meta,
sa.Column('lastimportdate', sa.DateTime(True), nullable=False), sa.Column('lastimportdate', sa.DateTime(True), nullable=False),
sa.Column('sequence_id', sa.Integer), sa.Column('sequence_id', sa.Integer),
sa.Column('indexed', sa.Boolean)) sa.Column('indexed', sa.Boolean))
self.properties = sa.Table('nominatim_properties', meta, self.properties = sa.Table(
'nominatim_properties', meta,
sa.Column('property', sa.Text, nullable=False), sa.Column('property', sa.Text, nullable=False),
sa.Column('value', sa.Text)) sa.Column('value', sa.Text))
self.placex = sa.Table('placex', meta, self.placex = sa.Table(
'placex', meta,
sa.Column('place_id', sa.BigInteger, nullable=False), sa.Column('place_id', sa.BigInteger, nullable=False),
sa.Column('parent_place_id', sa.BigInteger), sa.Column('parent_place_id', sa.BigInteger),
sa.Column('linked_place_id', sa.BigInteger), sa.Column('linked_place_id', sa.BigInteger),
@@ -55,14 +58,16 @@ class SearchTables:
sa.Column('postcode', sa.Text), sa.Column('postcode', sa.Text),
sa.Column('centroid', Geometry)) sa.Column('centroid', Geometry))
self.addressline = sa.Table('place_addressline', meta, self.addressline = sa.Table(
'place_addressline', meta,
sa.Column('place_id', sa.BigInteger), sa.Column('place_id', sa.BigInteger),
sa.Column('address_place_id', sa.BigInteger), sa.Column('address_place_id', sa.BigInteger),
sa.Column('distance', sa.Float), sa.Column('distance', sa.Float),
sa.Column('fromarea', sa.Boolean), sa.Column('fromarea', sa.Boolean),
sa.Column('isaddress', sa.Boolean)) sa.Column('isaddress', sa.Boolean))
self.postcode = sa.Table('location_postcode', meta, self.postcode = sa.Table(
'location_postcode', meta,
sa.Column('place_id', sa.BigInteger), sa.Column('place_id', sa.BigInteger),
sa.Column('parent_place_id', sa.BigInteger), sa.Column('parent_place_id', sa.BigInteger),
sa.Column('rank_search', sa.SmallInteger), sa.Column('rank_search', sa.SmallInteger),
@@ -73,7 +78,8 @@ class SearchTables:
sa.Column('postcode', sa.Text), sa.Column('postcode', sa.Text),
sa.Column('geometry', Geometry)) sa.Column('geometry', Geometry))
self.osmline = sa.Table('location_property_osmline', meta, self.osmline = sa.Table(
'location_property_osmline', meta,
sa.Column('place_id', sa.BigInteger, nullable=False), sa.Column('place_id', sa.BigInteger, nullable=False),
sa.Column('osm_id', sa.BigInteger), sa.Column('osm_id', sa.BigInteger),
sa.Column('parent_place_id', sa.BigInteger), sa.Column('parent_place_id', sa.BigInteger),
@@ -87,19 +93,22 @@ class SearchTables:
sa.Column('postcode', sa.Text), sa.Column('postcode', sa.Text),
sa.Column('country_code', sa.String(2))) sa.Column('country_code', sa.String(2)))
self.country_name = sa.Table('country_name', meta, self.country_name = sa.Table(
'country_name', meta,
sa.Column('country_code', sa.String(2)), sa.Column('country_code', sa.String(2)),
sa.Column('name', KeyValueStore), sa.Column('name', KeyValueStore),
sa.Column('derived_name', KeyValueStore), sa.Column('derived_name', KeyValueStore),
sa.Column('partition', sa.Integer)) sa.Column('partition', sa.Integer))
self.country_grid = sa.Table('country_osm_grid', meta, self.country_grid = sa.Table(
'country_osm_grid', meta,
sa.Column('country_code', sa.String(2)), sa.Column('country_code', sa.String(2)),
sa.Column('area', sa.Float), sa.Column('area', sa.Float),
sa.Column('geometry', Geometry)) sa.Column('geometry', Geometry))
# The following tables are not necessarily present. # The following tables are not necessarily present.
self.search_name = sa.Table('search_name', meta, self.search_name = sa.Table(
'search_name', meta,
sa.Column('place_id', sa.BigInteger), sa.Column('place_id', sa.BigInteger),
sa.Column('importance', sa.Float), sa.Column('importance', sa.Float),
sa.Column('search_rank', sa.SmallInteger), sa.Column('search_rank', sa.SmallInteger),
@@ -109,7 +118,8 @@ class SearchTables:
sa.Column('country_code', sa.String(2)), sa.Column('country_code', sa.String(2)),
sa.Column('centroid', Geometry)) sa.Column('centroid', Geometry))
self.tiger = sa.Table('location_property_tiger', meta, self.tiger = sa.Table(
'location_property_tiger', meta,
sa.Column('place_id', sa.BigInteger), sa.Column('place_id', sa.BigInteger),
sa.Column('parent_place_id', sa.BigInteger), sa.Column('parent_place_id', sa.BigInteger),
sa.Column('startnumber', sa.Integer), sa.Column('startnumber', sa.Integer),

View File

@@ -9,7 +9,6 @@ Custom types for SQLAlchemy.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Callable, Any, cast from typing import Callable, Any, cast
import sys
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.compiler import compiles
@@ -17,7 +16,6 @@ from sqlalchemy import types
from ...typing import SaColumn, SaBind from ...typing import SaColumn, SaBind
#pylint: disable=all
class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]): class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
""" Function to compute the spherical distance in meters. """ Function to compute the spherical distance in meters.
@@ -126,12 +124,12 @@ def spatialite_intersects_column(element: Geometry_ColumnIntersectsBbox,
arg1, arg2 = list(element.clauses) arg1, arg2 = list(element.clauses)
return "MbrIntersects(%s, %s) = 1 and "\ return "MbrIntersects(%s, %s) = 1 and "\
"%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\ "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
"WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\ " WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
"AND search_frame = %s)" %( " AND search_frame = %s)"\
compiler.process(arg1, **kw), % (compiler.process(arg1, **kw),
compiler.process(arg2, **kw), compiler.process(arg2, **kw),
arg1.table.name, arg1.table.name, arg1.name, arg1.table.name, arg1.table.name, arg1.name,
compiler.process(arg2, **kw)) compiler.process(arg2, **kw))
class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]): class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]):
@@ -149,23 +147,24 @@ def default_dwithin_column(element: Geometry_ColumnDWithin,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw) return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw)
@compiles(Geometry_ColumnDWithin, 'sqlite') @compiles(Geometry_ColumnDWithin, 'sqlite')
def spatialite_dwithin_column(element: Geometry_ColumnDWithin, def spatialite_dwithin_column(element: Geometry_ColumnDWithin,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
geom1, geom2, dist = list(element.clauses) geom1, geom2, dist = list(element.clauses)
return "ST_Distance(%s, %s) < %s and "\ return "ST_Distance(%s, %s) < %s and "\
"%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\ "%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
"WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\ " WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
"AND search_frame = ST_Expand(%s, %s))" %( " AND search_frame = ST_Expand(%s, %s))"\
compiler.process(geom1, **kw), % (compiler.process(geom1, **kw),
compiler.process(geom2, **kw), compiler.process(geom2, **kw),
compiler.process(dist, **kw), compiler.process(dist, **kw),
geom1.table.name, geom1.table.name, geom1.name, geom1.table.name, geom1.table.name, geom1.name,
compiler.process(geom2, **kw), compiler.process(geom2, **kw),
compiler.process(dist, **kw)) compiler.process(dist, **kw))
class Geometry(types.UserDefinedType): # type: ignore[type-arg] class Geometry(types.UserDefinedType): # type: ignore[type-arg]
""" Simplified type decorator for PostGIS geometry. This type """ Simplified type decorator for PostGIS geometry. This type
only supports geometries in 4326 projection. only supports geometries in 4326 projection.
""" """
@@ -174,11 +173,9 @@ class Geometry(types.UserDefinedType): # type: ignore[type-arg]
def __init__(self, subtype: str = 'Geometry'): def __init__(self, subtype: str = 'Geometry'):
self.subtype = subtype self.subtype = subtype
def get_col_spec(self) -> str: def get_col_spec(self) -> str:
return f'GEOMETRY({self.subtype}, 4326)' return f'GEOMETRY({self.subtype}, 4326)'
def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]: def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
def process(value: Any) -> str: def process(value: Any) -> str:
if isinstance(value, str): if isinstance(value, str):
@@ -187,23 +184,19 @@ class Geometry(types.UserDefinedType): # type: ignore[type-arg]
return cast(str, value.to_wkt()) return cast(str, value.to_wkt())
return process return process
def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]: def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
def process(value: Any) -> str: def process(value: Any) -> str:
assert isinstance(value, str) assert isinstance(value, str)
return value return value
return process return process
def column_expression(self, col: SaColumn) -> SaColumn: def column_expression(self, col: SaColumn) -> SaColumn:
return sa.func.ST_AsEWKB(col) return sa.func.ST_AsEWKB(col)
def bind_expression(self, bindvalue: SaBind) -> SaColumn: def bind_expression(self, bindvalue: SaBind) -> SaColumn:
return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self) return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
def intersects(self, other: SaColumn, use_index: bool = True) -> 'sa.Operators': def intersects(self, other: SaColumn, use_index: bool = True) -> 'sa.Operators':
if not use_index: if not use_index:
@@ -214,69 +207,55 @@ class Geometry(types.UserDefinedType): # type: ignore[type-arg]
return Geometry_IntersectsBbox(self.expr, other) return Geometry_IntersectsBbox(self.expr, other)
def is_line_like(self) -> SaColumn: def is_line_like(self) -> SaColumn:
return Geometry_IsLineLike(self) return Geometry_IsLineLike(self)
def is_area(self) -> SaColumn: def is_area(self) -> SaColumn:
return Geometry_IsAreaLike(self) return Geometry_IsAreaLike(self)
def within_distance(self, other: SaColumn, distance: SaColumn) -> SaColumn: def within_distance(self, other: SaColumn, distance: SaColumn) -> SaColumn:
if isinstance(self.expr, sa.Column): if isinstance(self.expr, sa.Column):
return Geometry_ColumnDWithin(self.expr, other, distance) return Geometry_ColumnDWithin(self.expr, other, distance)
return self.ST_Distance(other) < distance return self.ST_Distance(other) < distance
def ST_Distance(self, other: SaColumn) -> SaColumn: def ST_Distance(self, other: SaColumn) -> SaColumn:
return sa.func.ST_Distance(self, other, type_=sa.Float) return sa.func.ST_Distance(self, other, type_=sa.Float)
def ST_Contains(self, other: SaColumn) -> SaColumn: def ST_Contains(self, other: SaColumn) -> SaColumn:
return sa.func.ST_Contains(self, other, type_=sa.Boolean) return sa.func.ST_Contains(self, other, type_=sa.Boolean)
def ST_CoveredBy(self, other: SaColumn) -> SaColumn: def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean) return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
def ST_ClosestPoint(self, other: SaColumn) -> SaColumn: def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry), return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
other) other)
def ST_Buffer(self, other: SaColumn) -> SaColumn: def ST_Buffer(self, other: SaColumn) -> SaColumn:
return sa.func.ST_Buffer(self, other, type_=Geometry) return sa.func.ST_Buffer(self, other, type_=Geometry)
def ST_Expand(self, other: SaColumn) -> SaColumn: def ST_Expand(self, other: SaColumn) -> SaColumn:
return sa.func.ST_Expand(self, other, type_=Geometry) return sa.func.ST_Expand(self, other, type_=Geometry)
def ST_Collect(self) -> SaColumn: def ST_Collect(self) -> SaColumn:
return sa.func.ST_Collect(self, type_=Geometry) return sa.func.ST_Collect(self, type_=Geometry)
def ST_Centroid(self) -> SaColumn: def ST_Centroid(self) -> SaColumn:
return sa.func.ST_Centroid(self, type_=Geometry) return sa.func.ST_Centroid(self, type_=Geometry)
def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn: def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry) return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)
def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn: def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float) return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)
def distance_spheroid(self, other: SaColumn) -> SaColumn: def distance_spheroid(self, other: SaColumn) -> SaColumn:
return Geometry_DistanceSpheroid(self, other) return Geometry_DistanceSpheroid(self, other)
@compiles(Geometry, 'sqlite') @compiles(Geometry, 'sqlite')
def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def] def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
return 'GEOMETRY' return 'GEOMETRY'
@@ -290,6 +269,7 @@ SQLITE_FUNCTION_ALIAS = (
('ST_LineInterpolatePoint', sa.Float, 'ST_Line_Interpolate_Point'), ('ST_LineInterpolatePoint', sa.Float, 'ST_Line_Interpolate_Point'),
) )
def _add_function_alias(func: str, ftype: type, alias: str) -> None: def _add_function_alias(func: str, ftype: type, alias: str) -> None:
_FuncDef = type(func, (sa.sql.functions.GenericFunction, ), { _FuncDef = type(func, (sa.sql.functions.GenericFunction, ), {
"type": ftype(), "type": ftype(),
@@ -304,5 +284,6 @@ def _add_function_alias(func: str, ftype: type, alias: str) -> None:
compiles(_FuncDef, 'sqlite')(_sqlite_impl) compiles(_FuncDef, 'sqlite')(_sqlite_impl)
for alias in SQLITE_FUNCTION_ALIAS: for alias in SQLITE_FUNCTION_ALIAS:
_add_function_alias(*alias) _add_function_alias(*alias)

View File

@@ -7,7 +7,7 @@
""" """
Custom type for an array of integers. Custom type for an array of integers.
""" """
from typing import Any, List, cast, Optional from typing import Any, List, Optional
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.compiler import compiles
@@ -15,7 +15,6 @@ from sqlalchemy.dialects.postgresql import ARRAY
from ...typing import SaDialect, SaColumn from ...typing import SaDialect, SaColumn
# pylint: disable=all
class IntList(sa.types.TypeDecorator[Any]): class IntList(sa.types.TypeDecorator[Any]):
""" A list of integers saved as a text of comma-separated numbers. """ A list of integers saved as a text of comma-separated numbers.
@@ -46,12 +45,11 @@ class IntArray(sa.types.TypeDecorator[Any]):
def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]: def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
if dialect.name == 'postgresql': if dialect.name == 'postgresql':
return ARRAY(sa.Integer()) #pylint: disable=invalid-name return ARRAY(sa.Integer())
return IntList() return IntList()
class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
def __add__(self, other: SaColumn) -> 'sa.ColumnOperators': def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
""" Concate the array with the given array. If one of the """ Concate the array with the given array. If one of the
@@ -59,7 +57,6 @@ class IntArray(sa.types.TypeDecorator[Any]):
""" """
return ArrayCat(self.expr, other) return ArrayCat(self.expr, other)
def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators': def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
""" Return true if the array contains all the value of the argument """ Return true if the array contains all the value of the argument
array. array.
@@ -67,7 +64,6 @@ class IntArray(sa.types.TypeDecorator[Any]):
return ArrayContains(self.expr, other) return ArrayContains(self.expr, other)
class ArrayAgg(sa.sql.functions.GenericFunction[Any]): class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
""" Aggregate function to collect elements in an array. """ Aggregate function to collect elements in an array.
""" """
@@ -82,7 +78,6 @@ def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> s
return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw) return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
class ArrayContains(sa.sql.expression.FunctionElement[Any]): class ArrayContains(sa.sql.expression.FunctionElement[Any]):
""" Function to check if an array is fully contained in another. """ Function to check if an array is fully contained in another.
""" """
@@ -102,7 +97,6 @@ def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw:
return "array_contains(%s)" % compiler.process(element.clauses, **kw) return "array_contains(%s)" % compiler.process(element.clauses, **kw)
class ArrayCat(sa.sql.expression.FunctionElement[Any]): class ArrayCat(sa.sql.expression.FunctionElement[Any]):
""" Function to check if an array is fully contained in another. """ Function to check if an array is fully contained in another.
""" """
@@ -120,4 +114,3 @@ def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) ->
def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str: def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses) arg1, arg2 = list(element.clauses)
return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw)) return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))

View File

@@ -15,7 +15,6 @@ from sqlalchemy.dialects.sqlite import JSON as sqlite_json
from ...typing import SaDialect from ...typing import SaDialect
# pylint: disable=all
class Json(sa.types.TypeDecorator[Any]): class Json(sa.types.TypeDecorator[Any]):
""" Dialect-independent type for JSON. """ Dialect-independent type for JSON.
@@ -25,6 +24,6 @@ class Json(sa.types.TypeDecorator[Any]):
def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]: def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
if dialect.name == 'postgresql': if dialect.name == 'postgresql':
return JSONB(none_as_null=True) # type: ignore[no-untyped-call] return JSONB(none_as_null=True) # type: ignore[no-untyped-call]
return sqlite_json(none_as_null=True) return sqlite_json(none_as_null=True)

View File

@@ -16,7 +16,6 @@ from sqlalchemy.dialects.sqlite import JSON as sqlite_json
from ...typing import SaDialect, SaColumn from ...typing import SaDialect, SaColumn
# pylint: disable=all
class KeyValueStore(sa.types.TypeDecorator[Any]): class KeyValueStore(sa.types.TypeDecorator[Any]):
""" Dialect-independent type of a simple key-value store of strings. """ Dialect-independent type of a simple key-value store of strings.
@@ -26,12 +25,11 @@ class KeyValueStore(sa.types.TypeDecorator[Any]):
def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]: def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
if dialect.name == 'postgresql': if dialect.name == 'postgresql':
return HSTORE() # type: ignore[no-untyped-call] return HSTORE() # type: ignore[no-untyped-call]
return sqlite_json(none_as_null=True) return sqlite_json(none_as_null=True)
class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
def merge(self, other: SaColumn) -> 'sa.Operators': def merge(self, other: SaColumn) -> 'sa.Operators':
""" Merge the values from the given KeyValueStore into this """ Merge the values from the given KeyValueStore into this
@@ -48,15 +46,16 @@ class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
name = 'JsonConcat' name = 'JsonConcat'
inherit_cache = True inherit_cache = True
@compiles(KeyValueConcat) @compiles(KeyValueConcat)
def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str: def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses) arg1, arg2 = list(element.clauses)
return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw)) return "(%s || coalesce(%s, ''::hstore))"\
% (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
@compiles(KeyValueConcat, 'sqlite') @compiles(KeyValueConcat, 'sqlite')
def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str: def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses) arg1, arg2 = list(element.clauses)
return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw)) return "json_patch(%s, coalesce(%s, '{}'))"\
% (compiler.process(arg1, **kw), compiler.process(arg2, **kw))

View File

@@ -10,7 +10,6 @@ Custom functions for SQLite.
from typing import cast, Optional, Set, Any from typing import cast, Optional, Set, Any
import json import json
# pylint: disable=protected-access
def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float: def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float:
""" Custom weight function for search results. """ Custom weight function for search results.
@@ -118,5 +117,5 @@ async def _make_aggregate(aioconn: Any, *args: Any) -> None:
def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None: def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None:
try: try:
conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate)) conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate))
except Exception as error: # pylint: disable=broad-exception-caught except Exception as error:
conn._handle_exception(error) conn._handle_exception(error)

View File

@@ -16,6 +16,7 @@ import sqlalchemy as sa
from .connection import SearchConnection from .connection import SearchConnection
from .version import NOMINATIM_API_VERSION from .version import NOMINATIM_API_VERSION
@dataclasses.dataclass @dataclasses.dataclass
class StatusResult: class StatusResult:
""" Result of a call to the status API. """ Result of a call to the status API.

View File

@@ -19,7 +19,6 @@ from binascii import unhexlify
from .errors import UsageError from .errors import UsageError
from .localization import Locales from .localization import Locales
# pylint: disable=no-member,too-many-boolean-expressions,too-many-instance-attributes
@dataclasses.dataclass @dataclasses.dataclass
class PlaceID: class PlaceID:
@@ -72,27 +71,23 @@ class Point(NamedTuple):
x: float x: float
y: float y: float
@property @property
def lat(self) -> float: def lat(self) -> float:
""" Return the latitude of the point. """ Return the latitude of the point.
""" """
return self.y return self.y
@property @property
def lon(self) -> float: def lon(self) -> float:
""" Return the longitude of the point. """ Return the longitude of the point.
""" """
return self.x return self.x
def to_geojson(self) -> str: def to_geojson(self) -> str:
""" Return the point in GeoJSON format. """ Return the point in GeoJSON format.
""" """
return f'{{"type": "Point","coordinates": [{self.x}, {self.y}]}}' return f'{{"type": "Point","coordinates": [{self.x}, {self.y}]}}'
@staticmethod @staticmethod
def from_wkb(wkb: Union[str, bytes]) -> 'Point': def from_wkb(wkb: Union[str, bytes]) -> 'Point':
""" Create a point from EWKB as returned from the database. """ Create a point from EWKB as returned from the database.
@@ -115,7 +110,6 @@ class Point(NamedTuple):
return Point(x, y) return Point(x, y)
@staticmethod @staticmethod
def from_param(inp: Any) -> 'Point': def from_param(inp: Any) -> 'Point':
""" Create a point from an input parameter. The parameter """ Create a point from an input parameter. The parameter
@@ -144,19 +138,18 @@ class Point(NamedTuple):
return Point(x, y) return Point(x, y)
def to_wkt(self) -> str: def to_wkt(self) -> str:
""" Return the WKT representation of the point. """ Return the WKT representation of the point.
""" """
return f'POINT({self.x} {self.y})' return f'POINT({self.x} {self.y})'
AnyPoint = Union[Point, Tuple[float, float]] AnyPoint = Union[Point, Tuple[float, float]]
WKB_BBOX_HEADER_LE = b'\x01\x03\x00\x00\x20\xE6\x10\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00' WKB_BBOX_HEADER_LE = b'\x01\x03\x00\x00\x20\xE6\x10\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00'
WKB_BBOX_HEADER_BE = b'\x00\x20\x00\x00\x03\x00\x00\x10\xe6\x00\x00\x00\x01\x00\x00\x00\x05' WKB_BBOX_HEADER_BE = b'\x00\x20\x00\x00\x03\x00\x00\x10\xe6\x00\x00\x00\x01\x00\x00\x00\x05'
class Bbox: class Bbox:
""" A bounding box in WGS84 projection. """ A bounding box in WGS84 projection.
@@ -169,56 +162,48 @@ class Bbox:
""" """
self.coords = (minx, miny, maxx, maxy) self.coords = (minx, miny, maxx, maxy)
@property @property
def minlat(self) -> float: def minlat(self) -> float:
""" Southern-most latitude, corresponding to the minimum y coordinate. """ Southern-most latitude, corresponding to the minimum y coordinate.
""" """
return self.coords[1] return self.coords[1]
@property @property
def maxlat(self) -> float: def maxlat(self) -> float:
""" Northern-most latitude, corresponding to the maximum y coordinate. """ Northern-most latitude, corresponding to the maximum y coordinate.
""" """
return self.coords[3] return self.coords[3]
@property @property
def minlon(self) -> float: def minlon(self) -> float:
""" Western-most longitude, corresponding to the minimum x coordinate. """ Western-most longitude, corresponding to the minimum x coordinate.
""" """
return self.coords[0] return self.coords[0]
@property @property
def maxlon(self) -> float: def maxlon(self) -> float:
""" Eastern-most longitude, corresponding to the maximum x coordinate. """ Eastern-most longitude, corresponding to the maximum x coordinate.
""" """
return self.coords[2] return self.coords[2]
@property @property
def area(self) -> float: def area(self) -> float:
""" Return the area of the box in WGS84. """ Return the area of the box in WGS84.
""" """
return (self.coords[2] - self.coords[0]) * (self.coords[3] - self.coords[1]) return (self.coords[2] - self.coords[0]) * (self.coords[3] - self.coords[1])
def contains(self, pt: Point) -> bool: def contains(self, pt: Point) -> bool:
""" Check if the point is inside or on the boundary of the box. """ Check if the point is inside or on the boundary of the box.
""" """
return self.coords[0] <= pt[0] and self.coords[1] <= pt[1]\ return self.coords[0] <= pt[0] and self.coords[1] <= pt[1]\
and self.coords[2] >= pt[0] and self.coords[3] >= pt[1] and self.coords[2] >= pt[0] and self.coords[3] >= pt[1]
def to_wkt(self) -> str: def to_wkt(self) -> str:
""" Return the WKT representation of the Bbox. This """ Return the WKT representation of the Bbox. This
is a simple polygon with four points. is a simple polygon with four points.
""" """
return 'POLYGON(({0} {1},{0} {3},{2} {3},{2} {1},{0} {1}))'\ return 'POLYGON(({0} {1},{0} {3},{2} {3},{2} {1},{0} {1}))'\
.format(*self.coords) # pylint: disable=consider-using-f-string .format(*self.coords)
@staticmethod @staticmethod
def from_wkb(wkb: Union[None, str, bytes]) -> 'Optional[Bbox]': def from_wkb(wkb: Union[None, str, bytes]) -> 'Optional[Bbox]':
@@ -242,7 +227,6 @@ class Bbox:
return Bbox(min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)) return Bbox(min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
@staticmethod @staticmethod
def from_point(pt: Point, buffer: float) -> 'Bbox': def from_point(pt: Point, buffer: float) -> 'Bbox':
""" Return a Bbox around the point with the buffer added to all sides. """ Return a Bbox around the point with the buffer added to all sides.
@@ -250,7 +234,6 @@ class Bbox:
return Bbox(pt[0] - buffer, pt[1] - buffer, return Bbox(pt[0] - buffer, pt[1] - buffer,
pt[0] + buffer, pt[1] + buffer) pt[0] + buffer, pt[1] + buffer)
@staticmethod @staticmethod
def from_param(inp: Any) -> 'Bbox': def from_param(inp: Any) -> 'Bbox':
""" Return a Bbox from an input parameter. The box may be """ Return a Bbox from an input parameter. The box may be
@@ -383,7 +366,9 @@ def format_categories(categories: List[Tuple[str, str]]) -> List[Tuple[str, str]
""" """
return categories return categories
TParam = TypeVar('TParam', bound='LookupDetails') # pylint: disable=invalid-name
TParam = TypeVar('TParam', bound='LookupDetails')
@dataclasses.dataclass @dataclasses.dataclass
class LookupDetails: class LookupDetails:
@@ -434,7 +419,7 @@ class LookupDetails:
else field.default else field.default
if field.metadata and 'transform' in field.metadata: if field.metadata and 'transform' in field.metadata:
return field.metadata['transform'](v) return field.metadata['transform'](v)
if not isinstance(v, field.type): # type: ignore[arg-type] if not isinstance(v, field.type): # type: ignore[arg-type]
raise UsageError(f"Parameter '{field.name}' needs to be of {field.type!s}.") raise UsageError(f"Parameter '{field.name}' needs to be of {field.type!s}.")
return v return v
@@ -446,15 +431,17 @@ class LookupDetails:
class ReverseDetails(LookupDetails): class ReverseDetails(LookupDetails):
""" Collection of parameters for the reverse call. """ Collection of parameters for the reverse call.
""" """
max_rank: int = dataclasses.field(default=30, max_rank: int = dataclasses.field(default=30,
metadata={'transform': lambda v: max(0, min(v, 30))} metadata={'transform': lambda v: max(0, min(v, 30))})
)
""" Highest address rank to return. """ Highest address rank to return.
""" """
layers: DataLayer = DataLayer.ADDRESS | DataLayer.POI layers: DataLayer = DataLayer.ADDRESS | DataLayer.POI
""" Filter which kind of data to include. """ Filter which kind of data to include.
""" """
@dataclasses.dataclass @dataclasses.dataclass
class SearchDetails(LookupDetails): class SearchDetails(LookupDetails):
""" Collection of parameters for the search call. """ Collection of parameters for the search call.
@@ -463,54 +450,63 @@ class SearchDetails(LookupDetails):
""" Maximum number of results to be returned. The actual number of results """ Maximum number of results to be returned. The actual number of results
may be less. may be less.
""" """
min_rank: int = dataclasses.field(default=0, min_rank: int = dataclasses.field(default=0,
metadata={'transform': lambda v: max(0, min(v, 30))} metadata={'transform': lambda v: max(0, min(v, 30))})
)
""" Lowest address rank to return. """ Lowest address rank to return.
""" """
max_rank: int = dataclasses.field(default=30, max_rank: int = dataclasses.field(default=30,
metadata={'transform': lambda v: max(0, min(v, 30))} metadata={'transform': lambda v: max(0, min(v, 30))})
)
""" Highest address rank to return. """ Highest address rank to return.
""" """
layers: Optional[DataLayer] = dataclasses.field(default=None, layers: Optional[DataLayer] = dataclasses.field(default=None,
metadata={'transform': lambda r : r}) metadata={'transform': lambda r: r})
""" Filter which kind of data to include. When 'None' (the default) then """ Filter which kind of data to include. When 'None' (the default) then
filtering by layers is disabled. filtering by layers is disabled.
""" """
countries: List[str] = dataclasses.field(default_factory=list, countries: List[str] = dataclasses.field(default_factory=list,
metadata={'transform': format_country}) metadata={'transform': format_country})
""" Restrict search results to the given countries. An empty list (the """ Restrict search results to the given countries. An empty list (the
default) will disable this filter. default) will disable this filter.
""" """
excluded: List[int] = dataclasses.field(default_factory=list, excluded: List[int] = dataclasses.field(default_factory=list,
metadata={'transform': format_excluded}) metadata={'transform': format_excluded})
""" List of OSM objects to exclude from the results. Currently only """ List of OSM objects to exclude from the results. Currently only
works when the internal place ID is given. works when the internal place ID is given.
An empty list (the default) will disable this filter. An empty list (the default) will disable this filter.
""" """
viewbox: Optional[Bbox] = dataclasses.field(default=None, viewbox: Optional[Bbox] = dataclasses.field(default=None,
metadata={'transform': Bbox.from_param}) metadata={'transform': Bbox.from_param})
""" Focus the search on a given map area. """ Focus the search on a given map area.
""" """
bounded_viewbox: bool = False bounded_viewbox: bool = False
""" Use 'viewbox' as a filter and restrict results to places within the """ Use 'viewbox' as a filter and restrict results to places within the
given area. given area.
""" """
near: Optional[Point] = dataclasses.field(default=None, near: Optional[Point] = dataclasses.field(default=None,
metadata={'transform': Point.from_param}) metadata={'transform': Point.from_param})
""" Order results by distance to the given point. """ Order results by distance to the given point.
""" """
near_radius: Optional[float] = dataclasses.field(default=None, near_radius: Optional[float] = dataclasses.field(default=None,
metadata={'transform': lambda r : r}) metadata={'transform': lambda r: r})
""" Use near point as a filter and drop results outside the given """ Use near point as a filter and drop results outside the given
radius. Radius is given in degrees WSG84. radius. Radius is given in degrees WSG84.
""" """
categories: List[Tuple[str, str]] = dataclasses.field(default_factory=list, categories: List[Tuple[str, str]] = dataclasses.field(default_factory=list,
metadata={'transform': format_categories}) metadata={'transform': format_categories})
""" Restrict search to places with one of the given class/type categories. """ Restrict search to places with one of the given class/type categories.
An empty list (the default) will disable this filter. An empty list (the default) will disable this filter.
""" """
viewbox_x2: Optional[Bbox] = None viewbox_x2: Optional[Bbox] = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
@@ -520,7 +516,6 @@ class SearchDetails(LookupDetails):
self.viewbox_x2 = Bbox(self.viewbox.minlon - xext, self.viewbox.minlat - yext, self.viewbox_x2 = Bbox(self.viewbox.minlon - xext, self.viewbox.minlat - yext,
self.viewbox.maxlon + xext, self.viewbox.maxlat + yext) self.viewbox.maxlon + xext, self.viewbox.maxlat + yext)
def restrict_min_max_rank(self, new_min: int, new_max: int) -> None: def restrict_min_max_rank(self, new_min: int, new_max: int) -> None:
""" Change the min_rank and max_rank fields to respect the """ Change the min_rank and max_rank fields to respect the
given boundaries. given boundaries.
@@ -529,7 +524,6 @@ class SearchDetails(LookupDetails):
self.min_rank = max(self.min_rank, new_min) self.min_rank = max(self.min_rank, new_min)
self.max_rank = min(self.max_rank, new_max) self.max_rank = min(self.max_rank, new_max)
def is_impossible(self) -> bool: def is_impossible(self) -> bool:
""" Check if the parameter configuration is contradictionary and """ Check if the parameter configuration is contradictionary and
cannot yield any results. cannot yield any results.
@@ -542,7 +536,6 @@ class SearchDetails(LookupDetails):
or (self.max_rank <= 4 and or (self.max_rank <= 4 and
self.layers is not None and not self.layers & DataLayer.ADDRESS)) self.layers is not None and not self.layers & DataLayer.ADDRESS))
def layer_enabled(self, layer: DataLayer) -> bool: def layer_enabled(self, layer: DataLayer) -> bool:
""" Check if the given layer has been chosen. Also returns """ Check if the given layer has been chosen. Also returns
true when layer restriction has been disabled completely. true when layer restriction has been disabled completely.

View File

@@ -11,7 +11,7 @@ Complex type definitions are moved here, to keep the source files readable.
""" """
from typing import Union, TYPE_CHECKING from typing import Union, TYPE_CHECKING
# pylint: disable=missing-class-docstring,useless-import-alias # flake8: noqa
# SQLAlchemy introduced generic types in version 2.0 making typing # SQLAlchemy introduced generic types in version 2.0 making typing
# incompatible with older versions. Add wrappers here so we don't have # incompatible with older versions. Add wrappers here so we don't have

View File

@@ -12,9 +12,10 @@ import io
try: try:
import ujson as json import ujson as json
except ModuleNotFoundError: except ModuleNotFoundError:
import json # type: ignore[no-redef] import json # type: ignore[no-redef]
T = TypeVar('T')
T = TypeVar('T') # pylint: disable=invalid-name
class JsonWriter: class JsonWriter:
""" JSON encoder that renders the output directly into an output """ JSON encoder that renders the output directly into an output
@@ -33,7 +34,6 @@ class JsonWriter:
self.data = io.StringIO() self.data = io.StringIO()
self.pending = '' self.pending = ''
def __call__(self) -> str: def __call__(self) -> str:
""" Return the rendered JSON content as a string. """ Return the rendered JSON content as a string.
The writer remains usable after calling this function. The writer remains usable after calling this function.
@@ -44,7 +44,6 @@ class JsonWriter:
self.pending = '' self.pending = ''
return self.data.getvalue() return self.data.getvalue()
def start_object(self) -> 'JsonWriter': def start_object(self) -> 'JsonWriter':
""" Write the open bracket of a JSON object. """ Write the open bracket of a JSON object.
""" """
@@ -53,7 +52,6 @@ class JsonWriter:
self.pending = '{' self.pending = '{'
return self return self
def end_object(self) -> 'JsonWriter': def end_object(self) -> 'JsonWriter':
""" Write the closing bracket of a JSON object. """ Write the closing bracket of a JSON object.
""" """
@@ -63,7 +61,6 @@ class JsonWriter:
self.pending = '}' self.pending = '}'
return self return self
def start_array(self) -> 'JsonWriter': def start_array(self) -> 'JsonWriter':
""" Write the opening bracket of a JSON array. """ Write the opening bracket of a JSON array.
""" """
@@ -72,7 +69,6 @@ class JsonWriter:
self.pending = '[' self.pending = '['
return self return self
def end_array(self) -> 'JsonWriter': def end_array(self) -> 'JsonWriter':
""" Write the closing bracket of a JSON array. """ Write the closing bracket of a JSON array.
""" """
@@ -82,7 +78,6 @@ class JsonWriter:
self.pending = ']' self.pending = ']'
return self return self
def key(self, name: str) -> 'JsonWriter': def key(self, name: str) -> 'JsonWriter':
""" Write the key string of a JSON object. """ Write the key string of a JSON object.
""" """
@@ -92,7 +87,6 @@ class JsonWriter:
self.pending = ':' self.pending = ':'
return self return self
def value(self, value: Any) -> 'JsonWriter': def value(self, value: Any) -> 'JsonWriter':
""" Write out a value as JSON. The function uses the json.dumps() """ Write out a value as JSON. The function uses the json.dumps()
function for encoding the JSON. Thus any value that can be function for encoding the JSON. Thus any value that can be
@@ -100,7 +94,6 @@ class JsonWriter:
""" """
return self.raw(json.dumps(value, ensure_ascii=False)) return self.raw(json.dumps(value, ensure_ascii=False))
def float(self, value: float, precision: int) -> 'JsonWriter': def float(self, value: float, precision: int) -> 'JsonWriter':
""" Write out a float value with the given precision. """ Write out a float value with the given precision.
""" """
@@ -114,7 +107,6 @@ class JsonWriter:
self.pending = ',' self.pending = ','
return self return self
def raw(self, raw_json: str) -> 'JsonWriter': def raw(self, raw_json: str) -> 'JsonWriter':
""" Write out the given value as is. This function is useful if """ Write out the given value as is. This function is useful if
a value is already available in JSON format. a value is already available in JSON format.
@@ -125,7 +117,6 @@ class JsonWriter:
self.data.write(raw_json) self.data.write(raw_json)
return self return self
def keyval(self, key: str, value: Any) -> 'JsonWriter': def keyval(self, key: str, value: Any) -> 'JsonWriter':
""" Write out an object element with the given key and value. """ Write out an object element with the given key and value.
This is a shortcut for calling 'key()', 'value()' and 'next()'. This is a shortcut for calling 'key()', 'value()' and 'next()'.
@@ -134,7 +125,6 @@ class JsonWriter:
self.value(value) self.value(value)
return self.next() return self.next()
def keyval_not_none(self, key: str, value: Optional[T], def keyval_not_none(self, key: str, value: Optional[T],
transform: Optional[Callable[[T], Any]] = None) -> 'JsonWriter': transform: Optional[Callable[[T], Any]] = None) -> 'JsonWriter':
""" Write out an object element only if the value is not None. """ Write out an object element only if the value is not None.

View File

@@ -8,6 +8,4 @@
Implementation of API version v1 (aka the legacy version). Implementation of API version v1 (aka the legacy version).
""" """
#pylint: disable=useless-import-alias
from .server_glue import ROUTES as ROUTES from .server_glue import ROUTES as ROUTES

View File

@@ -15,6 +15,7 @@ from typing import Tuple, Optional, Mapping, Union
from ..results import ReverseResult, SearchResult from ..results import ReverseResult, SearchResult
from ..types import Bbox from ..types import Bbox
def get_label_tag(category: Tuple[str, str], extratags: Optional[Mapping[str, str]], def get_label_tag(category: Tuple[str, str], extratags: Optional[Mapping[str, str]],
rank: int, country: Optional[str]) -> str: rank: int, country: Optional[str]) -> str:
""" Create a label tag for the given place that can be used as an XML name. """ Create a label tag for the given place that can be used as an XML name.
@@ -33,8 +34,8 @@ def get_label_tag(category: Tuple[str, str], extratags: Optional[Mapping[str, st
label = category[1] if category[1] != 'yes' else category[0] label = category[1] if category[1] != 'yes' else category[0]
elif rank < 28: elif rank < 28:
label = 'road' label = 'road'
elif category[0] == 'place'\ elif (category[0] == 'place'
and category[1] in ('house_number', 'house_name', 'country_code'): and category[1] in ('house_number', 'house_name', 'country_code')):
label = category[1] label = category[1]
else: else:
label = category[0] label = category[0]

View File

@@ -22,14 +22,17 @@ from . import format_json, format_xml
from .. import logging as loglib from .. import logging as loglib
from ..server import content_types as ct from ..server import content_types as ct
class RawDataList(List[Dict[str, Any]]): class RawDataList(List[Dict[str, Any]]):
""" Data type for formatting raw data lists 'as is' in json. """ Data type for formatting raw data lists 'as is' in json.
""" """
dispatch = FormatDispatcher({'text': ct.CONTENT_TEXT, dispatch = FormatDispatcher({'text': ct.CONTENT_TEXT,
'xml': ct.CONTENT_XML, 'xml': ct.CONTENT_XML,
'debug': ct.CONTENT_HTML}) 'debug': ct.CONTENT_HTML})
@dispatch.error_format_func @dispatch.error_format_func
def _format_error(content_type: str, msg: str, status: int) -> str: def _format_error(content_type: str, msg: str, status: int) -> str:
if content_type == ct.CONTENT_XML: if content_type == ct.CONTENT_XML:
@@ -65,13 +68,13 @@ def _format_status_json(result: StatusResult, _: Mapping[str, Any]) -> str:
out = JsonWriter() out = JsonWriter()
out.start_object()\ out.start_object()\
.keyval('status', result.status)\ .keyval('status', result.status)\
.keyval('message', result.message)\ .keyval('message', result.message)\
.keyval_not_none('data_updated', result.data_updated, .keyval_not_none('data_updated', result.data_updated,
lambda v: v.isoformat())\ lambda v: v.isoformat())\
.keyval('software_version', str(result.software_version))\ .keyval('software_version', str(result.software_version))\
.keyval_not_none('database_version', result.database_version, str)\ .keyval_not_none('database_version', result.database_version, str)\
.end_object() .end_object()
return out() return out()
@@ -119,7 +122,7 @@ def _add_parent_rows_grouped(writer: JsonWriter, rows: AddressLines,
writer.key('hierarchy').start_object() writer.key('hierarchy').start_object()
for group, grouped in data.items(): for group, grouped in data.items():
writer.key(group).start_array() writer.key(group).start_array()
grouped.sort() # sorts alphabetically by local name grouped.sort() # sorts alphabetically by local name
for line in grouped: for line in grouped:
writer.raw(line).next() writer.raw(line).next()
writer.end_array().next() writer.end_array().next()
@@ -135,32 +138,32 @@ def _format_details_json(result: DetailedResult, options: Mapping[str, Any]) ->
out = JsonWriter() out = JsonWriter()
out.start_object()\ out.start_object()\
.keyval_not_none('place_id', result.place_id)\ .keyval_not_none('place_id', result.place_id)\
.keyval_not_none('parent_place_id', result.parent_place_id) .keyval_not_none('parent_place_id', result.parent_place_id)
if result.osm_object is not None: if result.osm_object is not None:
out.keyval('osm_type', result.osm_object[0])\ out.keyval('osm_type', result.osm_object[0])\
.keyval('osm_id', result.osm_object[1]) .keyval('osm_id', result.osm_object[1])
out.keyval('category', result.category[0])\ out.keyval('category', result.category[0])\
.keyval('type', result.category[1])\ .keyval('type', result.category[1])\
.keyval('admin_level', result.admin_level)\ .keyval('admin_level', result.admin_level)\
.keyval('localname', result.locale_name or '')\ .keyval('localname', result.locale_name or '')\
.keyval('names', result.names or {})\ .keyval('names', result.names or {})\
.keyval('addresstags', result.address or {})\ .keyval('addresstags', result.address or {})\
.keyval_not_none('housenumber', result.housenumber)\ .keyval_not_none('housenumber', result.housenumber)\
.keyval_not_none('calculated_postcode', result.postcode)\ .keyval_not_none('calculated_postcode', result.postcode)\
.keyval_not_none('country_code', result.country_code)\ .keyval_not_none('country_code', result.country_code)\
.keyval_not_none('indexed_date', result.indexed_date, lambda v: v.isoformat())\ .keyval_not_none('indexed_date', result.indexed_date, lambda v: v.isoformat())\
.keyval_not_none('importance', result.importance)\ .keyval_not_none('importance', result.importance)\
.keyval('calculated_importance', result.calculated_importance())\ .keyval('calculated_importance', result.calculated_importance())\
.keyval('extratags', result.extratags or {})\ .keyval('extratags', result.extratags or {})\
.keyval_not_none('calculated_wikipedia', result.wikipedia)\ .keyval_not_none('calculated_wikipedia', result.wikipedia)\
.keyval('rank_address', result.rank_address)\ .keyval('rank_address', result.rank_address)\
.keyval('rank_search', result.rank_search)\ .keyval('rank_search', result.rank_search)\
.keyval('isarea', 'Polygon' in (geom or result.geometry.get('type') or ''))\ .keyval('isarea', 'Polygon' in (geom or result.geometry.get('type') or ''))\
.key('centroid').raw(centroid).next()\ .key('centroid').raw(centroid).next()\
.key('geometry').raw(geom or centroid).next() .key('geometry').raw(geom or centroid).next()
if options.get('icon_base_url', None): if options.get('icon_base_url', None):
icon = ICONS.get(result.category) icon = ICONS.get(result.category)
@@ -241,32 +244,32 @@ def _format_search_xml(results: SearchResults, options: Mapping[str, Any]) -> st
extra) extra)
@dispatch.format_func(SearchResults, 'geojson') @dispatch.format_func(SearchResults, 'geojson')
def _format_search_geojson(results: SearchResults, def _format_search_geojson(results: SearchResults,
options: Mapping[str, Any]) -> str: options: Mapping[str, Any]) -> str:
return format_json.format_base_geojson(results, options, False) return format_json.format_base_geojson(results, options, False)
@dispatch.format_func(SearchResults, 'geocodejson') @dispatch.format_func(SearchResults, 'geocodejson')
def _format_search_geocodejson(results: SearchResults, def _format_search_geocodejson(results: SearchResults,
options: Mapping[str, Any]) -> str: options: Mapping[str, Any]) -> str:
return format_json.format_base_geocodejson(results, options, False) return format_json.format_base_geocodejson(results, options, False)
@dispatch.format_func(SearchResults, 'json') @dispatch.format_func(SearchResults, 'json')
def _format_search_json(results: SearchResults, def _format_search_json(results: SearchResults,
options: Mapping[str, Any]) -> str: options: Mapping[str, Any]) -> str:
return format_json.format_base_json(results, options, False, return format_json.format_base_json(results, options, False,
class_label='class') class_label='class')
@dispatch.format_func(SearchResults, 'jsonv2') @dispatch.format_func(SearchResults, 'jsonv2')
def _format_search_jsonv2(results: SearchResults, def _format_search_jsonv2(results: SearchResults,
options: Mapping[str, Any]) -> str: options: Mapping[str, Any]) -> str:
return format_json.format_base_json(results, options, False, return format_json.format_base_json(results, options, False,
class_label='category') class_label='category')
@dispatch.format_func(RawDataList, 'json') @dispatch.format_func(RawDataList, 'json')
def _format_raw_data_json(results: RawDataList, _: Mapping[str, Any]) -> str: def _format_raw_data_json(results: RawDataList, _: Mapping[str, Any]) -> str:
out = JsonWriter() out = JsonWriter()
@@ -275,7 +278,7 @@ def _format_raw_data_json(results: RawDataList, _: Mapping[str, Any]) -> str:
out.start_object() out.start_object()
for k, v in res.items(): for k, v in res.items():
if isinstance(v, dt.datetime): if isinstance(v, dt.datetime):
out.keyval(k, v.isoformat(sep= ' ', timespec='seconds')) out.keyval(k, v.isoformat(sep=' ', timespec='seconds'))
else: else:
out.keyval(k, v) out.keyval(k, v)
out.end_object().next() out.end_object().next()

View File

@@ -13,7 +13,6 @@ from ..utils.json_writer import JsonWriter
from ..results import AddressLines, ReverseResults, SearchResults from ..results import AddressLines, ReverseResults, SearchResults
from . import classtypes as cl from . import classtypes as cl
#pylint: disable=too-many-branches
def _write_osm_id(out: JsonWriter, osm_object: Optional[Tuple[str, int]]) -> None: def _write_osm_id(out: JsonWriter, osm_object: Optional[Tuple[str, int]]) -> None:
if osm_object is not None: if osm_object is not None:
@@ -22,7 +21,7 @@ def _write_osm_id(out: JsonWriter, osm_object: Optional[Tuple[str, int]]) -> Non
def _write_typed_address(out: JsonWriter, address: Optional[AddressLines], def _write_typed_address(out: JsonWriter, address: Optional[AddressLines],
country_code: Optional[str]) -> None: country_code: Optional[str]) -> None:
parts = {} parts = {}
for line in (address or []): for line in (address or []):
if line.isaddress: if line.isaddress:
@@ -52,13 +51,12 @@ def _write_geocodejson_address(out: JsonWriter,
out.keyval('postcode', line.local_name) out.keyval('postcode', line.local_name)
elif line.category[1] == 'house_number': elif line.category[1] == 'house_number':
out.keyval('housenumber', line.local_name) out.keyval('housenumber', line.local_name)
elif (obj_place_id is None or obj_place_id != line.place_id) \ elif ((obj_place_id is None or obj_place_id != line.place_id)
and line.rank_address >= 4 and line.rank_address < 28: and line.rank_address >= 4 and line.rank_address < 28):
rank_name = GEOCODEJSON_RANKS[line.rank_address] rank_name = GEOCODEJSON_RANKS[line.rank_address]
if rank_name not in extra: if rank_name not in extra:
extra[rank_name] = line.local_name extra[rank_name] = line.local_name
for k, v in extra.items(): for k, v in extra.items():
out.keyval(k, v) out.keyval(k, v)
@@ -87,17 +85,16 @@ def format_base_json(results: Union[ReverseResults, SearchResults],
_write_osm_id(out, result.osm_object) _write_osm_id(out, result.osm_object)
out.keyval('lat', f"{result.centroid.lat}")\ out.keyval('lat', f"{result.centroid.lat}")\
.keyval('lon', f"{result.centroid.lon}")\ .keyval('lon', f"{result.centroid.lon}")\
.keyval(class_label, result.category[0])\ .keyval(class_label, result.category[0])\
.keyval('type', result.category[1])\ .keyval('type', result.category[1])\
.keyval('place_rank', result.rank_search)\ .keyval('place_rank', result.rank_search)\
.keyval('importance', result.calculated_importance())\ .keyval('importance', result.calculated_importance())\
.keyval('addresstype', cl.get_label_tag(result.category, result.extratags, .keyval('addresstype', cl.get_label_tag(result.category, result.extratags,
result.rank_address, result.rank_address,
result.country_code))\ result.country_code))\
.keyval('name', result.locale_name or '')\ .keyval('name', result.locale_name or '')\
.keyval('display_name', result.display_name or '') .keyval('display_name', result.display_name or '')
if options.get('icon_base_url', None): if options.get('icon_base_url', None):
icon = cl.ICONS.get(result.category) icon = cl.ICONS.get(result.category)
@@ -117,10 +114,10 @@ def format_base_json(results: Union[ReverseResults, SearchResults],
bbox = cl.bbox_from_result(result) bbox = cl.bbox_from_result(result)
out.key('boundingbox').start_array()\ out.key('boundingbox').start_array()\
.value(f"{bbox.minlat:0.7f}").next()\ .value(f"{bbox.minlat:0.7f}").next()\
.value(f"{bbox.maxlat:0.7f}").next()\ .value(f"{bbox.maxlat:0.7f}").next()\
.value(f"{bbox.minlon:0.7f}").next()\ .value(f"{bbox.minlon:0.7f}").next()\
.value(f"{bbox.maxlon:0.7f}").next()\ .value(f"{bbox.maxlon:0.7f}").next()\
.end_array().next() .end_array().next()
if result.geometry: if result.geometry:
@@ -153,9 +150,9 @@ def format_base_geojson(results: Union[ReverseResults, SearchResults],
out = JsonWriter() out = JsonWriter()
out.start_object()\ out.start_object()\
.keyval('type', 'FeatureCollection')\ .keyval('type', 'FeatureCollection')\
.keyval('licence', cl.OSM_ATTRIBUTION)\ .keyval('licence', cl.OSM_ATTRIBUTION)\
.key('features').start_array() .key('features').start_array()
for result in results: for result in results:
out.start_object()\ out.start_object()\
@@ -187,7 +184,7 @@ def format_base_geojson(results: Union[ReverseResults, SearchResults],
if options.get('namedetails', False): if options.get('namedetails', False):
out.keyval('namedetails', result.names) out.keyval('namedetails', result.names)
out.end_object().next() # properties out.end_object().next() # properties
out.key('bbox').start_array() out.key('bbox').start_array()
for coord in cl.bbox_from_result(result).coords: for coord in cl.bbox_from_result(result).coords:
@@ -214,20 +211,20 @@ def format_base_geocodejson(results: Union[ReverseResults, SearchResults],
out = JsonWriter() out = JsonWriter()
out.start_object()\ out.start_object()\
.keyval('type', 'FeatureCollection')\ .keyval('type', 'FeatureCollection')\
.key('geocoding').start_object()\ .key('geocoding').start_object()\
.keyval('version', '0.1.0')\ .keyval('version', '0.1.0')\
.keyval('attribution', cl.OSM_ATTRIBUTION)\ .keyval('attribution', cl.OSM_ATTRIBUTION)\
.keyval('licence', 'ODbL')\ .keyval('licence', 'ODbL')\
.keyval_not_none('query', options.get('query'))\ .keyval_not_none('query', options.get('query'))\
.end_object().next()\ .end_object().next()\
.key('features').start_array() .key('features').start_array()
for result in results: for result in results:
out.start_object()\ out.start_object()\
.keyval('type', 'Feature')\ .keyval('type', 'Feature')\
.key('properties').start_object()\ .key('properties').start_object()\
.key('geocoding').start_object() .key('geocoding').start_object()
out.keyval_not_none('place_id', result.place_id) out.keyval_not_none('place_id', result.place_id)

View File

@@ -15,7 +15,6 @@ from ..results import AddressLines, ReverseResult, ReverseResults, \
SearchResult, SearchResults SearchResult, SearchResults
from . import classtypes as cl from . import classtypes as cl
#pylint: disable=too-many-branches
def _write_xml_address(root: ET.Element, address: AddressLines, def _write_xml_address(root: ET.Element, address: AddressLines,
country_code: Optional[str]) -> None: country_code: Optional[str]) -> None:
@@ -30,7 +29,7 @@ def _write_xml_address(root: ET.Element, address: AddressLines,
if line.names and 'ISO3166-2' in line.names and line.admin_level: if line.names and 'ISO3166-2' in line.names and line.admin_level:
parts[f"ISO3166-2-lvl{line.admin_level}"] = line.names['ISO3166-2'] parts[f"ISO3166-2-lvl{line.admin_level}"] = line.names['ISO3166-2']
for k,v in parts.items(): for k, v in parts.items():
ET.SubElement(root, k).text = v ET.SubElement(root, k).text = v
if country_code: if country_code:
@@ -120,7 +119,7 @@ def format_base_xml(results: Union[ReverseResults, SearchResults],
if options.get('namedetails', False): if options.get('namedetails', False):
eroot = ET.SubElement(root if simple else place, 'namedetails') eroot = ET.SubElement(root if simple else place, 'namedetails')
if result.names: if result.names:
for k,v in result.names.items(): for k, v in result.names.items():
ET.SubElement(eroot, 'name', attrib={'desc': k}).text = v ET.SubElement(eroot, 'name', attrib={'desc': k}).text = v
return '<?xml version="1.0" encoding="UTF-8" ?>\n' + ET.tostring(root, encoding='unicode') return '<?xml version="1.0" encoding="UTF-8" ?>\n' + ET.tostring(root, encoding='unicode')

View File

@@ -15,6 +15,7 @@ import re
from ..results import SearchResult, SearchResults, SourceTable from ..results import SearchResult, SearchResults, SourceTable
from ..types import SearchDetails, GeometryFormat from ..types import SearchDetails, GeometryFormat
REVERSE_MAX_RANKS = [2, 2, 2, # 0-2 Continent/Sea REVERSE_MAX_RANKS = [2, 2, 2, # 0-2 Continent/Sea
4, 4, # 3-4 Country 4, 4, # 3-4 Country
8, # 5 State 8, # 5 State
@@ -28,7 +29,7 @@ REVERSE_MAX_RANKS = [2, 2, 2, # 0-2 Continent/Sea
26, # 16 Major Streets 26, # 16 Major Streets
27, # 17 Minor Streets 27, # 17 Minor Streets
30 # 18 Building 30 # 18 Building
] ]
def zoom_to_rank(zoom: int) -> int: def zoom_to_rank(zoom: int) -> int:
@@ -52,7 +53,6 @@ def feature_type_to_rank(feature_type: Optional[str]) -> Tuple[int, int]:
return FEATURE_TYPE_TO_RANK.get(feature_type, (0, 30)) return FEATURE_TYPE_TO_RANK.get(feature_type, (0, 30))
#pylint: disable=too-many-arguments,too-many-branches
def extend_query_parts(queryparts: Dict[str, Any], details: Dict[str, Any], def extend_query_parts(queryparts: Dict[str, Any], details: Dict[str, Any],
feature_type: Optional[str], feature_type: Optional[str],
namedetails: bool, extratags: bool, namedetails: bool, extratags: bool,
@@ -135,15 +135,18 @@ def _is_postcode_relation_for(result: SearchResult, postcode: str) -> bool:
and result.names.get('ref') == postcode and result.names.get('ref') == postcode
def _deg(axis:str) -> str: def _deg(axis: str) -> str:
return f"(?P<{axis}_deg>\\d+\\.\\d+)°?" return f"(?P<{axis}_deg>\\d+\\.\\d+)°?"
def _deg_min(axis: str) -> str: def _deg_min(axis: str) -> str:
return f"(?P<{axis}_deg>\\d+)[°\\s]+(?P<{axis}_min>[\\d.]+)[']*" return f"(?P<{axis}_deg>\\d+)[°\\s]+(?P<{axis}_min>[\\d.]+)[']*"
def _deg_min_sec(axis: str) -> str: def _deg_min_sec(axis: str) -> str:
return f"(?P<{axis}_deg>\\d+)[°\\s]+(?P<{axis}_min>\\d+)['\\s]+(?P<{axis}_sec>[\\d.]+)[\"″]*" return f"(?P<{axis}_deg>\\d+)[°\\s]+(?P<{axis}_min>\\d+)['\\s]+(?P<{axis}_sec>[\\d.]+)[\"″]*"
COORD_REGEX = [re.compile(r'(?:(?P<pre>.*?)\s+)??' + r + r'(?:\s+(?P<post>.*))?') for r in ( COORD_REGEX = [re.compile(r'(?:(?P<pre>.*?)\s+)??' + r + r'(?:\s+(?P<post>.*))?') for r in (
r"(?P<ns>[NS])\s*" + _deg('lat') + r"[\s,]+" + r"(?P<ew>[EW])\s*" + _deg('lon'), r"(?P<ns>[NS])\s*" + _deg('lat') + r"[\s,]+" + r"(?P<ew>[EW])\s*" + _deg('lon'),
_deg('lat') + r"\s*(?P<ns>[NS])[\s,]+" + _deg('lon') + r"\s*(?P<ew>[EW])", _deg('lat') + r"\s*(?P<ns>[NS])[\s,]+" + _deg('lon') + r"\s*(?P<ew>[EW])",
@@ -154,6 +157,7 @@ COORD_REGEX = [re.compile(r'(?:(?P<pre>.*?)\s+)??' + r + r'(?:\s+(?P<post>.*))?'
r"\[?(?P<lat_deg>[+-]?\d+\.\d+)[\s,]+(?P<lon_deg>[+-]?\d+\.\d+)\]?" r"\[?(?P<lat_deg>[+-]?\d+\.\d+)[\s,]+(?P<lon_deg>[+-]?\d+\.\d+)\]?"
)] )]
def extract_coords_from_query(query: str) -> Tuple[str, Optional[float], Optional[float]]: def extract_coords_from_query(query: str) -> Tuple[str, Optional[float], Optional[float]]:
""" Look for something that is formatted like a coordinate at the """ Look for something that is formatted like a coordinate at the
beginning or end of the query. If found, extract the coordinate and beginning or end of the query. If found, extract the coordinate and
@@ -185,6 +189,7 @@ def extract_coords_from_query(query: str) -> Tuple[str, Optional[float], Optiona
CATEGORY_REGEX = re.compile(r'(?P<pre>.*?)\[(?P<cls>[a-zA-Z_]+)=(?P<typ>[a-zA-Z_]+)\](?P<post>.*)') CATEGORY_REGEX = re.compile(r'(?P<pre>.*?)\[(?P<cls>[a-zA-Z_]+)=(?P<typ>[a-zA-Z_]+)\](?P<post>.*)')
def extract_category_from_query(query: str) -> Tuple[str, Optional[str], Optional[str]]: def extract_category_from_query(query: str) -> Tuple[str, Optional[str], Optional[str]]:
""" Extract a hidden category specification of the form '[key=value]' from """ Extract a hidden category specification of the form '[key=value]' from
the query. If found, extract key and value and the query. If found, extract key and value and

View File

@@ -27,6 +27,7 @@ from . import helpers
from ..server import content_types as ct from ..server import content_types as ct
from ..server.asgi_adaptor import ASGIAdaptor from ..server.asgi_adaptor import ASGIAdaptor
def build_response(adaptor: ASGIAdaptor, output: str, status: int = 200, def build_response(adaptor: ASGIAdaptor, output: str, status: int = 200,
num_results: int = 0) -> Any: num_results: int = 0) -> Any:
""" Create a response from the given output. Wraps a JSONP function """ Create a response from the given output. Wraps a JSONP function
@@ -47,8 +48,8 @@ def get_accepted_languages(adaptor: ASGIAdaptor) -> str:
""" Return the accepted languages. """ Return the accepted languages.
""" """
return adaptor.get('accept-language')\ return adaptor.get('accept-language')\
or adaptor.get_header('accept-language')\ or adaptor.get_header('accept-language')\
or adaptor.config().DEFAULT_LANGUAGE or adaptor.config().DEFAULT_LANGUAGE
def setup_debugging(adaptor: ASGIAdaptor) -> bool: def setup_debugging(adaptor: ASGIAdaptor) -> bool:
@@ -88,7 +89,7 @@ def parse_format(adaptor: ASGIAdaptor, result_type: Type[Any], default: str) ->
if not formatting.supports_format(result_type, fmt): if not formatting.supports_format(result_type, fmt):
adaptor.raise_error("Parameter 'format' must be one of: " + adaptor.raise_error("Parameter 'format' must be one of: " +
', '.join(formatting.list_formats(result_type))) ', '.join(formatting.list_formats(result_type)))
adaptor.content_type = formatting.get_content_type(fmt) adaptor.content_type = formatting.get_content_type(fmt)
return fmt return fmt
@@ -119,7 +120,7 @@ def parse_geometry_details(adaptor: ASGIAdaptor, fmt: str) -> Dict[str, Any]:
return {'address_details': True, return {'address_details': True,
'geometry_simplification': adaptor.get_float('polygon_threshold', 0.0), 'geometry_simplification': adaptor.get_float('polygon_threshold', 0.0),
'geometry_output': output 'geometry_output': output
} }
async def status_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any: async def status_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
@@ -135,7 +136,7 @@ async def status_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
status_code = 200 status_code = 200
return build_response(params, params.formatting().format_result(result, fmt, {}), return build_response(params, params.formatting().format_result(result, fmt, {}),
status=status_code) status=status_code)
async def details_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any: async def details_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
@@ -161,11 +162,11 @@ async def details_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
linked_places=params.get_bool('linkedplaces', True), linked_places=params.get_bool('linkedplaces', True),
parented_places=params.get_bool('hierarchy', False), parented_places=params.get_bool('hierarchy', False),
keywords=params.get_bool('keywords', False), keywords=params.get_bool('keywords', False),
geometry_output = GeometryFormat.GEOJSON geometry_output=(GeometryFormat.GEOJSON
if params.get_bool('polygon_geojson', False) if params.get_bool('polygon_geojson', False)
else GeometryFormat.NONE, else GeometryFormat.NONE),
locales=locales locales=locales
) )
if debug: if debug:
return build_response(params, loglib.get_and_disable()) return build_response(params, loglib.get_and_disable())
@@ -173,10 +174,11 @@ async def details_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
if result is None: if result is None:
params.raise_error('No place with that OSM ID found.', status=404) params.raise_error('No place with that OSM ID found.', status=404)
output = params.formatting().format_result(result, fmt, output = params.formatting().format_result(
{'locales': locales, result, fmt,
'group_hierarchy': params.get_bool('group_hierarchy', False), {'locales': locales,
'icon_base_url': params.config().MAPICON_URL}) 'group_hierarchy': params.get_bool('group_hierarchy', False),
'icon_base_url': params.config().MAPICON_URL})
return build_response(params, output, num_results=1) return build_response(params, output, num_results=1)
@@ -253,7 +255,7 @@ async def lookup_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
async def _unstructured_search(query: str, api: NominatimAPIAsync, async def _unstructured_search(query: str, api: NominatimAPIAsync,
details: Dict[str, Any]) -> SearchResults: details: Dict[str, Any]) -> SearchResults:
if not query: if not query:
return SearchResults() return SearchResults()
@@ -290,15 +292,15 @@ async def search_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
debug = setup_debugging(params) debug = setup_debugging(params)
details = parse_geometry_details(params, fmt) details = parse_geometry_details(params, fmt)
details['countries'] = params.get('countrycodes', None) details['countries'] = params.get('countrycodes', None)
details['excluded'] = params.get('exclude_place_ids', None) details['excluded'] = params.get('exclude_place_ids', None)
details['viewbox'] = params.get('viewbox', None) or params.get('viewboxlbrt', None) details['viewbox'] = params.get('viewbox', None) or params.get('viewboxlbrt', None)
details['bounded_viewbox'] = params.get_bool('bounded', False) details['bounded_viewbox'] = params.get_bool('bounded', False)
details['dedupe'] = params.get_bool('dedupe', True) details['dedupe'] = params.get_bool('dedupe', True)
max_results = max(1, min(50, params.get_int('limit', 10))) max_results = max(1, min(50, params.get_int('limit', 10)))
details['max_results'] = max_results + min(10, max_results) \ details['max_results'] = (max_results + min(10, max_results)
if details['dedupe'] else max_results if details['dedupe'] else max_results)
details['min_rank'], details['max_rank'] = \ details['min_rank'], details['max_rank'] = \
helpers.feature_type_to_rank(params.get('featureType', '')) helpers.feature_type_to_rank(params.get('featureType', ''))

View File

@@ -25,6 +25,7 @@ from .clicmd.args import NominatimArgs, Subcommand
LOG = logging.getLogger() LOG = logging.getLogger()
class CommandlineParser: class CommandlineParser:
""" Wraps some of the common functions for parsing the command line """ Wraps some of the common functions for parsing the command line
and setting up subcommands. and setting up subcommands.
@@ -57,7 +58,6 @@ class CommandlineParser:
group.add_argument('-j', '--threads', metavar='NUM', type=int, group.add_argument('-j', '--threads', metavar='NUM', type=int,
help='Number of parallel threads to use') help='Number of parallel threads to use')
def nominatim_version_text(self) -> str: def nominatim_version_text(self) -> str:
""" Program name and version number as string """ Program name and version number as string
""" """
@@ -66,7 +66,6 @@ class CommandlineParser:
text += f' ({version.GIT_COMMIT_HASH})' text += f' ({version.GIT_COMMIT_HASH})'
return text return text
def add_subcommand(self, name: str, cmd: Subcommand) -> None: def add_subcommand(self, name: str, cmd: Subcommand) -> None:
""" Add a subcommand to the parser. The subcommand must be a class """ Add a subcommand to the parser. The subcommand must be a class
with a function add_args() that adds the parameters for the with a function add_args() that adds the parameters for the
@@ -82,7 +81,6 @@ class CommandlineParser:
parser.set_defaults(command=cmd) parser.set_defaults(command=cmd)
cmd.add_args(parser) cmd.add_args(parser)
def run(self, **kwargs: Any) -> int: def run(self, **kwargs: Any) -> int:
""" Parse the command line arguments of the program and execute the """ Parse the command line arguments of the program and execute the
appropriate subcommand. appropriate subcommand.
@@ -122,7 +120,7 @@ class CommandlineParser:
return ret return ret
except UsageError as exception: except UsageError as exception:
if log.isEnabledFor(logging.DEBUG): if log.isEnabledFor(logging.DEBUG):
raise # use Python's exception printing raise # use Python's exception printing
log.fatal('FATAL: %s', exception) log.fatal('FATAL: %s', exception)
# If we get here, then execution has failed in some way. # If we get here, then execution has failed in some way.
@@ -139,7 +137,6 @@ class CommandlineParser:
# a subcommand. # a subcommand.
# #
# No need to document the functions each time. # No need to document the functions each time.
# pylint: disable=C0111
class AdminServe: class AdminServe:
"""\ """\
Start a simple web server for serving the API. Start a simple web server for serving the API.
@@ -164,15 +161,13 @@ class AdminServe:
choices=('falcon', 'starlette'), choices=('falcon', 'starlette'),
help='Webserver framework to run. (default: falcon)') help='Webserver framework to run. (default: falcon)')
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
asyncio.run(self.run_uvicorn(args)) asyncio.run(self.run_uvicorn(args))
return 0 return 0
async def run_uvicorn(self, args: NominatimArgs) -> None: async def run_uvicorn(self, args: NominatimArgs) -> None:
import uvicorn # pylint: disable=import-outside-toplevel import uvicorn
server_info = args.server.split(':', 1) server_info = args.server.split(':', 1)
host = server_info[0] host = server_info[0]
@@ -226,7 +221,7 @@ def get_set_parser() -> CommandlineParser:
parser.add_subcommand('details', apicmd.APIDetails()) parser.add_subcommand('details', apicmd.APIDetails())
parser.add_subcommand('status', apicmd.APIStatus()) parser.add_subcommand('status', apicmd.APIStatus())
except ModuleNotFoundError as ex: except ModuleNotFoundError as ex:
if not ex.name or 'nominatim_api' not in ex.name: # pylint: disable=E1135 if not ex.name or 'nominatim_api' not in ex.name:
raise ex raise ex
parser.parser.epilog = \ parser.parser.epilog = \
@@ -235,7 +230,6 @@ def get_set_parser() -> CommandlineParser:
'\n export, convert, serve, search, reverse, lookup, details, status'\ '\n export, convert, serve, search, reverse, lookup, details, status'\
"\n\nRun 'pip install nominatim-api' to install the package." "\n\nRun 'pip install nominatim-api' to install the package."
return parser return parser

View File

@@ -18,13 +18,10 @@ from .args import NominatimArgs
from ..db.connection import connect from ..db.connection import connect
from ..tools.freeze import is_frozen from ..tools.freeze import is_frozen
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
# Using non-top-level imports to avoid eventually unused imports.
# pylint: disable=E0012,C0415
LOG = logging.getLogger() LOG = logging.getLogger()
class UpdateAddData: class UpdateAddData:
"""\ """\
Add additional data from a file or an online source. Add additional data from a file or an online source.
@@ -65,7 +62,6 @@ class UpdateAddData:
group2.add_argument('--socket-timeout', dest='socket_timeout', type=int, default=60, group2.add_argument('--socket-timeout', dest='socket_timeout', type=int, default=60,
help='Set timeout for file downloads') help='Set timeout for file downloads')
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
from ..tools import add_osm_data from ..tools import add_osm_data
@@ -103,7 +99,6 @@ class UpdateAddData:
return 0 return 0
async def _add_tiger_data(self, args: NominatimArgs) -> int: async def _add_tiger_data(self, args: NominatimArgs) -> int:
from ..tokenizer import factory as tokenizer_factory from ..tokenizer import factory as tokenizer_factory
from ..tools import tiger_data from ..tools import tiger_data
@@ -113,5 +108,5 @@ class UpdateAddData:
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
return await tiger_data.add_tiger_data(args.tiger_data, return await tiger_data.add_tiger_data(args.tiger_data,
args.config, args.config,
args.threads or psutil.cpu_count() or 1, args.threads or psutil.cpu_count() or 1,
tokenizer) tokenizer)

View File

@@ -57,7 +57,6 @@ class AdminFuncs:
mgroup.add_argument('--place-id', type=int, mgroup.add_argument('--place-id', type=int,
help='Analyse indexing of the given Nominatim object') help='Analyse indexing of the given Nominatim object')
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
if args.warm: if args.warm:
@@ -93,7 +92,6 @@ class AdminFuncs:
return 1 return 1
def _warm(self, args: NominatimArgs) -> int: def _warm(self, args: NominatimArgs) -> int:
try: try:
import nominatim_api as napi import nominatim_api as napi

View File

@@ -22,11 +22,10 @@ import nominatim_api.logging as loglib
from ..errors import UsageError from ..errors import UsageError
from .args import NominatimArgs from .args import NominatimArgs
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
LOG = logging.getLogger() LOG = logging.getLogger()
STRUCTURED_QUERY = ( STRUCTURED_QUERY = (
('amenity', 'name and/or type of POI'), ('amenity', 'name and/or type of POI'),
('street', 'housenumber and street'), ('street', 'housenumber and street'),
@@ -37,6 +36,7 @@ STRUCTURED_QUERY = (
('postalcode', 'postcode') ('postalcode', 'postcode')
) )
EXTRADATA_PARAMS = ( EXTRADATA_PARAMS = (
('addressdetails', 'Include a breakdown of the address into elements'), ('addressdetails', 'Include a breakdown of the address into elements'),
('extratags', ("Include additional information if available " ('extratags', ("Include additional information if available "
@@ -44,6 +44,7 @@ EXTRADATA_PARAMS = (
('namedetails', 'Include a list of alternative names') ('namedetails', 'Include a list of alternative names')
) )
def _add_list_format(parser: argparse.ArgumentParser) -> None: def _add_list_format(parser: argparse.ArgumentParser) -> None:
group = parser.add_argument_group('Other options') group = parser.add_argument_group('Other options')
group.add_argument('--list-formats', action='store_true', group.add_argument('--list-formats', action='store_true',
@@ -62,7 +63,7 @@ def _add_api_output_arguments(parser: argparse.ArgumentParser) -> None:
group.add_argument('--polygon-output', group.add_argument('--polygon-output',
choices=['geojson', 'kml', 'svg', 'text'], choices=['geojson', 'kml', 'svg', 'text'],
help='Output geometry of results as a GeoJSON, KML, SVG or WKT') help='Output geometry of results as a GeoJSON, KML, SVG or WKT')
group.add_argument('--polygon-threshold', type=float, default = 0.0, group.add_argument('--polygon-threshold', type=float, default=0.0,
metavar='TOLERANCE', metavar='TOLERANCE',
help=("Simplify output geometry." help=("Simplify output geometry."
"Parameter is difference tolerance in degrees.")) "Parameter is difference tolerance in degrees."))
@@ -173,7 +174,6 @@ class APISearch:
help='Do not remove duplicates from the result list') help='Do not remove duplicates from the result list')
_add_list_format(parser) _add_list_format(parser)
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
formatter = napi.load_format_dispatcher('v1', args.project_dir) formatter = napi.load_format_dispatcher('v1', args.project_dir)
@@ -189,7 +189,7 @@ class APISearch:
try: try:
with napi.NominatimAPI(args.project_dir) as api: with napi.NominatimAPI(args.project_dir) as api:
params: Dict[str, Any] = {'max_results': args.limit + min(args.limit, 10), params: Dict[str, Any] = {'max_results': args.limit + min(args.limit, 10),
'address_details': True, # needed for display name 'address_details': True, # needed for display name
'geometry_output': _get_geometry_output(args), 'geometry_output': _get_geometry_output(args),
'geometry_simplification': args.polygon_threshold, 'geometry_simplification': args.polygon_threshold,
'countries': args.countrycodes, 'countries': args.countrycodes,
@@ -197,7 +197,7 @@ class APISearch:
'viewbox': args.viewbox, 'viewbox': args.viewbox,
'bounded_viewbox': args.bounded, 'bounded_viewbox': args.bounded,
'locales': _get_locales(args, api.config.DEFAULT_LANGUAGE) 'locales': _get_locales(args, api.config.DEFAULT_LANGUAGE)
} }
if args.query: if args.query:
results = api.search(args.query, **params) results = api.search(args.query, **params)
@@ -253,7 +253,6 @@ class APIReverse:
_add_api_output_arguments(parser) _add_api_output_arguments(parser)
_add_list_format(parser) _add_list_format(parser)
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
formatter = napi.load_format_dispatcher('v1', args.project_dir) formatter = napi.load_format_dispatcher('v1', args.project_dir)
@@ -276,7 +275,7 @@ class APIReverse:
result = api.reverse(napi.Point(args.lon, args.lat), result = api.reverse(napi.Point(args.lon, args.lat),
max_rank=zoom_to_rank(args.zoom or 18), max_rank=zoom_to_rank(args.zoom or 18),
layers=layers, layers=layers,
address_details=True, # needed for display name address_details=True, # needed for display name
geometry_output=_get_geometry_output(args), geometry_output=_get_geometry_output(args),
geometry_simplification=args.polygon_threshold, geometry_simplification=args.polygon_threshold,
locales=_get_locales(args, api.config.DEFAULT_LANGUAGE)) locales=_get_locales(args, api.config.DEFAULT_LANGUAGE))
@@ -299,7 +298,6 @@ class APIReverse:
return 42 return 42
class APILookup: class APILookup:
"""\ """\
Execute API lookup query. Execute API lookup query.
@@ -319,7 +317,6 @@ class APILookup:
_add_api_output_arguments(parser) _add_api_output_arguments(parser)
_add_list_format(parser) _add_list_format(parser)
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
formatter = napi.load_format_dispatcher('v1', args.project_dir) formatter = napi.load_format_dispatcher('v1', args.project_dir)
@@ -340,7 +337,7 @@ class APILookup:
try: try:
with napi.NominatimAPI(args.project_dir) as api: with napi.NominatimAPI(args.project_dir) as api:
results = api.lookup(places, results = api.lookup(places,
address_details=True, # needed for display name address_details=True, # needed for display name
geometry_output=_get_geometry_output(args), geometry_output=_get_geometry_output(args),
geometry_simplification=args.polygon_threshold or 0.0, geometry_simplification=args.polygon_threshold or 0.0,
locales=_get_locales(args, api.config.DEFAULT_LANGUAGE)) locales=_get_locales(args, api.config.DEFAULT_LANGUAGE))
@@ -401,7 +398,6 @@ class APIDetails:
help='Preferred language order for presenting search results') help='Preferred language order for presenting search results')
_add_list_format(parser) _add_list_format(parser)
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
formatter = napi.load_format_dispatcher('v1', args.project_dir) formatter = napi.load_format_dispatcher('v1', args.project_dir)
@@ -421,7 +417,7 @@ class APIDetails:
place = napi.OsmID('W', args.way, args.object_class) place = napi.OsmID('W', args.way, args.object_class)
elif args.relation: elif args.relation:
place = napi.OsmID('R', args.relation, args.object_class) place = napi.OsmID('R', args.relation, args.object_class)
elif args.place_id is not None: elif args.place_id is not None:
place = napi.PlaceID(args.place_id) place = napi.PlaceID(args.place_id)
else: else:
raise UsageError('One of the arguments --node/-n --way/-w ' raise UsageError('One of the arguments --node/-n --way/-w '
@@ -435,10 +431,10 @@ class APIDetails:
linked_places=args.linkedplaces, linked_places=args.linkedplaces,
parented_places=args.hierarchy, parented_places=args.hierarchy,
keywords=args.keywords, keywords=args.keywords,
geometry_output=napi.GeometryFormat.GEOJSON geometry_output=(napi.GeometryFormat.GEOJSON
if args.polygon_geojson if args.polygon_geojson
else napi.GeometryFormat.NONE, else napi.GeometryFormat.NONE),
locales=locales) locales=locales)
except napi.UsageError as ex: except napi.UsageError as ex:
raise UsageError(ex) from ex raise UsageError(ex) from ex
@@ -472,7 +468,6 @@ class APIStatus:
help='Format of result (use --list-formats to see supported formats)') help='Format of result (use --list-formats to see supported formats)')
_add_list_format(parser) _add_list_format(parser)
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
formatter = napi.load_format_dispatcher('v1', args.project_dir) formatter = napi.load_format_dispatcher('v1', args.project_dir)

View File

@@ -16,8 +16,10 @@ from ..errors import UsageError
from ..config import Configuration from ..config import Configuration
from ..typing import Protocol from ..typing import Protocol
LOG = logging.getLogger() LOG = logging.getLogger()
class Subcommand(Protocol): class Subcommand(Protocol):
""" """
Interface to be implemented by classes implementing a CLI subcommand. Interface to be implemented by classes implementing a CLI subcommand.
@@ -178,7 +180,6 @@ class NominatimArgs:
polygon_geojson: bool polygon_geojson: bool
group_hierarchy: bool group_hierarchy: bool
def osm2pgsql_options(self, default_cache: int, def osm2pgsql_options(self, default_cache: int,
default_threads: int) -> Dict[str, Any]: default_threads: int) -> Dict[str, Any]:
""" Return the standard osm2pgsql options that can be derived """ Return the standard osm2pgsql options that can be derived
@@ -196,9 +197,8 @@ class NominatimArgs:
slim_index=self.config.TABLESPACE_OSM_INDEX, slim_index=self.config.TABLESPACE_OSM_INDEX,
main_data=self.config.TABLESPACE_PLACE_DATA, main_data=self.config.TABLESPACE_PLACE_DATA,
main_index=self.config.TABLESPACE_PLACE_INDEX main_index=self.config.TABLESPACE_PLACE_INDEX
) )
) )
def get_osm_file_list(self) -> Optional[List[Path]]: def get_osm_file_list(self) -> Optional[List[Path]]:
""" Return the --osm-file argument as a list of Paths or None """ Return the --osm-file argument as a list of Paths or None

View File

@@ -15,10 +15,6 @@ from pathlib import Path
from ..errors import UsageError from ..errors import UsageError
from .args import NominatimArgs from .args import NominatimArgs
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
# Using non-top-level imports to avoid eventually unused imports.
# pylint: disable=E0012,C0415
class WithAction(argparse.Action): class WithAction(argparse.Action):
""" Special action that saves a list of flags, given on the command-line """ Special action that saves a list of flags, given on the command-line
@@ -43,7 +39,6 @@ class WithAction(argparse.Action):
super().__init__(full_option_strings, argparse.SUPPRESS, nargs=0, **kwargs) super().__init__(full_option_strings, argparse.SUPPRESS, nargs=0, **kwargs)
def __call__(self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, def __call__(self, parser: argparse.ArgumentParser, namespace: argparse.Namespace,
values: Union[str, Sequence[Any], None], values: Union[str, Sequence[Any], None],
option_string: Optional[str] = None) -> None: option_string: Optional[str] = None) -> None:
@@ -81,7 +76,6 @@ class ConvertDB:
group.add_argument('--details', action=WithAction, dest_set=self.options, default=True, group.add_argument('--details', action=WithAction, dest_set=self.options, default=True,
help='Enable/disable support for details API (default: enabled)') help='Enable/disable support for details API (default: enabled)')
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
if args.output.exists(): if args.output.exists():
raise UsageError(f"File '{args.output}' already exists. Refusing to overwrite.") raise UsageError(f"File '{args.output}' already exists. Refusing to overwrite.")

View File

@@ -18,20 +18,15 @@ import nominatim_api as napi
from nominatim_api.results import create_from_placex_row, ReverseResult, add_result_details from nominatim_api.results import create_from_placex_row, ReverseResult, add_result_details
from nominatim_api.types import LookupDetails from nominatim_api.types import LookupDetails
import sqlalchemy as sa # pylint: disable=C0411 import sqlalchemy as sa
from ..errors import UsageError from ..errors import UsageError
from .args import NominatimArgs from .args import NominatimArgs
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
# Using non-top-level imports to avoid eventually unused imports.
# pylint: disable=E0012,C0415
# Needed for SQLAlchemy
# pylint: disable=singleton-comparison
LOG = logging.getLogger() LOG = logging.getLogger()
RANK_RANGE_MAP = { RANK_RANGE_MAP = {
'country': (4, 4), 'country': (4, 4),
'state': (5, 9), 'state': (5, 9),
@@ -42,6 +37,7 @@ RANK_RANGE_MAP = {
'path': (27, 27) 'path': (27, 27)
} }
RANK_TO_OUTPUT_MAP = { RANK_TO_OUTPUT_MAP = {
4: 'country', 4: 'country',
5: 'state', 6: 'state', 7: 'state', 8: 'state', 9: 'state', 5: 'state', 6: 'state', 7: 'state', 8: 'state', 9: 'state',
@@ -50,6 +46,7 @@ RANK_TO_OUTPUT_MAP = {
17: 'suburb', 18: 'suburb', 19: 'suburb', 20: 'suburb', 21: 'suburb', 17: 'suburb', 18: 'suburb', 19: 'suburb', 20: 'suburb', 21: 'suburb',
26: 'street', 27: 'path'} 26: 'street', 27: 'path'}
class QueryExport: class QueryExport:
"""\ """\
Export places as CSV file from the database. Export places as CSV file from the database.
@@ -84,7 +81,6 @@ class QueryExport:
dest='relation', dest='relation',
help='Export only children of this OSM relation') help='Export only children of this OSM relation')
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
return asyncio.run(export(args)) return asyncio.run(export(args))
@@ -104,15 +100,15 @@ async def export(args: NominatimArgs) -> int:
t = conn.t.placex t = conn.t.placex
sql = sa.select(t.c.place_id, t.c.parent_place_id, sql = sa.select(t.c.place_id, t.c.parent_place_id,
t.c.osm_type, t.c.osm_id, t.c.name, t.c.osm_type, t.c.osm_id, t.c.name,
t.c.class_, t.c.type, t.c.admin_level, t.c.class_, t.c.type, t.c.admin_level,
t.c.address, t.c.extratags, t.c.address, t.c.extratags,
t.c.housenumber, t.c.postcode, t.c.country_code, t.c.housenumber, t.c.postcode, t.c.country_code,
t.c.importance, t.c.wikipedia, t.c.indexed_date, t.c.importance, t.c.wikipedia, t.c.indexed_date,
t.c.rank_address, t.c.rank_search, t.c.rank_address, t.c.rank_search,
t.c.centroid)\ t.c.centroid)\
.where(t.c.linked_place_id == None)\ .where(t.c.linked_place_id == None)\
.where(t.c.rank_address.between(*output_range)) .where(t.c.rank_address.between(*output_range))
parent_place_id = await get_parent_id(conn, args.node, args.way, args.relation) parent_place_id = await get_parent_id(conn, args.node, args.way, args.relation)
if parent_place_id: if parent_place_id:
@@ -159,7 +155,6 @@ async def dump_results(conn: napi.SearchConnection,
await add_result_details(conn, results, await add_result_details(conn, results,
LookupDetails(address_details=True, locales=locale)) LookupDetails(address_details=True, locales=locale))
for result in results: for result in results:
data = {'placeid': result.place_id, data = {'placeid': result.place_id,
'postcode': result.postcode} 'postcode': result.postcode}

View File

@@ -12,10 +12,6 @@ import argparse
from ..db.connection import connect from ..db.connection import connect
from .args import NominatimArgs from .args import NominatimArgs
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
# Using non-top-level imports to avoid eventually unused imports.
# pylint: disable=E0012,C0415
class SetupFreeze: class SetupFreeze:
"""\ """\
@@ -30,8 +26,7 @@ class SetupFreeze:
""" """
def add_args(self, parser: argparse.ArgumentParser) -> None: def add_args(self, parser: argparse.ArgumentParser) -> None:
pass # No options pass # No options
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
from ..tools import freeze from ..tools import freeze

View File

@@ -16,11 +16,6 @@ from ..db import status
from ..db.connection import connect from ..db.connection import connect
from .args import NominatimArgs from .args import NominatimArgs
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
# Using non-top-level imports to avoid eventually unused imports.
# pylint: disable=E0012,C0415
class UpdateIndex: class UpdateIndex:
"""\ """\
@@ -43,7 +38,6 @@ class UpdateIndex:
group.add_argument('--maxrank', '-R', type=int, metavar='RANK', default=30, group.add_argument('--maxrank', '-R', type=int, metavar='RANK', default=30,
help='Maximum/finishing rank') help='Maximum/finishing rank')
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
asyncio.run(self._do_index(args)) asyncio.run(self._do_index(args))
@@ -54,7 +48,6 @@ class UpdateIndex:
return 0 return 0
async def _do_index(self, args: NominatimArgs) -> None: async def _do_index(self, args: NominatimArgs) -> None:
from ..tokenizer import factory as tokenizer_factory from ..tokenizer import factory as tokenizer_factory
@@ -64,7 +57,7 @@ class UpdateIndex:
indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
args.threads or psutil.cpu_count() or 1) args.threads or psutil.cpu_count() or 1)
has_pending = True # run at least once has_pending = True # run at least once
while has_pending: while has_pending:
if not args.no_boundaries: if not args.no_boundaries:
await indexer.index_boundaries(args.minrank, args.maxrank) await indexer.index_boundaries(args.minrank, args.maxrank)

View File

@@ -18,13 +18,10 @@ from ..db.connection import connect, table_exists
from ..tokenizer.base import AbstractTokenizer from ..tokenizer.base import AbstractTokenizer
from .args import NominatimArgs from .args import NominatimArgs
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
# Using non-top-level imports to avoid eventually unused imports.
# pylint: disable=E0012,C0415
LOG = logging.getLogger() LOG = logging.getLogger()
def _parse_osm_object(obj: str) -> Tuple[str, int]: def _parse_osm_object(obj: str) -> Tuple[str, int]:
""" Parse the given argument into a tuple of OSM type and ID. """ Parse the given argument into a tuple of OSM type and ID.
Raises an ArgumentError if the format is not recognized. Raises an ArgumentError if the format is not recognized.
@@ -86,8 +83,7 @@ class UpdateRefresh:
group.add_argument('--enable-debug-statements', action='store_true', group.add_argument('--enable-debug-statements', action='store_true',
help='Enable debug warning statements in functions') help='Enable debug warning statements in functions')
def run(self, args: NominatimArgs) -> int:
def run(self, args: NominatimArgs) -> int: #pylint: disable=too-many-branches, too-many-statements
from ..tools import refresh, postcodes from ..tools import refresh, postcodes
from ..indexer.indexer import Indexer from ..indexer.indexer import Indexer
@@ -131,7 +127,7 @@ class UpdateRefresh:
LOG.warning('Import secondary importance raster data from %s', args.project_dir) LOG.warning('Import secondary importance raster data from %s', args.project_dir)
if refresh.import_secondary_importance(args.config.get_libpq_dsn(), if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
args.project_dir) > 0: args.project_dir) > 0:
LOG.fatal('FATAL: Cannot update secondary importance raster data') LOG.fatal('FATAL: Cannot update secondary importance raster data')
return 1 return 1
need_function_refresh = True need_function_refresh = True
@@ -173,7 +169,6 @@ class UpdateRefresh:
return 0 return 0
def _get_tokenizer(self, config: Configuration) -> AbstractTokenizer: def _get_tokenizer(self, config: Configuration) -> AbstractTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
from ..tokenizer import factory as tokenizer_factory from ..tokenizer import factory as tokenizer_factory

View File

@@ -22,10 +22,6 @@ from .args import NominatimArgs
LOG = logging.getLogger() LOG = logging.getLogger()
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
# Using non-top-level imports to make pyosmium optional for replication only.
# pylint: disable=C0415
class UpdateReplication: class UpdateReplication:
"""\ """\
@@ -71,7 +67,6 @@ class UpdateReplication:
group.add_argument('--socket-timeout', dest='socket_timeout', type=int, default=60, group.add_argument('--socket-timeout', dest='socket_timeout', type=int, default=60,
help='Set timeout for file downloads') help='Set timeout for file downloads')
def _init_replication(self, args: NominatimArgs) -> int: def _init_replication(self, args: NominatimArgs) -> int:
from ..tools import replication, refresh from ..tools import replication, refresh
@@ -84,7 +79,6 @@ class UpdateReplication:
refresh.create_functions(conn, args.config, True, False) refresh.create_functions(conn, args.config, True, False)
return 0 return 0
def _check_for_updates(self, args: NominatimArgs) -> int: def _check_for_updates(self, args: NominatimArgs) -> int:
from ..tools import replication from ..tools import replication
@@ -92,7 +86,6 @@ class UpdateReplication:
return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL, return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL,
socket_timeout=args.socket_timeout) socket_timeout=args.socket_timeout)
def _report_update(self, batchdate: dt.datetime, def _report_update(self, batchdate: dt.datetime,
start_import: dt.datetime, start_import: dt.datetime,
start_index: Optional[dt.datetime]) -> None: start_index: Optional[dt.datetime]) -> None:
@@ -106,7 +99,6 @@ class UpdateReplication:
round_time(end - start_import), round_time(end - start_import),
round_time(end - batchdate)) round_time(end - batchdate))
def _compute_update_interval(self, args: NominatimArgs) -> int: def _compute_update_interval(self, args: NominatimArgs) -> int:
if args.catch_up: if args.catch_up:
return 0 return 0
@@ -123,7 +115,6 @@ class UpdateReplication:
return update_interval return update_interval
async def _update(self, args: NominatimArgs) -> None: async def _update(self, args: NominatimArgs) -> None:
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
from ..tools import replication from ..tools import replication
@@ -186,7 +177,6 @@ class UpdateReplication:
LOG.warning("No new changes. Sleeping for %d sec.", recheck_interval) LOG.warning("No new changes. Sleeping for %d sec.", recheck_interval)
time.sleep(recheck_interval) time.sleep(recheck_interval)
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
socket.setdefaulttimeout(args.socket_timeout) socket.setdefaulttimeout(args.socket_timeout)

View File

@@ -23,13 +23,10 @@ from ..tokenizer.base import AbstractTokenizer
from ..version import NOMINATIM_VERSION from ..version import NOMINATIM_VERSION
from .args import NominatimArgs from .args import NominatimArgs
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
# Using non-top-level imports to avoid eventually unused imports.
# pylint: disable=C0415
LOG = logging.getLogger() LOG = logging.getLogger()
class SetupAll: class SetupAll:
"""\ """\
Create a new Nominatim database from an OSM file. Create a new Nominatim database from an OSM file.
@@ -42,36 +39,35 @@ class SetupAll:
def add_args(self, parser: argparse.ArgumentParser) -> None: def add_args(self, parser: argparse.ArgumentParser) -> None:
group1 = parser.add_argument_group('Required arguments') group1 = parser.add_argument_group('Required arguments')
group1.add_argument('--osm-file', metavar='FILE', action='append', group1.add_argument('--osm-file', metavar='FILE', action='append',
help='OSM file to be imported' help='OSM file to be imported'
' (repeat for importing multiple files)', ' (repeat for importing multiple files)',
default=None) default=None)
group1.add_argument('--continue', dest='continue_at', group1.add_argument('--continue', dest='continue_at',
choices=['import-from-file', 'load-data', 'indexing', 'db-postprocess'], choices=['import-from-file', 'load-data', 'indexing', 'db-postprocess'],
help='Continue an import that was interrupted', help='Continue an import that was interrupted',
default=None) default=None)
group2 = parser.add_argument_group('Optional arguments') group2 = parser.add_argument_group('Optional arguments')
group2.add_argument('--osm2pgsql-cache', metavar='SIZE', type=int, group2.add_argument('--osm2pgsql-cache', metavar='SIZE', type=int,
help='Size of cache to be used by osm2pgsql (in MB)') help='Size of cache to be used by osm2pgsql (in MB)')
group2.add_argument('--reverse-only', action='store_true', group2.add_argument('--reverse-only', action='store_true',
help='Do not create tables and indexes for searching') help='Do not create tables and indexes for searching')
group2.add_argument('--no-partitions', action='store_true', group2.add_argument('--no-partitions', action='store_true',
help=("Do not partition search indices " help="Do not partition search indices "
"(speeds up import of single country extracts)")) "(speeds up import of single country extracts)")
group2.add_argument('--no-updates', action='store_true', group2.add_argument('--no-updates', action='store_true',
help="Do not keep tables that are only needed for " help="Do not keep tables that are only needed for "
"updating the database later") "updating the database later")
group2.add_argument('--offline', action='store_true', group2.add_argument('--offline', action='store_true',
help="Do not attempt to load any additional data from the internet") help="Do not attempt to load any additional data from the internet")
group3 = parser.add_argument_group('Expert options') group3 = parser.add_argument_group('Expert options')
group3.add_argument('--ignore-errors', action='store_true', group3.add_argument('--ignore-errors', action='store_true',
help='Continue import even when errors in SQL are present') help='Continue import even when errors in SQL are present')
group3.add_argument('--index-noanalyse', action='store_true', group3.add_argument('--index-noanalyse', action='store_true',
help='Do not perform analyse operations during index (expert only)') help='Do not perform analyse operations during index (expert only)')
group3.add_argument('--prepare-database', action='store_true', group3.add_argument('--prepare-database', action='store_true',
help='Create the database but do not import any data') help='Create the database but do not import any data')
def run(self, args: NominatimArgs) -> int:
def run(self, args: NominatimArgs) -> int: # pylint: disable=too-many-statements, too-many-branches
if args.osm_file is None and args.continue_at is None and not args.prepare_database: if args.osm_file is None and args.continue_at is None and not args.prepare_database:
raise UsageError("No input files (use --osm-file).") raise UsageError("No input files (use --osm-file).")
@@ -85,7 +81,6 @@ class SetupAll:
return asyncio.run(self.async_run(args)) return asyncio.run(self.async_run(args))
async def async_run(self, args: NominatimArgs) -> int: async def async_run(self, args: NominatimArgs) -> int:
from ..data import country_info from ..data import country_info
from ..tools import database_import, postcodes, freeze from ..tools import database_import, postcodes, freeze
@@ -97,7 +92,7 @@ class SetupAll:
if args.prepare_database or args.continue_at is None: if args.prepare_database or args.continue_at is None:
LOG.warning('Creating database') LOG.warning('Creating database')
database_import.setup_database_skeleton(args.config.get_libpq_dsn(), database_import.setup_database_skeleton(args.config.get_libpq_dsn(),
rouser=args.config.DATABASE_WEBUSER) rouser=args.config.DATABASE_WEBUSER)
if args.prepare_database: if args.prepare_database:
return 0 return 0
@@ -120,8 +115,7 @@ class SetupAll:
postcodes.update_postcodes(args.config.get_libpq_dsn(), postcodes.update_postcodes(args.config.get_libpq_dsn(),
args.project_dir, tokenizer) args.project_dir, tokenizer)
if args.continue_at in \ if args.continue_at in ('import-from-file', 'load-data', 'indexing', None):
('import-from-file', 'load-data', 'indexing', None):
LOG.warning('Indexing places') LOG.warning('Indexing places')
indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, num_threads) indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, num_threads)
await indexer.index_full(analyse=not args.index_noanalyse) await indexer.index_full(analyse=not args.index_noanalyse)
@@ -145,7 +139,6 @@ class SetupAll:
return 0 return 0
def _base_import(self, args: NominatimArgs) -> None: def _base_import(self, args: NominatimArgs) -> None:
from ..tools import database_import, refresh from ..tools import database_import, refresh
from ..data import country_info from ..data import country_info
@@ -159,8 +152,8 @@ class SetupAll:
database_import.check_existing_database_plugins(args.config.get_libpq_dsn()) database_import.check_existing_database_plugins(args.config.get_libpq_dsn())
LOG.warning('Setting up country tables') LOG.warning('Setting up country tables')
country_info.setup_country_tables(args.config.get_libpq_dsn(), country_info.setup_country_tables(args.config.get_libpq_dsn(),
args.config.lib_dir.data, args.config.lib_dir.data,
args.no_partitions) args.no_partitions)
LOG.warning('Importing OSM data file') LOG.warning('Importing OSM data file')
database_import.import_osm_data(files, database_import.import_osm_data(files,
@@ -171,20 +164,19 @@ class SetupAll:
LOG.warning('Importing wikipedia importance data') LOG.warning('Importing wikipedia importance data')
data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir) data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(), if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
data_path) > 0: data_path) > 0:
LOG.error('Wikipedia importance dump file not found. ' LOG.error('Wikipedia importance dump file not found. '
'Calculating importance values of locations will not ' 'Calculating importance values of locations will not '
'use Wikipedia importance data.') 'use Wikipedia importance data.')
LOG.warning('Importing secondary importance raster data') LOG.warning('Importing secondary importance raster data')
if refresh.import_secondary_importance(args.config.get_libpq_dsn(), if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
args.project_dir) != 0: args.project_dir) != 0:
LOG.error('Secondary importance file not imported. ' LOG.error('Secondary importance file not imported. '
'Falling back to default ranking.') 'Falling back to default ranking.')
self._setup_tables(args.config, args.reverse_only) self._setup_tables(args.config, args.reverse_only)
def _setup_tables(self, config: Configuration, reverse_only: bool) -> None: def _setup_tables(self, config: Configuration, reverse_only: bool) -> None:
""" Set up the basic database layout: tables, indexes and functions. """ Set up the basic database layout: tables, indexes and functions.
""" """
@@ -205,7 +197,6 @@ class SetupAll:
LOG.warning('Create functions (3rd pass)') LOG.warning('Create functions (3rd pass)')
refresh.create_functions(conn, config, False, False) refresh.create_functions(conn, config, False, False)
def _get_tokenizer(self, continue_at: Optional[str], def _get_tokenizer(self, continue_at: Optional[str],
config: Configuration) -> AbstractTokenizer: config: Configuration) -> AbstractTokenizer:
""" Set up a new tokenizer or load an already initialised one. """ Set up a new tokenizer or load an already initialised one.
@@ -219,7 +210,6 @@ class SetupAll:
# just load the tokenizer # just load the tokenizer
return tokenizer_factory.get_tokenizer_for_db(config) return tokenizer_factory.get_tokenizer_for_db(config)
def _finalize_database(self, dsn: str, offline: bool) -> None: def _finalize_database(self, dsn: str, offline: bool) -> None:
""" Determine the database date and set the status accordingly. """ Determine the database date and set the status accordingly.
""" """
@@ -230,5 +220,5 @@ class SetupAll:
dbdate = status.compute_database_date(conn, offline) dbdate = status.compute_database_date(conn, offline)
status.set_status(conn, dbdate) status.set_status(conn, dbdate)
LOG.info('Database is at %s.', dbdate) LOG.info('Database is at %s.', dbdate)
except Exception as exc: # pylint: disable=broad-except except Exception as exc:
LOG.error('Cannot determine date of database: %s', exc) LOG.error('Cannot determine date of database: %s', exc)

View File

@@ -18,12 +18,9 @@ from ..tools.special_phrases.sp_wiki_loader import SPWikiLoader
from ..tools.special_phrases.sp_csv_loader import SPCsvLoader from ..tools.special_phrases.sp_csv_loader import SPCsvLoader
from .args import NominatimArgs from .args import NominatimArgs
LOG = logging.getLogger() LOG = logging.getLogger()
# Do not repeat documentation of subcommand classes.
# pylint: disable=C0111
# Using non-top-level imports to avoid eventually unused imports.
# pylint: disable=E0012,C0415
class ImportSpecialPhrases: class ImportSpecialPhrases:
"""\ """\
@@ -62,7 +59,6 @@ class ImportSpecialPhrases:
group.add_argument('--no-replace', action='store_true', group.add_argument('--no-replace', action='store_true',
help='Keep the old phrases and only add the new ones') help='Keep the old phrases and only add the new ones')
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
if args.import_from_wiki: if args.import_from_wiki:
@@ -77,7 +73,6 @@ class ImportSpecialPhrases:
return 0 return 0
def start_import(self, args: NominatimArgs, loader: SpecialPhraseLoader) -> None: def start_import(self, args: NominatimArgs, loader: SpecialPhraseLoader) -> None:
""" """
Create the SPImporter object containing the right Create the SPImporter object containing the right

View File

@@ -25,7 +25,8 @@ from .errors import UsageError
from . import paths from . import paths
LOG = logging.getLogger() LOG = logging.getLogger()
CONFIG_CACHE : Dict[str, Any] = {} CONFIG_CACHE: Dict[str, Any] = {}
def flatten_config_list(content: Any, section: str = '') -> List[Any]: def flatten_config_list(content: Any, section: str = '') -> List[Any]:
""" Flatten YAML configuration lists that contain include sections """ Flatten YAML configuration lists that contain include sections
@@ -79,14 +80,12 @@ class Configuration:
self.lib_dir = _LibDirs() self.lib_dir = _LibDirs()
self._private_plugins: Dict[str, object] = {} self._private_plugins: Dict[str, object] = {}
def set_libdirs(self, **kwargs: StrPath) -> None: def set_libdirs(self, **kwargs: StrPath) -> None:
""" Set paths to library functions and data. """ Set paths to library functions and data.
""" """
for key, value in kwargs.items(): for key, value in kwargs.items():
setattr(self.lib_dir, key, None if value is None else Path(value)) setattr(self.lib_dir, key, None if value is None else Path(value))
def __getattr__(self, name: str) -> str: def __getattr__(self, name: str) -> str:
name = 'NOMINATIM_' + name name = 'NOMINATIM_' + name
@@ -95,7 +94,6 @@ class Configuration:
return self._config[name] or '' return self._config[name] or ''
def get_bool(self, name: str) -> bool: def get_bool(self, name: str) -> bool:
""" Return the given configuration parameter as a boolean. """ Return the given configuration parameter as a boolean.
@@ -108,7 +106,6 @@ class Configuration:
""" """
return getattr(self, name).lower() in ('1', 'yes', 'true') return getattr(self, name).lower() in ('1', 'yes', 'true')
def get_int(self, name: str) -> int: def get_int(self, name: str) -> int:
""" Return the given configuration parameter as an int. """ Return the given configuration parameter as an int.
@@ -128,7 +125,6 @@ class Configuration:
LOG.fatal("Invalid setting NOMINATIM_%s. Needs to be a number.", name) LOG.fatal("Invalid setting NOMINATIM_%s. Needs to be a number.", name)
raise UsageError("Configuration error.") from exp raise UsageError("Configuration error.") from exp
def get_str_list(self, name: str) -> Optional[List[str]]: def get_str_list(self, name: str) -> Optional[List[str]]:
""" Return the given configuration parameter as a list of strings. """ Return the given configuration parameter as a list of strings.
The values are assumed to be given as a comma-sparated list and The values are assumed to be given as a comma-sparated list and
@@ -148,7 +144,6 @@ class Configuration:
return [v.strip() for v in raw.split(',')] if raw else None return [v.strip() for v in raw.split(',')] if raw else None
def get_path(self, name: str) -> Optional[Path]: def get_path(self, name: str) -> Optional[Path]:
""" Return the given configuration parameter as a Path. """ Return the given configuration parameter as a Path.
@@ -174,7 +169,6 @@ class Configuration:
return cfgpath.resolve() return cfgpath.resolve()
def get_libpq_dsn(self) -> str: def get_libpq_dsn(self) -> str:
""" Get configured database DSN converted into the key/value format """ Get configured database DSN converted into the key/value format
understood by libpq and psycopg. understood by libpq and psycopg.
@@ -194,7 +188,6 @@ class Configuration:
return dsn return dsn
def get_database_params(self) -> Mapping[str, Union[str, int, None]]: def get_database_params(self) -> Mapping[str, Union[str, int, None]]:
""" Get the configured parameters for the database connection """ Get the configured parameters for the database connection
as a mapping. as a mapping.
@@ -206,7 +199,6 @@ class Configuration:
return conninfo_to_dict(dsn) return conninfo_to_dict(dsn)
def get_import_style_file(self) -> Path: def get_import_style_file(self) -> Path:
""" Return the import style file as a path object. Translates the """ Return the import style file as a path object. Translates the
name of the standard styles automatically into a file in the name of the standard styles automatically into a file in the
@@ -219,7 +211,6 @@ class Configuration:
return self.find_config_file('', 'IMPORT_STYLE') return self.find_config_file('', 'IMPORT_STYLE')
def get_os_env(self) -> Dict[str, str]: def get_os_env(self) -> Dict[str, str]:
""" Return a copy of the OS environment with the Nominatim configuration """ Return a copy of the OS environment with the Nominatim configuration
merged in. merged in.
@@ -229,7 +220,6 @@ class Configuration:
return env return env
def load_sub_configuration(self, filename: StrPath, def load_sub_configuration(self, filename: StrPath,
config: Optional[str] = None) -> Any: config: Optional[str] = None) -> Any:
""" Load additional configuration from a file. `filename` is the name """ Load additional configuration from a file. `filename` is the name
@@ -267,7 +257,6 @@ class Configuration:
CONFIG_CACHE[str(configfile)] = result CONFIG_CACHE[str(configfile)] = result
return result return result
def load_plugin_module(self, module_name: str, internal_path: str) -> Any: def load_plugin_module(self, module_name: str, internal_path: str) -> Any:
""" Load a Python module as a plugin. """ Load a Python module as a plugin.
@@ -310,7 +299,6 @@ class Configuration:
return sys.modules.get(module_name) or importlib.import_module(module_name) return sys.modules.get(module_name) or importlib.import_module(module_name)
def find_config_file(self, filename: StrPath, def find_config_file(self, filename: StrPath,
config: Optional[str] = None) -> Path: config: Optional[str] = None) -> Path:
""" Resolve the location of a configuration file given a filename and """ Resolve the location of a configuration file given a filename and
@@ -334,7 +322,6 @@ class Configuration:
filename = cfg_filename filename = cfg_filename
search_paths = [self.project_dir, self.config_dir] search_paths = [self.project_dir, self.config_dir]
for path in search_paths: for path in search_paths:
if path is not None and (path / filename).is_file(): if path is not None and (path / filename).is_file():
@@ -344,7 +331,6 @@ class Configuration:
filename, search_paths) filename, search_paths)
raise UsageError("Config file not found.") raise UsageError("Config file not found.")
def _load_from_yaml(self, cfgfile: Path) -> Any: def _load_from_yaml(self, cfgfile: Path) -> Any:
""" Load a YAML configuration file. This installs a special handler that """ Load a YAML configuration file. This installs a special handler that
allows to include other YAML files using the '!include' operator. allows to include other YAML files using the '!include' operator.
@@ -353,7 +339,6 @@ class Configuration:
Loader=yaml.SafeLoader) Loader=yaml.SafeLoader)
return yaml.safe_load(cfgfile.read_text(encoding='utf-8')) return yaml.safe_load(cfgfile.read_text(encoding='utf-8'))
def _yaml_include_representer(self, loader: Any, node: yaml.Node) -> Any: def _yaml_include_representer(self, loader: Any, node: yaml.Node) -> Any:
""" Handler for the '!include' operator in YAML files. """ Handler for the '!include' operator in YAML files.

View File

@@ -16,6 +16,7 @@ from ..errors import UsageError
from ..config import Configuration from ..config import Configuration
from ..tokenizer.base import AbstractTokenizer from ..tokenizer.base import AbstractTokenizer
def _flatten_name_list(names: Any) -> Dict[str, str]: def _flatten_name_list(names: Any) -> Dict[str, str]:
if names is None: if names is None:
return {} return {}
@@ -39,7 +40,6 @@ def _flatten_name_list(names: Any) -> Dict[str, str]:
return flat return flat
class _CountryInfo: class _CountryInfo:
""" Caches country-specific properties from the configuration file. """ Caches country-specific properties from the configuration file.
""" """
@@ -47,7 +47,6 @@ class _CountryInfo:
def __init__(self) -> None: def __init__(self) -> None:
self._info: Dict[str, Dict[str, Any]] = {} self._info: Dict[str, Dict[str, Any]] = {}
def load(self, config: Configuration) -> None: def load(self, config: Configuration) -> None:
""" Load the country properties from the configuration files, """ Load the country properties from the configuration files,
if they are not loaded yet. if they are not loaded yet.
@@ -63,7 +62,6 @@ class _CountryInfo:
for x in prop['languages'].split(',')] for x in prop['languages'].split(',')]
prop['names'] = _flatten_name_list(prop.get('names')) prop['names'] = _flatten_name_list(prop.get('names'))
def items(self) -> Iterable[Tuple[str, Dict[str, Any]]]: def items(self) -> Iterable[Tuple[str, Dict[str, Any]]]:
""" Return tuples of (country_code, property dict) as iterable. """ Return tuples of (country_code, property dict) as iterable.
""" """
@@ -75,7 +73,6 @@ class _CountryInfo:
return self._info.get(country_code, {}) return self._info.get(country_code, {})
_COUNTRY_INFO = _CountryInfo() _COUNTRY_INFO = _CountryInfo()
@@ -86,14 +83,17 @@ def setup_country_config(config: Configuration) -> None:
""" """
_COUNTRY_INFO.load(config) _COUNTRY_INFO.load(config)
@overload @overload
def iterate() -> Iterable[Tuple[str, Dict[str, Any]]]: def iterate() -> Iterable[Tuple[str, Dict[str, Any]]]:
... ...
@overload @overload
def iterate(prop: str) -> Iterable[Tuple[str, Any]]: def iterate(prop: str) -> Iterable[Tuple[str, Any]]:
... ...
def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Dict[str, Any]]]: def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Dict[str, Any]]]:
""" Iterate over country code and properties. """ Iterate over country code and properties.
@@ -168,7 +168,7 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
# country names (only in languages as provided) # country names (only in languages as provided)
if name: if name:
names.update({k : v for k, v in name.items() if _include_key(k)}) names.update({k: v for k, v in name.items() if _include_key(k)})
analyzer.add_country_names(code, names) analyzer.add_country_names(code, names)

View File

@@ -10,6 +10,7 @@ the tokenizer.
""" """
from typing import Optional, Mapping, Any, Tuple from typing import Optional, Mapping, Any, Tuple
class PlaceInfo: class PlaceInfo:
""" This data class contains all information the tokenizer can access """ This data class contains all information the tokenizer can access
about a place. about a place.
@@ -18,7 +19,6 @@ class PlaceInfo:
def __init__(self, info: Mapping[str, Any]) -> None: def __init__(self, info: Mapping[str, Any]) -> None:
self._info = info self._info = info
@property @property
def name(self) -> Optional[Mapping[str, str]]: def name(self) -> Optional[Mapping[str, str]]:
""" A dictionary with the names of the place. Keys and values represent """ A dictionary with the names of the place. Keys and values represent
@@ -28,7 +28,6 @@ class PlaceInfo:
""" """
return self._info.get('name') return self._info.get('name')
@property @property
def address(self) -> Optional[Mapping[str, str]]: def address(self) -> Optional[Mapping[str, str]]:
""" A dictionary with the address elements of the place. They key """ A dictionary with the address elements of the place. They key
@@ -43,7 +42,6 @@ class PlaceInfo:
""" """
return self._info.get('address') return self._info.get('address')
@property @property
def country_code(self) -> Optional[str]: def country_code(self) -> Optional[str]:
""" The country code of the country the place is in. Guaranteed """ The country code of the country the place is in. Guaranteed
@@ -52,7 +50,6 @@ class PlaceInfo:
""" """
return self._info.get('country_code') return self._info.get('country_code')
@property @property
def rank_address(self) -> int: def rank_address(self) -> int:
""" The [rank address][1] before any rank correction is applied. """ The [rank address][1] before any rank correction is applied.
@@ -61,7 +58,6 @@ class PlaceInfo:
""" """
return self._info.get('rank_address', 0) return self._info.get('rank_address', 0)
@property @property
def centroid(self) -> Optional[Tuple[float, float]]: def centroid(self) -> Optional[Tuple[float, float]]:
""" A center point of the place in WGS84. May be None when the """ A center point of the place in WGS84. May be None when the
@@ -70,17 +66,15 @@ class PlaceInfo:
x, y = self._info.get('centroid_x'), self._info.get('centroid_y') x, y = self._info.get('centroid_x'), self._info.get('centroid_y')
return None if x is None or y is None else (x, y) return None if x is None or y is None else (x, y)
def is_a(self, key: str, value: str) -> bool: def is_a(self, key: str, value: str) -> bool:
""" Set to True when the place's primary tag corresponds to the given """ Set to True when the place's primary tag corresponds to the given
key and value. key and value.
""" """
return self._info.get('class') == key and self._info.get('type') == value return self._info.get('class') == key and self._info.get('type') == value
def is_country(self) -> bool: def is_country(self) -> bool:
""" Set to True when the place is a valid country boundary. """ Set to True when the place is a valid country boundary.
""" """
return self.rank_address == 4 \ return self.rank_address == 4 \
and self.is_a('boundary', 'administrative') \ and self.is_a('boundary', 'administrative') \
and self.country_code is not None and self.country_code is not None

View File

@@ -9,6 +9,7 @@ Data class for a single name of a place.
""" """
from typing import Optional, Dict, Mapping from typing import Optional, Dict, Mapping
class PlaceName: class PlaceName:
""" Each name and address part of a place is encapsulated in an object of """ Each name and address part of a place is encapsulated in an object of
this class. It saves not only the name proper but also describes the this class. It saves not only the name proper but also describes the
@@ -32,11 +33,9 @@ class PlaceName:
self.suffix = suffix self.suffix = suffix
self.attr: Dict[str, str] = {} self.attr: Dict[str, str] = {}
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PlaceName(name={self.name!r},kind={self.kind!r},suffix={self.suffix!r})" return f"PlaceName(name={self.name!r},kind={self.kind!r},suffix={self.suffix!r})"
def clone(self, name: Optional[str] = None, def clone(self, name: Optional[str] = None,
kind: Optional[str] = None, kind: Optional[str] = None,
suffix: Optional[str] = None, suffix: Optional[str] = None,
@@ -57,21 +56,18 @@ class PlaceName:
return newobj return newobj
def set_attr(self, key: str, value: str) -> None: def set_attr(self, key: str, value: str) -> None:
""" Add the given property to the name. If the property was already """ Add the given property to the name. If the property was already
set, then the value is overwritten. set, then the value is overwritten.
""" """
self.attr[key] = value self.attr[key] = value
def get_attr(self, key: str, default: Optional[str] = None) -> Optional[str]: def get_attr(self, key: str, default: Optional[str] = None) -> Optional[str]:
""" Return the given property or the value of 'default' if it """ Return the given property or the value of 'default' if it
is not set. is not set.
""" """
return self.attr.get(key, default) return self.attr.get(key, default)
def has_attr(self, key: str) -> bool: def has_attr(self, key: str) -> bool:
""" Check if the given attribute is set. """ Check if the given attribute is set.
""" """

View File

@@ -14,6 +14,7 @@ import re
from ..errors import UsageError from ..errors import UsageError
from . import country_info from . import country_info
class CountryPostcodeMatcher: class CountryPostcodeMatcher:
""" Matches and formats a postcode according to a format definition """ Matches and formats a postcode according to a format definition
of the given country. of the given country.
@@ -30,7 +31,6 @@ class CountryPostcodeMatcher:
self.output = config.get('output', r'\g<0>') self.output = config.get('output', r'\g<0>')
def match(self, postcode: str) -> Optional[Match[str]]: def match(self, postcode: str) -> Optional[Match[str]]:
""" Match the given postcode against the postcode pattern for this """ Match the given postcode against the postcode pattern for this
matcher. Returns a `re.Match` object if the match was successful matcher. Returns a `re.Match` object if the match was successful
@@ -44,7 +44,6 @@ class CountryPostcodeMatcher:
return None return None
def normalize(self, match: Match[str]) -> str: def normalize(self, match: Match[str]) -> str:
""" Return the default format of the postcode for the given match. """ Return the default format of the postcode for the given match.
`match` must be a `re.Match` object previously returned by `match` must be a `re.Match` object previously returned by
@@ -71,14 +70,12 @@ class PostcodeFormatter:
else: else:
raise UsageError(f"Invalid entry 'postcode' for country '{ccode}'") raise UsageError(f"Invalid entry 'postcode' for country '{ccode}'")
def set_default_pattern(self, pattern: str) -> None: def set_default_pattern(self, pattern: str) -> None:
""" Set the postcode match pattern to use, when a country does not """ Set the postcode match pattern to use, when a country does not
have a specific pattern. have a specific pattern.
""" """
self.default_matcher = CountryPostcodeMatcher('', {'pattern': pattern}) self.default_matcher = CountryPostcodeMatcher('', {'pattern': pattern})
def get_matcher(self, country_code: Optional[str]) -> Optional[CountryPostcodeMatcher]: def get_matcher(self, country_code: Optional[str]) -> Optional[CountryPostcodeMatcher]:
""" Return the CountryPostcodeMatcher for the given country. """ Return the CountryPostcodeMatcher for the given country.
Returns None if the country doesn't have a postcode and the Returns None if the country doesn't have a postcode and the
@@ -92,7 +89,6 @@ class PostcodeFormatter:
return self.country_matcher.get(country_code, self.default_matcher) return self.country_matcher.get(country_code, self.default_matcher)
def match(self, country_code: Optional[str], postcode: str) -> Optional[Match[str]]: def match(self, country_code: Optional[str], postcode: str) -> Optional[Match[str]]:
""" Match the given postcode against the postcode pattern for this """ Match the given postcode against the postcode pattern for this
matcher. Returns a `re.Match` object if the country has a pattern matcher. Returns a `re.Match` object if the country has a pattern
@@ -105,7 +101,6 @@ class PostcodeFormatter:
return self.country_matcher.get(country_code, self.default_matcher).match(postcode) return self.country_matcher.get(country_code, self.default_matcher).match(postcode)
def normalize(self, country_code: str, match: Match[str]) -> str: def normalize(self, country_code: str, match: Match[str]) -> str:
""" Return the default format of the postcode for the given match. """ Return the default format of the postcode for the given match.
`match` must be a `re.Match` object previously returned by `match` must be a `re.Match` object previously returned by

View File

@@ -23,6 +23,7 @@ LOG = logging.getLogger()
Cursor = psycopg.Cursor[Any] Cursor = psycopg.Cursor[Any]
Connection = psycopg.Connection[Any] Connection = psycopg.Connection[Any]
def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any: def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned. """ Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised. If the query yields more than one row, a ValueError is raised.
@@ -42,9 +43,10 @@ def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -
def table_exists(conn: Connection, table: str) -> bool: def table_exists(conn: Connection, table: str) -> bool:
""" Check that a table with the given name exists in the database. """ Check that a table with the given name exists in the database.
""" """
num = execute_scalar(conn, num = execute_scalar(
"""SELECT count(*) FROM pg_tables conn,
WHERE tablename = %s and schemaname = 'public'""", (table, )) """SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, ))
return num == 1 if isinstance(num, int) else False return num == 1 if isinstance(num, int) else False
@@ -52,9 +54,9 @@ def table_has_column(conn: Connection, table: str, column: str) -> bool:
""" Check if the table 'table' exists and has a column with name 'column'. """ Check if the table 'table' exists and has a column with name 'column'.
""" """
has_column = execute_scalar(conn, has_column = execute_scalar(conn,
"""SELECT count(*) FROM information_schema.columns """SELECT count(*) FROM information_schema.columns
WHERE table_name = %s and column_name = %s""", WHERE table_name = %s and column_name = %s""",
(table, column)) (table, column))
return has_column > 0 if isinstance(has_column, int) else False return has_column > 0 if isinstance(has_column, int) else False
@@ -77,8 +79,9 @@ def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> b
return True return True
def drop_tables(conn: Connection, *names: str, def drop_tables(conn: Connection, *names: str,
if_exists: bool = True, cascade: bool = False) -> None: if_exists: bool = True, cascade: bool = False) -> None:
""" Drop one or more tables with the given names. """ Drop one or more tables with the given names.
Set `if_exists` to False if a non-existent table should raise Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. `cascade` will cause an exception instead of just being ignored. `cascade` will cause

View File

@@ -11,6 +11,7 @@ from typing import Optional, cast
from .connection import Connection, table_exists from .connection import Connection, table_exists
def set_property(conn: Connection, name: str, value: str) -> None: def set_property(conn: Connection, name: str, value: str) -> None:
""" Add or replace the property with the given name. """ Add or replace the property with the given name.
""" """

View File

@@ -18,6 +18,7 @@ LOG = logging.getLogger()
QueueItem = Optional[Tuple[psycopg.abc.Query, Any]] QueueItem = Optional[Tuple[psycopg.abc.Query, Any]]
class QueryPool: class QueryPool:
""" Pool to run SQL queries in parallel asynchronous execution. """ Pool to run SQL queries in parallel asynchronous execution.
@@ -32,7 +33,6 @@ class QueryPool:
self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args)) self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args))
for _ in range(pool_size)] for _ in range(pool_size)]
async def put_query(self, query: psycopg.abc.Query, params: Any) -> None: async def put_query(self, query: psycopg.abc.Query, params: Any) -> None:
""" Schedule a query for execution. """ Schedule a query for execution.
""" """
@@ -41,7 +41,6 @@ class QueryPool:
self.wait_time += time.time() - tstart self.wait_time += time.time() - tstart
await asyncio.sleep(0) await asyncio.sleep(0)
async def finish(self) -> None: async def finish(self) -> None:
""" Wait for all queries to finish and close the pool. """ Wait for all queries to finish and close the pool.
""" """
@@ -57,7 +56,6 @@ class QueryPool:
if excp is not None: if excp is not None:
raise excp raise excp
async def _worker_loop(self, dsn: str, **conn_args: Any) -> None: async def _worker_loop(self, dsn: str, **conn_args: Any) -> None:
conn_args['autocommit'] = True conn_args['autocommit'] = True
aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args) aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args)
@@ -78,10 +76,8 @@ class QueryPool:
str(item[0]), str(item[1])) str(item[0]), str(item[1]))
# item is still valid here, causing a retry # item is still valid here, causing a retry
async def __aenter__(self) -> 'QueryPool': async def __aenter__(self) -> 'QueryPool':
return self return self
async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
await self.finish() await self.finish()

View File

@@ -15,6 +15,7 @@ from .connection import Connection, server_version_tuple, postgis_version_tuple
from ..config import Configuration from ..config import Configuration
from ..db.query_pool import QueryPool from ..db.query_pool import QueryPool
def _get_partitions(conn: Connection) -> Set[int]: def _get_partitions(conn: Connection) -> Set[int]:
""" Get the set of partitions currently in use. """ Get the set of partitions currently in use.
""" """
@@ -35,6 +36,7 @@ def _get_tables(conn: Connection) -> Set[str]:
return set((row[0] for row in list(cur))) return set((row[0] for row in list(cur)))
def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str: def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str:
""" Returns the version of the slim middle tables. """ Returns the version of the slim middle tables.
""" """
@@ -73,9 +75,10 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
ps3 = postgis_version >= (3, 0) ps3 = postgis_version >= (3, 0)
return { return {
'has_index_non_key_column': pg11plus, 'has_index_non_key_column': pg11plus,
'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST' 'spgist_geom': 'SPGIST' if pg11plus and ps3 else 'GIST'
} }
class SQLPreprocessor: class SQLPreprocessor:
""" A environment for preprocessing SQL files from the """ A environment for preprocessing SQL files from the
lib-sql directory. lib-sql directory.
@@ -102,7 +105,6 @@ class SQLPreprocessor:
self.env.globals['db'] = db_info self.env.globals['db'] = db_info
self.env.globals['postgres'] = _setup_postgresql_features(conn) self.env.globals['postgres'] = _setup_postgresql_features(conn)
def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None: def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
""" Execute the given SQL template string on the connection. """ Execute the given SQL template string on the connection.
The keyword arguments may supply additional parameters The keyword arguments may supply additional parameters
@@ -114,7 +116,6 @@ class SQLPreprocessor:
cur.execute(sql) cur.execute(sql)
conn.commit() conn.commit()
def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None: def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
""" Execute the given SQL file on the connection. The keyword arguments """ Execute the given SQL file on the connection. The keyword arguments
may supply additional parameters for preprocessing. may supply additional parameters for preprocessing.
@@ -125,7 +126,6 @@ class SQLPreprocessor:
cur.execute(sql) cur.execute(sql)
conn.commit() conn.commit()
async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1, async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
**kwargs: Any) -> None: **kwargs: Any) -> None:
""" Execute the given SQL files using parallel asynchronous connections. """ Execute the given SQL files using parallel asynchronous connections.

View File

@@ -18,6 +18,7 @@ from ..errors import UsageError
LOG = logging.getLogger() LOG = logging.getLogger()
def _pipe_to_proc(proc: 'subprocess.Popen[bytes]', def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
fdesc: Union[IO[bytes], gzip.GzipFile]) -> int: fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
assert proc.stdin is not None assert proc.stdin is not None
@@ -31,6 +32,7 @@ def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
return len(chunk) return len(chunk)
def execute_file(dsn: str, fname: Path, def execute_file(dsn: str, fname: Path,
ignore_errors: bool = False, ignore_errors: bool = False,
pre_code: Optional[str] = None, pre_code: Optional[str] = None,

View File

@@ -8,6 +8,7 @@
Custom exception and error classes for Nominatim. Custom exception and error classes for Nominatim.
""" """
class UsageError(Exception): class UsageError(Exception):
""" An error raised because of bad user input. This error will usually """ An error raised because of bad user input. This error will usually
not cause a stack trace to be printed unless debugging is enabled. not cause a stack trace to be printed unless debugging is enabled.

View File

@@ -21,6 +21,7 @@ from . import runners
LOG = logging.getLogger() LOG = logging.getLogger()
class Indexer: class Indexer:
""" Main indexing routine. """ Main indexing routine.
""" """
@@ -30,7 +31,6 @@ class Indexer:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.num_threads = num_threads self.num_threads = num_threads
def has_pending(self) -> bool: def has_pending(self) -> bool:
""" Check if any data still needs indexing. """ Check if any data still needs indexing.
This function must only be used after the import has finished. This function must only be used after the import has finished.
@@ -41,7 +41,6 @@ class Indexer:
cur.execute("SELECT 'a' FROM placex WHERE indexed_status > 0 LIMIT 1") cur.execute("SELECT 'a' FROM placex WHERE indexed_status > 0 LIMIT 1")
return cur.rowcount > 0 return cur.rowcount > 0
async def index_full(self, analyse: bool = True) -> None: async def index_full(self, analyse: bool = True) -> None:
""" Index the complete database. This will first index boundaries """ Index the complete database. This will first index boundaries
followed by all other objects. When `analyse` is True, then the followed by all other objects. When `analyse` is True, then the
@@ -75,7 +74,6 @@ class Indexer:
if not self.has_pending(): if not self.has_pending():
break break
async def index_boundaries(self, minrank: int, maxrank: int) -> int: async def index_boundaries(self, minrank: int, maxrank: int) -> int:
""" Index only administrative boundaries within the given rank range. """ Index only administrative boundaries within the given rank range.
""" """
@@ -138,7 +136,6 @@ class Indexer:
(minrank, maxrank)) (minrank, maxrank))
total_tuples = {row.rank_address: row.count for row in cur} total_tuples = {row.rank_address: row.count for row in cur}
with self.tokenizer.name_analyzer() as analyzer: with self.tokenizer.name_analyzer() as analyzer:
for rank in range(max(1, minrank), maxrank + 1): for rank in range(max(1, minrank), maxrank + 1):
if rank >= 30: if rank >= 30:
@@ -156,7 +153,6 @@ class Indexer:
return total return total
async def index_postcodes(self) -> int: async def index_postcodes(self) -> int:
"""Index the entries of the location_postcode table. """Index the entries of the location_postcode table.
""" """
@@ -164,7 +160,6 @@ class Indexer:
return await self._index(runners.PostcodeRunner(), batch=20) return await self._index(runners.PostcodeRunner(), batch=20)
def update_status_table(self) -> None: def update_status_table(self) -> None:
""" Update the status in the status table to 'indexed'. """ Update the status in the status table to 'indexed'.
""" """
@@ -193,7 +188,7 @@ class Indexer:
if total_tuples > 0: if total_tuples > 0:
async with await psycopg.AsyncConnection.connect( async with await psycopg.AsyncConnection.connect(
self.dsn, row_factory=psycopg.rows.dict_row) as aconn,\ self.dsn, row_factory=psycopg.rows.dict_row) as aconn, \
QueryPool(self.dsn, self.num_threads, autocommit=True) as pool: QueryPool(self.dsn, self.num_threads, autocommit=True) as pool:
fetcher_time = 0.0 fetcher_time = 0.0
tstart = time.time() tstart = time.time()
@@ -224,7 +219,6 @@ class Indexer:
return progress.done() return progress.done()
def _prepare_indexing(self, runner: runners.Runner) -> int: def _prepare_indexing(self, runner: runners.Runner) -> int:
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore") hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")

View File

@@ -14,6 +14,7 @@ LOG = logging.getLogger()
INITIAL_PROGRESS = 10 INITIAL_PROGRESS = 10
class ProgressLogger: class ProgressLogger:
""" Tracks and prints progress for the indexing process. """ Tracks and prints progress for the indexing process.
`name` is the name of the indexing step being tracked. `name` is the name of the indexing step being tracked.

View File

@@ -19,11 +19,11 @@ from ..typing import Protocol
from ..data.place_info import PlaceInfo from ..data.place_info import PlaceInfo
from ..tokenizer.base import AbstractAnalyzer from ..tokenizer.base import AbstractAnalyzer
# pylint: disable=C0111
def _mk_valuelist(template: str, num: int) -> pysql.Composed: def _mk_valuelist(template: str, num: int) -> pysql.Composed:
return pysql.SQL(',').join([pysql.SQL(template)] * num) return pysql.SQL(',').join([pysql.SQL(template)] * num)
def _analyze_place(place: DictRow, analyzer: AbstractAnalyzer) -> Json: def _analyze_place(place: DictRow, analyzer: AbstractAnalyzer) -> Json:
return Json(analyzer.process_place(PlaceInfo(place))) return Json(analyzer.process_place(PlaceInfo(place)))
@@ -41,6 +41,7 @@ SELECT_SQL = pysql.SQL("""SELECT place_id, extra.*
LATERAL placex_indexing_prepare(px) as extra """) LATERAL placex_indexing_prepare(px) as extra """)
UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)" UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
class AbstractPlacexRunner: class AbstractPlacexRunner:
""" Returns SQL commands for indexing of the placex table. """ Returns SQL commands for indexing of the placex table.
""" """
@@ -49,7 +50,6 @@ class AbstractPlacexRunner:
self.rank = rank self.rank = rank
self.analyzer = analyzer self.analyzer = analyzer
def index_places_query(self, batch_size: int) -> Query: def index_places_query(self, batch_size: int) -> Query:
return pysql.SQL( return pysql.SQL(
""" UPDATE placex """ UPDATE placex
@@ -59,7 +59,6 @@ class AbstractPlacexRunner:
WHERE place_id = v.id WHERE place_id = v.id
""").format(_mk_valuelist(UPDATE_LINE, batch_size)) """).format(_mk_valuelist(UPDATE_LINE, batch_size))
def index_places_params(self, place: DictRow) -> Sequence[Any]: def index_places_params(self, place: DictRow) -> Sequence[Any]:
return (place['place_id'], return (place['place_id'],
place['name'], place['name'],
@@ -118,7 +117,6 @@ class InterpolationRunner:
def __init__(self, analyzer: AbstractAnalyzer) -> None: def __init__(self, analyzer: AbstractAnalyzer) -> None:
self.analyzer = analyzer self.analyzer = analyzer
def name(self) -> str: def name(self) -> str:
return "interpolation lines (location_property_osmline)" return "interpolation lines (location_property_osmline)"
@@ -126,14 +124,12 @@ class InterpolationRunner:
return """SELECT count(*) FROM location_property_osmline return """SELECT count(*) FROM location_property_osmline
WHERE indexed_status > 0""" WHERE indexed_status > 0"""
def sql_get_objects(self) -> Query: def sql_get_objects(self) -> Query:
return """SELECT place_id, get_interpolation_address(address, osm_id) as address return """SELECT place_id, get_interpolation_address(address, osm_id) as address
FROM location_property_osmline FROM location_property_osmline
WHERE indexed_status > 0 WHERE indexed_status > 0
ORDER BY geometry_sector""" ORDER BY geometry_sector"""
def index_places_query(self, batch_size: int) -> Query: def index_places_query(self, batch_size: int) -> Query:
return pysql.SQL("""UPDATE location_property_osmline return pysql.SQL("""UPDATE location_property_osmline
SET indexed_status = 0, address = v.addr, token_info = v.ti SET indexed_status = 0, address = v.addr, token_info = v.ti
@@ -141,13 +137,11 @@ class InterpolationRunner:
WHERE place_id = v.id WHERE place_id = v.id
""").format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", batch_size)) """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", batch_size))
def index_places_params(self, place: DictRow) -> Sequence[Any]: def index_places_params(self, place: DictRow) -> Sequence[Any]:
return (place['place_id'], place['address'], return (place['place_id'], place['address'],
_analyze_place(place, self.analyzer)) _analyze_place(place, self.analyzer))
class PostcodeRunner(Runner): class PostcodeRunner(Runner):
""" Provides the SQL commands for indexing the location_postcode table. """ Provides the SQL commands for indexing the location_postcode table.
""" """
@@ -155,22 +149,18 @@ class PostcodeRunner(Runner):
def name(self) -> str: def name(self) -> str:
return "postcodes (location_postcode)" return "postcodes (location_postcode)"
def sql_count_objects(self) -> Query: def sql_count_objects(self) -> Query:
return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0' return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'
def sql_get_objects(self) -> Query: def sql_get_objects(self) -> Query:
return """SELECT place_id FROM location_postcode return """SELECT place_id FROM location_postcode
WHERE indexed_status > 0 WHERE indexed_status > 0
ORDER BY country_code, postcode""" ORDER BY country_code, postcode"""
def index_places_query(self, batch_size: int) -> Query: def index_places_query(self, batch_size: int) -> Query:
return pysql.SQL("""UPDATE location_postcode SET indexed_status = 0 return pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
WHERE place_id IN ({})""")\ WHERE place_id IN ({})""")\
.format(pysql.SQL(',').join((pysql.Placeholder() for _ in range(batch_size)))) .format(pysql.SQL(',').join((pysql.Placeholder() for _ in range(batch_size))))
def index_places_params(self, place: DictRow) -> Sequence[Any]: def index_places_params(self, place: DictRow) -> Sequence[Any]:
return (place['place_id'], ) return (place['place_id'], )

View File

@@ -17,6 +17,7 @@ from ..config import Configuration
from ..db.connection import Connection from ..db.connection import Connection
from ..data.place_info import PlaceInfo from ..data.place_info import PlaceInfo
class AbstractAnalyzer(ABC): class AbstractAnalyzer(ABC):
""" The analyzer provides the functions for analysing names and building """ The analyzer provides the functions for analysing names and building
the token database. the token database.
@@ -28,17 +29,14 @@ class AbstractAnalyzer(ABC):
def __enter__(self) -> 'AbstractAnalyzer': def __enter__(self) -> 'AbstractAnalyzer':
return self return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.close() self.close()
@abstractmethod @abstractmethod
def close(self) -> None: def close(self) -> None:
""" Free all resources used by the analyzer. """ Free all resources used by the analyzer.
""" """
@abstractmethod @abstractmethod
def get_word_token_info(self, words: List[str]) -> List[Tuple[str, str, int]]: def get_word_token_info(self, words: List[str]) -> List[Tuple[str, str, int]]:
""" Return token information for the given list of words. """ Return token information for the given list of words.
@@ -57,7 +55,6 @@ class AbstractAnalyzer(ABC):
(original word, word token, word id). (original word, word token, word id).
""" """
@abstractmethod @abstractmethod
def normalize_postcode(self, postcode: str) -> str: def normalize_postcode(self, postcode: str) -> str:
""" Convert the postcode to its standardized form. """ Convert the postcode to its standardized form.
@@ -72,14 +69,12 @@ class AbstractAnalyzer(ABC):
The given postcode after normalization. The given postcode after normalization.
""" """
@abstractmethod @abstractmethod
def update_postcodes_from_db(self) -> None: def update_postcodes_from_db(self) -> None:
""" Update the tokenizer's postcode tokens from the current content """ Update the tokenizer's postcode tokens from the current content
of the `location_postcode` table. of the `location_postcode` table.
""" """
@abstractmethod @abstractmethod
def update_special_phrases(self, def update_special_phrases(self,
phrases: Iterable[Tuple[str, str, str, str]], phrases: Iterable[Tuple[str, str, str, str]],
@@ -95,7 +90,6 @@ class AbstractAnalyzer(ABC):
ones that already exist. ones that already exist.
""" """
@abstractmethod @abstractmethod
def add_country_names(self, country_code: str, names: Dict[str, str]) -> None: def add_country_names(self, country_code: str, names: Dict[str, str]) -> None:
""" Add the given names to the tokenizer's list of country tokens. """ Add the given names to the tokenizer's list of country tokens.
@@ -106,7 +100,6 @@ class AbstractAnalyzer(ABC):
names: Dictionary of name type to name. names: Dictionary of name type to name.
""" """
@abstractmethod @abstractmethod
def process_place(self, place: PlaceInfo) -> Any: def process_place(self, place: PlaceInfo) -> Any:
""" Extract tokens for the given place and compute the """ Extract tokens for the given place and compute the
@@ -122,7 +115,6 @@ class AbstractAnalyzer(ABC):
""" """
class AbstractTokenizer(ABC): class AbstractTokenizer(ABC):
""" The tokenizer instance is the central instance of the tokenizer in """ The tokenizer instance is the central instance of the tokenizer in
the system. There will only be a single instance of the tokenizer the system. There will only be a single instance of the tokenizer
@@ -146,7 +138,6 @@ class AbstractTokenizer(ABC):
tokenizers. tokenizers.
""" """
@abstractmethod @abstractmethod
def init_from_project(self, config: Configuration) -> None: def init_from_project(self, config: Configuration) -> None:
""" Initialise the tokenizer from an existing database setup. """ Initialise the tokenizer from an existing database setup.
@@ -158,7 +149,6 @@ class AbstractTokenizer(ABC):
config: Read-only object with configuration options. config: Read-only object with configuration options.
""" """
@abstractmethod @abstractmethod
def finalize_import(self, config: Configuration) -> None: def finalize_import(self, config: Configuration) -> None:
""" This function is called at the very end of an import when all """ This function is called at the very end of an import when all
@@ -170,7 +160,6 @@ class AbstractTokenizer(ABC):
config: Read-only object with configuration options. config: Read-only object with configuration options.
""" """
@abstractmethod @abstractmethod
def update_sql_functions(self, config: Configuration) -> None: def update_sql_functions(self, config: Configuration) -> None:
""" Update the SQL part of the tokenizer. This function is called """ Update the SQL part of the tokenizer. This function is called
@@ -184,7 +173,6 @@ class AbstractTokenizer(ABC):
config: Read-only object with configuration options. config: Read-only object with configuration options.
""" """
@abstractmethod @abstractmethod
def check_database(self, config: Configuration) -> Optional[str]: def check_database(self, config: Configuration) -> Optional[str]:
""" Check that the database is set up correctly and ready for being """ Check that the database is set up correctly and ready for being
@@ -199,7 +187,6 @@ class AbstractTokenizer(ABC):
how to resolve the issue. If everything is okay, return `None`. how to resolve the issue. If everything is okay, return `None`.
""" """
@abstractmethod @abstractmethod
def update_statistics(self, config: Configuration, threads: int = 1) -> None: def update_statistics(self, config: Configuration, threads: int = 1) -> None:
""" Recompute any tokenizer statistics necessary for efficient lookup. """ Recompute any tokenizer statistics necessary for efficient lookup.
@@ -208,14 +195,12 @@ class AbstractTokenizer(ABC):
it to be called in order to work. it to be called in order to work.
""" """
@abstractmethod @abstractmethod
def update_word_tokens(self) -> None: def update_word_tokens(self) -> None:
""" Do house-keeping on the tokenizers internal data structures. """ Do house-keeping on the tokenizers internal data structures.
Remove unused word tokens, resort data etc. Remove unused word tokens, resort data etc.
""" """
@abstractmethod @abstractmethod
def name_analyzer(self) -> AbstractAnalyzer: def name_analyzer(self) -> AbstractAnalyzer:
""" Create a new analyzer for tokenizing names and queries """ Create a new analyzer for tokenizing names and queries
@@ -231,7 +216,6 @@ class AbstractTokenizer(ABC):
call the close() function before destructing the analyzer. call the close() function before destructing the analyzer.
""" """
@abstractmethod @abstractmethod
def most_frequent_words(self, conn: Connection, num: int) -> List[str]: def most_frequent_words(self, conn: Connection, num: int) -> List[str]:
""" Return a list of the most frequent full words in the database. """ Return a list of the most frequent full words in the database.

View File

@@ -29,6 +29,7 @@ from ..tokenizer.base import AbstractTokenizer, TokenizerModule
LOG = logging.getLogger() LOG = logging.getLogger()
def _import_tokenizer(name: str) -> TokenizerModule: def _import_tokenizer(name: str) -> TokenizerModule:
""" Load the tokenizer.py module from project directory. """ Load the tokenizer.py module from project directory.
""" """

View File

@@ -61,7 +61,6 @@ class ICURuleLoader:
# Load optional sanitizer rule set. # Load optional sanitizer rule set.
self.sanitizer_rules = rules.get('sanitizers', []) self.sanitizer_rules = rules.get('sanitizers', [])
def load_config_from_db(self, conn: Connection) -> None: def load_config_from_db(self, conn: Connection) -> None:
""" Get previously saved parts of the configuration from the """ Get previously saved parts of the configuration from the
database. database.
@@ -81,7 +80,6 @@ class ICURuleLoader:
self.analysis_rules = [] self.analysis_rules = []
self._setup_analysis() self._setup_analysis()
def save_config_to_db(self, conn: Connection) -> None: def save_config_to_db(self, conn: Connection) -> None:
""" Save the part of the configuration that cannot be changed into """ Save the part of the configuration that cannot be changed into
the database. the database.
@@ -90,20 +88,17 @@ class ICURuleLoader:
set_property(conn, DBCFG_IMPORT_TRANS_RULES, self.transliteration_rules) set_property(conn, DBCFG_IMPORT_TRANS_RULES, self.transliteration_rules)
set_property(conn, DBCFG_IMPORT_ANALYSIS_RULES, json.dumps(self.analysis_rules)) set_property(conn, DBCFG_IMPORT_ANALYSIS_RULES, json.dumps(self.analysis_rules))
def make_sanitizer(self) -> PlaceSanitizer: def make_sanitizer(self) -> PlaceSanitizer:
""" Create a place sanitizer from the configured rules. """ Create a place sanitizer from the configured rules.
""" """
return PlaceSanitizer(self.sanitizer_rules, self.config) return PlaceSanitizer(self.sanitizer_rules, self.config)
def make_token_analysis(self) -> ICUTokenAnalysis: def make_token_analysis(self) -> ICUTokenAnalysis:
""" Create a token analyser from the reviouly loaded rules. """ Create a token analyser from the reviouly loaded rules.
""" """
return ICUTokenAnalysis(self.normalization_rules, return ICUTokenAnalysis(self.normalization_rules,
self.transliteration_rules, self.analysis) self.transliteration_rules, self.analysis)
def get_search_rules(self) -> str: def get_search_rules(self) -> str:
""" Return the ICU rules to be used during search. """ Return the ICU rules to be used during search.
The rules combine normalization and transliteration. The rules combine normalization and transliteration.
@@ -116,23 +111,20 @@ class ICURuleLoader:
rules.write(self.transliteration_rules) rules.write(self.transliteration_rules)
return rules.getvalue() return rules.getvalue()
def get_normalization_rules(self) -> str: def get_normalization_rules(self) -> str:
""" Return rules for normalisation of a term. """ Return rules for normalisation of a term.
""" """
return self.normalization_rules return self.normalization_rules
def get_transliteration_rules(self) -> str: def get_transliteration_rules(self) -> str:
""" Return the rules for converting a string into its asciii representation. """ Return the rules for converting a string into its asciii representation.
""" """
return self.transliteration_rules return self.transliteration_rules
def _setup_analysis(self) -> None: def _setup_analysis(self) -> None:
""" Process the rules used for creating the various token analyzers. """ Process the rules used for creating the various token analyzers.
""" """
self.analysis: Dict[Optional[str], TokenAnalyzerRule] = {} self.analysis: Dict[Optional[str], TokenAnalyzerRule] = {}
if not isinstance(self.analysis_rules, list): if not isinstance(self.analysis_rules, list):
raise UsageError("Configuration section 'token-analysis' must be a list.") raise UsageError("Configuration section 'token-analysis' must be a list.")
@@ -140,7 +132,7 @@ class ICURuleLoader:
norm = Transliterator.createFromRules("rule_loader_normalization", norm = Transliterator.createFromRules("rule_loader_normalization",
self.normalization_rules) self.normalization_rules)
trans = Transliterator.createFromRules("rule_loader_transliteration", trans = Transliterator.createFromRules("rule_loader_transliteration",
self.transliteration_rules) self.transliteration_rules)
for section in self.analysis_rules: for section in self.analysis_rules:
name = section.get('id', None) name = section.get('id', None)
@@ -154,7 +146,6 @@ class ICURuleLoader:
self.analysis[name] = TokenAnalyzerRule(section, norm, trans, self.analysis[name] = TokenAnalyzerRule(section, norm, trans,
self.config) self.config)
@staticmethod @staticmethod
def _cfg_to_icu_rules(rules: Mapping[str, Any], section: str) -> str: def _cfg_to_icu_rules(rules: Mapping[str, Any], section: str) -> str:
""" Load an ICU ruleset from the given section. If the section is a """ Load an ICU ruleset from the given section. If the section is a
@@ -189,7 +180,6 @@ class TokenAnalyzerRule:
self.config = self._analysis_mod.configure(rules, normalizer, self.config = self._analysis_mod.configure(rules, normalizer,
transliterator) transliterator)
def create(self, normalizer: Any, transliterator: Any) -> Analyzer: def create(self, normalizer: Any, transliterator: Any) -> Analyzer:
""" Create a new analyser instance for the given rule. """ Create a new analyser instance for the given rule.
""" """

View File

@@ -14,8 +14,9 @@ from icu import Transliterator
from .token_analysis.base import Analyzer from .token_analysis.base import Analyzer
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any from typing import Any # noqa
from .icu_rule_loader import TokenAnalyzerRule # pylint: disable=cyclic-import from .icu_rule_loader import TokenAnalyzerRule
class ICUTokenAnalysis: class ICUTokenAnalysis:
""" Container class collecting the transliterators and token analysis """ Container class collecting the transliterators and token analysis
@@ -35,7 +36,6 @@ class ICUTokenAnalysis:
self.analysis = {name: arules.create(self.normalizer, self.to_ascii) self.analysis = {name: arules.create(self.normalizer, self.to_ascii)
for name, arules in analysis_rules.items()} for name, arules in analysis_rules.items()}
def get_analyzer(self, name: Optional[str]) -> Analyzer: def get_analyzer(self, name: Optional[str]) -> Analyzer:
""" Return the given named analyzer. If no analyzer with that """ Return the given named analyzer. If no analyzer with that
name exists, return the default analyzer. name exists, return the default analyzer.

View File

@@ -17,7 +17,7 @@ from pathlib import Path
from psycopg.types.json import Jsonb from psycopg.types.json import Jsonb
from psycopg import sql as pysql from psycopg import sql as pysql
from ..db.connection import connect, Connection, Cursor, server_version_tuple,\ from ..db.connection import connect, Connection, Cursor, server_version_tuple, \
drop_tables, table_exists, execute_scalar drop_tables, table_exists, execute_scalar
from ..config import Configuration from ..config import Configuration
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
@@ -32,10 +32,11 @@ DBCFG_TERM_NORMALIZATION = "tokenizer_term_normalization"
LOG = logging.getLogger() LOG = logging.getLogger()
WORD_TYPES =(('country_names', 'C'), WORD_TYPES = (('country_names', 'C'),
('postcodes', 'P'), ('postcodes', 'P'),
('full_word', 'W'), ('full_word', 'W'),
('housenumbers', 'H')) ('housenumbers', 'H'))
def create(dsn: str, data_dir: Path) -> 'ICUTokenizer': def create(dsn: str, data_dir: Path) -> 'ICUTokenizer':
""" Create a new instance of the tokenizer provided by this module. """ Create a new instance of the tokenizer provided by this module.
@@ -54,7 +55,6 @@ class ICUTokenizer(AbstractTokenizer):
self.data_dir = data_dir self.data_dir = data_dir
self.loader: Optional[ICURuleLoader] = None self.loader: Optional[ICURuleLoader] = None
def init_new_db(self, config: Configuration, init_db: bool = True) -> None: def init_new_db(self, config: Configuration, init_db: bool = True) -> None:
""" Set up a new tokenizer for the database. """ Set up a new tokenizer for the database.
@@ -70,7 +70,6 @@ class ICUTokenizer(AbstractTokenizer):
self._setup_db_tables(config) self._setup_db_tables(config)
self._create_base_indices(config, 'word') self._create_base_indices(config, 'word')
def init_from_project(self, config: Configuration) -> None: def init_from_project(self, config: Configuration) -> None:
""" Initialise the tokenizer from the project directory. """ Initialise the tokenizer from the project directory.
""" """
@@ -79,14 +78,12 @@ class ICUTokenizer(AbstractTokenizer):
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
self.loader.load_config_from_db(conn) self.loader.load_config_from_db(conn)
def finalize_import(self, config: Configuration) -> None: def finalize_import(self, config: Configuration) -> None:
""" Do any required postprocessing to make the tokenizer data ready """ Do any required postprocessing to make the tokenizer data ready
for use. for use.
""" """
self._create_lookup_indices(config, 'word') self._create_lookup_indices(config, 'word')
def update_sql_functions(self, config: Configuration) -> None: def update_sql_functions(self, config: Configuration) -> None:
""" Reimport the SQL functions for this tokenizer. """ Reimport the SQL functions for this tokenizer.
""" """
@@ -94,14 +91,12 @@ class ICUTokenizer(AbstractTokenizer):
sqlp = SQLPreprocessor(conn, config) sqlp = SQLPreprocessor(conn, config)
sqlp.run_sql_file(conn, 'tokenizer/icu_tokenizer.sql') sqlp.run_sql_file(conn, 'tokenizer/icu_tokenizer.sql')
def check_database(self, config: Configuration) -> None: def check_database(self, config: Configuration) -> None:
""" Check that the tokenizer is set up correctly. """ Check that the tokenizer is set up correctly.
""" """
# Will throw an error if there is an issue. # Will throw an error if there is an issue.
self.init_from_project(config) self.init_from_project(config)
def update_statistics(self, config: Configuration, threads: int = 2) -> None: def update_statistics(self, config: Configuration, threads: int = 2) -> None:
""" Recompute frequencies for all name words. """ Recompute frequencies for all name words.
""" """
@@ -126,28 +121,29 @@ class ICUTokenizer(AbstractTokenizer):
SELECT unnest(nameaddress_vector) as id, count(*) SELECT unnest(nameaddress_vector) as id, count(*)
FROM search_name GROUP BY id""") FROM search_name GROUP BY id""")
cur.execute('CREATE INDEX ON addressword_frequencies(id)') cur.execute('CREATE INDEX ON addressword_frequencies(id)')
cur.execute("""CREATE OR REPLACE FUNCTION word_freq_update(wid INTEGER, cur.execute("""
INOUT info JSONB) CREATE OR REPLACE FUNCTION word_freq_update(wid INTEGER,
AS $$ INOUT info JSONB)
DECLARE rec RECORD; AS $$
BEGIN DECLARE rec RECORD;
IF info is null THEN BEGIN
info = '{}'::jsonb; IF info is null THEN
END IF; info = '{}'::jsonb;
FOR rec IN SELECT count FROM word_frequencies WHERE id = wid END IF;
LOOP FOR rec IN SELECT count FROM word_frequencies WHERE id = wid
info = info || jsonb_build_object('count', rec.count); LOOP
END LOOP; info = info || jsonb_build_object('count', rec.count);
FOR rec IN SELECT count FROM addressword_frequencies WHERE id = wid END LOOP;
LOOP FOR rec IN SELECT count FROM addressword_frequencies WHERE id = wid
info = info || jsonb_build_object('addr_count', rec.count); LOOP
END LOOP; info = info || jsonb_build_object('addr_count', rec.count);
IF info = '{}'::jsonb THEN END LOOP;
info = null; IF info = '{}'::jsonb THEN
END IF; info = null;
END; END IF;
$$ LANGUAGE plpgsql IMMUTABLE; END;
""") $$ LANGUAGE plpgsql IMMUTABLE;
""")
LOG.info('Update word table with recomputed frequencies') LOG.info('Update word table with recomputed frequencies')
drop_tables(conn, 'tmp_word') drop_tables(conn, 'tmp_word')
cur.execute("""CREATE TABLE tmp_word AS cur.execute("""CREATE TABLE tmp_word AS
@@ -200,8 +196,6 @@ class ICUTokenizer(AbstractTokenizer):
self._create_lookup_indices(config, 'tmp_word') self._create_lookup_indices(config, 'tmp_word')
self._move_temporary_word_table('tmp_word') self._move_temporary_word_table('tmp_word')
def _cleanup_housenumbers(self) -> None: def _cleanup_housenumbers(self) -> None:
""" Remove unused house numbers. """ Remove unused house numbers.
""" """
@@ -235,8 +229,6 @@ class ICUTokenizer(AbstractTokenizer):
(list(candidates.values()), )) (list(candidates.values()), ))
conn.commit() conn.commit()
def update_word_tokens(self) -> None: def update_word_tokens(self) -> None:
""" Remove unused tokens. """ Remove unused tokens.
""" """
@@ -244,7 +236,6 @@ class ICUTokenizer(AbstractTokenizer):
self._cleanup_housenumbers() self._cleanup_housenumbers()
LOG.warning("Tokenizer house-keeping done.") LOG.warning("Tokenizer house-keeping done.")
def name_analyzer(self) -> 'ICUNameAnalyzer': def name_analyzer(self) -> 'ICUNameAnalyzer':
""" Create a new analyzer for tokenizing names and queries """ Create a new analyzer for tokenizing names and queries
using this tokinzer. Analyzers are context managers and should using this tokinzer. Analyzers are context managers and should
@@ -264,7 +255,6 @@ class ICUTokenizer(AbstractTokenizer):
return ICUNameAnalyzer(self.dsn, self.loader.make_sanitizer(), return ICUNameAnalyzer(self.dsn, self.loader.make_sanitizer(),
self.loader.make_token_analysis()) self.loader.make_token_analysis())
def most_frequent_words(self, conn: Connection, num: int) -> List[str]: def most_frequent_words(self, conn: Connection, num: int) -> List[str]:
""" Return a list of the `num` most frequent full words """ Return a list of the `num` most frequent full words
in the database. in the database.
@@ -276,7 +266,6 @@ class ICUTokenizer(AbstractTokenizer):
ORDER BY count DESC LIMIT %s""", (num,)) ORDER BY count DESC LIMIT %s""", (num,))
return list(s[0].split('@')[0] for s in cur) return list(s[0].split('@')[0] for s in cur)
def _save_config(self) -> None: def _save_config(self) -> None:
""" Save the configuration that needs to remain stable for the given """ Save the configuration that needs to remain stable for the given
database as database properties. database as database properties.
@@ -285,7 +274,6 @@ class ICUTokenizer(AbstractTokenizer):
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
self.loader.save_config_to_db(conn) self.loader.save_config_to_db(conn)
def _setup_db_tables(self, config: Configuration) -> None: def _setup_db_tables(self, config: Configuration) -> None:
""" Set up the word table and fill it with pre-computed word """ Set up the word table and fill it with pre-computed word
frequencies. frequencies.
@@ -309,7 +297,6 @@ class ICUTokenizer(AbstractTokenizer):
""") """)
conn.commit() conn.commit()
def _create_base_indices(self, config: Configuration, table_name: str) -> None: def _create_base_indices(self, config: Configuration, table_name: str) -> None:
""" Set up the word table and fill it with pre-computed word """ Set up the word table and fill it with pre-computed word
frequencies. frequencies.
@@ -330,21 +317,21 @@ class ICUTokenizer(AbstractTokenizer):
column_type=ctype) column_type=ctype)
conn.commit() conn.commit()
def _create_lookup_indices(self, config: Configuration, table_name: str) -> None: def _create_lookup_indices(self, config: Configuration, table_name: str) -> None:
""" Create additional indexes used when running the API. """ Create additional indexes used when running the API.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
sqlp = SQLPreprocessor(conn, config) sqlp = SQLPreprocessor(conn, config)
# Index required for details lookup. # Index required for details lookup.
sqlp.run_string(conn, """ sqlp.run_string(
conn,
"""
CREATE INDEX IF NOT EXISTS idx_{{table_name}}_word_id CREATE INDEX IF NOT EXISTS idx_{{table_name}}_word_id
ON {{table_name}} USING BTREE (word_id) {{db.tablespace.search_index}} ON {{table_name}} USING BTREE (word_id) {{db.tablespace.search_index}}
""", """,
table_name=table_name) table_name=table_name)
conn.commit() conn.commit()
def _move_temporary_word_table(self, old: str) -> None: def _move_temporary_word_table(self, old: str) -> None:
""" Rename all tables and indexes used by the tokenizer. """ Rename all tables and indexes used by the tokenizer.
""" """
@@ -361,8 +348,6 @@ class ICUTokenizer(AbstractTokenizer):
conn.commit() conn.commit()
class ICUNameAnalyzer(AbstractAnalyzer): class ICUNameAnalyzer(AbstractAnalyzer):
""" The ICU analyzer uses the ICU library for splitting names. """ The ICU analyzer uses the ICU library for splitting names.
@@ -379,7 +364,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
self._cache = _TokenCache() self._cache = _TokenCache()
def close(self) -> None: def close(self) -> None:
""" Free all resources used by the analyzer. """ Free all resources used by the analyzer.
""" """
@@ -387,20 +371,17 @@ class ICUNameAnalyzer(AbstractAnalyzer):
self.conn.close() self.conn.close()
self.conn = None self.conn = None
def _search_normalized(self, name: str) -> str: def _search_normalized(self, name: str) -> str:
""" Return the search token transliteration of the given name. """ Return the search token transliteration of the given name.
""" """
return cast(str, self.token_analysis.search.transliterate(name)).strip() return cast(str, self.token_analysis.search.transliterate(name)).strip()
def _normalized(self, name: str) -> str: def _normalized(self, name: str) -> str:
""" Return the normalized version of the given name with all """ Return the normalized version of the given name with all
non-relevant information removed. non-relevant information removed.
""" """
return cast(str, self.token_analysis.normalizer.transliterate(name)).strip() return cast(str, self.token_analysis.normalizer.transliterate(name)).strip()
def get_word_token_info(self, words: Sequence[str]) -> List[Tuple[str, str, int]]: def get_word_token_info(self, words: Sequence[str]) -> List[Tuple[str, str, int]]:
""" Return token information for the given list of words. """ Return token information for the given list of words.
If a word starts with # it is assumed to be a full name If a word starts with # it is assumed to be a full name
@@ -432,8 +413,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):
part_ids = {r[0]: r[1] for r in cur} part_ids = {r[0]: r[1] for r in cur}
return [(k, v, full_ids.get(v, None)) for k, v in full_tokens.items()] \ return [(k, v, full_ids.get(v, None)) for k, v in full_tokens.items()] \
+ [(k, v, part_ids.get(v, None)) for k, v in partial_tokens.items()] + [(k, v, part_ids.get(v, None)) for k, v in partial_tokens.items()]
def normalize_postcode(self, postcode: str) -> str: def normalize_postcode(self, postcode: str) -> str:
""" Convert the postcode to a standardized form. """ Convert the postcode to a standardized form.
@@ -443,7 +423,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
""" """
return postcode.strip().upper() return postcode.strip().upper()
def update_postcodes_from_db(self) -> None: def update_postcodes_from_db(self) -> None:
""" Update postcode tokens in the word table from the location_postcode """ Update postcode tokens in the word table from the location_postcode
table. table.
@@ -516,9 +495,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
cur.executemany("""SELECT create_postcode_word(%s, %s)""", terms) cur.executemany("""SELECT create_postcode_word(%s, %s)""", terms)
def update_special_phrases(self, phrases: Iterable[Tuple[str, str, str, str]], def update_special_phrases(self, phrases: Iterable[Tuple[str, str, str, str]],
should_replace: bool) -> None: should_replace: bool) -> None:
""" Replace the search index for special phrases with the new phrases. """ Replace the search index for special phrases with the new phrases.
@@ -548,7 +524,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
LOG.info("Total phrases: %s. Added: %s. Deleted: %s", LOG.info("Total phrases: %s. Added: %s. Deleted: %s",
len(norm_phrases), added, deleted) len(norm_phrases), added, deleted)
def _add_special_phrases(self, cursor: Cursor, def _add_special_phrases(self, cursor: Cursor,
new_phrases: Set[Tuple[str, str, str, str]], new_phrases: Set[Tuple[str, str, str, str]],
existing_phrases: Set[Tuple[str, str, str, str]]) -> int: existing_phrases: Set[Tuple[str, str, str, str]]) -> int:
@@ -568,10 +543,9 @@ class ICUNameAnalyzer(AbstractAnalyzer):
return added return added
def _remove_special_phrases(self, cursor: Cursor, def _remove_special_phrases(self, cursor: Cursor,
new_phrases: Set[Tuple[str, str, str, str]], new_phrases: Set[Tuple[str, str, str, str]],
existing_phrases: Set[Tuple[str, str, str, str]]) -> int: existing_phrases: Set[Tuple[str, str, str, str]]) -> int:
""" Remove all phrases from the database that are no longer in the """ Remove all phrases from the database that are no longer in the
new phrase list. new phrase list.
""" """
@@ -587,7 +561,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
return len(to_delete) return len(to_delete)
def add_country_names(self, country_code: str, names: Mapping[str, str]) -> None: def add_country_names(self, country_code: str, names: Mapping[str, str]) -> None:
""" Add default names for the given country to the search index. """ Add default names for the given country to the search index.
""" """
@@ -599,7 +572,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
self.sanitizer.process_names(info)[0], self.sanitizer.process_names(info)[0],
internal=True) internal=True)
def _add_country_full_names(self, country_code: str, names: Sequence[PlaceName], def _add_country_full_names(self, country_code: str, names: Sequence[PlaceName],
internal: bool = False) -> None: internal: bool = False) -> None:
""" Add names for the given country from an already sanitized """ Add names for the given country from an already sanitized
@@ -651,7 +623,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
""" """
cur.execute(sql, (country_code, list(new_tokens))) cur.execute(sql, (country_code, list(new_tokens)))
def process_place(self, place: PlaceInfo) -> Mapping[str, Any]: def process_place(self, place: PlaceInfo) -> Mapping[str, Any]:
""" Determine tokenizer information about the given place. """ Determine tokenizer information about the given place.
@@ -674,7 +645,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
return token_info.to_dict() return token_info.to_dict()
def _process_place_address(self, token_info: '_TokenInfo', def _process_place_address(self, token_info: '_TokenInfo',
address: Sequence[PlaceName]) -> None: address: Sequence[PlaceName]) -> None:
for item in address: for item in address:
@@ -687,12 +657,11 @@ class ICUNameAnalyzer(AbstractAnalyzer):
elif item.kind == 'place': elif item.kind == 'place':
if not item.suffix: if not item.suffix:
token_info.add_place(itertools.chain(*self._compute_name_tokens([item]))) token_info.add_place(itertools.chain(*self._compute_name_tokens([item])))
elif not item.kind.startswith('_') and not item.suffix and \ elif (not item.kind.startswith('_') and not item.suffix and
item.kind not in ('country', 'full', 'inclusion'): item.kind not in ('country', 'full', 'inclusion')):
token_info.add_address_term(item.kind, token_info.add_address_term(item.kind,
itertools.chain(*self._compute_name_tokens([item]))) itertools.chain(*self._compute_name_tokens([item])))
def _compute_housenumber_token(self, hnr: PlaceName) -> Tuple[Optional[int], Optional[str]]: def _compute_housenumber_token(self, hnr: PlaceName) -> Tuple[Optional[int], Optional[str]]:
""" Normalize the housenumber and return the word token and the """ Normalize the housenumber and return the word token and the
canonical form. canonical form.
@@ -728,7 +697,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
return result return result
def _retrieve_full_tokens(self, name: str) -> List[int]: def _retrieve_full_tokens(self, name: str) -> List[int]:
""" Get the full name token for the given name, if it exists. """ Get the full name token for the given name, if it exists.
The name is only retrieved for the standard analyser. The name is only retrieved for the standard analyser.
@@ -749,7 +717,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
return full return full
def _compute_name_tokens(self, names: Sequence[PlaceName]) -> Tuple[Set[int], Set[int]]: def _compute_name_tokens(self, names: Sequence[PlaceName]) -> Tuple[Set[int], Set[int]]:
""" Computes the full name and partial name tokens for the given """ Computes the full name and partial name tokens for the given
dictionary of names. dictionary of names.
@@ -787,7 +754,6 @@ class ICUNameAnalyzer(AbstractAnalyzer):
return full_tokens, partial_tokens return full_tokens, partial_tokens
def _add_postcode(self, item: PlaceName) -> Optional[str]: def _add_postcode(self, item: PlaceName) -> Optional[str]:
""" Make sure the normalized postcode is present in the word table. """ Make sure the normalized postcode is present in the word table.
""" """
@@ -835,11 +801,9 @@ class _TokenInfo:
self.address_tokens: Dict[str, str] = {} self.address_tokens: Dict[str, str] = {}
self.postcode: Optional[str] = None self.postcode: Optional[str] = None
def _mk_array(self, tokens: Iterable[Any]) -> str: def _mk_array(self, tokens: Iterable[Any]) -> str:
return f"{{{','.join((str(s) for s in tokens))}}}" return f"{{{','.join((str(s) for s in tokens))}}}"
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" Return the token information in database importable format. """ Return the token information in database importable format.
""" """
@@ -866,13 +830,11 @@ class _TokenInfo:
return out return out
def set_names(self, fulls: Iterable[int], partials: Iterable[int]) -> None: def set_names(self, fulls: Iterable[int], partials: Iterable[int]) -> None:
""" Adds token information for the normalised names. """ Adds token information for the normalised names.
""" """
self.names = self._mk_array(itertools.chain(fulls, partials)) self.names = self._mk_array(itertools.chain(fulls, partials))
def add_housenumber(self, token: Optional[int], hnr: Optional[str]) -> None: def add_housenumber(self, token: Optional[int], hnr: Optional[str]) -> None:
""" Extract housenumber information from a list of normalised """ Extract housenumber information from a list of normalised
housenumbers. housenumbers.
@@ -882,7 +844,6 @@ class _TokenInfo:
self.housenumbers.add(hnr) self.housenumbers.add(hnr)
self.housenumber_tokens.add(token) self.housenumber_tokens.add(token)
def add_street(self, tokens: Iterable[int]) -> None: def add_street(self, tokens: Iterable[int]) -> None:
""" Add addr:street match terms. """ Add addr:street match terms.
""" """
@@ -890,13 +851,11 @@ class _TokenInfo:
self.street_tokens = set() self.street_tokens = set()
self.street_tokens.update(tokens) self.street_tokens.update(tokens)
def add_place(self, tokens: Iterable[int]) -> None: def add_place(self, tokens: Iterable[int]) -> None:
""" Add addr:place search and match terms. """ Add addr:place search and match terms.
""" """
self.place_tokens.update(tokens) self.place_tokens.update(tokens)
def add_address_term(self, key: str, partials: Iterable[int]) -> None: def add_address_term(self, key: str, partials: Iterable[int]) -> None:
""" Add additional address terms. """ Add additional address terms.
""" """

View File

@@ -39,7 +39,6 @@ class PlaceSanitizer:
self.handlers.append(module.create(SanitizerConfig(func))) self.handlers.append(module.create(SanitizerConfig(func)))
def process_names(self, place: PlaceInfo) -> Tuple[List[PlaceName], List[PlaceName]]: def process_names(self, place: PlaceInfo) -> Tuple[List[PlaceName], List[PlaceName]]:
""" Extract a sanitized list of names and address parts from the """ Extract a sanitized list of names and address parts from the
given place. The function returns a tuple given place. The function returns a tuple

View File

@@ -27,7 +27,6 @@ class ProcessInfo:
self.names = self._convert_name_dict(place.name) self.names = self._convert_name_dict(place.name)
self.address = self._convert_name_dict(place.address) self.address = self._convert_name_dict(place.address)
@staticmethod @staticmethod
def _convert_name_dict(names: Optional[Mapping[str, str]]) -> List[PlaceName]: def _convert_name_dict(names: Optional[Mapping[str, str]]) -> List[PlaceName]:
""" Convert a dictionary of names into a list of PlaceNames. """ Convert a dictionary of names into a list of PlaceNames.

View File

@@ -30,6 +30,7 @@ from ...data.place_name import PlaceName
from .base import ProcessInfo from .base import ProcessInfo
from .config import SanitizerConfig from .config import SanitizerConfig
class _HousenumberSanitizer: class _HousenumberSanitizer:
def __init__(self, config: SanitizerConfig) -> None: def __init__(self, config: SanitizerConfig) -> None:
@@ -38,7 +39,6 @@ class _HousenumberSanitizer:
self.filter_name = config.get_filter('convert-to-name', 'FAIL_ALL') self.filter_name = config.get_filter('convert-to-name', 'FAIL_ALL')
def __call__(self, obj: ProcessInfo) -> None: def __call__(self, obj: ProcessInfo) -> None:
if not obj.address: if not obj.address:
return return
@@ -57,7 +57,6 @@ class _HousenumberSanitizer:
obj.address = new_address obj.address = new_address
def sanitize(self, value: str) -> Iterator[str]: def sanitize(self, value: str) -> Iterator[str]:
""" Extract housenumbers in a regularized format from an OSM value. """ Extract housenumbers in a regularized format from an OSM value.
@@ -68,7 +67,6 @@ class _HousenumberSanitizer:
if hnr: if hnr:
yield from self._regularize(hnr) yield from self._regularize(hnr)
def _regularize(self, hnr: str) -> Iterator[str]: def _regularize(self, hnr: str) -> Iterator[str]:
yield hnr yield hnr

View File

@@ -26,6 +26,7 @@ from ...data.postcode_format import PostcodeFormatter
from .base import ProcessInfo from .base import ProcessInfo
from .config import SanitizerConfig from .config import SanitizerConfig
class _PostcodeSanitizer: class _PostcodeSanitizer:
def __init__(self, config: SanitizerConfig) -> None: def __init__(self, config: SanitizerConfig) -> None:
@@ -36,7 +37,6 @@ class _PostcodeSanitizer:
if default_pattern is not None and isinstance(default_pattern, str): if default_pattern is not None and isinstance(default_pattern, str):
self.matcher.set_default_pattern(default_pattern) self.matcher.set_default_pattern(default_pattern)
def __call__(self, obj: ProcessInfo) -> None: def __call__(self, obj: ProcessInfo) -> None:
if not obj.address: if not obj.address:
return return
@@ -55,7 +55,6 @@ class _PostcodeSanitizer:
postcode.name = formatted[0] postcode.name = formatted[0]
postcode.set_attr('variant', formatted[1]) postcode.set_attr('variant', formatted[1])
def scan(self, postcode: str, country: Optional[str]) -> Optional[Tuple[str, str]]: def scan(self, postcode: str, country: Optional[str]) -> Optional[Tuple[str, str]]:
""" Check the postcode for correct formatting and return the """ Check the postcode for correct formatting and return the
normalized version. Returns None if the postcode does not normalized version. Returns None if the postcode does not
@@ -67,10 +66,8 @@ class _PostcodeSanitizer:
assert country is not None assert country is not None
return self.matcher.normalize(country, match),\ return self.matcher.normalize(country, match), \
' '.join(filter(lambda p: p is not None, match.groups())) ' '.join(filter(lambda p: p is not None, match.groups()))
def create(config: SanitizerConfig) -> Callable[[ProcessInfo], None]: def create(config: SanitizerConfig) -> Callable[[ProcessInfo], None]:

View File

@@ -19,6 +19,7 @@ from .config import SanitizerConfig
COUNTY_MATCH = re.compile('(.*), [A-Z][A-Z]') COUNTY_MATCH = re.compile('(.*), [A-Z][A-Z]')
def _clean_tiger_county(obj: ProcessInfo) -> None: def _clean_tiger_county(obj: ProcessInfo) -> None:
""" Remove the state reference from tiger:county tags. """ Remove the state reference from tiger:county tags.

View File

@@ -20,6 +20,7 @@ if TYPE_CHECKING:
else: else:
_BaseUserDict = UserDict _BaseUserDict = UserDict
class SanitizerConfig(_BaseUserDict): class SanitizerConfig(_BaseUserDict):
""" The `SanitizerConfig` class is a read-only dictionary """ The `SanitizerConfig` class is a read-only dictionary
with configuration options for the sanitizer. with configuration options for the sanitizer.
@@ -61,7 +62,6 @@ class SanitizerConfig(_BaseUserDict):
return values return values
def get_bool(self, param: str, default: Optional[bool] = None) -> bool: def get_bool(self, param: str, default: Optional[bool] = None) -> bool:
""" Extract a configuration parameter as a boolean. """ Extract a configuration parameter as a boolean.
@@ -82,7 +82,6 @@ class SanitizerConfig(_BaseUserDict):
return value return value
def get_delimiter(self, default: str = ',;') -> Pattern[str]: def get_delimiter(self, default: str = ',;') -> Pattern[str]:
""" Return the 'delimiters' parameter in the configuration as a """ Return the 'delimiters' parameter in the configuration as a
compiled regular expression that can be used to split strings on compiled regular expression that can be used to split strings on
@@ -105,7 +104,6 @@ class SanitizerConfig(_BaseUserDict):
return re.compile('\\s*[{}]+\\s*'.format(''.join('\\' + d for d in delimiter_set))) return re.compile('\\s*[{}]+\\s*'.format(''.join('\\' + d for d in delimiter_set)))
def get_filter(self, param: str, default: Union[str, Sequence[str]] = 'PASS_ALL' def get_filter(self, param: str, default: Union[str, Sequence[str]] = 'PASS_ALL'
) -> Callable[[str], bool]: ) -> Callable[[str], bool]:
""" Returns a filter function for the given parameter of the sanitizer """ Returns a filter function for the given parameter of the sanitizer

View File

@@ -60,6 +60,7 @@ from ...data.place_name import PlaceName
from .base import ProcessInfo from .base import ProcessInfo
from .config import SanitizerConfig from .config import SanitizerConfig
class _TagSanitizer: class _TagSanitizer:
def __init__(self, config: SanitizerConfig) -> None: def __init__(self, config: SanitizerConfig) -> None:
@@ -74,7 +75,6 @@ class _TagSanitizer:
self.has_country_code = config.get('country_code', None) is not None self.has_country_code = config.get('country_code', None) is not None
def __call__(self, obj: ProcessInfo) -> None: def __call__(self, obj: ProcessInfo) -> None:
tags = obj.names if self.type == 'name' else obj.address tags = obj.names if self.type == 'name' else obj.address
@@ -93,13 +93,11 @@ class _TagSanitizer:
or not self.filter_name(tag.name): or not self.filter_name(tag.name):
filtered_tags.append(tag) filtered_tags.append(tag)
if self.type == 'name': if self.type == 'name':
obj.names = filtered_tags obj.names = filtered_tags
else: else:
obj.address = filtered_tags obj.address = filtered_tags
def _set_allowed_ranks(self, ranks: Sequence[str]) -> Tuple[bool, ...]: def _set_allowed_ranks(self, ranks: Sequence[str]) -> Tuple[bool, ...]:
""" Returns a tuple of 31 boolean values corresponding to the """ Returns a tuple of 31 boolean values corresponding to the
address ranks 0-30. Value at index 'i' is True if rank 'i' address ranks 0-30. Value at index 'i' is True if rank 'i'
@@ -117,7 +115,6 @@ class _TagSanitizer:
for i in range(start, end + 1): for i in range(start, end + 1):
allowed_ranks[i] = True allowed_ranks[i] = True
return tuple(allowed_ranks) return tuple(allowed_ranks)

View File

@@ -16,6 +16,7 @@ from typing import Callable
from .base import ProcessInfo from .base import ProcessInfo
from .config import SanitizerConfig from .config import SanitizerConfig
def create(config: SanitizerConfig) -> Callable[[ProcessInfo], None]: def create(config: SanitizerConfig) -> Callable[[ProcessInfo], None]:
""" Create a name processing function that splits name values with """ Create a name processing function that splits name values with
multiple values into their components. multiple values into their components.

View File

@@ -36,6 +36,7 @@ from ...data import country_info
from .base import ProcessInfo from .base import ProcessInfo
from .config import SanitizerConfig from .config import SanitizerConfig
class _AnalyzerByLanguage: class _AnalyzerByLanguage:
""" Processor for tagging the language of names in a place. """ Processor for tagging the language of names in a place.
""" """
@@ -47,7 +48,6 @@ class _AnalyzerByLanguage:
self._compute_default_languages(config.get('use-defaults', 'no')) self._compute_default_languages(config.get('use-defaults', 'no'))
def _compute_default_languages(self, use_defaults: str) -> None: def _compute_default_languages(self, use_defaults: str) -> None:
self.deflangs: Dict[Optional[str], List[str]] = {} self.deflangs: Dict[Optional[str], List[str]] = {}
@@ -55,18 +55,16 @@ class _AnalyzerByLanguage:
for ccode, clangs in country_info.iterate('languages'): for ccode, clangs in country_info.iterate('languages'):
if len(clangs) == 1 or use_defaults == 'all': if len(clangs) == 1 or use_defaults == 'all':
if self.whitelist: if self.whitelist:
self.deflangs[ccode] = [l for l in clangs if l in self.whitelist] self.deflangs[ccode] = [cl for cl in clangs if cl in self.whitelist]
else: else:
self.deflangs[ccode] = clangs self.deflangs[ccode] = clangs
def _suffix_matches(self, suffix: str) -> bool: def _suffix_matches(self, suffix: str) -> bool:
if self.whitelist is None: if self.whitelist is None:
return len(suffix) in (2, 3) and suffix.islower() return len(suffix) in (2, 3) and suffix.islower()
return suffix in self.whitelist return suffix in self.whitelist
def __call__(self, obj: ProcessInfo) -> None: def __call__(self, obj: ProcessInfo) -> None:
if not obj.names: if not obj.names:
return return
@@ -80,14 +78,13 @@ class _AnalyzerByLanguage:
else: else:
langs = self.deflangs.get(obj.place.country_code) langs = self.deflangs.get(obj.place.country_code)
if langs: if langs:
if self.replace: if self.replace:
name.set_attr('analyzer', langs[0]) name.set_attr('analyzer', langs[0])
else: else:
more_names.append(name.clone(attr={'analyzer': langs[0]})) more_names.append(name.clone(attr={'analyzer': langs[0]}))
more_names.extend(name.clone(attr={'analyzer': l}) for l in langs[1:]) more_names.extend(name.clone(attr={'analyzer': lg}) for lg in langs[1:])
obj.names.extend(more_names) obj.names.extend(more_names)

View File

@@ -18,11 +18,13 @@ from .base import ProcessInfo
from .config import SanitizerConfig from .config import SanitizerConfig
from ...data.place_name import PlaceName from ...data.place_name import PlaceName
def create(_: SanitizerConfig) -> Callable[[ProcessInfo], None]: def create(_: SanitizerConfig) -> Callable[[ProcessInfo], None]:
"""Set up the sanitizer """Set up the sanitizer
""" """
return tag_japanese return tag_japanese
def reconbine_housenumber( def reconbine_housenumber(
new_address: List[PlaceName], new_address: List[PlaceName],
tmp_housenumber: Optional[str], tmp_housenumber: Optional[str],
@@ -56,6 +58,7 @@ def reconbine_housenumber(
) )
return new_address return new_address
def reconbine_place( def reconbine_place(
new_address: List[PlaceName], new_address: List[PlaceName],
tmp_neighbourhood: Optional[str], tmp_neighbourhood: Optional[str],
@@ -88,6 +91,8 @@ def reconbine_place(
) )
) )
return new_address return new_address
def tag_japanese(obj: ProcessInfo) -> None: def tag_japanese(obj: ProcessInfo) -> None:
"""Recombine kind of address """Recombine kind of address
""" """

View File

@@ -12,6 +12,7 @@ from typing import Mapping, List, Any
from ...typing import Protocol from ...typing import Protocol
from ...data.place_name import PlaceName from ...data.place_name import PlaceName
class Analyzer(Protocol): class Analyzer(Protocol):
""" The `create()` function of an analysis module needs to return an """ The `create()` function of an analysis module needs to return an
object that implements the following functions. object that implements the following functions.

View File

@@ -15,6 +15,7 @@ import re
from ...config import flatten_config_list from ...config import flatten_config_list
from ...errors import UsageError from ...errors import UsageError
class ICUVariant(NamedTuple): class ICUVariant(NamedTuple):
""" A single replacement rule for variant creation. """ A single replacement rule for variant creation.
""" """
@@ -64,7 +65,6 @@ class _VariantMaker:
def __init__(self, normalizer: Any) -> None: def __init__(self, normalizer: Any) -> None:
self.norm = normalizer self.norm = normalizer
def compute(self, rule: Any) -> Iterator[ICUVariant]: def compute(self, rule: Any) -> Iterator[ICUVariant]:
""" Generator for all ICUVariant tuples from a single variant rule. """ Generator for all ICUVariant tuples from a single variant rule.
""" """
@@ -88,7 +88,6 @@ class _VariantMaker:
for froms, tos in _create_variants(*src, repl, decompose): for froms, tos in _create_variants(*src, repl, decompose):
yield ICUVariant(froms, tos) yield ICUVariant(froms, tos)
def _parse_variant_word(self, name: str) -> Optional[Tuple[str, str, str]]: def _parse_variant_word(self, name: str) -> Optional[Tuple[str, str, str]]:
name = name.strip() name = name.strip()
match = re.fullmatch(r'([~^]?)([^~$^]*)([~$]?)', name) match = re.fullmatch(r'([~^]?)([^~$^]*)([~$]?)', name)

View File

@@ -17,7 +17,8 @@ from ...data.place_name import PlaceName
from .config_variants import get_variant_config from .config_variants import get_variant_config
from .generic_mutation import MutationVariantGenerator from .generic_mutation import MutationVariantGenerator
### Configuration section # Configuration section
def configure(rules: Mapping[str, Any], normalizer: Any, _: Any) -> Dict[str, Any]: def configure(rules: Mapping[str, Any], normalizer: Any, _: Any) -> Dict[str, Any]:
""" Extract and preprocess the configuration for this module. """ Extract and preprocess the configuration for this module.
@@ -47,7 +48,7 @@ def configure(rules: Mapping[str, Any], normalizer: Any, _: Any) -> Dict[str, An
return config return config
### Analysis section # Analysis section
def create(normalizer: Any, transliterator: Any, def create(normalizer: Any, transliterator: Any,
config: Mapping[str, Any]) -> 'GenericTokenAnalysis': config: Mapping[str, Any]) -> 'GenericTokenAnalysis':
@@ -77,14 +78,12 @@ class GenericTokenAnalysis:
# set up mutation rules # set up mutation rules
self.mutations = [MutationVariantGenerator(*cfg) for cfg in config['mutations']] self.mutations = [MutationVariantGenerator(*cfg) for cfg in config['mutations']]
def get_canonical_id(self, name: PlaceName) -> str: def get_canonical_id(self, name: PlaceName) -> str:
""" Return the normalized form of the name. This is the standard form """ Return the normalized form of the name. This is the standard form
from which possible variants for the name can be derived. from which possible variants for the name can be derived.
""" """
return cast(str, self.norm.transliterate(name.name)).strip() return cast(str, self.norm.transliterate(name.name)).strip()
def compute_variants(self, norm_name: str) -> List[str]: def compute_variants(self, norm_name: str) -> List[str]:
""" Compute the spelling variants for the given normalized name """ Compute the spelling variants for the given normalized name
and transliterate the result. and transliterate the result.
@@ -96,7 +95,6 @@ class GenericTokenAnalysis:
return [name for name in self._transliterate_unique_list(norm_name, variants) if name] return [name for name in self._transliterate_unique_list(norm_name, variants) if name]
def _transliterate_unique_list(self, norm_name: str, def _transliterate_unique_list(self, norm_name: str,
iterable: Iterable[str]) -> Iterator[Optional[str]]: iterable: Iterable[str]) -> Iterator[Optional[str]]:
seen = set() seen = set()
@@ -108,7 +106,6 @@ class GenericTokenAnalysis:
seen.add(variant) seen.add(variant)
yield self.to_ascii.transliterate(variant).strip() yield self.to_ascii.transliterate(variant).strip()
def _generate_word_variants(self, norm_name: str) -> Iterable[str]: def _generate_word_variants(self, norm_name: str) -> Iterable[str]:
baseform = '^ ' + norm_name + ' ^' baseform = '^ ' + norm_name + ' ^'
baselen = len(baseform) baselen = len(baseform)

View File

@@ -16,6 +16,7 @@ from ...errors import UsageError
LOG = logging.getLogger() LOG = logging.getLogger()
def _zigzag(outer: Iterable[str], inner: Iterable[str]) -> Iterator[str]: def _zigzag(outer: Iterable[str], inner: Iterable[str]) -> Iterator[str]:
return itertools.chain.from_iterable(itertools.zip_longest(outer, inner, fillvalue='')) return itertools.chain.from_iterable(itertools.zip_longest(outer, inner, fillvalue=''))
@@ -36,7 +37,6 @@ class MutationVariantGenerator:
"This is not allowed.", pattern) "This is not allowed.", pattern)
raise UsageError("Bad mutation pattern in configuration.") raise UsageError("Bad mutation pattern in configuration.")
def generate(self, names: Iterable[str]) -> Iterator[str]: def generate(self, names: Iterable[str]) -> Iterator[str]:
""" Generator function for the name variants. 'names' is an iterable """ Generator function for the name variants. 'names' is an iterable
over a set of names for which the variants are to be generated. over a set of names for which the variants are to be generated.
@@ -49,7 +49,6 @@ class MutationVariantGenerator:
for seps in self._fillers(len(parts)): for seps in self._fillers(len(parts)):
yield ''.join(_zigzag(parts, seps)) yield ''.join(_zigzag(parts, seps))
def _fillers(self, num_parts: int) -> Iterator[Tuple[str, ...]]: def _fillers(self, num_parts: int) -> Iterator[Tuple[str, ...]]:
""" Returns a generator for strings to join the given number of string """ Returns a generator for strings to join the given number of string
parts in all possible combinations. parts in all possible combinations.

View File

@@ -19,16 +19,18 @@ RE_DIGIT_ALPHA = re.compile(r'(\d)\s*([^\d\s␣])')
RE_ALPHA_DIGIT = re.compile(r'([^\s\d␣])\s*(\d)') RE_ALPHA_DIGIT = re.compile(r'([^\s\d␣])\s*(\d)')
RE_NAMED_PART = re.compile(r'[a-z]{4}') RE_NAMED_PART = re.compile(r'[a-z]{4}')
### Configuration section # Configuration section
def configure(*_: Any) -> None: def configure(*_: Any) -> None:
""" All behaviour is currently hard-coded. """ All behaviour is currently hard-coded.
""" """
return None return None
### Analysis section # Analysis section
def create(normalizer: Any, transliterator: Any, config: None) -> 'HousenumberTokenAnalysis': # pylint: disable=W0613
def create(normalizer: Any, transliterator: Any, config: None) -> 'HousenumberTokenAnalysis':
""" Create a new token analysis instance for this module. """ Create a new token analysis instance for this module.
""" """
return HousenumberTokenAnalysis(normalizer, transliterator) return HousenumberTokenAnalysis(normalizer, transliterator)

View File

@@ -2,7 +2,7 @@
# #
# This file is part of Nominatim. (https://nominatim.org) # This file is part of Nominatim. (https://nominatim.org)
# #
# Copyright (C) 2022 by the Nominatim developer community. # Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log. # For a full list of authors see the git log.
""" """
Specialized processor for postcodes. Supports a 'lookup' variant of the Specialized processor for postcodes. Supports a 'lookup' variant of the
@@ -13,16 +13,18 @@ from typing import Any, List
from ...data.place_name import PlaceName from ...data.place_name import PlaceName
from .generic_mutation import MutationVariantGenerator from .generic_mutation import MutationVariantGenerator
### Configuration section # Configuration section
def configure(*_: Any) -> None: def configure(*_: Any) -> None:
""" All behaviour is currently hard-coded. """ All behaviour is currently hard-coded.
""" """
return None return None
### Analysis section # Analysis section
def create(normalizer: Any, transliterator: Any, config: None) -> 'PostcodeTokenAnalysis': # pylint: disable=W0613
def create(normalizer: Any, transliterator: Any, config: None) -> 'PostcodeTokenAnalysis':
""" Create a new token analysis instance for this module. """ Create a new token analysis instance for this module.
""" """
return PostcodeTokenAnalysis(normalizer, transliterator) return PostcodeTokenAnalysis(normalizer, transliterator)
@@ -44,13 +46,11 @@ class PostcodeTokenAnalysis:
self.mutator = MutationVariantGenerator(' ', (' ', '')) self.mutator = MutationVariantGenerator(' ', (' ', ''))
def get_canonical_id(self, name: PlaceName) -> str: def get_canonical_id(self, name: PlaceName) -> str:
""" Return the standard form of the postcode. """ Return the standard form of the postcode.
""" """
return name.name.strip().upper() return name.name.strip().upper()
def compute_variants(self, norm_name: str) -> List[str]: def compute_variants(self, norm_name: str) -> List[str]:
""" Compute the spelling variants for the given normalized postcode. """ Compute the spelling variants for the given normalized postcode.

View File

@@ -18,6 +18,7 @@ from .exec_utils import run_osm2pgsql
LOG = logging.getLogger() LOG = logging.getLogger()
def _run_osm2pgsql(dsn: str, options: MutableMapping[str, Any]) -> None: def _run_osm2pgsql(dsn: str, options: MutableMapping[str, Any]) -> None:
run_osm2pgsql(options) run_osm2pgsql(options)

View File

@@ -22,6 +22,7 @@ from ..data.place_info import PlaceInfo
LOG = logging.getLogger() LOG = logging.getLogger()
def _get_place_info(cursor: Cursor, osm_id: Optional[str], def _get_place_info(cursor: Cursor, osm_id: Optional[str],
place_id: Optional[int]) -> DictCursorResult: place_id: Optional[int]) -> DictCursorResult:
sql = """SELECT place_id, extra.* sql = """SELECT place_id, extra.*

View File

@@ -12,7 +12,7 @@ from enum import Enum
from textwrap import dedent from textwrap import dedent
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect, Connection, server_version_tuple,\ from ..db.connection import connect, Connection, server_version_tuple, \
index_exists, table_exists, execute_scalar index_exists, table_exists, execute_scalar
from ..db import properties from ..db import properties
from ..errors import UsageError from ..errors import UsageError
@@ -22,6 +22,7 @@ from ..version import NOMINATIM_VERSION, parse_version
CHECKLIST = [] CHECKLIST = []
class CheckState(Enum): class CheckState(Enum):
""" Possible states of a check. FATAL stops check execution entirely. """ Possible states of a check. FATAL stops check execution entirely.
""" """
@@ -31,9 +32,11 @@ class CheckState(Enum):
NOT_APPLICABLE = 3 NOT_APPLICABLE = 3
WARN = 4 WARN = 4
CheckResult = Union[CheckState, Tuple[CheckState, Mapping[str, Any]]] CheckResult = Union[CheckState, Tuple[CheckState, Mapping[str, Any]]]
CheckFunc = Callable[[Connection, Configuration], CheckResult] CheckFunc = Callable[[Connection, Configuration], CheckResult]
def _check(hint: Optional[str] = None) -> Callable[[CheckFunc], CheckFunc]: def _check(hint: Optional[str] = None) -> Callable[[CheckFunc], CheckFunc]:
""" Decorator for checks. It adds the function to the list of """ Decorator for checks. It adds the function to the list of
checks to execute and adds the code for printing progress messages. checks to execute and adds the code for printing progress messages.
@@ -68,6 +71,7 @@ def _check(hint: Optional[str] = None) -> Callable[[CheckFunc], CheckFunc]:
return decorator return decorator
class _BadConnection: class _BadConnection:
def __init__(self, msg: str) -> None: def __init__(self, msg: str) -> None:
@@ -77,13 +81,14 @@ class _BadConnection:
""" Dummy function to provide the implementation. """ Dummy function to provide the implementation.
""" """
def check_database(config: Configuration) -> int: def check_database(config: Configuration) -> int:
""" Run a number of checks on the database and return the status. """ Run a number of checks on the database and return the status.
""" """
try: try:
conn = connect(config.get_libpq_dsn()) conn = connect(config.get_libpq_dsn())
except UsageError as err: except UsageError as err:
conn = _BadConnection(str(err)) # type: ignore[assignment] conn = _BadConnection(str(err)) # type: ignore[assignment]
overall_result = 0 overall_result = 0
for check in CHECKLIST: for check in CHECKLIST:
@@ -110,7 +115,7 @@ def _get_indexes(conn: Connection) -> List[str]:
'idx_osmline_parent_osm_id', 'idx_osmline_parent_osm_id',
'idx_postcode_id', 'idx_postcode_id',
'idx_postcode_postcode' 'idx_postcode_postcode'
] ]
# These won't exist if --reverse-only import was used # These won't exist if --reverse-only import was used
if table_exists(conn, 'search_name'): if table_exists(conn, 'search_name'):
@@ -154,6 +159,7 @@ def check_connection(conn: Any, config: Configuration) -> CheckResult:
return CheckState.OK return CheckState.OK
@_check(hint="""\ @_check(hint="""\
Database version ({db_version}) doesn't match Nominatim version ({nom_version}) Database version ({db_version}) doesn't match Nominatim version ({nom_version})
@@ -195,6 +201,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
instruction=instruction, instruction=instruction,
config=config) config=config)
@_check(hint="""\ @_check(hint="""\
placex table not found placex table not found
@@ -274,7 +281,7 @@ def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
return CheckState.OK return CheckState.OK
if freeze.is_frozen(conn): if freeze.is_frozen(conn):
index_cmd="""\ index_cmd = """\
Database is marked frozen, it cannot be updated. Database is marked frozen, it cannot be updated.
Low counts of unindexed places are fine.""" Low counts of unindexed places are fine."""
return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd) return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd)

View File

@@ -21,6 +21,7 @@ from nominatim_api.sql.sqlalchemy_types import Geometry, IntArray
LOG = logging.getLogger() LOG = logging.getLogger()
async def convert(project_dir: Optional[Union[str, Path]], async def convert(project_dir: Optional[Union[str, Path]],
outfile: Path, options: Set[str]) -> None: outfile: Path, options: Set[str]) -> None:
""" Export an existing database to sqlite. The resulting database """ Export an existing database to sqlite. The resulting database
@@ -53,7 +54,6 @@ class SqliteWriter:
self.dest = dest self.dest = dest
self.options = options self.options = options
async def write(self) -> None: async def write(self) -> None:
""" Create the database structure and copy the data from """ Create the database structure and copy the data from
the source database to the destination. the source database to the destination.
@@ -67,7 +67,6 @@ class SqliteWriter:
await self.create_word_table() await self.create_word_table()
await self.create_indexes() await self.create_indexes()
async def create_tables(self) -> None: async def create_tables(self) -> None:
""" Set up the database tables. """ Set up the database tables.
""" """
@@ -87,7 +86,6 @@ class SqliteWriter:
sa.func.RecoverGeometryColumn(table.name, col.name, 4326, sa.func.RecoverGeometryColumn(table.name, col.name, 4326,
col.type.subtype.upper(), 'XY'))) col.type.subtype.upper(), 'XY')))
async def create_class_tables(self) -> None: async def create_class_tables(self) -> None:
""" Set up the table that serve class/type-specific geometries. """ Set up the table that serve class/type-specific geometries.
""" """
@@ -99,7 +97,6 @@ class SqliteWriter:
sa.Column('place_id', sa.BigInteger), sa.Column('place_id', sa.BigInteger),
sa.Column('centroid', Geometry)) sa.Column('centroid', Geometry))
async def create_word_table(self) -> None: async def create_word_table(self) -> None:
""" Create the word table. """ Create the word table.
This table needs the property information to determine the This table needs the property information to determine the
@@ -122,7 +119,6 @@ class SqliteWriter:
await self.dest.connection.run_sync(sa.Index('idx_word_woken', dest.c.word_token).create) await self.dest.connection.run_sync(sa.Index('idx_word_woken', dest.c.word_token).create)
async def copy_data(self) -> None: async def copy_data(self) -> None:
""" Copy data for all registered tables. """ Copy data for all registered tables.
""" """
@@ -151,7 +147,6 @@ class SqliteWriter:
data = [{'tablename': t} for t in self.dest.t.meta.tables] data = [{'tablename': t} for t in self.dest.t.meta.tables]
await self.dest.execute(pg_tables.insert().values(data)) await self.dest.execute(pg_tables.insert().values(data))
async def create_indexes(self) -> None: async def create_indexes(self) -> None:
""" Add indexes necessary for the frontend. """ Add indexes necessary for the frontend.
""" """
@@ -197,14 +192,12 @@ class SqliteWriter:
await self.dest.execute(sa.select( await self.dest.execute(sa.select(
sa.func.CreateSpatialIndex(t, 'centroid'))) sa.func.CreateSpatialIndex(t, 'centroid')))
async def create_spatial_index(self, table: str, column: str) -> None: async def create_spatial_index(self, table: str, column: str) -> None:
""" Create a spatial index on the given table and column. """ Create a spatial index on the given table and column.
""" """
await self.dest.execute(sa.select( await self.dest.execute(sa.select(
sa.func.CreateSpatialIndex(getattr(self.dest.t, table).name, column))) sa.func.CreateSpatialIndex(getattr(self.dest.t, table).name, column)))
async def create_index(self, table_name: str, column: str) -> None: async def create_index(self, table_name: str, column: str) -> None:
""" Create a simple index on the given table and column. """ Create a simple index on the given table and column.
""" """
@@ -212,7 +205,6 @@ class SqliteWriter:
await self.dest.connection.run_sync( await self.dest.connection.run_sync(
sa.Index(f"idx_{table}_{column}", getattr(table.c, column)).create) sa.Index(f"idx_{table}_{column}", getattr(table.c, column)).create)
async def create_search_index(self) -> None: async def create_search_index(self) -> None:
""" Create the tables and indexes needed for word lookup. """ Create the tables and indexes needed for word lookup.
""" """
@@ -242,7 +234,6 @@ class SqliteWriter:
await self.dest.connection.run_sync( await self.dest.connection.run_sync(
sa.Index('idx_reverse_search_name_word', rsn.c.word).create) sa.Index('idx_reverse_search_name_word', rsn.c.word).create)
def select_from(self, table: str) -> SaSelect: def select_from(self, table: str) -> SaSelect:
""" Create the SQL statement to select the source columns and rows. """ Create the SQL statement to select the source columns and rows.
""" """
@@ -258,9 +249,9 @@ class SqliteWriter:
columns.geometry), columns.geometry),
else_=sa.func.ST_SimplifyPreserveTopology( else_=sa.func.ST_SimplifyPreserveTopology(
columns.geometry, 0.0001) columns.geometry, 0.0001)
)).label('geometry')) )).label('geometry'))
sql = sa.select(*(sa.func.ST_AsText(c).label(c.name) sql = sa.select(*(sa.func.ST_AsText(c).label(c.name)
if isinstance(c.type, Geometry) else c for c in columns)) if isinstance(c.type, Geometry) else c for c in columns))
return sql return sql

View File

@@ -20,7 +20,7 @@ from psycopg import sql as pysql
from ..errors import UsageError from ..errors import UsageError
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\ from ..db.connection import connect, get_pg_env, Connection, server_version_tuple, \
postgis_version_tuple, drop_tables, table_exists, execute_scalar postgis_version_tuple, drop_tables, table_exists, execute_scalar
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
from ..db.query_pool import QueryPool from ..db.query_pool import QueryPool
@@ -29,6 +29,7 @@ from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
LOG = logging.getLogger() LOG = logging.getLogger()
def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None: def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
""" Compares the version for the given module and raises an exception """ Compares the version for the given module and raises an exception
if the actual version is too old. if the actual version is too old.
@@ -251,7 +252,7 @@ async def _progress_print() -> None:
async def create_search_indices(conn: Connection, config: Configuration, async def create_search_indices(conn: Connection, config: Configuration,
drop: bool = False, threads: int = 1) -> None: drop: bool = False, threads: int = 1) -> None:
""" Create tables that have explicit partitioning. """ Create tables that have explicit partitioning.
""" """

View File

@@ -35,7 +35,7 @@ def run_osm2pgsql(options: Mapping[str, Any]) -> None:
'--number-processes', '1' if options['append'] else str(options['threads']), '--number-processes', '1' if options['append'] else str(options['threads']),
'--cache', str(options['osm2pgsql_cache']), '--cache', str(options['osm2pgsql_cache']),
'--style', str(options['osm2pgsql_style']) '--style', str(options['osm2pgsql_style'])
] ]
if str(options['osm2pgsql_style']).endswith('.lua'): if str(options['osm2pgsql_style']).endswith('.lua'):
env['LUA_PATH'] = ';'.join((str(options['osm2pgsql_style_path'] / '?.lua'), env['LUA_PATH'] = ';'.join((str(options['osm2pgsql_style_path'] / '?.lua'),
@@ -50,7 +50,6 @@ def run_osm2pgsql(options: Mapping[str, Any]) -> None:
cmd.extend(('--output', 'gazetteer', '--hstore', '--latlon')) cmd.extend(('--output', 'gazetteer', '--hstore', '--latlon'))
cmd.extend(_mk_tablespace_options('main', options)) cmd.extend(_mk_tablespace_options('main', options))
if options['flatnode_file']: if options['flatnode_file']:
cmd.extend(('--flat-nodes', options['flatnode_file'])) cmd.extend(('--flat-nodes', options['flatnode_file']))

View File

@@ -28,6 +28,7 @@ UPDATE_TABLES = [
'wikipedia_%' 'wikipedia_%'
] ]
def drop_update_tables(conn: Connection) -> None: def drop_update_tables(conn: Connection) -> None:
""" Drop all tables only necessary for updating the database from """ Drop all tables only necessary for updating the database from
OSM replication data. OSM replication data.
@@ -49,8 +50,8 @@ def drop_flatnode_file(fpath: Optional[Path]) -> None:
if fpath and fpath.exists(): if fpath and fpath.exists():
fpath.unlink() fpath.unlink()
def is_frozen(conn: Connection) -> bool: def is_frozen(conn: Connection) -> bool:
""" Returns true if database is in a frozen state """ Returns true if database is in a frozen state
""" """
return table_exists(conn, 'place') is False return table_exists(conn, 'place') is False

View File

@@ -13,7 +13,7 @@ import logging
from ..errors import UsageError from ..errors import UsageError
from ..config import Configuration from ..config import Configuration
from ..db import properties from ..db import properties
from ..db.connection import connect, Connection,\ from ..db.connection import connect, Connection, \
table_exists, register_hstore table_exists, register_hstore
from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
from ..tokenizer import factory as tokenizer_factory from ..tokenizer import factory as tokenizer_factory
@@ -21,7 +21,8 @@ from . import refresh
LOG = logging.getLogger() LOG = logging.getLogger()
_MIGRATION_FUNCTIONS : List[Tuple[NominatimVersion, Callable[..., None]]] = [] _MIGRATION_FUNCTIONS: List[Tuple[NominatimVersion, Callable[..., None]]] = []
def migrate(config: Configuration, paths: Any) -> int: def migrate(config: Configuration, paths: Any) -> int:
""" Check for the current database version and execute migrations, """ Check for the current database version and execute migrations,

View File

@@ -25,6 +25,7 @@ from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
LOG = logging.getLogger() LOG = logging.getLogger()
def _to_float(numstr: str, max_value: float) -> float: def _to_float(numstr: str, max_value: float) -> float:
""" Convert the number in string into a float. The number is expected """ Convert the number in string into a float. The number is expected
to be in the range of [-max_value, max_value]. Otherwise rises a to be in the range of [-max_value, max_value]. Otherwise rises a
@@ -36,6 +37,7 @@ def _to_float(numstr: str, max_value: float) -> float:
return num return num
class _PostcodeCollector: class _PostcodeCollector:
""" Collector for postcodes of a single country. """ Collector for postcodes of a single country.
""" """
@@ -46,7 +48,6 @@ class _PostcodeCollector:
self.collected: Dict[str, PointsCentroid] = defaultdict(PointsCentroid) self.collected: Dict[str, PointsCentroid] = defaultdict(PointsCentroid)
self.normalization_cache: Optional[Tuple[str, Optional[str]]] = None self.normalization_cache: Optional[Tuple[str, Optional[str]]] = None
def add(self, postcode: str, x: float, y: float) -> None: def add(self, postcode: str, x: float, y: float) -> None:
""" Add the given postcode to the collection cache. If the postcode """ Add the given postcode to the collection cache. If the postcode
already existed, it is overwritten with the new centroid. already existed, it is overwritten with the new centroid.
@@ -63,7 +64,6 @@ class _PostcodeCollector:
if normalized: if normalized:
self.collected[normalized] += (x, y) self.collected[normalized] += (x, y)
def commit(self, conn: Connection, analyzer: AbstractAnalyzer, project_dir: Path) -> None: def commit(self, conn: Connection, analyzer: AbstractAnalyzer, project_dir: Path) -> None:
""" Update postcodes for the country from the postcodes selected so far """ Update postcodes for the country from the postcodes selected so far
as well as any externally supplied postcodes. as well as any externally supplied postcodes.
@@ -97,9 +97,9 @@ class _PostcodeCollector:
""").format(pysql.Literal(self.country)), """).format(pysql.Literal(self.country)),
to_update) to_update)
def _compute_changes(
def _compute_changes(self, conn: Connection) \ self, conn: Connection
-> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[float, float, str]]]: ) -> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[float, float, str]]]:
""" Compute which postcodes from the collected postcodes have to be """ Compute which postcodes from the collected postcodes have to be
added or modified and which from the location_postcode table added or modified and which from the location_postcode table
have to be deleted. have to be deleted.
@@ -125,7 +125,6 @@ class _PostcodeCollector:
return to_add, to_delete, to_update return to_add, to_delete, to_update
def _update_from_external(self, analyzer: AbstractAnalyzer, project_dir: Path) -> None: def _update_from_external(self, analyzer: AbstractAnalyzer, project_dir: Path) -> None:
""" Look for an external postcode file for the active country in """ Look for an external postcode file for the active country in
the project directory and add missing postcodes when found. the project directory and add missing postcodes when found.
@@ -155,7 +154,6 @@ class _PostcodeCollector:
finally: finally:
csvfile.close() csvfile.close()
def _open_external(self, project_dir: Path) -> Optional[TextIO]: def _open_external(self, project_dir: Path) -> Optional[TextIO]:
fname = project_dir / f'{self.country}_postcodes.csv' fname = project_dir / f'{self.country}_postcodes.csv'
@@ -225,6 +223,7 @@ def update_postcodes(dsn: str, project_dir: Path, tokenizer: AbstractTokenizer)
analyzer.update_postcodes_from_db() analyzer.update_postcodes_from_db()
def can_compute(dsn: str) -> bool: def can_compute(dsn: str) -> bool:
""" """
Check that the place table exists so that Check that the place table exists so that

View File

@@ -16,7 +16,7 @@ from pathlib import Path
from psycopg import sql as pysql from psycopg import sql as pysql
from ..config import Configuration from ..config import Configuration
from ..db.connection import Connection, connect, postgis_version_tuple,\ from ..db.connection import Connection, connect, postgis_version_tuple, \
drop_tables drop_tables
from ..db.utils import execute_file from ..db.utils import execute_file
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
@@ -25,6 +25,7 @@ LOG = logging.getLogger()
OSM_TYPE = {'N': 'node', 'W': 'way', 'R': 'relation'} OSM_TYPE = {'N': 'node', 'W': 'way', 'R': 'relation'}
def _add_address_level_rows_from_entry(rows: MutableSequence[Tuple[Any, ...]], def _add_address_level_rows_from_entry(rows: MutableSequence[Tuple[Any, ...]],
entry: Mapping[str, Any]) -> None: entry: Mapping[str, Any]) -> None:
""" Converts a single entry from the JSON format for address rank """ Converts a single entry from the JSON format for address rank
@@ -51,7 +52,7 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
The table has the following columns: The table has the following columns:
country, class, type, rank_search, rank_address country, class, type, rank_search, rank_address
""" """
rows: List[Tuple[Any, ...]] = [] rows: List[Tuple[Any, ...]] = []
for entry in levels: for entry in levels:
_add_address_level_rows_from_entry(rows, entry) _add_address_level_rows_from_entry(rows, entry)
@@ -199,6 +200,7 @@ def import_secondary_importance(dsn: str, data_path: Path, ignore_errors: bool =
return 0 return 0
def recompute_importance(conn: Connection) -> None: def recompute_importance(conn: Connection) -> None:
""" Recompute wikipedia links and importance for all entries in placex. """ Recompute wikipedia links and importance for all entries in placex.
This is a long-running operations that must not be executed in This is a long-running operations that must not be executed in

Some files were not shown because too many files have changed in this diff Show More