namespace-Pt committed on
Commit d7bd91c
1 Parent(s): b6abf69

Upload configuration_llama.py with huggingface_hub

Files changed (1)
  1. configuration_llama.py +6 -12
configuration_llama.py CHANGED
@@ -143,8 +143,9 @@ class LlamaConfig(PretrainedConfig):
         beacon_attend_previous=True,
         beacon_ratio=[8],
         beacon_ratio_mix="step-random",
-        beacon_seed=42,
-        beacon_layers=None,
+        beacon_param=["q","k","v","o"],
+        retrieval_method=None,
+        retrieval_topk=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -177,9 +178,9 @@ class LlamaConfig(PretrainedConfig):
         self.beacon_ratio = beacon_ratio
         self.beacon_stride_mix = beacon_stride_mix
         self.beacon_ratio_mix = beacon_ratio_mix
-        self.beacon_seed = beacon_seed
-        self.beacon_layers = beacon_layers
-        self._beacon_validation()
+        self.beacon_param = beacon_param
+        self.retrieval_method = retrieval_method
+        self.retrieval_topk = retrieval_topk
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -210,10 +211,3 @@ class LlamaConfig(PretrainedConfig):
             )
         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
             raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
-
-    def _beacon_validation(self):
-        for stride in self.beacon_stride:
-            assert self.beacon_window >= stride, f"Make sure the beacon_window {self.beacon_window} >= beacon_stride {stride}!"
-        assert self.beacon_attn in ["segmentation", "step-expansion", "full-coverage"], f"beacon_attn {self.beacon_attn} not implemented!"
-        assert self.beacon_stride_mix in ["instance-random", "step-random", "mix-random"], f"beacon_stride_mix {self.beacon_stride_mix} not implemented!"
-        assert self.beacon_ratio_mix in ["instance-random", "step-random", "mix-random"] or "adapt-" in self.beacon_ratio_mix, f"beacon_ratio_mix {self.beacon_ratio_mix} not implemented!"
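
For orientation (not part of the commit), a minimal sketch of instantiating the configuration after this change. beacon_param, retrieval_method, and retrieval_topk are the fields introduced by this diff; the direct module import assumes configuration_llama.py has been downloaded locally, and any values beyond the defaults visible in the diff are illustrative assumptions.

# Minimal sketch, assuming configuration_llama.py from this repo sits on the
# local path; in practice the config is usually loaded through transformers
# with trust_remote_code=True.
from configuration_llama import LlamaConfig

config = LlamaConfig(
    beacon_ratio=[8],                    # pre-existing default, unchanged here
    beacon_ratio_mix="step-random",      # pre-existing default, unchanged here
    beacon_param=["q", "k", "v", "o"],   # added in this commit (default per the diff)
    retrieval_method=None,               # added in this commit, defaults to None
    retrieval_topk=None,                 # added in this commit, defaults to None
)

# beacon_seed and beacon_layers are no longer dedicated arguments, and the
# _beacon_validation() call was dropped from __init__ by this commit.
print(config.beacon_param, config.retrieval_method, config.retrieval_topk)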