Bug in config.json?

#7
by dhruvmullick - opened

While loading the model, I'm getting a KeyError:

model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-70B-Instruct", torch_dtype=torch.bfloat16, quantization_config = quant_config)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py", line 3775, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1066, in __init__
    self.model = LlamaModel(config)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 845, in __init__
    [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 845, in <listcomp>
    [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 632, in __init__
    self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 306, in __init__
    self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 110, in __init__
    self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling["type"])
KeyError: 'type'
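
quant_config is omitted above; the traceback fires while the decoder layers are still being constructed, before any quantization is applied, so the exact quantization settings shouldn't matter. For a self-contained reproduction, an illustrative bitsandbytes 4-bit setup (not necessarily the exact one used here) would be:

import torch
import transformers

# Illustrative quantization config; any value should hit the same KeyError,
# since the failure happens during model construction.
quant_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-70B-Instruct",
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
)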

This seems related to the latest config update:
https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct/commit/25acb1b514688b222a02a89c6976a8d7ad0e017f

I'm using transformers==4.43
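
The failing line shown in the traceback evaluates its fallback eagerly: config.rope_scaling.get("rope_type", config.rope_scaling["type"]) looks up "type" before .get() ever runs, so a config that carries only the new "rope_type" key (which is what the updated config.json appears to ship) raises KeyError even though "rope_type" is present. A minimal illustration (the dict below is just a stand-in for the real rope_scaling section):

# Stand-in for the updated rope_scaling entry: only the new-style key is present.
rope_scaling = {"rope_type": "llama3", "factor": 8.0}

# The default argument is evaluated before .get() runs, so this raises
# KeyError: 'type' even though "rope_type" exists.
rope_scaling.get("rope_type", rope_scaling["type"])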

@ArthurZ @osanseviero

You can resolve this by adding both 'rope_type': 'llama3' and 'type': 'llama3' to the rope_scaling section of config.json.
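
If you'd rather not hand-edit the cached config.json, an equivalent runtime patch is to add the legacy key on the loaded config before building the model. This is an untested sketch that just mirrors the workaround above (quantization arguments omitted for brevity):

import torch
import transformers

model_id = "meta-llama/Meta-Llama-3.1-70B-Instruct"

config = transformers.AutoConfig.from_pretrained(model_id)
# Expose the value under the legacy "type" key as well, mirroring the
# config.json workaround of keeping both 'rope_type' and 'type'.
config.rope_scaling["type"] = config.rope_scaling["rope_type"]

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.bfloat16,
)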

@dhruvmullick Try using the latest version, transformers==4.43.1; it works.
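
A quick way to confirm which version is actually active in the environment:

import transformers

# Should print 4.43.1 or newer for the fix described above to be present.
print(transformers.__version__)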

Thanks @amandalmia and @cmrfrd.

In that case, the model card needs to be updated, since it says:

"Starting with transformers >= 4.43.0 onward, ..."
