namespace-Pt
commited on
Commit
•
7c3c80c
1
Parent(s):
2fee8e7
Upload modeling_retrieval.py with huggingface_hub
Browse files- modeling_retrieval.py +106 -0
modeling_retrieval.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
from tqdm import tqdm
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class BM25Retriever:
|
9 |
+
def __init__(self, k1:float=0.9, b:float=0.4) -> None:
|
10 |
+
self.k1 = k1
|
11 |
+
self.b = b
|
12 |
+
|
13 |
+
def index(self, corpus: List[Union[str, List[int]]], verbose: bool=False, stop_tokens: Optional[set]=None):
|
14 |
+
"""Build in-memory BM25 index."""
|
15 |
+
if stop_tokens is None:
|
16 |
+
stop_tokens = {}
|
17 |
+
|
18 |
+
dfs = defaultdict(int)
|
19 |
+
tfs = []
|
20 |
+
inverted_lists = defaultdict(list)
|
21 |
+
doc_lengths = np.zeros(len(corpus), dtype=np.float32)
|
22 |
+
|
23 |
+
if verbose:
|
24 |
+
iterator = tqdm(corpus, desc="Indexing")
|
25 |
+
else:
|
26 |
+
iterator = corpus
|
27 |
+
|
28 |
+
for i, doc in enumerate(iterator):
|
29 |
+
if isinstance(doc, str):
|
30 |
+
doc = doc.split(" ")
|
31 |
+
df = {}
|
32 |
+
tf = defaultdict(int)
|
33 |
+
for token in doc:
|
34 |
+
if token not in stop_tokens:
|
35 |
+
tf[token] += 1
|
36 |
+
df[token] = 1
|
37 |
+
tfs.append(dict(tf))
|
38 |
+
for token in df:
|
39 |
+
dfs[token] += 1
|
40 |
+
# store the doc offset in the inverted lists of the corresponding token
|
41 |
+
inverted_lists[token].append(i)
|
42 |
+
|
43 |
+
doc_lengths[i] = len(doc)
|
44 |
+
|
45 |
+
self.dfs = dict(dfs)
|
46 |
+
self.tfs = tfs
|
47 |
+
self.doc_length = doc_lengths
|
48 |
+
self.inverted_lists = {k: np.array(v) for k, v in inverted_lists.items()}
|
49 |
+
self.N = len(corpus)
|
50 |
+
|
51 |
+
def search(self, queries: Union[str, List[int], List[str], List[List[int]]], hits: int=100, k1: Optional[float]=None, b: Optional[float]=None, verbose: bool=False):
|
52 |
+
"""Search over the BM25 index."""
|
53 |
+
if k1 is None:
|
54 |
+
k1 = self.k1
|
55 |
+
if b is None:
|
56 |
+
b = self.b
|
57 |
+
|
58 |
+
hits = min(self.N, hits)
|
59 |
+
|
60 |
+
global_scores = np.zeros(self.N, dtype=np.float32)
|
61 |
+
|
62 |
+
if isinstance(queries, str):
|
63 |
+
queries = [queries]
|
64 |
+
elif isinstance(queries, list) and isinstance(queries[0], int):
|
65 |
+
queries = [queries]
|
66 |
+
|
67 |
+
all_scores = np.zeros((len(queries), hits), dtype=np.float32)
|
68 |
+
all_indices = np.zeros((len(queries), hits), dtype=np.int64)
|
69 |
+
|
70 |
+
if verbose:
|
71 |
+
iterator = tqdm(queries, desc="Searching")
|
72 |
+
else:
|
73 |
+
iterator = queries
|
74 |
+
|
75 |
+
for i, query in enumerate(iterator):
|
76 |
+
if isinstance(query, str):
|
77 |
+
query = query.split(" ")
|
78 |
+
# TODO: stem
|
79 |
+
|
80 |
+
for token in query:
|
81 |
+
if token in self.inverted_lists:
|
82 |
+
candidates = self.inverted_lists[token]
|
83 |
+
else:
|
84 |
+
continue
|
85 |
+
|
86 |
+
tfs = np.array([self.tfs[candidate][token] for candidate in candidates], dtype=np.float32)
|
87 |
+
df = self.dfs[token]
|
88 |
+
idf = np.log((self.N - df + 0.5) / (df + 0.5) + 1)
|
89 |
+
|
90 |
+
candidate_scores = idf * (k1 + 1) * tfs / (tfs + k1 * (1 - b + b * self.doc_length[candidates]))
|
91 |
+
global_scores[candidates] += candidate_scores
|
92 |
+
|
93 |
+
indice = np.argpartition(-global_scores, hits - 1)[:hits]
|
94 |
+
score = global_scores[indice]
|
95 |
+
|
96 |
+
sorted_idx = np.argsort(score)[::-1]
|
97 |
+
indice = indice[sorted_idx]
|
98 |
+
score = score[sorted_idx]
|
99 |
+
|
100 |
+
invalid_pos = score == 0
|
101 |
+
indice[invalid_pos] = -1
|
102 |
+
score[invalid_pos] = -float('inf')
|
103 |
+
|
104 |
+
all_scores[i] = score
|
105 |
+
all_indices[i] = indice
|
106 |
+
return all_scores, all_indices
|