Add full tools support to the chat template

#45
by Rocketknight1 HF staff - opened

This PR is still in progress! It should work, but it needs testing to verify it exactly matches the outputs of mistral-common.

Update: this has now been tested and confirmed to match the output from mistral-common!

Test script to confirm:

from transformers import AutoTokenizer
from mistral_common.protocol.instruct.tool_calls import Function, Tool

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, ToolMessage
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from mistral_common.protocol.instruct.request import ChatCompletionRequest

# Hugging Face side of the comparison: render the same conversation through
# the tokenizer's chat template, once as text and once as token ids.

# Load the tokenizer from the "pr/45" revision rather than the default branch.
hf_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1", revision="pr/45")

# Tool definition, wrapped in the {"type": "function", "function": {...}}
# envelope that is passed to `tools=` below. Key order is kept as in the
# original two-step construction.
hf_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "format": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The temperature unit to use. Infer this from the users location.",
                },
            },
            "required": ["location", "format"],
        },
    },
}

# Conversation under test: user question -> assistant tool call -> tool result.
tool_call = {"name": "get_current_weather", "arguments": {"location": "Paris, France"}}
test_chat = [
    {"role": "user", "content": "What's the weather like today in Paris"},
    {"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call, "id": "abcdef123"}]},
    # The tool result is attributed to the function that was actually called
    # above (it previously named an undeclared "get_current_temperature").
    # NOTE(review): presumably the template ignores "name" for tool messages,
    # since the mistral-common ToolMessage carries none - confirm.
    {"role": "tool", "name": "get_current_weather", "tool_call_id": "abcdef123", "content": "22.0"},
]

# Same inputs both times; only the `tokenize` flag differs.
hf_text = hf_tokenizer.apply_chat_template(test_chat, tokenize=False, tools=[hf_tool])
hf_tokens = hf_tokenizer.apply_chat_template(test_chat, tokenize=True, tools=[hf_tool])

# mistral-common side of the comparison: build the equivalent request with the
# reference tokenizer, then compare against the Hugging Face outputs
# (`hf_text` / `hf_tokens`, computed earlier in this script).
mistral_tokenizer = MistralTokenizer.v3()

# Same tool schema as the Hugging Face tool definition, expressed with
# mistral-common's Tool / Function types.
mistral_tool = Tool(
    function=Function(
        name="get_current_weather",
        description="Get the current weather",
        parameters={
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "format": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The temperature unit to use. Infer this from the users location.",
                },
            },
            "required": ["location", "format"],
        },
    )
)

# Same three-turn conversation: user question, assistant tool call, tool result.
mistral_query = ChatCompletionRequest(
    tools=[mistral_tool],
    messages=[
        UserMessage(content="What's the weather like today in Paris"),
        AssistantMessage(
            tool_calls=[
                ToolCall(
                    type="function",
                    function=FunctionCall(name="get_current_weather", arguments={"location": "Paris, France"}),
                    id="abcdef123",
                )
            ]
        ),
        ToolMessage(content="22.0", tool_call_id="abcdef123"),
    ],
    model="test",
)

# Encode the request once and read both the text and the token ids from the
# same result, instead of encoding the identical request twice.
mistral_encoded = mistral_tokenizer.encode_chat_completion(mistral_query)
encodeds = mistral_encoded.text
# Replace the "▁" marker with a plain space before comparing with hf_text.
mistral_text = encodeds.replace("▁", " ")
mistral_tokens = mistral_encoded.tokens

# Both lines print True when the two implementations agree.
print(hf_text == mistral_text)
print(hf_tokens == mistral_tokens)
pandora-s changed pull request status to merged

Sign up or log in to comment