namespace-Pt committed
Commit 2fee8e7
1 Parent(s): d7bd91c

Upload modeling_llama.py with huggingface_hub

Files changed (1)
  1. modeling_llama.py +180 -71
modeling_llama.py CHANGED
@@ -226,8 +226,35 @@ class LlamaMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
+        if "mlp" in config.beacon_param:
+            self.beacon_up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+            self.beacon_down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+            self.beacon_up_proj._is_hf_initialized = True
+            self.beacon_down_proj._is_hf_initialized = True
+
+    def _init_beacon_proj(self, beacon_param=None):
+        """Initialize the beacon projection weight with that of the ordinal projection."""
+        if beacon_param is None:
+            beacon_param = self.config.beacon_param
+
+        if is_deepspeed_zero3_enabled():
+            import deepspeed
+            params = [self.up_proj, self.down_proj, self.beacon_up_proj, self.beacon_down_proj]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                if "mlp" in beacon_param:
+                    self.beacon_up_proj.weight.data[:] = self.up_proj.weight.data
+                    self.beacon_down_proj.weight.data[:] = self.down_proj.weight.data
+        else:
+            # only copy the value in-place, without tieing the weight
+            if "mlp" in beacon_param:
+                self.beacon_up_proj.weight.data[:] = self.up_proj.weight.data
+                self.beacon_down_proj.weight.data[:] = self.down_proj.weight.data
+
+    def forward(self, x, beacon_size):
         if self.config.pretraining_tp > 1:
+            # TODO: support pretraining_tp
+            raise NotImplementedError
+
             slice = self.intermediate_size // self.config.pretraining_tp
             gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
             up_proj_slices = self.up_proj.weight.split(slice, dim=0)
@@ -243,8 +270,28 @@ class LlamaMLP(nn.Module):
                 F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
             ]
             down_proj = sum(down_proj)
+
         else:
-            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+            if "mlp" in self.config.beacon_param:
+                if beacon_size > 0:
+                    ordinal_hidden_states = x[:, :-beacon_size]
+                    beacon_hidden_states = x[:, -beacon_size:]
+
+                    # ordinal_up_proj = self.up_proj(ordinal_hidden_states)
+                    # beacon_up_proj = self.beacon_up_proj(beacon_hidden_states)
+                    # up_proj = torch.cat([ordinal_up_proj, beacon_up_proj], dim=1)
+                    # intermediate = self.act_fn(self.gate_proj(x)) * up_proj
+                    # ordinal_down_proj = self.down_proj(intermediate[:, :-beacon_size])
+                    # beacon_down_proj = self.beacon_down_proj(intermediate[:, -beacon_size:])
+                    # down_proj = torch.cat([ordinal_down_proj, beacon_down_proj], dim=1)
+
+                    ordinal_down_proj = self.down_proj(self.act_fn(self.gate_proj(ordinal_hidden_states)) * self.up_proj(ordinal_hidden_states))
+                    beacon_down_proj = self.beacon_down_proj(self.act_fn(self.gate_proj(beacon_hidden_states)) * self.beacon_up_proj(beacon_hidden_states))
+                    down_proj = torch.cat([ordinal_down_proj, beacon_down_proj], dim=1)
+                else:
+                    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+            else:
+                down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
         return down_proj
 
@@ -297,16 +344,20 @@ class LlamaAttention(nn.Module):
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
         self._init_rope()
 
-        # NOTE: add extra parameters for fold tokens
-        self.beacon_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
-        self.beacon_k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.beacon_v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.beacon_o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+        # NOTE: add extra parameters for beacon tokens
         # skip post initialization to speed up loading
-        self.beacon_q_proj._is_hf_initialized = True
-        self.beacon_k_proj._is_hf_initialized = True
-        self.beacon_v_proj._is_hf_initialized = True
-        self.beacon_o_proj._is_hf_initialized = True
+        if "q" in config.beacon_param:
+            self.beacon_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+            self.beacon_q_proj._is_hf_initialized = True
+        if "k" in config.beacon_param:
+            self.beacon_k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+            self.beacon_k_proj._is_hf_initialized = True
+        if "v" in config.beacon_param:
+            self.beacon_v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+            self.beacon_v_proj._is_hf_initialized = True
+        if "o" in config.beacon_param:
+            self.beacon_o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+            self.beacon_o_proj._is_hf_initialized = True
 
     def _init_rope(self):
         if self.config.rope_scaling is None:
@@ -335,22 +386,33 @@
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
-    def _init_beacon_proj(self):
-        """Initialize the fold projection weight with that of the ordinal projection."""
+    def _init_beacon_proj(self, beacon_param=None):
+        """Initialize the beacon projection weight with that of the ordinal projection."""
+        if beacon_param is None:
+            beacon_param = self.config.beacon_param
+
         if is_deepspeed_zero3_enabled():
            import deepspeed
            params = [self.beacon_q_proj.weight, self.beacon_k_proj.weight, self.beacon_v_proj.weight, self.beacon_o_proj.weight, self.q_proj.weight, self.k_proj.weight, self.v_proj.weight, self.o_proj.weight]
            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                if "q" in beacon_param:
+                    self.beacon_q_proj.weight.data[:] = self.q_proj.weight.data
+                if "k" in beacon_param:
+                    self.beacon_k_proj.weight.data[:] = self.k_proj.weight.data
+                if "v" in beacon_param:
+                    self.beacon_v_proj.weight.data[:] = self.v_proj.weight.data
+                if "o" in beacon_param:
+                    self.beacon_o_proj.weight.data[:] = self.o_proj.weight.data
+        else:
+            # only copy the value in-place, without tieing the weight
+            if "q" in beacon_param:
                self.beacon_q_proj.weight.data[:] = self.q_proj.weight.data
+            if "k" in beacon_param:
                self.beacon_k_proj.weight.data[:] = self.k_proj.weight.data
+            if "v" in beacon_param:
                self.beacon_v_proj.weight.data[:] = self.v_proj.weight.data
+            if "o" in beacon_param:
                self.beacon_o_proj.weight.data[:] = self.o_proj.weight.data
-        else:
-            # only copy the value in-place, without tieing the weight
-            self.beacon_q_proj.weight.data[:] = self.q_proj.weight.data
-            self.beacon_k_proj.weight.data[:] = self.k_proj.weight.data
-            self.beacon_v_proj.weight.data[:] = self.v_proj.weight.data
-            self.beacon_o_proj.weight.data[:] = self.o_proj.weight.data
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -360,17 +422,26 @@
             ordinal_hidden_states = hidden_states[:, :-beacon_size]
             beacon_hidden_states = hidden_states[:, -beacon_size:]
 
-            ordinal_query_states = self.q_proj(ordinal_hidden_states)
-            ordinal_key_states = self.k_proj(ordinal_hidden_states)
-            ordinal_value_states = self.v_proj(ordinal_hidden_states)
-
-            beacon_query_states = self.beacon_q_proj(beacon_hidden_states)
-            beacon_key_states = self.beacon_k_proj(beacon_hidden_states)
-            beacon_value_states = self.beacon_v_proj(beacon_hidden_states)
+            if "q" in self.config.beacon_param:
+                ordinal_query_states = self.q_proj(ordinal_hidden_states)
+                beacon_query_states = self.beacon_q_proj(beacon_hidden_states)
+                query_states = torch.cat([ordinal_query_states, beacon_query_states], dim=1)
+            else:
+                query_states = self.q_proj(hidden_states)
+
+            if "k" in self.config.beacon_param:
+                ordinal_key_states = self.k_proj(ordinal_hidden_states)
+                beacon_key_states = self.beacon_k_proj(beacon_hidden_states)
+                key_states = torch.cat([ordinal_key_states, beacon_key_states], dim=1)
+            else:
+                key_states = self.k_proj(hidden_states)
 
-            query_states = torch.cat([ordinal_query_states, beacon_query_states], dim=1)
-            key_states = torch.cat([ordinal_key_states, beacon_key_states], dim=1)
-            value_states = torch.cat([ordinal_value_states, beacon_value_states], dim=1)
+            if "v" in self.config.beacon_param:
+                ordinal_value_states = self.v_proj(ordinal_hidden_states)
+                beacon_value_states = self.beacon_v_proj(beacon_hidden_states)
+                value_states = torch.cat([ordinal_value_states, beacon_value_states], dim=1)
+            else:
+                value_states = self.v_proj(hidden_states)
 
         else:
             query_states = self.q_proj(hidden_states)
@@ -378,6 +449,18 @@
             value_states = self.v_proj(hidden_states)
 
         return query_states, key_states, value_states
+
+    def o_proj_with_beacon(self, attn_output, beacon_size=0):
+        if beacon_size > 0:
+            if "o" in self.config.beacon_param:
+                ordinal_attn_output = self.o_proj(attn_output[:, :-beacon_size])
+                beacon_attn_output = self.beacon_o_proj(attn_output[:, -beacon_size:])
+                attn_output = torch.cat([ordinal_attn_output, beacon_attn_output], dim=1)
+            else:
+                attn_output = self.o_proj(attn_output)
+        else:
+            attn_output = self.o_proj(attn_output)
+        return attn_output
 
     def forward(
         self,
@@ -403,8 +486,10 @@
         else:
             past_seq_len = 0
 
-        # TODO: support pretraining_tp
         if self.config.pretraining_tp > 1:
+            # TODO: support pretraining_tp
+            raise NotImplementedError
+
             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
             query_slices = self.q_proj.weight.split(
                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
@@ -430,13 +515,14 @@
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
+        # return keys and values before rope
+        # NOTE: incrementally return keys and values for efficiency
+        past_key_value = (key_states, value_states, beacon_size, raw_size_to_cache, window_size)
+
         if past_key is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key, key_states], dim=2)
             value_states = torch.cat([past_value, value_states], dim=2)
-
-        # return keys and values before rope
-        past_key_value = (key_states, value_states, beacon_size, raw_size_to_cache, window_size)
 
         key_position_ids = position_ids
         # align query position_ids with key
@@ -480,16 +566,13 @@
 
         if self.config.pretraining_tp > 1:
             # TODO: support pretraining_tp
+            raise NotImplementedError
             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
             o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
             attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+
         else:
-            if beacon_size > 0:
-                regular_attn_output = self.o_proj(attn_output[:, :-beacon_size])
-                beacon_attn_output = self.beacon_o_proj(attn_output[:, -beacon_size:])
-                attn_output = torch.cat([regular_attn_output, beacon_attn_output], dim=1)
-            else:
-                attn_output = self.o_proj(attn_output)
+            attn_output = self.o_proj_with_beacon(attn_output, beacon_size)
 
         if not output_attentions:
             attn_weights = None
@@ -545,14 +628,15 @@ class LlamaSdpaAttention(LlamaAttention):
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
+        # return keys and values before rope
+        # NOTE: incrementally return keys and values for efficiency
+        past_key_value = (key_states, value_states, beacon_size, raw_size_to_cache, window_size)
+
         if past_key is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key, key_states], dim=2)
             value_states = torch.cat([past_value, value_states], dim=2)
 
-        # return keys and values before rope
-        past_key_value = (key_states, value_states, beacon_size, raw_size_to_cache, window_size)
-
         key_position_ids = position_ids
         # align query position_ids with key
         query_position_ids = key_position_ids[:, -q_len:]
@@ -588,13 +672,20 @@
 
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        if beacon_size > 0:
-            regular_attn_output = self.o_proj(attn_output[:, :-beacon_size])
-            beacon_attn_output = self.beacon_o_proj(attn_output[:, -beacon_size:])
-            attn_output = torch.cat([regular_attn_output, beacon_attn_output], dim=1)
-        else:
-            attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj_with_beacon(attn_output, beacon_size)
+
+        # for debug
+        # if torch.distributed.get_rank() == 4 and self.layer_idx == 0:
+        #     torch.save({
+        #         "hidden_states": hidden_states,
+        #         "past_key_value": past_key_value,
+        #         "query_states": query_states,
+        #         "key_states": key_states,
+        #         "value_states": value_states,
+        #         "attn_output": attn_output,
+        #         "attention_mask": attention_mask,
+        #         "key_position_ids": key_position_ids,
+        #     }, "beacon_llama_layer_0")
 
         return attn_output, None, past_key_value
 
@@ -645,6 +736,9 @@ class LlamaDecoderLayer(nn.Module):
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
 
+        # NOTE: get beacon_size in case the mlp is included in beacon_param
+        past_key, past_value, beacon_size, raw_size_to_cache, window_size = past_key_value
+
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -664,7 +758,7 @@
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, beacon_size)
        hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
@@ -843,10 +937,6 @@ def compute_loss(logits, labels, shift=False):
     if (valid_token_num == 0).any():
         batch_loss = batch_loss.masked_fill(valid_token_num == 0, 0.)
 
-    # print("beacon")
-    # print(f"token_loss: {token_loss[:, :100].tolist()}")
-    # print(f"batch_loss: {batch_loss}")
-    # input()
     return loss, batch_loss, valid_token_num
 
 @dataclass
@@ -895,7 +985,7 @@ class LlamaModel(LlamaPreTrainedModel):
         self.post_init()
 
     def _init_beacon_embed(self):
-        """Initialize the fold token embedding with that of the eos token."""
+        """Initialize the beacon token embedding with that of the eos token."""
         if is_deepspeed_zero3_enabled():
             import deepspeed
             params = [self.beacon_embed_tokens.weight, self.embed_tokens.weight]
@@ -1109,6 +1199,15 @@ class LlamaModel(LlamaPreTrainedModel):
 
         hidden_states = self.norm(hidden_states)
 
+        # for debug
+        # if torch.distributed.get_rank() == 4:
+        #     torch.save({
+        #         "hidden_states": hidden_states,
+        #         "past_key_values": past_key_values,
+        #         "attention_mask": attention_mask,
+        #         "position_ids": position_ids,
+        #     }, "beacon_llama_inputs")
+
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -1139,9 +1238,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
 
     def set_memory(self):
         config: LlamaConfig = self.config
-        info = f"applying activation beacon on {'all' if config.beacon_layers is None else config.beacon_layers} layers, with window size {config.beacon_window}, stride {config.beacon_stride} (mixed by {config.beacon_stride_mix}), {config.beacon_attn} attention, and condensing ratio {config.beacon_ratio} (mixed by {config.beacon_ratio_mix}), seed {config.beacon_seed}..."
-        logger.info(info)
-
         self.memory = Memory(
             model_config=config,
             beacon_window=config.beacon_window,
@@ -1151,10 +1247,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             beacon_ratio=config.beacon_ratio,
             beacon_stride_mix=config.beacon_stride_mix,
             beacon_ratio_mix=config.beacon_ratio_mix,
-            beacon_seed=config.beacon_seed,
-            beacon_layers=config.beacon_layers,
+            beacon_param=config.beacon_param,
             k_seq_dim=2,
             v_seq_dim=2,
+            retrieval_method=config.retrieval_method,
+            retrieval_topk=config.retrieval_topk,
         )
 
     def get_input_embeddings(self):
@@ -1180,15 +1277,26 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         """Override the default from_pretrained to extend vocab size according to beacon_size."""
         model, loading_info = super().from_pretrained(*args, **kwargs, output_loading_info=True)
         missing_keys = loading_info["missing_keys"]
-        # only initialize weights when they are missing from the checkpoint
-        if any("beacon_embed_tokens" in missing_key for missing_key in missing_keys):
-            # initialize weights of embedding layers for fold tokens
-            model.model._init_beacon_embed()
-        if any("beacon_q_proj" in missing_key for missing_key in missing_keys):
-            # initialize weights of linear layers for fold tokens
-            for layer in model.model.layers:
-                if hasattr(layer.self_attn, "_init_beacon_proj"):
-                    layer.self_attn._init_beacon_proj()
+        # only initialize beacon weights when they are missing from the checkpoint
+        beacon_param = set()
+        for missing_key in missing_keys:
+            if "beacon_embed_tokens" in missing_key:
+                model.model._init_beacon_embed()
+            elif "beacon_q_proj" in missing_key:
+                beacon_param.add("q")
+            elif "beacon_k_proj" in missing_key:
+                beacon_param.add("k")
+            elif "beacon_v_proj" in missing_key:
+                beacon_param.add("v")
+            elif "beacon_o_proj" in missing_key:
+                beacon_param.add("o")
+            elif "beacon_up_proj" in missing_key:
+                beacon_param.add("mlp")
+
+        # initialize weights of possible q,k,v,o,mlp
+        for layer in model.model.layers:
+            layer.self_attn._init_beacon_proj(beacon_param)
+            layer.mlp._init_beacon_proj(beacon_param)
         return model
 
     def _native_forward(
@@ -1397,7 +1505,7 @@ def evaluate_perplexity(model, dataloader, accelerator:Optional[Accelerator]=Non
 
         # NOTE: we need the loss for each element in the batch for accurate computation, because the number of valid tokens may differ among elements
         if hasattr(output, "batch_loss"):
-            # output from Fold-Llama has batch_loss by default
+            # output from our model has batch_loss by default
             batch_loss = output.batch_loss
             valid_token_num = output.valid_token_num
         else:
@@ -1415,9 +1523,10 @@ def evaluate_perplexity(model, dataloader, accelerator:Optional[Accelerator]=Non
            all_loss[_id].append((_loss * _num, _num))
 
    for _id, loss_and_num in all_loss.items():
-        # sum up the loss for all valid tokens, and divide the number of valid tokens
+        # sum up the loss for all valid tokens in the entire sequence, and divide the number of valid tokens
        all_loss[_id] = sum([x[0] for x in loss_and_num]) / sum(x[1] for x in loss_and_num)
 
+    # average across then take exp
    perplexity = math.exp(sum(all_loss.values()) / len(all_loss))
    return perplexity
 
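For reference, the core pattern this commit introduces in qkv_proj_with_beacon, o_proj_with_beacon and LlamaMLP.forward is to route the trailing beacon_size positions through separate projection weights and concatenate the results back along the sequence dimension. Below is a minimal, self-contained sketch of that pattern; the class and variable names are illustrative only and are not part of modeling_llama.py.

import torch
import torch.nn as nn

class BeaconSplitLinear(nn.Module):
    """Toy module: ordinary tokens use ordinal_proj, the last `beacon_size`
    tokens use beacon_proj, and the two outputs are concatenated on dim=1."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.ordinal_proj = nn.Linear(in_features, out_features, bias=False)
        self.beacon_proj = nn.Linear(in_features, out_features, bias=False)
        # start the beacon projection as an exact copy of the ordinal one,
        # copying values in place rather than tying the parameters
        self.beacon_proj.weight.data[:] = self.ordinal_proj.weight.data

    def forward(self, hidden_states, beacon_size=0):
        if beacon_size > 0:
            ordinal = self.ordinal_proj(hidden_states[:, :-beacon_size])
            beacon = self.beacon_proj(hidden_states[:, -beacon_size:])
            return torch.cat([ordinal, beacon], dim=1)
        return self.ordinal_proj(hidden_states)

x = torch.randn(2, 10, 16)            # (batch, seq_len, hidden)
layer = BeaconSplitLinear(16, 16)
print(layer(x, beacon_size=2).shape)  # torch.Size([2, 10, 16])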
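The comment "only copy the value in-place, without tieing the weight" in _init_beacon_proj refers to the difference between copying parameter values and sharing the parameter object. A small sketch of the distinction, assuming plain nn.Linear layers:

import torch.nn as nn

src = nn.Linear(8, 8, bias=False)
dst = nn.Linear(8, 8, bias=False)

# in-place copy (the approach used here): dst starts equal to src,
# but the two layers keep separate parameters and can diverge in training
dst.weight.data[:] = src.weight.data
assert dst.weight is not src.weight

# weight tying (deliberately avoided): both layers share one parameter
# tensor, so any update to one is an update to the other
tied = nn.Linear(8, 8, bias=False)
tied.weight = src.weight
assert tied.weight is src.weight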
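The past_key_value entry each attention layer now returns is a five-element tuple built from the current chunk only, before concatenation with the cached keys and values ("incrementally return keys and values for efficiency"), and LlamaDecoderLayer unpacks the same tuple to recover beacon_size for the MLP. A sketch of that layout; the shapes are illustrative (batch 1, 32 KV heads, a 128-token chunk, head_dim 128) and not taken from this commit.

import torch

key_states = torch.randn(1, 32, 128, 128)    # (batch, num_kv_heads, chunk_len, head_dim)
value_states = torch.randn(1, 32, 128, 128)
beacon_size, raw_size_to_cache, window_size = 16, 128, 1024

# what an attention layer returns for its layer's cache entry
past_key_value = (key_states, value_states, beacon_size, raw_size_to_cache, window_size)

# what LlamaDecoderLayer.forward unpacks before calling self.mlp(hidden_states, beacon_size)
past_key, past_value, beacon_size, raw_size_to_cache, window_size = past_key_value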