namespace-Pt commited on
Commit
7c3c80c
1 Parent(s): 2fee8e7

Upload modeling_retrieval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_retrieval.py +106 -0
modeling_retrieval.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List, Optional, Union
3
+ from tqdm import tqdm
4
+ from collections import defaultdict
5
+
6
+
7
+
8
+ class BM25Retriever:
9
+ def __init__(self, k1:float=0.9, b:float=0.4) -> None:
10
+ self.k1 = k1
11
+ self.b = b
12
+
13
+ def index(self, corpus: List[Union[str, List[int]]], verbose: bool=False, stop_tokens: Optional[set]=None):
14
+ """Build in-memory BM25 index."""
15
+ if stop_tokens is None:
16
+ stop_tokens = {}
17
+
18
+ dfs = defaultdict(int)
19
+ tfs = []
20
+ inverted_lists = defaultdict(list)
21
+ doc_lengths = np.zeros(len(corpus), dtype=np.float32)
22
+
23
+ if verbose:
24
+ iterator = tqdm(corpus, desc="Indexing")
25
+ else:
26
+ iterator = corpus
27
+
28
+ for i, doc in enumerate(iterator):
29
+ if isinstance(doc, str):
30
+ doc = doc.split(" ")
31
+ df = {}
32
+ tf = defaultdict(int)
33
+ for token in doc:
34
+ if token not in stop_tokens:
35
+ tf[token] += 1
36
+ df[token] = 1
37
+ tfs.append(dict(tf))
38
+ for token in df:
39
+ dfs[token] += 1
40
+ # store the doc offset in the inverted lists of the corresponding token
41
+ inverted_lists[token].append(i)
42
+
43
+ doc_lengths[i] = len(doc)
44
+
45
+ self.dfs = dict(dfs)
46
+ self.tfs = tfs
47
+ self.doc_length = doc_lengths
48
+ self.inverted_lists = {k: np.array(v) for k, v in inverted_lists.items()}
49
+ self.N = len(corpus)
50
+
51
+ def search(self, queries: Union[str, List[int], List[str], List[List[int]]], hits: int=100, k1: Optional[float]=None, b: Optional[float]=None, verbose: bool=False):
52
+ """Search over the BM25 index."""
53
+ if k1 is None:
54
+ k1 = self.k1
55
+ if b is None:
56
+ b = self.b
57
+
58
+ hits = min(self.N, hits)
59
+
60
+ global_scores = np.zeros(self.N, dtype=np.float32)
61
+
62
+ if isinstance(queries, str):
63
+ queries = [queries]
64
+ elif isinstance(queries, list) and isinstance(queries[0], int):
65
+ queries = [queries]
66
+
67
+ all_scores = np.zeros((len(queries), hits), dtype=np.float32)
68
+ all_indices = np.zeros((len(queries), hits), dtype=np.int64)
69
+
70
+ if verbose:
71
+ iterator = tqdm(queries, desc="Searching")
72
+ else:
73
+ iterator = queries
74
+
75
+ for i, query in enumerate(iterator):
76
+ if isinstance(query, str):
77
+ query = query.split(" ")
78
+ # TODO: stem
79
+
80
+ for token in query:
81
+ if token in self.inverted_lists:
82
+ candidates = self.inverted_lists[token]
83
+ else:
84
+ continue
85
+
86
+ tfs = np.array([self.tfs[candidate][token] for candidate in candidates], dtype=np.float32)
87
+ df = self.dfs[token]
88
+ idf = np.log((self.N - df + 0.5) / (df + 0.5) + 1)
89
+
90
+ candidate_scores = idf * (k1 + 1) * tfs / (tfs + k1 * (1 - b + b * self.doc_length[candidates]))
91
+ global_scores[candidates] += candidate_scores
92
+
93
+ indice = np.argpartition(-global_scores, hits - 1)[:hits]
94
+ score = global_scores[indice]
95
+
96
+ sorted_idx = np.argsort(score)[::-1]
97
+ indice = indice[sorted_idx]
98
+ score = score[sorted_idx]
99
+
100
+ invalid_pos = score == 0
101
+ indice[invalid_pos] = -1
102
+ score[invalid_pos] = -float('inf')
103
+
104
+ all_scores[i] = score
105
+ all_indices[i] = indice
106
+ return all_scores, all_indices