Merge pull request #3692 from lonvia/word-lookup-variants

Avoid matching penalty for abbreviated search terms
This commit is contained in:
Sarah Hoffmann
2025-03-31 16:38:31 +02:00
committed by GitHub
7 changed files with 104 additions and 41 deletions

View File

@@ -121,10 +121,10 @@ class ICUTokenizer(AbstractTokenizer):
SELECT unnest(nameaddress_vector) as id, count(*)
FROM search_name GROUP BY id)
SELECT coalesce(a.id, w.id) as id,
(CASE WHEN w.count is null THEN '{}'::JSONB
(CASE WHEN w.count is null or w.count <= 1 THEN '{}'::JSONB
ELSE jsonb_build_object('count', w.count) END
||
CASE WHEN a.count is null THEN '{}'::JSONB
CASE WHEN a.count is null or a.count <= 1 THEN '{}'::JSONB
ELSE jsonb_build_object('addr_count', a.count) END) as info
FROM word_freq w FULL JOIN addr_freq a ON a.id = w.id;
""")
@@ -134,9 +134,10 @@ class ICUTokenizer(AbstractTokenizer):
drop_tables(conn, 'tmp_word')
cur.execute("""CREATE TABLE tmp_word AS
SELECT word_id, word_token, type, word,
(CASE WHEN wf.info is null THEN word.info
ELSE coalesce(word.info, '{}'::jsonb) || wf.info
END) as info
coalesce(word.info, '{}'::jsonb)
- 'count' - 'addr_count' ||
coalesce(wf.info, '{}'::jsonb)
as info
FROM word LEFT JOIN word_frequencies wf
ON word.word_id = wf.id
""")
@@ -584,10 +585,14 @@ class ICUNameAnalyzer(AbstractAnalyzer):
if word_id:
result = self._cache.housenumbers.get(word_id, result)
if result[0] is None:
variants = analyzer.compute_variants(word_id)
varout = analyzer.compute_variants(word_id)
if isinstance(varout, tuple):
variants = varout[0]
else:
variants = varout
if variants:
hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)",
(word_id, list(variants)))
(word_id, variants))
result = hid, variants[0]
self._cache.housenumbers[word_id] = result
@@ -632,13 +637,17 @@ class ICUNameAnalyzer(AbstractAnalyzer):
full, part = self._cache.names.get(token_id, (None, None))
if full is None:
variants = analyzer.compute_variants(word_id)
varset = analyzer.compute_variants(word_id)
if isinstance(varset, tuple):
variants, lookups = varset
else:
variants, lookups = varset, None
if not variants:
continue
with self.conn.cursor() as cur:
cur.execute("SELECT * FROM getorcreate_full_word(%s, %s)",
(token_id, variants))
cur.execute("SELECT * FROM getorcreate_full_word(%s, %s, %s)",
(token_id, variants, lookups))
full, part = cast(Tuple[int, List[int]], cur.fetchone())
self._cache.names[token_id] = (full, part)

View File

@@ -7,7 +7,7 @@
"""
Common data types and protocols for analysers.
"""
from typing import Mapping, List, Any
from typing import Mapping, List, Any, Union, Tuple
from ...typing import Protocol
from ...data.place_name import PlaceName
@@ -33,7 +33,7 @@ class Analyzer(Protocol):
for example because the character set in use does not match.
"""
def compute_variants(self, canonical_id: str) -> List[str]:
def compute_variants(self, canonical_id: str) -> Union[List[str], Tuple[List[str], List[str]]]:
""" Compute the transliterated spelling variants for the given
canonical ID.

View File

@@ -7,7 +7,7 @@
"""
Generic processor for names that creates abbreviation variants.
"""
from typing import Mapping, Dict, Any, Iterable, Iterator, Optional, List, cast
from typing import Mapping, Dict, Any, Iterable, Optional, List, cast, Tuple
import itertools
from ...errors import UsageError
@@ -78,7 +78,7 @@ class GenericTokenAnalysis:
"""
return cast(str, self.norm.transliterate(name.name)).strip()
def compute_variants(self, norm_name: str) -> List[str]:
def compute_variants(self, norm_name: str) -> Tuple[List[str], List[str]]:
""" Compute the spelling variants for the given normalized name
and transliterate the result.
"""
@@ -87,18 +87,20 @@ class GenericTokenAnalysis:
for mutation in self.mutations:
variants = mutation.generate(variants)
return [name for name in self._transliterate_unique_list(norm_name, variants) if name]
def _transliterate_unique_list(self, norm_name: str,
iterable: Iterable[str]) -> Iterator[Optional[str]]:
seen = set()
varset = set(map(str.strip, variants))
if self.variant_only:
seen.add(norm_name)
varset.discard(norm_name)
for variant in map(str.strip, iterable):
if variant not in seen:
seen.add(variant)
yield self.to_ascii.transliterate(variant).strip()
trans = []
norm = []
for var in varset:
t = self.to_ascii.transliterate(var).strip()
if t:
trans.append(t)
norm.append(var)
return trans, norm
def _generate_word_variants(self, norm_name: str) -> Iterable[str]:
baseform = '^ ' + norm_name + ' ^'