import torch
import numpy as np
import torch.distributed as dist
from transformers.utils import logging
from typing import List, Tuple, Optional
from .modeling_retrieval import BM25Retriever

logger = logging.get_logger(__name__)


class Memory(torch.nn.Module):
    """Sliding-window memory of beacon and raw key/value activations.

    The input sequence is consumed window by window (see :meth:`step`). For
    every layer the memory keeps two caches: ``beacon_activations`` (condensed
    activations accumulated so far) and ``raw_activations`` (the most recent
    raw activations that still have to be visible to the next window).
    """

    def __init__(
        self,
        model_config,
        beacon_window:int=1024,
        beacon_stride:List[int]=[512],
        beacon_attn:str="step-expansion",
        beacon_attend_previous:bool=True,
        beacon_ratio:List[int]=[8],
        beacon_stride_mix:str="step-random",
        beacon_ratio_mix:str="step-random",
        beacon_param:List[str]=["q", "k", "v", "o"],
        k_seq_dim:int=2,
        v_seq_dim:int=2,
        retrieval_method:str=None,
        retrieval_topk:int=2,
    ) -> None:
        """Validate the configuration and initialise an empty memory.

        Args:
            model_config: Object providing ``vocab_size``, ``num_hidden_layers``
                and ``max_position_embeddings``.
            beacon_window: Window size; must be >= every candidate stride.
            beacon_stride: Candidate strides.
            beacon_attn: One of "segmentation", "step-expansion", "full-coverage".
            beacon_attend_previous: Only used in the logged summary here.
            beacon_ratio: Candidate condensing ratios (0 means "keep raw").
            beacon_stride_mix: How a stride is picked among the candidates.
            beacon_ratio_mix: How a condensing ratio is picked among the candidates.
            beacon_param: Only used in the logged summary here.
            k_seq_dim: Sequence dimension of cached keys.
            v_seq_dim: Sequence dimension of cached values.
            retrieval_method: ``None`` or "bm25".
            retrieval_topk: Number of windows retrieved when retrieval is on.
        """
        super().__init__()

        for stride in beacon_stride:
            assert beacon_window >= stride, f"Make sure the beacon_window {beacon_window} >= beacon_stride {stride}!"
        assert beacon_attn in ["segmentation", "step-expansion", "full-coverage"], f"beacon_attn {beacon_attn} not implemented!"
        assert beacon_stride_mix in ["instance-random", "step-random", "mix-random"], f"beacon_stride_mix {beacon_stride_mix} not implemented!"
        assert beacon_ratio_mix in ["instance-random", "step-random", "mix-random", "sequence"] or "adapt-" in beacon_ratio_mix, f"beacon_ratio_mix {beacon_ratio_mix} not implemented!"
        if retrieval_method == "bm25":
            assert len(beacon_stride) == 1, f"Currently retrieval do not support dynamic strides."
            assert retrieval_topk >= 2, f"Make sure retrieval_topk >= 2. Found {retrieval_topk}."
            # NOTE: use the local `beacon_ratio`; `self.beacon_ratio` is not assigned yet at this point
            assert len(beacon_ratio) == 2, f"Make sure there are two beacon ratios specified, one for retrieved windows and the other for non-retrieved windows. Found {beacon_ratio}"

        info = f"applying activation beacon on {beacon_param}, with window size {beacon_window}, stride {beacon_stride} (mixed by {beacon_stride_mix}), {beacon_attn} attention ({'attending to previous beacons' if beacon_attend_previous else 'not attending to previous beacons'}), condensing ratio {beacon_ratio} (mixed by {beacon_ratio_mix}), {retrieval_method+' retrieval'+' top-'+str(retrieval_topk) if retrieval_method is not None else 'no retrieval'}, ..."
        logger.info(info)

        self.beacon_window = beacon_window
        # copy the lists so the (mutable) default arguments are never shared between instances
        self.beacon_stride = list(beacon_stride)
        self.beacon_attn = beacon_attn
        self.beacon_ratio = list(beacon_ratio)
        self.beacon_stride_mix = beacon_stride_mix
        self.beacon_ratio_mix = beacon_ratio_mix

        # the largest number of beacons a single window can ever need (at least one)
        max_beacon_size = max([beacon_window // x for x in beacon_ratio if x > 0] + [1])
        # beacon tokens all share the id `vocab_size`
        self.beacon_tokens = torch.zeros(max_beacon_size, dtype=torch.long) + model_config.vocab_size

        # initialize necessary parameters
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.num_layers = model_config.num_hidden_layers
        self.max_position_embeddings = model_config.max_position_embeddings

        self.retrieval_method = retrieval_method
        self.retrieval_topk = retrieval_topk

        self.rng = np.random.default_rng(42)

        self.reset()

    @property
    def finish(self):
        """Whether the window cursor has reached the end of the current sequence."""
        return self.end_idx == self.sequence_length

    def get_memory_size(self):
        """Return ``(beacon_memory_size, raw_memory_size, memory_size)``.

        Sizes are read from the first layer's cached keys along ``k_seq_dim``;
        a missing cache counts as 0.
        """
        beacon_memory_size = 0
        raw_memory_size = 0
        if self.beacon_activations[0][0] is not None:
            beacon_memory_size += self.beacon_activations[0][0].shape[self.k_seq_dim]
        if self.raw_activations[0][0] is not None:
            raw_memory_size += self.raw_activations[0][0].shape[self.k_seq_dim]
        memory_size = beacon_memory_size + raw_memory_size
        return beacon_memory_size, raw_memory_size, memory_size

    def reset(self):
        """Clear all cursors, cached activations and accumulated losses."""
        # the length of current sequence
        self.sequence_length = 0
        # the length of all sequences until the memory is reset
        self.total_sequence_length = 0
        # the cursor pointing to the start of the current window
        self.start_idx = 0
        # the cursor pointing to the end of the current window
        self.end_idx = 0
        # the beacon sizes of all strides
        self._beacon_sizes = []
        # the step index
        self.step_idx = 0

        # per-instance cached choices; always (re)initialised so that
        # "instance-random" strides work regardless of the ratio mix
        self._stride = None
        self._ratio = None

        self.batch_loss = None
        self.valid_token_num = None

        self.raw_activations = [(None, None) for _ in range(self.num_layers)]
        self.beacon_activations = [(None, None) for _ in range(self.num_layers)]

        if self.retrieval_method == "bm25":
            self.retriever = BM25Retriever()
            self.beacon_ratio_mix = "retrieval"

        # NOTE: when training, we strictly align the rng_state across processes
        if self.training and dist.is_initialized():
            rng_state = self.rng.__getstate__()
            if dist.get_rank() == 0:
                obj = [rng_state]
            else:
                obj = [None]
            dist.broadcast_object_list(obj, src=0)
            self.rng.__setstate__(obj[0])

    def prepare(self, input_ids, attention_mask, labels):
        """
        Prepare inputs for the model.

        Stores the new sequence, rebases the window cursors onto it, shifts the
        labels by one position and, for bm25 retrieval at step 0, records the
        indices of the retrieved windows in ``self._topk_indices``.
        """
        # TODO: support batch_size > 1?
        assert input_ids.shape[0] == 1, f"Make sure batch_size is 1!"

        # NOTE: rebase the start/end idx so that it becomes an offset from the start of the current sequence
        self.start_idx -= self.sequence_length
        self.end_idx -= self.sequence_length

        self.sequence_length = input_ids.shape[1]
        self.total_sequence_length += input_ids.shape[1]

        if labels is not None:
            # rotate labels in advance so that the loss of the last token is not ignored in every window
            labels = torch.cat([labels[:, 1:], labels.new_zeros((labels.shape[0], 1)) - 100], dim=-1)

        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

        # TODO: retrieval on specified steps
        # TODO: retrieval for future inputs
        # TODO: different retrieval methods
        if self.retrieval_method == "bm25" and self.step_idx == 0:
            index = BM25Retriever()
            window = self.beacon_window
            stride = self.beacon_stride[0]

            token_ids = input_ids[0].tolist()
            # the first document is the first full window, the following ones are stride-sized chunks
            corpus = [token_ids[:window]]
            corpus.extend(token_ids[j: j + stride] for j in range(window, len(token_ids), stride))

            # NOTE: use the last 32 token as query (very naive heuristics)
            query = token_ids[-32:]
            index.index(corpus)
            # NOTE(review): assumes BM25Retriever.search(query, hits=k) returns (scores, indices)
            # with one row per query and -1 padding for missing hits -- confirm against modeling_retrieval
            _, topk_indices = index.search(query, hits=self.retrieval_topk)
            self._topk_indices = set([x for x in topk_indices[0] if x > -1])

    def set_stride(self):
        """Choose a stride from self.beacon_stride"""
        beacon_stride = self.beacon_stride

        if len(beacon_stride) == 1:
            return beacon_stride[0]

        if self.beacon_stride_mix == "mix-random":
            stride_mix = self.rng.choice(["instance-random", "step-random"]).tolist()
        else:
            stride_mix = self.beacon_stride_mix

        if stride_mix == "instance-random":
            # draw once, then reuse the same stride until the memory is reset
            if self._stride is None:
                stride = self.rng.choice(beacon_stride).tolist()
                self._stride = stride
            else:
                stride = self._stride
        elif stride_mix == "step-random":
            stride = self.rng.choice(beacon_stride).tolist()
        else:
            raise NotImplementedError

        return stride

    def set_condensing_ratio(self, beacon_stride, start_idx, end_idx):
        """Choose a condensing ratio from self.beacon_ratio"""

        def filter_ratio(ratios, stride):
            """Keep only the ratios that are usable with the given stride."""
            valid_ratios = []
            for ratio in ratios:
                # stride must be bigger than condensing ratio because there must be at least one beacon
                if stride < ratio:
                    continue
                # step-expansion and segmentation requires the stride to be evenly divisible by condensing ratio
                if self.beacon_attn != "full-coverage" and ratio > 0 and (stride % ratio) != 0:
                    continue
                # when training, ratio=0 is valid if previous windows contain beacon or later windows contain beacon
                if ratio == 0 and self.training:
                    previous_beacons = [b for b in self._beacon_sizes if b != -1]
                    following_beacons = (start_idx + stride + self.beacon_window) <= self.sequence_length
                    if len(previous_beacons) == 0 and not following_beacons:
                        continue
                valid_ratios.append(ratio)
            assert len(valid_ratios), f"Cannot find valid condensing ratio (among {ratios}) for stride {stride}!"
            return valid_ratios

        def get_max_length(ratios):
            """Return, for each ratio, the longest sequence it can cover."""
            max_lengths = []
            for condensing_ratio in ratios:
                if condensing_ratio > 0:
                    max_lengths.append((self.max_position_embeddings - self.beacon_window) * condensing_ratio + self.beacon_window)
                else:
                    max_lengths.append(self.max_position_embeddings)
            return max_lengths

        if len(self.beacon_ratio) == 1:
            return self.beacon_ratio[0]

        beacon_ratio = filter_ratio(self.beacon_ratio, beacon_stride)

        if self.beacon_ratio_mix == "mix-random":
            ratio_mix = self.rng.choice(["instance-random", "step-random"]).tolist()
        else:
            ratio_mix = self.beacon_ratio_mix

        if ratio_mix == "instance-random":
            # draw once, then reuse the same ratio until the memory is reset
            if self._ratio is None:
                beacon_ratio = self.rng.choice(beacon_ratio).tolist()
                self._ratio = beacon_ratio
            else:
                beacon_ratio = self._ratio

        elif ratio_mix == "step-random":
            beacon_ratio = self.rng.choice(beacon_ratio).tolist()

        elif ratio_mix == "sequence":
            # walk through the valid ratios in order, sticking to the last one
            idx = min(self.step_idx, len(beacon_ratio) - 1)
            beacon_ratio = beacon_ratio[idx]

        elif ratio_mix == "retrieval":
            # for retrieved windows, we use low ratio; otherwise high ratio
            if self.step_idx in self._topk_indices:
                beacon_ratio = min(self.beacon_ratio)
            else:
                beacon_ratio = max(self.beacon_ratio)

        elif "adapt" in ratio_mix:
            if self._ratio is None:
                # "adapt-<N>": N is the number of future tokens to reserve room for
                future_length = int(ratio_mix.split("-")[1])
                sequence_length = self.total_sequence_length + future_length
                max_lengths = get_max_length(beacon_ratio)
                # candidates whose maximum length can hold the whole sequence
                valid_max_lengths_and_indices = [x for x in enumerate(max_lengths) if x[1] >= sequence_length]
                if len(valid_max_lengths_and_indices):
                    minimum_length_index = min(valid_max_lengths_and_indices, key=lambda x: x[1])[0]
                    # use the minimal possible length for this sequence (the smallest fold ratio)
                    beacon_ratio = beacon_ratio[minimum_length_index]
                else:
                    # no ratio can cover the sequence: fall back to the largest one
                    beacon_ratio = max(beacon_ratio)
                self._ratio = beacon_ratio
            else:
                beacon_ratio = self._ratio

        else:
            # consistent with set_stride: never hand back the unfiltered candidate list
            raise NotImplementedError

        return beacon_ratio

    def step(self):
        """
        Yield one window with the following logic:

        The window size is L, the stride is S.
        The window moves over S tokens at a time. The raw activations passed by the window are condensed according to a condensing_ratio.
        The beacons are added if and only if the raw activations fulfill the window.
        In the future, we may switch window size to decrease cache size of raw activations.

        Returns:
            ``(input_ids, attention_mask, past_key_values, labels)`` where each
            entry of ``past_key_values`` is
            ``(key, value, beacon_size, raw_size_to_cache, window_size)``.
        """
        # the starting position of the current window w.r.t. the start of the current input sequence
        start_idx = self.start_idx
        # the end position of the current window w.r.t. the start of the current input sequence
        end_idx = start_idx + self.beacon_window

        # indicates if the current window is completely filled by raw activations and new tokens
        # we only append beacon tokens for full windows
        is_full_window = True
        if end_idx > self.sequence_length:
            # the input is shorter than the initial window size
            end_idx = self.sequence_length
            is_full_window = False

        # the real window size (remaining_size + new_token_size)
        window_size = end_idx - start_idx

        if is_full_window:
            # set stride and condensing ratio
            beacon_stride = self.set_stride()
            condensing_ratio = self.set_condensing_ratio(beacon_stride, start_idx=start_idx, end_idx=end_idx)

            # the stride must be evenly divisible by condensing_ratio
            if condensing_ratio > 0:
                beacon_size = beacon_stride // condensing_ratio
            else:
                # the raw activations are used as beacon activations
                beacon_size = -1

            # forward start_idx and end_idx
            next_start_idx = start_idx + beacon_stride
            # how many raw activations to save
            raw_size_to_cache = end_idx - next_start_idx
        else:
            # no stride because the sequence has finished
            next_start_idx = start_idx
            # cache all recent raw activations to be used in the next window
            raw_size_to_cache = window_size
            beacon_size = 0

        # streamingly add new input_ids
        input_ids = self.input_ids[:, self.end_idx: end_idx]
        batch_size = input_ids.shape[0]

        if self.attention_mask is not None:
            attention_mask = self.attention_mask[:, self.end_idx: end_idx]
        else:
            attention_mask = torch.ones_like(input_ids)

        if self.labels is not None:
            labels = self.labels[:, self.end_idx: end_idx]
        else:
            labels = None

        # prepend 1 to attention mask for previous memory
        _, _, memory_size = self.get_memory_size()
        if memory_size > 0:
            attention_mask = torch.cat([attention_mask.new_ones(batch_size, memory_size), attention_mask], dim=1)

        # append beacons if necessary
        if is_full_window and beacon_size > 0:
            beacon_tokens = self.beacon_tokens[:beacon_size].expand(batch_size, -1).to(input_ids.device, dtype=input_ids.dtype)
            input_ids = torch.cat([input_ids, beacon_tokens], dim=1)
            # the appended beacon tokens are always attended to
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones(batch_size, beacon_size)], dim=1)
            if labels is not None:
                # beacon positions never contribute to the loss
                labels = torch.cat([labels, labels.new_zeros(batch_size, beacon_size) - 100], dim=1)

        # generate memory (memory_length = old_beacon_size + beacon_size * condensing_ratio + raw_cache_size)
        past_key_values = []
        for (beacon_key, beacon_value), (raw_key, raw_value) in zip(self.beacon_activations, self.raw_activations):
            key = cat_tensor([beacon_key, raw_key], dim=self.k_seq_dim)
            value = cat_tensor([beacon_value, raw_value], dim=self.v_seq_dim)
            past_key_values.append((key, value, beacon_size, raw_size_to_cache, window_size))

        # invoked in self.output()
        self._beacon_sizes.append(beacon_size)

        # update the cursors
        self.start_idx = next_start_idx
        self.end_idx = end_idx
        self.step_idx += 1

        return input_ids, attention_mask, past_key_values, labels

    def update_memory(self, past_key_values):
        """
        Accumulate beacon activations and raw activations.

        Args:
            past_key_values: One ``(key, value, beacon_size, raw_size_to_cache,
                window_size)`` tuple per layer.
        """
        for layer_idx, (key, value, beacon_size, raw_size_to_cache, window_size) in enumerate(past_key_values):
            # NOTE: the past_key_values are incrementally returned (only the new keys and values are returned)

            # key/value: (num_layer, 2, batch_size, num_head, new_seq_len, head_dim)
            # beacon_size: how many beacon activations are in key and value
            # raw_size_to_cache: how many raw activations should be kept

            previous_beacon_key, previous_beacon_value = self.beacon_activations[layer_idx]
            previous_raw_key, previous_raw_value = self.raw_activations[layer_idx]

            if beacon_size == 0:
                # this means the current input does not fulfill a window
                # thus, the key and value are all raw activations, and we accumulate them until the window is fulfilled
                beacon_key = previous_beacon_key
                beacon_value = previous_beacon_value
                assert raw_size_to_cache == window_size
                raw_key = cat_tensor([previous_raw_key, key], dim=self.k_seq_dim)
                raw_value = cat_tensor([previous_raw_value, value], dim=self.v_seq_dim)

            elif beacon_size == -1:
                # this means the raw activations are used as beacon activations for this window
                if raw_size_to_cache > 0:
                    # if we have raw activations, we must first concatenate previous raw activations and current ones,
                    # then extract raw_size_to_cache as raw memory, while others as beacon memory
                    concat_key = cat_tensor([previous_raw_key, key], dim=self.k_seq_dim)
                    concat_value = cat_tensor([previous_raw_value, value], dim=self.v_seq_dim)

                    beacon_key = cat_tensor([
                        previous_beacon_key,
                        slice_tensor(concat_key, end=-raw_size_to_cache, dim=self.k_seq_dim),
                    ], dim=self.k_seq_dim)
                    beacon_value = cat_tensor([
                        previous_beacon_value,
                        slice_tensor(concat_value, end=-raw_size_to_cache, dim=self.v_seq_dim),
                    ], dim=self.v_seq_dim)
                    raw_key = slice_tensor(concat_key, start=-raw_size_to_cache, dim=self.k_seq_dim)
                    raw_value = slice_tensor(concat_value, start=-raw_size_to_cache, dim=self.v_seq_dim)
                else:
                    # if we do not have raw activations, this means stride==window. we put all into beacon memory
                    beacon_key = cat_tensor([previous_beacon_key, key], dim=self.k_seq_dim)
                    beacon_value = cat_tensor([previous_beacon_value, value], dim=self.v_seq_dim)
                    raw_key = None
                    raw_value = None

            else:
                # [-beacon_size:] activations are from beacons, need to be accumulated
                # [-raw_cache_size-beacon_size:-beacon_size] raw activations will be cached; if they are shorter
                # than raw_cache_size, part of the previous raw activations will also be kept
                beacon_key = cat_tensor([
                    previous_beacon_key,
                    slice_tensor(key, start=-beacon_size, dim=self.k_seq_dim),
                ], dim=self.k_seq_dim)
                beacon_value = cat_tensor([
                    previous_beacon_value,
                    slice_tensor(value, start=-beacon_size, dim=self.v_seq_dim),
                ], dim=self.v_seq_dim)

                if key.shape[self.k_seq_dim] < raw_size_to_cache + beacon_size:
                    concat_raw_key = cat_tensor([
                        previous_raw_key,
                        slice_tensor(key, end=-beacon_size, dim=self.k_seq_dim),
                    ], dim=self.k_seq_dim)
                    concat_raw_value = cat_tensor([
                        previous_raw_value,
                        slice_tensor(value, end=-beacon_size, dim=self.v_seq_dim),
                    ], dim=self.v_seq_dim)
                    raw_key = slice_tensor(concat_raw_key, start=-raw_size_to_cache, dim=self.k_seq_dim)
                    raw_value = slice_tensor(concat_raw_value, start=-raw_size_to_cache, dim=self.v_seq_dim)
                else:
                    # becomes None when raw_size_to_cache = 0
                    raw_key = slice_tensor(key, start=-raw_size_to_cache - beacon_size, end=-beacon_size, dim=self.k_seq_dim)
                    raw_value = slice_tensor(value, start=-raw_size_to_cache - beacon_size, end=-beacon_size, dim=self.v_seq_dim)

            self.beacon_activations[layer_idx] = (beacon_key, beacon_value)
            self.raw_activations[layer_idx] = (raw_key, raw_value)

    def update_loss(self, batch_loss, valid_token_num):
        """
        Accumulate loss for later perplexity computation and backward pass.

        Args:
            batch_loss: Loss of the current window, already averaged over its valid tokens.
            valid_token_num: Number of valid tokens the average was taken over.
        """
        if self.batch_loss is None:
            # NOTE: multiply valid_token_num because batch_loss is divided by it in advance
            self.batch_loss = batch_loss * valid_token_num
            self.valid_token_num = valid_token_num
        else:
            # NOTE: avoid in-place operations, otherwise there will be gradient errors in training
            self.batch_loss = self.batch_loss + batch_loss * valid_token_num
            self.valid_token_num = self.valid_token_num + valid_token_num

    def output(self, model_outputs):
        """
        Override loss with accumulated loss and strip the logits of beacon tokens.
        """
        # override loss
        if self.batch_loss is not None:
            # here the batch_loss is the summation of all token losses in each element
            loss = self.batch_loss.sum() / self.valid_token_num.sum()

            # NOTE: prevent nan
            batch_loss = self.batch_loss / self.valid_token_num
            if (self.valid_token_num == 0).any():
                batch_loss = batch_loss.masked_fill(self.valid_token_num == 0, 0.)

            # NOTE: we must use dict to override values, otherwise trainer cannot find loss
            model_outputs["loss"] = loss
            model_outputs["batch_loss"] = batch_loss
            model_outputs["valid_token_num"] = self.valid_token_num

        # override last_hidden_states (used in generation)
        beacon_size = self._beacon_sizes[-1]
        # remove logits corresponding to beacon tokens
        if beacon_size > 0:
            model_outputs["logits"] = model_outputs["logits"][:, :-beacon_size]

        return model_outputs


