make API formatter loadable from project directory

This commit is contained in:
Sarah Hoffmann
2024-08-13 23:21:38 +02:00
parent 0c25e80be0
commit 52ee5dc73c
3 changed files with 45 additions and 11 deletions

View File

@@ -7,8 +7,10 @@
"""
Helper classes and functions for formatting results into API responses.
"""
from typing import Type, TypeVar, Dict, List, Callable, Any, Mapping
from typing import Type, TypeVar, Dict, List, Callable, Any, Mapping, Optional, cast
from collections import defaultdict
from pathlib import Path
import importlib
T = TypeVar('T') # pylint: disable=invalid-name
FormatFunc = Callable[[T, Mapping[str, Any]], str]
@@ -54,3 +56,30 @@ class FormatDispatcher:
`list_formats()`.
"""
return self.format_functions[type(result)][fmt](result, options)
def load_format_dispatcher(api_name: str, project_dir: Optional[Path]) -> FormatDispatcher:
""" Load the dispatcher for the given API.
The function first tries to find a module api/<api_name>/format.py
in the project directory. This file must export a single variable
`dispatcher`.
If the function does not exist, the default formatter is loaded.
"""
if project_dir is not None:
priv_module = project_dir / 'api' / api_name / 'format.py'
if priv_module.is_file():
spec = importlib.util.spec_from_file_location(f'api.{api_name},format',
str(priv_module))
if spec:
module = importlib.util.module_from_spec(spec)
# Do not add to global modules because there is no standard
# module name that Python can resolve.
assert spec.loader is not None
spec.loader.exec_module(module)
return cast(FormatDispatcher, module.dispatch)
return cast(FormatDispatcher,
importlib.import_module(f'nominatim_api.{api_name}.format').dispatch)

View File

@@ -17,8 +17,7 @@ from falcon.asgi import App, Request, Response
from ...config import Configuration
from ...core import NominatimAPIAsync
from ... import v1 as api_impl
from ...result_formatting import FormatDispatcher
from ...v1.format import dispatch as formatting
from ...result_formatting import FormatDispatcher, load_format_dispatcher
from ... import logging as loglib
from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
@@ -64,10 +63,12 @@ class ParamWrapper(ASGIAdaptor):
""" Adaptor class for server glue to Falcon framework.
"""
def __init__(self, req: Request, resp: Response, config: Configuration) -> None:
def __init__(self, req: Request, resp: Response,
config: Configuration, formatter: FormatDispatcher) -> None:
self.request = req
self.response = resp
self._config = config
self._formatter = formatter
def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
@@ -96,23 +97,26 @@ class ParamWrapper(ASGIAdaptor):
return self._config
def formatting(self) -> FormatDispatcher:
return formatting
return self._formatter
class EndpointWrapper:
""" Converter for server glue endpoint functions to Falcon request handlers.
"""
def __init__(self, name: str, func: EndpointFunc, api: NominatimAPIAsync) -> None:
def __init__(self, name: str, func: EndpointFunc, api: NominatimAPIAsync,
formatter: FormatDispatcher) -> None:
self.name = name
self.func = func
self.api = api
self.formatter = formatter
async def on_get(self, req: Request, resp: Response) -> None:
""" Implementation of the endpoint.
"""
await self.func(self.api, ParamWrapper(req, resp, self.api.config))
await self.func(self.api, ParamWrapper(req, resp, self.api.config,
self.formatter))
class FileLoggingMiddleware:
@@ -182,8 +186,9 @@ def get_application(project_dir: Path,
app.add_error_handler(asyncio.TimeoutError, timeout_error_handler)
legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', project_dir)
for name, func in api_impl.ROUTES:
endpoint = EndpointWrapper(name, func, api)
endpoint = EndpointWrapper(name, func, api, formatter)
app.add_route(f"/{name}", endpoint)
if legacy_urls:
app.add_route(f"/{name}.php", endpoint)

View File

@@ -24,8 +24,7 @@ from starlette.middleware.cors import CORSMiddleware
from ...config import Configuration
from ...core import NominatimAPIAsync
from ... import v1 as api_impl
from ...result_formatting import FormatDispatcher
from ...v1.format import dispatch as formatting
from ...result_formatting import FormatDispatcher, load_format_dispatcher
from ..asgi_adaptor import ASGIAdaptor, EndpointFunc
from ... import logging as loglib
@@ -73,7 +72,7 @@ class ParamWrapper(ASGIAdaptor):
def formatting(self) -> FormatDispatcher:
return formatting
return cast(FormatDispatcher, self.request.app.state.API.formatter)
def _wrap_endpoint(func: EndpointFunc)\
@@ -171,6 +170,7 @@ def get_application(project_dir: Path,
on_shutdown=[_shutdown])
app.state.API = NominatimAPIAsync(project_dir, environ)
app.state.formatter = load_format_dispatcher('v1', project_dir)
return app