Merge pull request #3692 from lonvia/word-lookup-variants

Avoid matching penalty for abbreviated search terms
Sarah Hoffmann
2025-03-31 16:38:31 +02:00
committed by GitHub
7 changed files with 104 additions and 41 deletions

View File

@@ -128,16 +128,14 @@ DECLARE
   partial_terms TEXT[] = '{}'::TEXT[];
   term TEXT;
   term_id INTEGER;
-  term_count INTEGER;
 BEGIN
   SELECT min(word_id) INTO full_token
     FROM word WHERE word = norm_term and type = 'W';

   IF full_token IS NULL THEN
     full_token := nextval('seq_word');
-    INSERT INTO word (word_id, word_token, type, word, info)
-      SELECT full_token, lookup_term, 'W', norm_term,
-             json_build_object('count', 0)
+    INSERT INTO word (word_id, word_token, type, word)
+      SELECT full_token, lookup_term, 'W', norm_term
       FROM unnest(lookup_terms) as lookup_term;
   END IF;
@@ -150,14 +148,67 @@ BEGIN
   partial_tokens := '{}'::INT[];
   FOR term IN SELECT unnest(partial_terms) LOOP
-    SELECT min(word_id), max(info->>'count') INTO term_id, term_count
+    SELECT min(word_id) INTO term_id
       FROM word WHERE word_token = term and type = 'w';

     IF term_id IS NULL THEN
       term_id := nextval('seq_word');
-      term_count := 0;
-      INSERT INTO word (word_id, word_token, type, info)
-        VALUES (term_id, term, 'w', json_build_object('count', term_count));
+      INSERT INTO word (word_id, word_token, type)
+        VALUES (term_id, term, 'w');
     END IF;

     partial_tokens := array_merge(partial_tokens, ARRAY[term_id]);
   END LOOP;
 END;
 $$
 LANGUAGE plpgsql;
+
+
+CREATE OR REPLACE FUNCTION getorcreate_full_word(norm_term TEXT,
+                                                 lookup_terms TEXT[],
+                                                 lookup_norm_terms TEXT[],
+                                                 OUT full_token INT,
+                                                 OUT partial_tokens INT[])
+  AS $$
+DECLARE
+  partial_terms TEXT[] = '{}'::TEXT[];
+  term TEXT;
+  term_id INTEGER;
+BEGIN
+  SELECT min(word_id) INTO full_token
+    FROM word WHERE word = norm_term and type = 'W';
+
+  IF full_token IS NULL THEN
+    full_token := nextval('seq_word');
+
+    IF lookup_norm_terms IS NULL THEN
+      INSERT INTO word (word_id, word_token, type, word)
+        SELECT full_token, lookup_term, 'W', norm_term
+          FROM unnest(lookup_terms) as lookup_term;
+    ELSE
+      INSERT INTO word (word_id, word_token, type, word, info)
+        SELECT full_token, t.lookup, 'W', norm_term,
+               CASE WHEN norm_term = t.norm THEN null
+                    ELSE json_build_object('lookup', t.norm) END
+          FROM unnest(lookup_terms, lookup_norm_terms) as t(lookup, norm);
+    END IF;
+  END IF;
+
+  FOR term IN SELECT unnest(string_to_array(unnest(lookup_terms), ' ')) LOOP
+    term := trim(term);
+    IF NOT (ARRAY[term] <@ partial_terms) THEN
+      partial_terms := partial_terms || term;
+    END IF;
+  END LOOP;
+
+  partial_tokens := '{}'::INT[];
+  FOR term IN SELECT unnest(partial_terms) LOOP
+    SELECT min(word_id) INTO term_id
+      FROM word WHERE word_token = term and type = 'w';
+
+    IF term_id IS NULL THEN
+      term_id := nextval('seq_word');
+      INSERT INTO word (word_id, word_token, type)
+        VALUES (term_id, term, 'w');
+    END IF;
+
+    partial_tokens := array_merge(partial_tokens, ARRAY[term_id]);
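
For illustration, a minimal Python sketch of what the new three-argument branch writes to the word table. The street name, abbreviation, and normalization below are invented; only the storage rule (record info->'lookup' when the variant's normalized form differs from the canonical term) comes from the SQL above.

def full_word_rows(norm_term, lookup_terms, lookup_norm_terms, full_token=1):
    """Yield (word_id, word_token, type, word, info) rows as the
    three-argument getorcreate_full_word() would insert them."""
    if lookup_norm_terms is None:
        for lookup in lookup_terms:
            yield (full_token, lookup, 'W', norm_term, None)
    else:
        for lookup, norm in zip(lookup_terms, lookup_norm_terms):
            # keep the variant's own normalized spelling only when it
            # differs from the canonical term
            info = None if norm == norm_term else {'lookup': norm}
            yield (full_token, lookup, 'W', norm_term, info)

for row in full_word_rows('great street',
                          ['great street', 'gt st'],
                          ['great street', 'gt st']):
    print(row)
# (1, 'great street', 'W', 'great street', None)
# (1, 'gt st', 'W', 'great street', {'lookup': 'gt st'})

The abbreviated row keeps its own normalized spelling, so a query containing the abbreviation can be matched against that spelling directly instead of the expanded canonical term, which is what removes the matching penalty.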

View File

@@ -121,10 +121,10 @@ class ICUTokenizer(AbstractTokenizer):
                                  SELECT unnest(nameaddress_vector) as id, count(*)
                                  FROM search_name GROUP BY id)
                  SELECT coalesce(a.id, w.id) as id,
-                        (CASE WHEN w.count is null THEN '{}'::JSONB
+                        (CASE WHEN w.count is null or w.count <= 1 THEN '{}'::JSONB
                               ELSE jsonb_build_object('count', w.count) END
                          ||
-                        CASE WHEN a.count is null THEN '{}'::JSONB
+                        CASE WHEN a.count is null or a.count <= 1 THEN '{}'::JSONB
                               ELSE jsonb_build_object('addr_count', a.count) END) as info
                  FROM word_freq w FULL JOIN addr_freq a ON a.id = w.id;
                  """)
@@ -134,9 +134,10 @@ class ICUTokenizer(AbstractTokenizer):
                 drop_tables(conn, 'tmp_word')
                 cur.execute("""CREATE TABLE tmp_word AS
                                 SELECT word_id, word_token, type, word,
-                                       (CASE WHEN wf.info is null THEN word.info
-                                        ELSE coalesce(word.info, '{}'::jsonb) || wf.info
-                                        END) as info
+                                       coalesce(word.info, '{}'::jsonb)
+                                       - 'count' - 'addr_count' ||
+                                       coalesce(wf.info, '{}'::jsonb)
+                                       as info
                                 FROM word LEFT JOIN word_frequencies wf
                                      ON word.word_id = wf.id
                             """)
@@ -584,10 +585,14 @@ class ICUNameAnalyzer(AbstractAnalyzer):
             if word_id:
                 result = self._cache.housenumbers.get(word_id, result)
                 if result[0] is None:
-                    variants = analyzer.compute_variants(word_id)
+                    varout = analyzer.compute_variants(word_id)
+                    if isinstance(varout, tuple):
+                        variants = varout[0]
+                    else:
+                        variants = varout
                     if variants:
                         hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)",
-                                             (word_id, list(variants)))
+                                             (word_id, variants))
                         result = hid, variants[0]
                         self._cache.housenumbers[word_id] = result
@@ -632,13 +637,17 @@ class ICUNameAnalyzer(AbstractAnalyzer):
                 full, part = self._cache.names.get(token_id, (None, None))
                 if full is None:
-                    variants = analyzer.compute_variants(word_id)
+                    varset = analyzer.compute_variants(word_id)
+                    if isinstance(varset, tuple):
+                        variants, lookups = varset
+                    else:
+                        variants, lookups = varset, None
                     if not variants:
                         continue

                     with self.conn.cursor() as cur:
-                        cur.execute("SELECT * FROM getorcreate_full_word(%s, %s)",
-                                    (token_id, variants))
+                        cur.execute("SELECT * FROM getorcreate_full_word(%s, %s, %s)",
+                                    (token_id, variants, lookups))
                         full, part = cast(Tuple[int, List[int]], cur.fetchone())
                         self._cache.names[token_id] = (full, part)
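
Both call sites above (housenumbers and names) unpack the analyzer result with the same isinstance check, because plug-in analysis modules may still return a bare list while the generic analyzer now returns a (variants, lookups) pair. A hypothetical shared helper capturing that pattern, sketched here for clarity only:

from typing import List, Optional, Tuple, Union

def unpack_variants(varout: Union[List[str], Tuple[List[str], List[str]]]
                    ) -> Tuple[List[str], Optional[List[str]]]:
    """Normalize old-style (plain list) and new-style (tuple) analyzer
    output into a (variants, lookups) pair, with lookups possibly None."""
    if isinstance(varout, tuple):
        return varout
    return varout, None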

View File

@@ -7,7 +7,7 @@
 """
 Common data types and protocols for analysers.
 """
-from typing import Mapping, List, Any
+from typing import Mapping, List, Any, Union, Tuple

 from ...typing import Protocol
 from ...data.place_name import PlaceName
@@ -33,7 +33,7 @@ class Analyzer(Protocol):
             for example because the character set in use does not match.
         """

-    def compute_variants(self, canonical_id: str) -> List[str]:
+    def compute_variants(self, canonical_id: str) -> Union[List[str], Tuple[List[str], List[str]]]:
         """ Compute the transliterated spelling variants for the given
             canonical ID.
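
A toy implementation of the widened protocol, purely illustrative (the class name and the identity-style variant logic are made up; other protocol methods are omitted). It shows the tuple form, where the second list carries the normalized lookup spellings in parallel with the transliterated variants:

from typing import List, Tuple

class EchoAnalyzer:
    """Trivial analyzer: each canonical ID is its own single variant."""

    def compute_variants(self, canonical_id: str) -> Tuple[List[str], List[str]]:
        variants = [canonical_id]
        # (transliterated variants, their normalized sources), index-aligned
        return variants, list(variants)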

View File

@@ -7,7 +7,7 @@
 """
 Generic processor for names that creates abbreviation variants.
 """
-from typing import Mapping, Dict, Any, Iterable, Iterator, Optional, List, cast
+from typing import Mapping, Dict, Any, Iterable, Optional, List, cast, Tuple
 import itertools

 from ...errors import UsageError
@@ -78,7 +78,7 @@ class GenericTokenAnalysis:
         """
         return cast(str, self.norm.transliterate(name.name)).strip()

-    def compute_variants(self, norm_name: str) -> List[str]:
+    def compute_variants(self, norm_name: str) -> Tuple[List[str], List[str]]:
         """ Compute the spelling variants for the given normalized name
             and transliterate the result.
         """
@@ -87,18 +87,20 @@
         for mutation in self.mutations:
             variants = mutation.generate(variants)

-        return [name for name in self._transliterate_unique_list(norm_name, variants) if name]
-
-    def _transliterate_unique_list(self, norm_name: str,
-                                   iterable: Iterable[str]) -> Iterator[Optional[str]]:
-        seen = set()
+        varset = set(map(str.strip, variants))
         if self.variant_only:
-            seen.add(norm_name)
+            varset.discard(norm_name)

-        for variant in map(str.strip, iterable):
-            if variant not in seen:
-                seen.add(variant)
-                yield self.to_ascii.transliterate(variant).strip()
+        trans = []
+        norm = []
+        for var in varset:
+            t = self.to_ascii.transliterate(var).strip()
+            if t:
+                trans.append(t)
+                norm.append(var)
+
+        return trans, norm

     def _generate_word_variants(self, norm_name: str) -> Iterable[str]:
         baseform = '^ ' + norm_name + ' ^'
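
The rewrite trades the old generator for two parallel lists, so every transliteration that survives the emptiness check stays aligned with the normalized spelling it came from. A standalone sketch of that loop (str.lower stands in for the ICU transliterator; the function name and data are invented):

def transliterate_variants(norm_name, variants, to_ascii=str.lower,
                           variant_only=False):
    """Dedupe variants, optionally drop the base name, and return parallel
    lists of transliterations and their normalized sources."""
    varset = set(map(str.strip, variants))
    if variant_only:
        varset.discard(norm_name)
    trans, norm = [], []
    for var in varset:
        t = to_ascii(var).strip()
        if t:                      # an empty transliteration drops both entries
            trans.append(t)
            norm.append(var)
    return trans, norm

print(transliterate_variants('great street', ['great street', ' gt st ']))
# (['great street', 'gt st'], ['great street', 'gt st']) -- set order may vary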

View File

@@ -230,19 +230,20 @@ def test_update_statistics(word_table, table_factory, temp_db_cursor,
                            tokenizer_factory, test_config):
     word_table.add_full_word(1000, 'hello')
     word_table.add_full_word(1001, 'bye')
+    word_table.add_full_word(1002, 'town')
     table_factory('search_name',
                   'place_id BIGINT, name_vector INT[], nameaddress_vector INT[]',
-                  [(12, [1000], [1001])])
+                  [(12, [1000], [1001]), (13, [1001], [1002]), (14, [1000, 1001], [1002])])
     tok = tokenizer_factory()

     tok.update_statistics(test_config)

-    assert temp_db_cursor.scalar("""SELECT count(*) FROM word
-                                    WHERE type = 'W' and word_id = 1000 and
-                                          (info->>'count')::int > 0""") == 1
-    assert temp_db_cursor.scalar("""SELECT count(*) FROM word
-                                    WHERE type = 'W' and word_id = 1001 and
-                                          (info->>'addr_count')::int > 0""") == 1
+    assert temp_db_cursor.row_set("""SELECT word_id,
+                                            (info->>'count')::int,
+                                            (info->>'addr_count')::int
+                                     FROM word
+                                     WHERE type = 'W'""") == \
+        {(1000, 2, None), (1001, 2, None), (1002, None, 2)}


 def test_normalize_postcode(analyzer):
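
How the expected set comes about: token 1000 appears in two name vectors, 1001 in two name vectors and one address vector, and 1002 in two address vectors; since the statistics update now discards counts of 1 or less, 1001 gets no addr_count. A quick check with the row data copied from the test:

from collections import Counter

rows = [(12, [1000], [1001]),
        (13, [1001], [1002]),
        (14, [1000, 1001], [1002])]

names = Counter(t for _, nv, _ in rows for t in nv)  # {1000: 2, 1001: 2}
addrs = Counter(t for _, _, av in rows for t in av)  # {1001: 1, 1002: 2}

def keep(count):                  # counts <= 1 are dropped from info
    return count if count > 1 else None

expected = {(w, keep(names[w]), keep(addrs[w])) for w in (1000, 1001, 1002)}
print(expected == {(1000, 2, None), (1001, 2, None), (1002, None, 2)})  # True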

View File

@@ -40,7 +40,7 @@ def make_analyser(*variants, variant_only=False):

 def get_normalized_variants(proc, name):
     norm = Transliterator.createFromRules("test_norm", DEFAULT_NORMALIZATION)
-    return proc.compute_variants(norm.transliterate(name).strip())
+    return proc.compute_variants(norm.transliterate(name).strip())[0]


 def test_no_variants():

View File

@@ -40,7 +40,7 @@ class TestMutationNoVariants:
     def variants(self, name):
         norm = Transliterator.createFromRules("test_norm", DEFAULT_NORMALIZATION)
-        return set(self.analysis.compute_variants(norm.transliterate(name).strip()))
+        return set(self.analysis.compute_variants(norm.transliterate(name).strip())[0])

     @pytest.mark.parametrize('pattern', ('(capture)', ['a list']))
     def test_bad_pattern(self, pattern):