def slice_tensor(x, start=None, end=None, dim=2):
    """Slice ``x[start:end]`` along ``dim`` (1 or 2), returning ``None`` for empty results.

    ``None`` is returned when ``x`` is ``None``, when ``end == 0``, when
    ``start`` equals the size of ``dim``, or when ``start == end`` (which
    includes both being ``None``).

    Raises:
        NotImplementedError: If ``dim`` is neither 1 nor 2.
    """
    if x is None:
        return None
    if end == 0:
        return None
    if start == x.shape[dim]:
        return None
    if start == end:
        return None
    if dim not in (1, 2):
        raise NotImplementedError
    # full slices for the leading dims, then the requested range on `dim`
    return x[(slice(None),) * dim + (slice(start, end),)]


def cat_tensor(list_of_tensors, dim=-1):
    """Concatenate the non-``None`` tensors along ``dim``.

    Returns the single tensor unchanged when only one is left, and ``None``
    when every entry is ``None``.
    """
    list_of_tensors = [t for t in list_of_tensors if t is not None]
    if len(list_of_tensors) > 1:
        result = torch.cat(list_of_tensors, dim=dim)
    elif len(list_of_tensors) == 1:
        result = list_of_tensors[0]
    else:
        result = None
    return result


def softmax(x:np.ndarray, axis=-1, temperature=1):
    """Numerically stable softmax of ``x / temperature`` along ``axis``; lists are converted to arrays."""
    if isinstance(x, list):
        x = np.array(x)
    x = x / temperature
    # subtract the max before exponentiating to avoid overflow
    x = x - x.max(axis=axis, keepdims=True)
    y = np.exp(x)
    return y / y.sum(axis=axis, keepdims=True)


def l1_norm(x):
    """Divide every element of ``x`` by the sum of ``x`` and return the result as a list."""
    sum_x = sum(x)
    x = [y/sum_x for y in x]
    return x