# sk/utils.py — uploaded by naqibhakimi, commit 797a2e2 ("initial"), 6.69 kB.
import math
import pickle
import re
import bs4
import torch
from GoogleNews import GoogleNews
from tqdm import tqdm
from kb import KB
from newspaper import Article, ArticleException
from pyvis.network import Network
def _append_relation(relations, head, relation_type, tail):
    """Append ``{'head', 'type', 'tail'}`` to ``relations`` if all three parts are non-empty."""
    head, relation_type, tail = head.strip(), relation_type.strip(), tail.strip()
    # Incomplete triplets (e.g. from a generation cut off at max_length) are dropped
    # rather than emitted with an empty or stale field.
    if head and relation_type and tail:
        relations.append({'head': head, 'type': relation_type, 'tail': tail})


def extract_relations_from_model_output(text):
    """Parse a linearized triplet sequence into a list of relation dicts.

    Expected token layout (``<s>``, ``</s>`` and ``<pad>`` are removed first)::

        <triplet> HEAD <subj> TAIL <obj> TYPE [<subj> TAIL <obj> TYPE ...] <triplet> ...

    Note that the text following ``<subj>`` is stored as the *tail* and the text
    following ``<obj>`` as the relation *type*. Several ``<subj> ... <obj> ...``
    pairs may follow one head. Tokens before the first marker are ignored.

    Args:
        text: Decoded model output.

    Returns:
        List of ``{'head': str, 'type': str, 'tail': str}`` dicts, in output order.
        Triplets with any empty part (for example, a trailing ``<subj> X`` without
        an ``<obj>``) are skipped.
    """
    relations = []
    subject, relation, object_ = '', '', ''
    current = 'x'  # which field plain tokens go to: 't' head, 's' tail, 'o' type, 'x' none
    text_replaced = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            # A new head starts: flush the pending triplet and clear every field.
            _append_relation(relations, subject, relation, object_)
            current = 't'
            subject, relation, object_ = '', '', ''
        elif token == "<subj>":
            # A new tail for the same head: flush, and forget the previous type so it
            # cannot be paired with this tail unless a new <obj> supplies one.
            _append_relation(relations, subject, relation, object_)
            current = 's'
            relation, object_ = '', ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += f' {token}'
        elif current == 's':
            object_ += f' {token}'
        elif current == 'o':
            relation += f' {token}'
    _append_relation(relations, subject, relation, object_)
    return relations
def _compute_span_boundaries(num_tokens, span_length):
    """Return ``[start, end]`` slice bounds of width ``span_length`` covering ``num_tokens`` tokens.

    ``ceil(num_tokens / span_length)`` windows are produced, and consecutive windows
    overlap by the same number of tokens. When the input fits in one window, a single
    ``[0, span_length]`` window is returned; slicing it simply yields fewer tokens.
    """
    num_spans = math.ceil(num_tokens / span_length)
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    stride = span_length - overlap
    boundaries = [[i * stride, i * stride + span_length] for i in range(num_spans)]
    if num_spans > 1:
        # ``overlap`` is rounded up, so the last window can end a few tokens before
        # ``num_tokens``. Pin it to the final ``span_length`` tokens so nothing is
        # dropped. All windows keep equal length, which torch.stack requires.
        boundaries[-1] = [num_tokens - span_length, num_tokens]
    return boundaries


def from_text_to_kb(text, model, tokenizer, article_url, span_length=128, article_title=None,
                    article_publish_date=None, verbose=False):
    """Extract relation triplets from ``text`` with a generative model and collect them in a ``KB``.

    The text is tokenized once and cut into overlapping windows of ``span_length``
    tokens. All windows are passed as one batch to ``model.generate``, using beam
    search with 3 beams, 3 returned sequences per window and ``max_length=256``.
    Every decoded sequence is parsed with ``extract_relations_from_model_output``.
    Each resulting relation is added to the KB with metadata recording
    ``article_url`` and the token window it came from.

    Args:
        text: Input text.
        model: Object exposing a Hugging Face-style ``generate(**inputs, **kwargs)``.
        tokenizer: Matching tokenizer. It is called with ``return_tensors="pt"`` and
            its ``batch_decode`` is used.
        article_url: Key under which span metadata is stored in each relation.
        span_length: Window width in tokens.
        article_title: Forwarded to ``KB.add_relation``.
        article_publish_date: Forwarded to ``KB.add_relation``.
        verbose: Print progress and diagnostic information.

    Returns:
        A ``KB`` with the extracted relations. It is empty if tokenization produced
        no tokens.
    """
    if verbose:
        print('Start tokenizing')
    inputs = tokenizer([text], return_tensors="pt")
    if verbose:
        print('End tokenizing')

    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    kb = KB()
    if num_tokens == 0:
        # Nothing to feed the model; torch.stack() would fail on an empty window list.
        return kb

    spans_boundaries = _compute_span_boundaries(num_tokens, span_length)
    if verbose:
        print(f"Input has {len(spans_boundaries)} spans")
        print(f"Span boundaries are {spans_boundaries}")

    # Re-batch the single tokenized sequence as one row per window.
    input_ids = inputs["input_ids"][0]
    attention_mask = inputs["attention_mask"][0]
    batch = {
        "input_ids": torch.stack([input_ids[start:end] for start, end in spans_boundaries]),
        "attention_mask": torch.stack([attention_mask[start:end] for start, end in spans_boundaries]),
    }

    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences,
    }
    generated_tokens = model.generate(**batch, **gen_kwargs)
    # Special tokens are kept: the parser relies on <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    for i, sentence_pred in enumerate(decoded_preds):
        # Assumes generate() returns num_return_sequences consecutive rows per input window.
        current_span_index = i // num_return_sequences
        for relation in extract_relations_from_model_output(sentence_pred):
            relation["meta"] = {
                article_url: {
                    "spans": [spans_boundaries[current_span_index]]
                }
            }
            kb.add_relation(relation, article_title, article_publish_date)
    return kb
