query analyzer for ICU tokenizer

Sarah Hoffmann
2023-05-22 08:46:19 +02:00
parent ff66595f7a
commit 004883bdb1
5 changed files with 547 additions and 4 deletions

.pylintrc

@@ -13,6 +13,6 @@ ignored-classes=NominatimArgs,closing
# 'too-many-ancestors' is triggered already by deriving from UserDict
# 'not-context-manager' disabled because it causes false positives once
# typed Python is enabled. See also https://github.com/PyCQA/pylint/issues/5273
-disable=too-few-public-methods,duplicate-code,too-many-ancestors,bad-option-value,no-self-use,not-context-manager,use-dict-literal,chained-comparison
+disable=too-few-public-methods,duplicate-code,too-many-ancestors,bad-option-value,no-self-use,not-context-manager,use-dict-literal,chained-comparison,attribute-defined-outside-init
-good-names=i,x,y,m,t,fd,db,cc,x1,x2,y1,y2,pt,k,v
+good-names=i,j,x,y,m,t,fd,db,cc,x1,x2,y1,y2,pt,k,v

nominatim/api/logging.py

@@ -7,7 +7,7 @@
""" """
Functions for specialised logging with HTML output. Functions for specialised logging with HTML output.
""" """
from typing import Any, cast from typing import Any, Iterator, Optional, List, cast
from contextvars import ContextVar from contextvars import ContextVar
import textwrap import textwrap
import io import io
@@ -56,6 +56,11 @@ class BaseLogger:
""" """
    def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None:
        """ Print the table generated by the generator function.
            The first row yielded is the header row; a None entry marks a
            section break (rendered as a separator line in text output).
        """
    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
        """ Print the SQL for the given statement.
        """
@@ -101,9 +106,28 @@ class HTMLLogger(BaseLogger):
    def var_dump(self, heading: str, var: Any) -> None:
if callable(var):
var = var()
        self._write(f'<h5>{heading}</h5>{self._python_var(var)}')
def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None:
head = next(rows)
assert head
self._write(f'<table><thead><tr><th colspan="{len(head)}">{heading}</th></tr><tr>')
for cell in head:
self._write(f'<th>{cell}</th>')
self._write('</tr></thead><tbody>')
for row in rows:
if row is not None:
self._write('<tr>')
for cell in row:
self._write(f'<td>{cell}</td>')
self._write('</tr>')
self._write('</tbody></table>')
    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
        sqlstr = self.format_sql(conn, statement)
        if CODE_HIGHLIGHT:
@@ -155,9 +179,33 @@ class TextLogger(BaseLogger):
    def var_dump(self, heading: str, var: Any) -> None:
if callable(var):
var = var()
        self._write(f'{heading}:\n {self._python_var(var)}\n\n')
def table_dump(self, heading: str, rows: Iterator[Optional[List[Any]]]) -> None:
self._write(f'{heading}:\n')
data = [list(map(self._python_var, row)) if row else None for row in rows]
assert data[0] is not None
num_cols = len(data[0])
maxlens = [max(len(d[i]) for d in data if d) for i in range(num_cols)]
tablewidth = sum(maxlens) + 3 * num_cols + 1
        row_format = '| ' + ' | '.join(f'{{:<{l}}}' for l in maxlens) + ' |\n'
self._write('-'*tablewidth + '\n')
self._write(row_format.format(*data[0]))
self._write('-'*tablewidth + '\n')
for row in data[1:]:
if row:
self._write(row_format.format(*row))
else:
self._write('-'*tablewidth + '\n')
if data[-1]:
self._write('-'*tablewidth + '\n')
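
Both logger implementations consume the same generator protocol: the first row yielded is the header and a None entry marks a section break. A minimal sketch of a conforming generator; names and values are illustrative, and the rendered output assumes _python_var() simply stringifies its argument:

from typing import Any, Iterator, List, Optional

def example_rows() -> Iterator[Optional[List[Any]]]:
    yield ['token', 'penalty']   # first row becomes the table header
    yield ['foo', 0.0]
    yield None                   # section break: TextLogger draws a rule
    yield ['bar', 0.3]

# TextLogger().table_dump('Word tokens', example_rows()) prints roughly:
#   Word tokens:
#   -------------------
#   | token | penalty |
#   -------------------
#   | foo   | 0.0     |
#   -------------------
#   | bar   | 0.3     |
#   -------------------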
    def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
        sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement), width=78))
        self._write(f"| {sqlstr}\n\n")

nominatim/api/search/icu_tokenizer.py

@@ -0,0 +1,294 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of query analysis for the ICU tokenizer.
"""
from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
from copy import copy
from collections import defaultdict
import dataclasses
import difflib
from icu import Transliterator
import sqlalchemy as sa
from nominatim.typing import SaRow
from nominatim.api.connection import SearchConnection
from nominatim.api.logging import log
from nominatim.api.search import query as qmod
# XXX: TODO
class AbstractQueryAnalyzer:
pass
DB_TO_TOKEN_TYPE = {
'W': qmod.TokenType.WORD,
'w': qmod.TokenType.PARTIAL,
'H': qmod.TokenType.HOUSENUMBER,
'P': qmod.TokenType.POSTCODE,
'C': qmod.TokenType.COUNTRY
}
class QueryPart(NamedTuple):
""" Normalized and transliterated form of a single term in the query.
When the term came out of a split during the transliteration,
the normalized string is the full word before transliteration.
        The word number keeps track of the word before transliteration
        and can be used to identify partially transliterated terms.
"""
token: str
normalized: str
word_number: int
QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]
def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
""" Return all combinations of words in the terms list after the
given position.
"""
total = len(terms)
for first in range(start, total):
word = terms[first].token
yield word, qmod.TokenRange(first, first + 1)
for last in range(first + 1, min(first + 20, total)):
word = ' '.join((word, terms[last].token))
yield word, qmod.TokenRange(first, last + 1)
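
To illustrate, for a query split into three terms, yield_words() enumerates every contiguous term sequence together with its token range (the QueryPart values here are made up):

parts = [QueryPart('hauptstr', 'hauptstrasse', 0),
         QueryPart('134', '134', 1),
         QueryPart('berlin', 'berlin', 2)]
for word, trange in yield_words(parts, 0):
    print(f'{trange.start}-{trange.end}: {word!r}')
# 0-1: 'hauptstr'
# 0-2: 'hauptstr 134'
# 0-3: 'hauptstr 134 berlin'
# 1-2: '134'
# 1-3: '134 berlin'
# 2-3: 'berlin'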
@dataclasses.dataclass
class ICUToken(qmod.Token):
""" Specialised token for ICU tokenizer.
"""
word_token: str
info: Optional[Dict[str, Any]]
def get_category(self) -> Tuple[str, str]:
assert self.info
return self.info.get('class', ''), self.info.get('type', '')
def rematch(self, norm: str) -> None:
""" Check how well the token matches the given normalized string
and add a penalty, if necessary.
"""
if not self.lookup_word:
return
seq = difflib.SequenceMatcher(a=self.lookup_word, b=norm)
distance = 0
for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
if tag == 'delete' and (afrom == 0 or ato == len(self.lookup_word)):
distance += 1
elif tag == 'replace':
distance += max((ato-afrom), (bto-bfrom))
elif tag != 'equal':
distance += abs((ato-afrom) - (bto-bfrom))
self.penalty += (distance/len(self.lookup_word))
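
A worked example of the distance metric, re-implemented stand-alone for illustration: a deletion at either end of the lookup word costs 1 regardless of its length (so abbreviated query input is penalised only lightly), while replacements cost their full length.

import difflib

def edit_penalty(lookup_word: str, norm: str) -> float:
    # Mirrors ICUToken.rematch() above.
    seq = difflib.SequenceMatcher(a=lookup_word, b=norm)
    distance = 0
    for tag, afrom, ato, bfrom, bto in seq.get_opcodes():
        if tag == 'delete' and (afrom == 0 or ato == len(lookup_word)):
            distance += 1          # cheap: delete at either end
        elif tag == 'replace':
            distance += max(ato - afrom, bto - bfrom)
        elif tag != 'equal':
            distance += abs((ato - afrom) - (bto - bfrom))
    return distance / len(lookup_word)

print(edit_penalty('hauptstrasse', 'hauptstr'))  # ~0.083: trailing delete costs 1
print(edit_penalty('foo', 'fuu'))                # ~0.667: two replaced characters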
@staticmethod
def from_db_row(row: SaRow) -> 'ICUToken':
""" Create a ICUToken from the row of the word table.
"""
count = 1 if row.info is None else row.info.get('count', 1)
penalty = 0.0
if row.type == 'w':
penalty = 0.3
elif row.type == 'H':
penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
if all(not c.isdigit() for c in row.word_token):
penalty += 0.2 * (len(row.word_token) - 1)
if row.info is None:
lookup_word = row.word
else:
lookup_word = row.info.get('lookup', row.word)
if lookup_word:
lookup_word = lookup_word.split('@', 1)[0]
else:
lookup_word = row.word_token
return ICUToken(penalty=penalty, token=row.word_id, count=count,
lookup_word=lookup_word, is_indexed=True,
word_token=row.word_token, info=row.info)
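
A short sketch of the resulting penalties; SimpleNamespace stands in for the SQLAlchemy row and the values are illustrative:

from types import SimpleNamespace

partial = SimpleNamespace(word_id=1, word_token='foo', type='w', word='foo', info=None)
print(ICUToken.from_db_row(partial).penalty)  # 0.3: flat penalty for partial terms

hnr = SimpleNamespace(word_id=2, word_token='34a', type='H', word='34a', info=None)
print(ICUToken.from_db_row(hnr).penalty)      # 0.1: one non-digit character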
class ICUQueryAnalyzer(AbstractQueryAnalyzer):
""" Converter for query strings into a tokenized query
using the tokens created by a ICU tokenizer.
"""
def __init__(self, conn: SearchConnection) -> None:
self.conn = conn
async def setup(self) -> None:
""" Set up static data structures needed for the analysis.
"""
rules = await self.conn.get_property('tokenizer_import_normalisation')
self.normalizer = Transliterator.createFromRules("normalization", rules)
rules = await self.conn.get_property('tokenizer_import_transliteration')
self.transliterator = Transliterator.createFromRules("transliteration", rules)
if 'word' not in self.conn.t.meta.tables:
sa.Table('word', self.conn.t.meta,
sa.Column('word_id', sa.Integer),
sa.Column('word_token', sa.Text, nullable=False),
sa.Column('type', sa.Text, nullable=False),
sa.Column('word', sa.Text),
sa.Column('info', self.conn.t.types.Json))
async def analyze_query(self, phrases: List[qmod.Phrase]) -> qmod.QueryStruct:
""" Analyze the given list of phrases and return the
tokenized query.
"""
log().section('Analyze query (using ICU tokenizer)')
normalized = list(filter(lambda p: p.text,
(qmod.Phrase(p.ptype, self.normalizer.transliterate(p.text))
for p in phrases)))
query = qmod.QueryStruct(normalized)
log().var_dump('Normalized query', query.source)
if not query.source:
return query
parts, words = self.split_query(query)
log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
for row in await self.lookup_in_db(list(words.keys())):
for trange in words[row.word_token]:
token = ICUToken.from_db_row(row)
if row.type == 'S':
if row.info['op'] in ('in', 'near'):
if trange.start == 0:
query.add_token(trange, qmod.TokenType.CATEGORY, token)
else:
query.add_token(trange, qmod.TokenType.QUALIFIER, token)
if trange.start == 0 or trange.end == query.num_token_slots():
token = copy(token)
token.penalty += 0.1 * (query.num_token_slots())
query.add_token(trange, qmod.TokenType.CATEGORY, token)
else:
query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
self.add_extra_tokens(query, parts)
self.rerank_tokens(query, parts)
log().table_dump('Word tokens', _dump_word_tokens(query))
return query
def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
""" Transliterate the phrases and split them into tokens.
Returns the list of transliterated tokens together with their
normalized form and a dictionary of words for lookup together
with their position.
"""
parts: QueryParts = []
phrase_start = 0
words = defaultdict(list)
wordnr = 0
for phrase in query.source:
query.nodes[-1].ptype = phrase.ptype
for word in phrase.text.split(' '):
trans = self.transliterator.transliterate(word)
if trans:
for term in trans.split(' '):
if term:
parts.append(QueryPart(term, word, wordnr))
query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
query.nodes[-1].btype = qmod.BreakType.WORD
wordnr += 1
query.nodes[-1].btype = qmod.BreakType.PHRASE
for word, wrange in yield_words(parts, phrase_start):
words[word].append(wrange)
phrase_start = len(parts)
query.nodes[-1].btype = qmod.BreakType.END
return parts, words
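
The token splitting hinges on transliteration rules that may insert spaces. A small sketch with a hypothetical rule (the same one the test fixture below uses): a single normalized word comes out as two terms, which become two QueryParts sharing a word_number, joined by a BreakType.TOKEN node instead of BreakType.WORD.

from icu import Transliterator

trans = Transliterator.createFromRules('demo', "'ä' > 'ä '")
print(trans.transliterate('mäfo').split(' '))   # ['mä', 'fo']
# -> QueryPart('mä', 'mäfo', 0) and QueryPart('fo', 'mäfo', 0)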
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
""" Return the token information from the database for the
given word tokens.
"""
t = self.conn.t.meta.tables['word']
return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
""" Add tokens to query that are not saved in the database.
"""
for part, node, i in zip(parts, query.nodes, range(1000)):
if len(part.token) <= 4 and part[0].isdigit()\
and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
ICUToken(0.5, 0, 1, part.token, True, part.token, None))
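
For readability, the positional ICUToken constructor call above corresponds to this keyword form (field order as used in from_db_row() further up):

ICUToken(penalty=0.5, token=0, count=1, lookup_word=part.token,
         is_indexed=True, word_token=part.token, info=None)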
def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
""" Add penalties to tokens that depend on presence of other token.
"""
for i, node, tlist in query.iter_token_lists():
if tlist.ttype == qmod.TokenType.POSTCODE:
for repl in node.starting:
if repl.end == tlist.end and repl.ttype != qmod.TokenType.POSTCODE \
and (repl.ttype != qmod.TokenType.HOUSENUMBER
or len(tlist.tokens[0].lookup_word) > 4):
repl.add_penalty(0.39)
elif tlist.ttype == qmod.TokenType.HOUSENUMBER:
if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
for repl in node.starting:
if repl.end == tlist.end and repl.ttype != qmod.TokenType.HOUSENUMBER \
and (repl.ttype != qmod.TokenType.HOUSENUMBER
or len(tlist.tokens[0].lookup_word) <= 3):
repl.add_penalty(0.5 - tlist.tokens[0].penalty)
elif tlist.ttype not in (qmod.TokenType.COUNTRY, qmod.TokenType.PARTIAL):
norm = parts[i].normalized
for j in range(i + 1, tlist.end):
if parts[j - 1].word_number != parts[j].word_number:
norm += ' ' + parts[j].normalized
for token in tlist.tokens:
cast(ICUToken, token).rematch(norm)
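
Note how the normalized string is reassembled here: parts split from one original word share a word_number and each carry the full original word in their normalized field, so only the first of them contributes and no space is inserted. A sketch with made-up parts:

parts = [QueryPart('haupt', 'hauptstrasse', 0),
         QueryPart('strasse', 'hauptstrasse', 0),
         QueryPart('berlin', 'berlin', 1)]
# A token list spanning parts 0-2 is rematched against 'hauptstrasse',
# one spanning parts 0-3 against 'hauptstrasse berlin'.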
def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
out = query.nodes[0].btype.value
for node, part in zip(query.nodes[1:], parts):
out += part.token + node.btype.value
return out
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
for node in query.nodes:
for tlist in node.starting:
for token in tlist.tokens:
t = cast(ICUToken, token)
yield [tlist.ttype.name, t.token, t.word_token or '',
t.lookup_word or '', t.penalty, t.count, t.info]
async def create_query_analyzer(conn: SearchConnection) -> AbstractQueryAnalyzer:
""" Create and set up a new query analyzer for a database based
on the ICU tokenizer.
"""
out = ICUQueryAnalyzer(conn)
await out.setup()
return out
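
A minimal usage sketch, assuming an open SearchConnection; the tokenize() helper is hypothetical:

from nominatim.api.connection import SearchConnection
from nominatim.api.search import query as qmod
from nominatim.api.search.icu_tokenizer import create_query_analyzer

async def tokenize(conn: SearchConnection, text: str) -> qmod.QueryStruct:
    # Analyze a single free-form phrase.
    analyzer = await create_query_analyzer(conn)
    return await analyzer.analyze_query([qmod.Phrase(qmod.PhraseType.NONE, text)])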

nominatim/api/search/query.py

@@ -7,7 +7,7 @@
""" """
Datastructures for a tokenized query. Datastructures for a tokenized query.
""" """
from typing import List, Tuple, Optional, NamedTuple from typing import List, Tuple, Optional, NamedTuple, Iterator
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import dataclasses import dataclasses
import enum import enum
@@ -124,6 +124,13 @@ class TokenList:
    tokens: List[Token]
def add_penalty(self, penalty: float) -> None:
""" Add the given penalty to all tokens in the list.
"""
for token in self.tokens:
token.penalty += penalty
@dataclasses.dataclass
class QueryNode:
    """ A node of the query representing a break between terms.
@@ -226,6 +233,14 @@ class QueryStruct:
                for i in range(trange.start, trange.end)]
def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
""" Iterator over all token lists in the query.
"""
for i, node in enumerate(self.nodes):
for tlist in node.starting:
yield i, node, tlist
    def find_lookup_word_by_id(self, token: int) -> str:
        """ Find the first token with the given token ID and return
            its lookup word. Returns 'None' if no such token exists.

test/python/api/search/test_icu_query_analyzer.py

@@ -0,0 +1,186 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Tests for query analyzer for ICU tokenizer.
"""
from pathlib import Path
import pytest
import pytest_asyncio
from nominatim.api import NominatimAPIAsync
from nominatim.api.search.query import Phrase, PhraseType, TokenType, BreakType
import nominatim.api.search.icu_tokenizer as tok
from nominatim.api.logging import set_log_output, get_and_disable
async def add_word(conn, word_id, word_token, wtype, word, info = None):
t = conn.t.meta.tables['word']
await conn.execute(t.insert(), {'word_id': word_id,
'word_token': word_token,
'type': wtype,
'word': word,
'info': info})
def make_phrase(query):
return [Phrase(PhraseType.NONE, s) for s in query.split(',')]
@pytest_asyncio.fixture
async def conn(table_factory):
""" Create an asynchronous SQLAlchemy engine for the test DB.
"""
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('tokenizer_import_normalisation', ':: lower();'),
('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '")))
table_factory('word',
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
yield conn
await api.close()
@pytest.mark.asyncio
async def test_empty_phrase(conn):
ana = await tok.create_query_analyzer(conn)
query = await ana.analyze_query([])
assert len(query.source) == 0
assert query.num_token_slots() == 0
@pytest.mark.asyncio
async def test_single_phrase_with_unknown_terms(conn):
ana = await tok.create_query_analyzer(conn)
await add_word(conn, 1, 'foo', 'w', 'FOO')
query = await ana.analyze_query(make_phrase('foo BAR'))
assert len(query.source) == 1
assert query.source[0].ptype == PhraseType.NONE
assert query.source[0].text == 'foo bar'
assert query.num_token_slots() == 2
assert len(query.nodes[0].starting) == 1
assert not query.nodes[1].starting
@pytest.mark.asyncio
async def test_multiple_phrases(conn):
ana = await tok.create_query_analyzer(conn)
await add_word(conn, 1, 'one', 'w', 'one')
await add_word(conn, 2, 'two', 'w', 'two')
await add_word(conn, 100, 'one two', 'W', 'one two')
await add_word(conn, 3, 'three', 'w', 'three')
query = await ana.analyze_query(make_phrase('one two,three'))
assert len(query.source) == 2
@pytest.mark.asyncio
async def test_splitting_in_transliteration(conn):
ana = await tok.create_query_analyzer(conn)
await add_word(conn, 1, '', 'W', 'ma')
await add_word(conn, 2, 'fo', 'W', 'fo')
query = await ana.analyze_query(make_phrase('mäfo'))
assert query.num_token_slots() == 2
assert query.nodes[0].starting
assert query.nodes[1].starting
assert query.nodes[1].btype == BreakType.TOKEN
@pytest.mark.asyncio
@pytest.mark.parametrize('term,order', [('23456', ['POSTCODE', 'HOUSENUMBER', 'WORD', 'PARTIAL']),
('3', ['HOUSENUMBER', 'POSTCODE', 'WORD', 'PARTIAL'])
])
async def test_penalty_postcodes_and_housenumbers(conn, term, order):
ana = await tok.create_query_analyzer(conn)
await add_word(conn, 1, term, 'P', None)
await add_word(conn, 2, term, 'H', term)
await add_word(conn, 3, term, 'w', term)
await add_word(conn, 4, term, 'W', term)
query = await ana.analyze_query(make_phrase(term))
assert query.num_token_slots() == 1
torder = [(tl.tokens[0].penalty, tl.ttype) for tl in query.nodes[0].starting]
torder.sort()
assert [t[1] for t in torder] == [TokenType[o] for o in order]
@pytest.mark.asyncio
async def test_category_words_only_at_beginning(conn):
ana = await tok.create_query_analyzer(conn)
await add_word(conn, 1, 'foo', 'S', 'FOO', {'op': 'in'})
await add_word(conn, 2, 'bar', 'w', 'BAR')
query = await ana.analyze_query(make_phrase('foo BAR foo'))
assert query.num_token_slots() == 3
assert len(query.nodes[0].starting) == 1
assert query.nodes[0].starting[0].ttype == TokenType.CATEGORY
assert not query.nodes[2].starting
@pytest.mark.asyncio
async def test_qualifier_words(conn):
ana = await tok.create_query_analyzer(conn)
await add_word(conn, 1, 'foo', 'S', None, {'op': '-'})
await add_word(conn, 2, 'bar', 'w', None)
query = await ana.analyze_query(make_phrase('foo BAR foo BAR foo'))
assert query.num_token_slots() == 5
assert set(t.ttype for t in query.nodes[0].starting) == {TokenType.CATEGORY, TokenType.QUALIFIER}
assert set(t.ttype for t in query.nodes[2].starting) == {TokenType.QUALIFIER}
assert set(t.ttype for t in query.nodes[4].starting) == {TokenType.CATEGORY, TokenType.QUALIFIER}
@pytest.mark.asyncio
async def test_add_unknown_housenumbers(conn):
ana = await tok.create_query_analyzer(conn)
await add_word(conn, 1, '23', 'H', '23')
query = await ana.analyze_query(make_phrase('466 23 99834 34a'))
assert query.num_token_slots() == 4
assert query.nodes[0].starting[0].ttype == TokenType.HOUSENUMBER
assert len(query.nodes[0].starting[0].tokens) == 1
assert query.nodes[0].starting[0].tokens[0].token == 0
assert query.nodes[1].starting[0].ttype == TokenType.HOUSENUMBER
assert len(query.nodes[1].starting[0].tokens) == 1
assert query.nodes[1].starting[0].tokens[0].token == 1
assert not query.nodes[2].starting
assert not query.nodes[3].starting
@pytest.mark.asyncio
@pytest.mark.parametrize('logtype', ['text', 'html'])
async def test_log_output(conn, logtype):
ana = await tok.create_query_analyzer(conn)
await add_word(conn, 1, 'foo', 'w', 'FOO')
set_log_output(logtype)
await ana.analyze_query(make_phrase('foo'))
assert get_and_disable()