---
language:
- en
license: mit
library_name: Tevatron
tags:
- vidore
datasets:
- Tevatron/docmatix-ir
- HuggingFaceM4/Docmatix
- Tevatron/msmarco-passage-aug
---

# DSE-Phi3-Docmatix-V2

DSE-Phi3-Docmatix-V2 is a bi-encoder model that encodes document screenshots into dense vectors for document retrieval. The Document Screenshot Embedding ([DSE](https://arxiv.org/abs/2406.11251)) approach works on documents in their original visual format. This preserves all of their information, including text, images, and layout, and avoids tedious parsing and the information loss that parsing can cause.

The model, `Tevatron/dse-phi3-docmatix-v2`, is trained on 1/10 of the `Tevatron/docmatix-ir` dataset. This dataset is a variant of `HuggingFaceM4/Docmatix` adapted specifically for training PDF retrievers with Vision Language Models in open-domain question answering scenarios. For details on dataset filtering and hard negative mining, refer to the [docmatix-ir](https://huggingface.co/datasets/Tevatron/docmatix-ir/blob/main/README.md) dataset page.

DSE shows strong zero-shot effectiveness for document retrieval with both visual and text input. For example, DSE-Phi3-Docmatix-V2 achieves **77.6** nDCG@5 on the [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) leaderboard in a **zero-shot setting** (i.e., without fine-tuning on ViDoRe training data).
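nDCG@5 rewards ranking relevant pages near the top of the first five results. The minimal sketch below computes it for a single query. It assumes linear relevance gains and a log2(rank + 1) position discount, which is a common formulation. The function names `dcg_at_k` and `ndcg_at_k` are hypothetical, and the official ViDoRe evaluation code may differ in details such as tie handling.

```python
import math
from typing import Sequence


def dcg_at_k(relevances: Sequence[float], k: int) -> float:
    """DCG@k with linear gain: sum of rel_i / log2(i + 1) over ranks i = 1..k."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances[:k], start=1))


def ndcg_at_k(ranked_relevances: Sequence[float], judged_relevances: Sequence[float], k: int = 5) -> float:
    """nDCG@k for a single query.

    ranked_relevances: relevance grade of each retrieved page, in ranked order.
    judged_relevances: relevance grades of all judged pages for the query,
        used to build the ideal (best possible) ranking.
    """
    idcg = dcg_at_k(sorted(judged_relevances, reverse=True), k)
    if idcg == 0.0:
        return 0.0  # no relevant pages for this query
    return dcg_at_k(ranked_relevances, k) / idcg


# Example: the only relevant page is retrieved at rank 2.
print(round(ndcg_at_k([0, 1, 0, 0, 0], [1], k=5), 3))  # 0.631
```

If a query's only relevant page is ranked second, nDCG@5 = 1/log2(3) ≈ 0.631. Leaderboard figures such as 77.6 are typically averages over many queries, reported on a 0 to 100 scale.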
## How to Train the Model from Scratch

Please see the Tevatron DSE example at https://github.com/texttron/tevatron/tree/main/examples/dse

## How to Use the Model

### Load the Model and Processor

```python
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# trust_remote_code=True lets transformers load the custom code shipped with the checkpoint.
processor = AutoProcessor.from_pretrained('Tevatron/dse-phi3-docmatix-v2', trust_remote_code=True)
# attn_implementation="flash_attention_2" requires the flash-attn package and a supported GPU.
model = AutoModelForCausalLM.from_pretrained(
    'Tevatron/dse-phi3-docmatix-v2',
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    use_cache=False,
).to('cuda:0')


def get_embedding(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Last-token pooling: take the hidden state of the final non-padding token.
    # This indexing assumes right padding.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    bs = last_hidden_state.shape[0]
    reps = last_hidden_state[torch.arange(bs, device=last_hidden_state.device), sequence_lengths]
    # L2-normalize so that cosine similarity reduces to a dot product.
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps
```

### Encode Text Query

```python
# Queries are prefixed with "query: ".
queries = ["query: Where can we see Llama?", "query: What is LLaMA model?"]
query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=128, truncation=True).to('cuda:0')
with torch.no_grad():
    output = model(**query_inputs, return_dict=True, output_hidden_states=True)
# Pool the last layer's hidden states into one embedding per query.
query_embeddings = get_embedding(output.hidden_states[-1], query_inputs["attention_mask"])
```

### Encode Document Screenshot

```python
from io import BytesIO

import requests
from PIL import Image

# URLs of the example screenshots
url1 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/animal-llama.png"
url2 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v2/resolve/main/meta-llama.png"

# Download the images and resize them to 1344x1344
response1 = requests.get(url1)
response2 = requests.get(url2)
passage_image1 = Image.open(BytesIO(response1.content)).resize((1344, 1344))
passage_image2 = Image.open(BytesIO(response2.content)).resize((1344, 1344))
passage_images = [passage_image1, passage_image2]
# Each prompt refers to its image through a <|image_N|> placeholder.
passage_prompts = ["<|image_1|>\nWhat is shown in this image?", "<|image_2|>\nWhat is shown in this image?"]

# Process inputs and get embeddings
passage_inputs = processor(passage_prompts, images=passage_images, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
# squeeze(0) drops a leading dimension of size 1 if present; it is a no-op otherwise.
passage_inputs['input_ids'] = passage_inputs['input_ids'].squeeze(0)
passage_inputs['attention_mask'] = passage_inputs['attention_mask'].squeeze(0)
passage_inputs['image_sizes'] = passage_inputs['image_sizes'].squeeze(0)
with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])
```

### Compute Similarity

```python
from torch.nn.functional import cosine_similarity

num_queries = query_embeddings.size(0)
for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    # Broadcasts the (1, dim) query against the (num_passages, dim) document embeddings.
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    # Cast bfloat16 to float32 before converting to NumPy.
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")
```

### Encode Document Text

This DSE checkpoint was warmed up on `Tevatron/msmarco-passage-aug`, so the model can also effectively encode documents given as text input.
```python
passage_prompts = [
    "The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.",
    "Llama (acronym for Large Language Model Meta AI, and formerly stylized as LLaMA) is a family of autoregressive large language models (LLMs) released by Meta AI starting in February 2023.[2][3] The latest version is Llama 3.1, released in July 2024.[4]"
]

# Text-only documents: no images are passed to the processor.
passage_inputs = processor(passage_prompts, images=None, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])

# Score against the query embeddings computed above.
for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")
```

### Citation

If you find this checkpoint helpful, please consider citing Phi3, Docmatix, and our DSE work.