def get_article(url):
    """Download and parse the article at ``url``.

    Returns the ``newspaper.Article`` after ``download()`` and ``parse()`` have run.
    Any error raised by those calls propagates to the caller.
    """
    parsed = Article(url)
    parsed.download()
    parsed.parse()
    return parsed
def from_url_to_kb(url, model, tokenizer):
    """Fetch the article at ``url`` and build a ``KB`` from its text.

    The article's own ``url`` attribute is used as the metadata key. Its title and
    publish date are forwarded to ``from_text_to_kb``.
    """
    article = get_article(url)
    title, published = article.title, article.publish_date
    return from_text_to_kb(article.text, model, tokenizer, article.url,
                           article_title=title, article_publish_date=published)
def get_news_links(query, lang="en", region="US", pages=1):
    """Return the unique Google News result links for ``query``.

    Args:
        query: Search string.
        lang: Language code passed to ``GoogleNews``.
        region: Region code passed to ``GoogleNews``.
        pages: Number of result pages to collect, starting from the first.

    Returns:
        List of unique URLs. The order is not guaranteed because results are
        de-duplicated through a ``set``.
    """
    googlenews = GoogleNews(lang=lang, region=region)
    # search() already retrieves results page 1.
    googlenews.search(query)
    # Copy so later ``+=`` never mutates a list GoogleNews may hold internally.
    all_urls = list(googlenews.get_links())
    # get_page() is 1-indexed and page 1 was fetched above, so continue from page 2.
    for page in range(2, pages + 1):
        googlenews.get_page(page)
        all_urls += googlenews.get_links()
    return list(set(all_urls))
def from_urls_to_kb(urls, model, tokenizer, verbose=False):
    """Build one ``KB`` by merging the KBs extracted from each URL in ``urls``.

    URLs whose processing raises ``ArticleException`` are skipped (and reported when
    ``verbose`` is set). Any other exception propagates.
    """
    def log(message):
        if verbose:
            print(message)

    merged_kb = KB()
    log(f"{len(urls)} links to visit")
    for link in urls:
        log(f"Visiting {link}...")
        try:
            merged_kb.merge_with_kb(from_url_to_kb(link, model, tokenizer))
        except ArticleException:
            log(f" Couldn't download article at url {link}")
    return merged_kb
def save_network_html(kb, filename="network.html"):
    """Render ``kb`` as an interactive directed graph and write it to ``filename`` as HTML.

    Each item of ``kb.entities`` becomes a node. Each relation in ``kb.relations``
    becomes a ``head -> tail`` edge whose title and label are the relation ``type``.
    A repulsion physics layout is configured, and the physics controls are included
    in the page.
    """
    net = Network(directed=True, width="700px", height="700px")

    color_entity = "#00FF00"
    for entity in kb.entities:
        net.add_node(entity, shape="circle", color=color_entity)

    for relation in kb.relations:
        net.add_edge(relation["head"], relation["tail"],
                     title=relation["type"], label=relation["type"])

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show_buttons(filter_=['physics'])
    # save_graph() only writes the HTML file. show() is a display helper: depending
    # on the pyvis release it opens a web browser, or defaults to notebook mode and
    # fails for a Network created outside a notebook.
    net.save_graph(filename)
def save_kb(kb, filename):
    """Pickle ``kb`` into ``filename``, overwriting any existing file."""
    with open(filename, mode="wb") as handle:
        pickle.dump(kb, handle)
class CustomUnpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that resolves every pickled global named ``KB`` to the imported ``KB``.

    The module path recorded in the pickle is ignored for that one name. Presumably
    this lets pickles written where ``KB`` lived elsewhere (e.g. ``__main__``) load
    here; confirm with the code that produces them. All other globals are resolved
    normally.

    Security: unpickling can execute arbitrary code, so only use this on trusted files.
    """

    def find_class(self, module, name):
        if name != 'KB':
            return super().find_class(module, name)
        return KB
def load_kb(filename):
    """Load and return the object pickled in ``filename`` using ``CustomUnpickler``.

    Security: unpickling can execute arbitrary code, so only load trusted files.
    """
    with open(filename, "rb") as handle:
        return CustomUnpickler(handle).load()
def process_transcript(src, dist):
    """Extract the plain text of an HTML transcript and write it to a file.

    Reads the HTML file ``src`` and finds every ``<div class="transcript-line">``.
    From each line it takes the text of the first ``<span class="transcript-text">``
    and collapses every whitespace run to a single space. The lines are joined with
    one space and written to ``dist``, which is overwritten. Lines without such a
    span are skipped.

    Args:
        src: Path of the input HTML file (read as UTF-8).
        dist: Path of the output text file (written as UTF-8).
    """
    with open(src, 'r', encoding='utf-8') as src_file:
        # An explicit parser makes the result independent of which optional
        # parsers (lxml, html5lib) happen to be installed.
        html = bs4.BeautifulSoup(src_file.read(), 'html.parser')

    texts = []
    for line in html.find_all('div', {'class': 'transcript-line'}):
        span = line.find('span', {'class': 'transcript-text'})
        if span is None:
            continue
        texts.append(re.sub(r'\s+', ' ', span.text))

    # Build the whole text before opening ``dist`` so a failure cannot leave it truncated.
    text = ' '.join(texts)
    with open(dist, 'w', encoding='utf-8') as dist_file:
        dist_file.write(text)