AccDiffusion / utils.py
fffiloni's picture
Upload 4 files
fe72a39 verified
raw
history blame contribute delete
No virus
35.1 kB
from __future__ import annotations
import os
import cv2
import abc
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention
from PIL import Image
import random
import matplotlib.pyplot as plt
import pdb
import math
from PIL import Image
class P2PCrossAttnProcessor:
    """Attention processor that hands the attention probabilities to a controller.

    The computation mirrors a plain attention forward pass (project to
    query/key/value, compute attention probabilities, weight the values and
    apply the output layers); the only addition is the controller call made
    right after the probabilities are computed.
    """

    def __init__(self, controller, place_in_unet):
        """Remember the controller and the location tag passed to it on every call."""
        super().__init__()
        self.controller = controller
        self.place_in_unet = place_in_unet

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Run attention on ``hidden_states`` and report the probabilities to the controller.

        When ``encoder_hidden_states`` is given the call is treated as
        cross-attention; otherwise keys and values come from ``hidden_states``
        itself (self-attention).
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # The one change versus a stock processor: the controller sees (and may
        # edit in place) the probabilities. Its return value is not used here.
        self.controller(attention_probs, is_cross, self.place_in_unet)

        attended = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projection, dropout = attn.to_out[0], attn.to_out[1]
        return dropout(projection(attended))
def create_controller(
    prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device, attn_res
) -> AttentionControl:
    """Build the attention controller selected by ``cross_attention_kwargs["edit_type"]``.

    Recognized keys of ``cross_attention_kwargs``:
        edit_type: one of ``"visualize"``, ``"replace"``, ``"refine"``, ``"reweight"``.
        local_blend_words: words used to build a ``LocalBlend``; optional.
        equalizer_words / equalizer_strengths: required for ``"reweight"``.
        n_cross_replace / n_self_replace: replace-step settings, default 0.4.

    Returns:
        An ``AttentionStore`` for ``"visualize"``; otherwise an
        ``AttentionReplace``, ``AttentionRefine`` or ``AttentionReweight``.

    Raises:
        ValueError: if ``edit_type`` is not one of the supported values.
        AssertionError: if a reweight edit lacks equalizer words/strengths or
            their lengths differ.
    """
    edit_type = cross_attention_kwargs.get("edit_type", None)
    local_blend_words = cross_attention_kwargs.get("local_blend_words", None)
    equalizer_words = cross_attention_kwargs.get("equalizer_words", None)
    equalizer_strengths = cross_attention_kwargs.get("equalizer_strengths", None)
    n_cross_replace = cross_attention_kwargs.get("n_cross_replace", 0.4)
    n_self_replace = cross_attention_kwargs.get("n_self_replace", 0.4)

    if edit_type == "visualize":
        return AttentionStore(attn_res=attn_res, device=device)

    if edit_type not in ("replace", "refine", "reweight"):
        raise ValueError(
            f"Edit type {edit_type} not recognized. Use one of: visualize, replace, refine, reweight."
        )

    if edit_type in ("replace", "refine"):
        # A LocalBlend is built whenever local_blend_words is provided at all
        # (even if empty), matching the historical behavior of these two edits.
        lb = None
        if local_blend_words is not None:
            lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res)
        controller_cls = AttentionReplace if edit_type == "replace" else AttentionRefine
        return controller_cls(
            prompts,
            num_inference_steps,
            n_cross_replace,
            n_self_replace,
            lb,
            tokenizer=tokenizer,
            device=device,
            attn_res=attn_res,
        )

    # edit_type == "reweight"
    assert (
        equalizer_words is not None and equalizer_strengths is not None
    ), "To use reweight edit, please specify equalizer_words and equalizer_strengths."
    assert len(equalizer_words) == len(
        equalizer_strengths
    ), "equalizer_words and equalizer_strengths must be of same length."
    equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)

    # Empty local_blend_words means "no local blend". Previously an empty
    # (non-None) value matched neither reweight branch and ended in the
    # misleading "edit type not recognized" error.
    lb = None
    if local_blend_words:
        lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device, attn_res=attn_res)
    return AttentionReweight(
        prompts,
        num_inference_steps,
        n_cross_replace,
        n_self_replace,
        tokenizer=tokenizer,
        device=device,
        equalizer=equalizer,
        attn_res=attn_res,
        local_blend=lb,
    )
class AttentionControl(abc.ABC):
    """Base class for objects that observe or edit attention maps layer by layer.

    The controller counts calls: once ``num_att_layers`` (plus
    ``num_uncond_att_layers``) layers have been seen, the layer counter
    wraps to zero, ``cur_step`` advances and ``between_steps`` is invoked.
    ``num_att_layers`` starts at -1 and has to be set by whoever registers
    the controller.
    """

    def __init__(self, attn_res=None):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.attn_res = attn_res

    def reset(self):
        """Rewind the step and layer counters to zero."""
        self.cur_step = 0
        self.cur_att_layer = 0

    @property
    def num_uncond_att_layers(self):
        """Number of leading layer calls per step that bypass ``forward``."""
        return 0

    def step_callback(self, x_t):
        """Hook applied to the latents after a step; the base version is the identity."""
        return x_t

    def between_steps(self):
        """Hook run each time a full pass over all attention layers completes."""
        return None

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Process (and possibly modify) the attention maps of one layer."""
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Feed the second half of ``attn`` (along dim 0) through ``forward`` in place.

        Only ``attn[h // 2:]`` is passed on and overwritten; the first half is
        left as is (presumably the unconditional half under classifier-free
        guidance -- TODO confirm against the pipeline).
        """
        if self.cur_att_layer >= self.num_uncond_att_layers:
            half = attn.shape[0] // 2
            attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)

        self.cur_att_layer += 1
        layers_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer == layers_per_step:
            # A full sweep over the layers is done: start the next step.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn
class EmptyControl(AttentionControl):
    """Controller that leaves every attention map untouched."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Return ``attn`` unchanged."""
        return attn
class AttentionStore(AttentionControl):
    """Controller that records attention maps and accumulates them across steps.

    ``step_store`` collects the maps seen during the current step, keyed by
    ``"<place>_<cross|self>"``. ``attention_store`` holds the running sum over
    completed steps; ``get_average_attention`` divides it by ``cur_step``.
    """

    def __init__(self, attn_res=None, device='cuda'):
        """``device`` may be a ``torch.device`` or a device string such as ``'cuda'``."""
        super(AttentionStore, self).__init__(attn_res)
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.device = device

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with one empty list per (place, attention kind)."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` for the current step and return it."""
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32**2:  # avoid memory overhead
            # self.device can be a plain string (the constructor default is
            # 'cuda'), which has no `.type`; normalize through torch.device
            # so both strings and torch.device objects work.
            if torch.device(self.device).type != 'cuda':
                attn = attn.cpu()
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Fold the maps of the finished step into the running sum."""
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated maps divided by the number of completed steps."""
        average_attention = {
            key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store
        }
        return average_attention

    def reset(self):
        """Reset the counters and drop everything stored so far."""
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
class LocalBlend:
    """Restrict an edit to the image regions attended to by selected words.

    A mask is derived from the stored cross-attention maps of the selected
    words; outside the mask each latent is replaced by the first latent in the
    batch (the source image).
    """

    def __init__(
        self, prompts: List[str], words: List[List[str]], tokenizer, device, threshold=0.3, attn_res=None
    ):
        """Build the per-prompt word selector.

        Args:
            prompts: one prompt per batch entry.
            words: for each prompt, the word (a single string) or words whose
                attention defines the blend region.
            tokenizer: tokenizer forwarded to ``get_word_inds``.
            device: device the selector tensor is moved to.
            threshold: cut-off applied to the max-normalized attention mask.
            attn_res: (height, width) of the attention maps to use.
        """
        self.max_num_words = 77
        self.attn_res = attn_res

        alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
        for i, (prompt, words_) in enumerate(zip(prompts, words)):
            if isinstance(words_, str):
                words_ = [words_]
            for word in words_:
                ind = get_word_inds(prompt, word, tokenizer)
                alpha_layers[i, :, :, :, :, ind] = 1
        # 1 at the token positions of the selected words, 0 elsewhere.
        self.alpha_layers = alpha_layers.to(device)
        self.threshold = threshold

    def __call__(self, x_t, attention_store):
        """Blend latents ``x_t`` using the attention maps in ``attention_store``.

        Note that this works on the latent level. Returns a tensor shaped like
        ``x_t`` in which everything outside the mask equals ``x_t[:1]``.
        """
        k = 1
        num_pixels = self.attn_res[0] * self.attn_res[1]
        # Use every cross-attention map whose spatial size matches attn_res.
        maps = [
            m
            for m in attention_store["down_cross"] + attention_store["mid_cross"] + attention_store["up_cross"]
            if m.shape[1] == num_pixels
        ]
        maps = [
            item.reshape(self.alpha_layers.shape[0], -1, 1, self.attn_res[0], self.attn_res[1], self.max_num_words)
            for item in maps
        ]
        maps = torch.cat(maps, dim=1)
        # alpha_layers is zero except at the selected words, so the product
        # keeps only their maps; sum over words, then average over dim 1.
        maps = (maps * self.alpha_layers).sum(-1).mean(1)
        mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        mask = F.interpolate(mask, size=(x_t.shape[2:]))
        mask = mask / mask.max(2, keepdim=True)[0].max(3, keepdim=True)[0]
        mask = mask.gt(self.threshold)
        # Union of the source mask with each target mask (boolean add == OR).
        mask = mask[:1] + mask[1:]
        # Follow the latent dtype instead of hard-coding float16: the 0/1 mask
        # is exact in any float dtype and this avoids half math on CPU.
        mask = mask.to(x_t.dtype)
        # Keep the generated content inside the mask, the source image outside.
        x_t = x_t[:1] + mask * (x_t - x_t[:1])
        return x_t
class AttentionControlEdit(AttentionStore, abc.ABC):
    """Base class for edits that inject the source prompt's attention into the targets.

    The batch is assumed to hold the source prompt first and the edited
    prompts after it: ``forward`` splits the maps into ``attn[0]`` (base) and
    ``attn[1:]`` (to be replaced).
    """

    def __init__(
        self,
        prompts,
        num_steps: int,
        cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
        self_replace_steps: Union[float, Tuple[float, float]],
        local_blend: Optional[LocalBlend],
        tokenizer,
        device,
        attn_res=None,
    ):
        """Set up the replacement schedules.

        Args:
            prompts: source prompt followed by the edited prompts.
            num_steps: number of diffusion steps.
            cross_replace_steps: forwarded to ``get_time_words_attention_alpha``.
            self_replace_steps: a single number ``e`` (meaning ``(0, e)``) or a
                ``(start, end)`` pair, as fractions of ``num_steps``.
            local_blend: optional ``LocalBlend`` applied in ``step_callback``.
            tokenizer: tokenizer used to locate words in the prompts.
            device: device for the schedule tensors.
            attn_res: attention resolution, stored by the base class.
        """
        super(AttentionControlEdit, self).__init__(attn_res=attn_res)
        # add tokenizer and device here
        self.tokenizer = tokenizer
        self.device = device

        self.batch_size = len(prompts)
        self.cross_replace_alpha = get_time_words_attention_alpha(
            prompts, num_steps, cross_replace_steps, self.tokenizer
        ).to(self.device)
        # Accept ints too (e.g. 1 for "all steps"); previously only floats
        # were wrapped, so an int crashed on the subscript below.
        if isinstance(self_replace_steps, (int, float)):
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend

    def step_callback(self, x_t):
        """Apply the local blend, if any, to the latents after a step."""
        if self.local_blend is not None:
            x_t = self.local_blend(x_t, self.attention_store)
        return x_t

    def replace_self_attention(self, attn_base, att_replace):
        """Use the base self-attention for maps no larger than ``attn_res[0]**2``."""
        if att_replace.shape[2] <= self.attn_res[0]**2:
            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        else:
            return att_replace

    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        """Return the cross-attention maps to use for the edited prompts."""
        raise NotImplementedError

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Store ``attn`` and overwrite the edited prompts' maps where scheduled."""
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        in_self_window = self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]
        if is_cross or in_self_window:
            h = attn.shape[0] // (self.batch_size)
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_replace = attn[0], attn[1:]
            if is_cross:
                # Per-step, per-word weight between replaced and original maps.
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_replace_new = (
                    self.replace_cross_attention(attn_base, attn_replace) * alpha_words
                    + (1 - alpha_words) * attn_replace
                )
                attn[1:] = attn_replace_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_replace)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn
class AttentionReplace(AttentionControlEdit):
    """Edit that swaps words: base cross-attention is re-routed through a token mapper."""

    def __init__(
        self,
        prompts,
        num_steps: int,
        cross_replace_steps: float,
        self_replace_steps: float,
        local_blend: Optional[LocalBlend] = None,
        tokenizer=None,
        device=None,
        attn_res=None,
    ):
        """Initialise the shared edit state, then build the replacement mapper."""
        super().__init__(
            prompts,
            num_steps,
            cross_replace_steps,
            self_replace_steps,
            local_blend,
            tokenizer,
            device,
            attn_res,
        )
        replacement_mapper = get_replacement_mapper(prompts, self.tokenizer)
        self.mapper = replacement_mapper.to(self.device)

    def replace_cross_attention(self, attn_base, att_replace):
        """Project the base maps onto each edited prompt's tokens via ``self.mapper``."""
        projected = torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper)
        return projected
class AttentionRefine(AttentionControlEdit):
    """Edit that refines a prompt: aligned tokens reuse the base attention."""

    def __init__(
        self,
        prompts,
        num_steps: int,
        cross_replace_steps: float,
        self_replace_steps: float,
        local_blend: Optional[LocalBlend] = None,
        tokenizer=None,
        device=None,
        attn_res=None
    ):
        """Initialise the shared edit state, then build the refinement mapper and weights."""
        super().__init__(
            prompts,
            num_steps,
            cross_replace_steps,
            self_replace_steps,
            local_blend,
            tokenizer,
            device,
            attn_res,
        )
        mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
        self.mapper = mapper.to(self.device)
        alphas = alphas.to(self.device)
        # One weight per (edited prompt, token), broadcast over heads and pixels.
        num_edits, num_tokens = alphas.shape[0], alphas.shape[1]
        self.alphas = alphas.reshape(num_edits, 1, 1, num_tokens)

    def replace_cross_attention(self, attn_base, att_replace):
        """Mix base attention (gathered through the mapper) with the edited attention."""
        gathered = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
        return gathered * self.alphas + att_replace * (1 - self.alphas)
class AttentionReweight(AttentionControlEdit):
    """Edit that scales the attention of selected tokens by an equalizer."""

    def __init__(
        self,
        prompts,
        num_steps: int,
        cross_replace_steps: float,
        self_replace_steps: float,
        equalizer,
        local_blend: Optional[LocalBlend] = None,
        controller: Optional[AttentionControlEdit] = None,
        tokenizer=None,
        device=None,
        attn_res=None,
    ):
        """Initialise the shared edit state and keep the equalizer and optional chained controller."""
        super().__init__(
            prompts,
            num_steps,
            cross_replace_steps,
            self_replace_steps,
            local_blend,
            tokenizer,
            device,
            attn_res,
        )
        self.equalizer = equalizer.to(self.device)
        self.prev_controller = controller

    def replace_cross_attention(self, attn_base, att_replace):
        """Scale the (optionally pre-edited) base maps per token by the equalizer."""
        chained = self.prev_controller
        if chained is not None:
            # Let the chained controller transform the base maps first.
            attn_base = chained.replace_cross_attention(attn_base, att_replace)
        token_scale = self.equalizer[:, None, None, :]
        return attn_base[None, :, :, :] * token_scale
### util functions for all Edits
def update_alpha_time_word(
    alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
):
    """Switch ``alpha`` on for a window of steps, for one prompt and a set of words.

    Args:
        alpha: tensor indexed as ``[step, prompt, word]``; modified in place.
        bounds: ``(start, end)`` as fractions of the number of steps, or a
            single number ``e`` meaning ``(0, e)``. Ints are accepted too.
        prompt_ind: index along the prompt dimension to update.
        word_inds: word positions to update; all positions when ``None``.

    Returns:
        The same ``alpha`` tensor: 1 inside the step window and 0 outside it
        for the selected prompt/words.
    """
    # Accept ints as well (e.g. 1 for "all steps"); previously an int fell
    # through and crashed on the subscript below.
    if isinstance(bounds, (int, float)):
        bounds = 0, bounds
    num_steps = alpha.shape[0]
    start, end = int(bounds[0] * num_steps), int(bounds[1] * num_steps)
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[:start, prompt_ind, word_inds] = 0
    alpha[start:end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha
def get_time_words_attention_alpha(
    prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
):
    """Build the per-step, per-word cross-attention replacement schedule.

    Args:
        prompts: source prompt followed by the edited prompts.
        num_steps: number of diffusion steps.
        cross_replace_steps: either bounds applied to every word, or a dict
            mapping words to bounds; the key ``"default_"`` covers all other
            words and falls back to ``(0.0, 1.0)`` when absent.
        tokenizer: tokenizer forwarded to ``get_word_inds``.
        max_num_words: size of the word dimension.

    Returns:
        Tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)``.
    """
    if isinstance(cross_replace_steps, dict):
        # Work on a copy so the caller's dict is not modified when the
        # "default_" entry is filled in below.
        cross_replace_steps = dict(cross_replace_steps)
    else:
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0.0, 1.0)

    num_edits = len(prompts) - 1
    alpha_time_words = torch.zeros(num_steps + 1, num_edits, max_num_words)
    for i in range(num_edits):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)

    # Word-specific bounds override the default schedule for those words.
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)

    alpha_time_words = alpha_time_words.reshape(num_steps + 1, num_edits, 1, 1, max_num_words)
    return alpha_time_words
### util functions for LocalBlend and ReplacementEdit
def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
    """Return the token positions that belong to a word of ``text``.

    Args:
        text: the prompt; words are obtained by splitting on single spaces.
        word_place: either a word (every occurrence is selected) or the index
            of a word in the split prompt.
        tokenizer: object with ``encode(text)`` and ``decode(ids)``; the first
            and last encoded tokens are skipped as special tokens.

    Returns:
        Integer numpy array of token positions; position 0 is the first
        special token, so the first real token is 1. Empty when the word is
        not in the prompt.
    """
    split_text = text.split(" ")
    if isinstance(word_place, str):
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif isinstance(word_place, int):
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0

        # Walk the tokens, advancing `ptr` to the next word once the pieces
        # seen so far cover the current word's length.
        for i, piece in enumerate(words_encode):
            cur_len += len(piece)
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    # Force an integer dtype: np.array([]) would be float64, which cannot be
    # used as an index when the word is not found.
    return np.array(out, dtype=np.int64)
### util functions for ReplacementEdit
def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
    """Build the token mapping matrix between prompt ``x`` and prompt ``y``.

    Both prompts must have the same number of space-separated words. Words
    that differ are mapped token-group to token-group (spreading the weight
    evenly when the groups differ in length); everything else is mapped
    one-to-one.

    Returns:
        ``(max_len, max_len)`` float16 tensor; rows index tokens of ``x``,
        columns tokens of ``y``.

    Raises:
        ValueError: if the prompts have different word counts.
    """
    words_x = x.split(" ")
    words_y = y.split(" ")
    if len(words_x) != len(words_y):
        raise ValueError(
            f"attention replacement edit can only be applied on prompts with the same length"
            f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words."
        )

    inds_replace = [idx for idx, (wx, wy) in enumerate(zip(words_x, words_y)) if wy != wx]
    inds_source = [get_word_inds(x, idx, tokenizer) for idx in inds_replace]
    inds_target = [get_word_inds(y, idx, tokenizer) for idx in inds_replace]
    num_replaced = len(inds_source)

    mapper = np.zeros((max_len, max_len))
    i = j = 0
    cur_inds = 0
    while i < max_len and j < max_len:
        if cur_inds < num_replaced and inds_source[cur_inds][0] == i:
            # At the first token of a replaced word: map the whole group.
            src_group, tgt_group = inds_source[cur_inds], inds_target[cur_inds]
            if len(src_group) == len(tgt_group):
                mapper[src_group, tgt_group] = 1
            else:
                ratio = 1 / len(tgt_group)
                for tgt in tgt_group:
                    mapper[src_group, tgt] = ratio
            cur_inds += 1
            i += len(src_group)
            j += len(tgt_group)
            continue

        # Unchanged token: identity mapping. Once all replaced words are
        # consumed, the diagonal is indexed by j alone.
        if cur_inds < num_replaced:
            mapper[i, j] = 1
        else:
            mapper[j, j] = 1
        i += 1
        j += 1

    # return torch.from_numpy(mapper).float()
    return torch.from_numpy(mapper).to(torch.float16)
def get_replacement_mapper(prompts, tokenizer, max_len=77):
    """Stack the replacement mappers from the first prompt to each later prompt."""
    source = prompts[0]
    mappers = [
        get_replacement_mapper_(source, target, tokenizer, max_len)
        for target in prompts[1:]
    ]
    return torch.stack(mappers)
### util functions for ReweightEdit
def get_equalizer(
    text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
):
    """Build the per-token scaling table used by the reweight edit.

    Starts from ones of shape ``(len(values), 77)`` and, for the i-th selected
    word, writes ``values[i]`` into that word's token columns (all rows).
    ``word_select`` may be a single word/index or a sequence of them.
    """
    if isinstance(word_select, (int, str)):
        word_select = (word_select,)

    num_rows = len(values)
    equalizer = torch.ones(num_rows, 77)
    strengths = torch.tensor(values, dtype=torch.float32)

    for position, word in enumerate(word_select):
        token_inds = get_word_inds(text, word, tokenizer)
        equalizer[:, token_inds] = strengths[position]
    return equalizer
### util functions for RefinementEdit
class ScoreParams:
    """Scores used by the global sequence alignment: gap, match and mismatch."""

    def __init__(self, gap, match, mismatch):
        self.gap = gap
        self.match = match
        self.mismatch = mismatch

    def mis_match_char(self, x, y):
        """Return the match score when ``x`` equals ``y``, else the mismatch score."""
        return self.match if x == y else self.mismatch
def get_matrix(size_x, size_y, gap):
    """Create the alignment score matrix with gap penalties on the first row and column.

    The result has shape ``(size_x + 1, size_y + 1)`` and dtype int32; entry
    ``k`` of the first row/column is ``k * gap``, everything else is zero.
    """
    scores = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    scores[0, 1:] = np.arange(1, size_y + 1) * gap
    scores[1:, 0] = np.arange(1, size_x + 1) * gap
    return scores
def get_traceback_matrix(size_x, size_y):
    """Create the traceback matrix with its borders pre-filled.

    Codes: 1 on the first row (move left), 2 on the first column (move up),
    4 at the origin (stop); interior cells start at 0.
    """
    trace = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    trace[0, 1:] = 1
    trace[1:, 0] = 2
    trace[0, 0] = 4
    return trace
def global_align(x, y, score):
    """Fill the score and traceback matrices for a global alignment of ``x`` and ``y``.

    Each cell takes the best of: coming from the left (gap), from above (gap)
    or diagonally (match/mismatch). Ties are resolved in that order: left,
    then up, then diagonal. Returns ``(matrix, trace_back)``.
    """
    rows, cols = len(x), len(y)
    matrix = get_matrix(rows, cols, score.gap)
    trace_back = get_traceback_matrix(rows, cols)

    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            left = matrix[i, j - 1] + score.gap
            up = matrix[i - 1, j] + score.gap
            diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
            matrix[i, j] = max(left, up, diag)

            # Compare against the stored cell value, as the matrix is int32.
            stored = matrix[i, j]
            if stored == left:
                move = 1
            elif stored == up:
                move = 2
            else:
                move = 3
            trace_back[i, j] = move
    return matrix, trace_back
def get_aligned_sequences(x, y, trace_back):
    """Walk the traceback matrix and recover the alignment of ``x`` and ``y``.

    Returns:
        ``(x_seq, y_seq, mapper_y_to_x)``. ``x_seq`` / ``y_seq`` are the
        aligned sequences with ``"-"`` for gaps, collected from the end of the
        sequences backwards (they are not reversed). ``mapper_y_to_x`` is an
        int64 tensor of ``(y_index, x_index)`` pairs in forward order, with
        ``x_index == -1`` where a ``y`` element has no counterpart.
    """
    x_seq, y_seq = [], []
    pairs = []
    i, j = len(x), len(y)

    while i > 0 or j > 0:
        move = trace_back[i, j]
        if move == 3:
            # Diagonal: both sequences advance.
            x_seq.append(x[i - 1])
            y_seq.append(y[j - 1])
            i -= 1
            j -= 1
            pairs.append((j, i))
        elif move == 1:
            # Left: gap in x.
            x_seq.append("-")
            y_seq.append(y[j - 1])
            j -= 1
            pairs.append((j, -1))
        elif move == 2:
            # Up: gap in y.
            x_seq.append(x[i - 1])
            y_seq.append("-")
            i -= 1
        elif move == 4:
            break

    pairs.reverse()
    return x_seq, y_seq, torch.tensor(pairs, dtype=torch.int64)
def get_mapper(x: str, y: str, tokenizer, max_len=77):
    """Align the tokens of ``y`` to the tokens of ``x``.

    Returns:
        ``(mapper, alphas)``, both of length ``max_len``. ``mapper[k]`` is the
        position in ``x`` aligned with token ``k`` of ``y`` (-1 when there is
        none); positions past the end of ``y`` continue counting from
        ``len(y_seq)``. ``alphas[k]`` is 0 for unaligned ``y`` tokens, 1 otherwise.
    """
    x_seq = tokenizer.encode(x)
    y_seq = tokenizer.encode(y)

    score = ScoreParams(0, 1, -1)
    _, trace_back = global_align(x_seq, y_seq, score)
    mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
    num_aligned = mapper_base.shape[0]
    aligned_x = mapper_base[:, 1]

    alphas = torch.ones(max_len)
    alphas[:num_aligned] = aligned_x.ne(-1).float()

    mapper = torch.zeros(max_len, dtype=torch.int64)
    mapper[:num_aligned] = aligned_x
    # Fill the tail with consecutive positions following the tokens of y.
    mapper[num_aligned:] = len(y_seq) + torch.arange(max_len - len(y_seq))
    return mapper, alphas
def get_refinement_mapper(prompts, tokenizer, max_len=77):
    """Stack the mappers and alpha weights from the first prompt to each later prompt."""
    source = prompts[0]
    pairs = [get_mapper(source, target, tokenizer, max_len) for target in prompts[1:]]
    mappers = [pair[0] for pair in pairs]
    alphas = [pair[1] for pair in pairs]
    return torch.stack(mappers), torch.stack(alphas)
def aggregate_attention(prompts, attention_store: AttentionStore, height: int, width: int, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps of one prompt at the ``height // 32`` x ``width // 32`` resolution.

    Args:
        prompts: the batch of prompts; only its length is used.
        attention_store: object providing ``get_average_attention()``.
        height, width: image size in pixels; the map size is each divided by 32.
        from_where: UNet locations to include (prefixes of the store keys).
        is_cross: use cross-attention maps if true, self-attention otherwise.
        select: index of the prompt whose maps are returned.

    Returns:
        CPU tensor of shape ``(width // 32, height // 32, last_dim)``,
        averaged over all matching maps and heads.
    """
    averaged = attention_store.get_average_attention()
    map_h = height // 32
    map_w = width // 32
    num_pixels = map_h * map_w
    kind = 'cross' if is_cross else 'self'

    # NOTE(review): the spatial axes are reshaped as (width, height); for
    # non-square inputs this looks transposed -- confirm together with the
    # callers that unpack the result's shape.
    selected = [
        item.reshape(len(prompts), -1, map_w, map_h, item.shape[-1])[select]
        for location in from_where
        for item in averaged[f"{location}_{kind}"]
        if item.shape[1] == num_pixels
    ]
    stacked = torch.cat(selected, dim=0)
    mean_map = stacked.sum(0) / stacked.shape[0]
    return mean_map.cpu()
def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0, t=0):
    """Render one captioned cross-attention heat map per token and save them via ``view_images``.

    Args:
        tokenizer: tokenizer providing ``encode`` and ``decode``.
        prompts: the batch of prompts.
        attention_store: store holding the accumulated attention maps.
        res: side length of the (square) attention maps to visualize.
        from_where: UNet locations to aggregate over.
        select: index of the prompt to visualize.
        t: tag forwarded to ``view_images`` (used in the output file name).
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # aggregate_attention takes the image height and width in pixels and
    # derives the map size as size // 32. The earlier call passed `res` once
    # and was one argument short, so it failed with a TypeError.
    # NOTE(review): assumes `res` is the attention-map side length -- confirm.
    side_px = res * 32
    attention_maps = aggregate_attention(prompts, attention_store, side_px, side_px, from_where, True, select)
    images = []
    for i, token in enumerate(tokens):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(token)))
        images.append(image)
    view_images(np.stack(images, axis=0), t=t, from_where=from_where)
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0, prompts: Optional[List[str]] = None):
    """Save the leading principal components of the averaged self-attention as images.

    Args:
        attention_store: store holding the accumulated attention maps.
        res: side length of the (square) attention maps to analyse.
        from_where: UNet locations to aggregate over.
        max_com: number of components to render.
        select: index of the prompt to analyse.
        prompts: the batch of prompts (only its length matters). When omitted,
            a single-prompt batch is assumed.
    """
    # aggregate_attention needs the prompt batch and a pixel height/width.
    # The earlier call passed neither and failed with a TypeError.
    # NOTE(review): assumes `res` is the attention-map side length, which the
    # reshape below relies on -- confirm.
    batch = prompts if prompts is not None else [""]
    side_px = res * 32
    attention_maps = aggregate_attention(
        batch, attention_store, side_px, side_px, from_where, False, select
    ).numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1), from_where=from_where)
def view_images(images, num_rows=1, offset_ratio=0.02, t=0, from_where: Union[str, List[str]] = ()):
    """Tile images into a grid and save it as ``./visualization/<from_where>/<t>.png``.

    Args:
        images: a list of HxWx3 arrays, a 4-D array of images, or one image.
        num_rows: number of grid rows.
        offset_ratio: gap between tiles as a fraction of the tile height.
        t: tag used as the output file name.
        from_where: a location name or list of names; lists are joined with
            ``"_"`` to form the output sub-directory.
    """
    if isinstance(images, list):
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            top, left = i * (h + offset), j * (w + offset)
            image_[top: top + h, left: left + w] = images[i * num_cols + j]

    pil_img = Image.fromarray(image_)

    # Always turn a list into a plain tag. Previously a one-element list was
    # formatted with its list repr ("['up']"), a plain string was joined
    # character by character, and the default value was a typing object on
    # which len() raised.
    tag = from_where if isinstance(from_where, str) else '_'.join(from_where)
    save_path = f'./visualization/{tag}'
    # makedirs also creates ./visualization itself and tolerates reruns.
    os.makedirs(save_path, exist_ok=True)
    pil_img.save(os.path.join(save_path, f"{t}.png"))
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a copy of ``image`` with a white strip below it holding centred ``text``.

    The strip is 20% of the image height; the caption is drawn with OpenCV's
    Hershey simplex font at scale 1 and thickness 2.
    """
    h, w, c = image.shape
    offset = int(h * .2)

    canvas = np.full((h + offset, w, c), 255, dtype=np.uint8)
    canvas[:h] = image

    font = cv2.FONT_HERSHEY_SIMPLEX
    text_w, text_h = cv2.getTextSize(text, font, 1, 2)[0]
    text_x = (w - text_w) // 2
    text_y = h + offset - text_h // 2
    cv2.putText(canvas, text, (text_x, text_y), font, 1, text_color, 2)
    return canvas
def get_views(height, width, window_size=32, stride=16, random_jitter=False):
    """Enumerate sliding-window views over a ``height`` x ``width`` grid.

    Args:
        height, width: size of the grid to cover.
        window_size: side length of each square view.
        stride: step between neighbouring views; may be a float.
        random_jitter: when true, each view is shifted randomly by up to
            ``(window_size - stride) // 4`` and all coordinates are offset by
            that amount, i.e. they refer to a grid padded by the jitter range
            on every side.

    Returns:
        List of ``(h_start, h_end, w_start, w_end)`` integer tuples in
        row-major order.
    """
    num_blocks_height = int((height - window_size) / stride - 1e-6) + 2 if height > window_size else 1
    num_blocks_width = int((width - window_size) / stride - 1e-6) + 2 if width > window_size else 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)

    # Cast to int: with a float stride (e.g. window_size / 2) the range would
    # be a float, which random.randint rejects on Python 3.12+.
    jitter_range = int((window_size - stride) // 4)

    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size

        # Pull windows that overshoot the far edge back inside the grid.
        if h_end > height:
            h_start = int(h_start + height - h_end)
            h_end = int(height)
        if w_end > width:
            w_start = int(w_start + width - w_end)
            w_end = int(width)
        # If the grid is smaller than the window, anchor the window at 0.
        if h_start < 0:
            h_end = int(h_end - h_start)
            h_start = 0
        if w_start < 0:
            w_end = int(w_end - w_start)
            w_start = 0

        if random_jitter:
            w_jitter = 0
            h_jitter = 0
            # Interior windows move both ways; windows touching one edge only
            # move in one direction; windows touching both edges stay put.
            if (w_start != 0) and (w_end != width):
                w_jitter = random.randint(-jitter_range, jitter_range)
            elif (w_start == 0) and (w_end != width):
                w_jitter = random.randint(-jitter_range, 0)
            elif (w_start != 0) and (w_end == width):
                w_jitter = random.randint(0, jitter_range)
            if (h_start != 0) and (h_end != height):
                h_jitter = random.randint(-jitter_range, jitter_range)
            elif (h_start == 0) and (h_end != height):
                h_jitter = random.randint(-jitter_range, 0)
            elif (h_start != 0) and (h_end == height):
                h_jitter = random.randint(0, jitter_range)
            # Shift into the coordinates of the padded grid.
            h_start += (h_jitter + jitter_range)
            h_end += (h_jitter + jitter_range)
            w_start += (w_jitter + jitter_range)
            w_end += (w_jitter + jitter_range)

        views.append((int(h_start), int(h_end), int(w_start), int(w_end)))
    return views
def get_multidiffusion_prompts(tokenizer, prompts, threthod, attention_store:AttentionStore, height:int, width:int, from_where: List[str], scale_num=4, random_jitter=False):
    """Derive a patch-specific prompt for every view at every upscaling factor.

    For each token of ``prompts[0]`` a binary mask of its high-attention
    region is built (above-mean attention, cleaned with erosion followed by
    dilation). For each scale from 2 to ``scale_num`` the masks are upscaled,
    the image is split into views, and a word is kept in a view's prompt when
    its mask covers at least ``threthod`` of one window's area.

    Args:
        tokenizer: tokenizer providing ``encode`` and ``decode``.
        prompts: batch of prompts; the first one is analysed.
        threthod: minimum covered fraction (threshold) for a word to be kept.
        attention_store: store holding the accumulated attention maps.
        height, width: image size in pixels, forwarded to ``aggregate_attention``.
        from_where: UNet locations to aggregate over.
        scale_num: largest upscaling factor.
        random_jitter: forwarded to ``get_views``; masks are padded to match.

    Returns:
        ``(prompt_dict, view_dict)``: for each scale, the list of per-view
        prompts and the list of views as returned by ``get_views``.
    """
    tokens = tokenizer.encode(prompts[0])
    decoder = tokenizer.decode

    # get cross_attention_maps
    attention_maps = aggregate_attention(prompts, attention_store, height, width, from_where, True, 0)

    # get high attention regions (one cleaned-up mask per token)
    kernel = np.ones((3, 3), np.uint8)
    masks = []
    for token_idx in range(len(tokens)):
        attention_map = attention_maps[:, :, token_idx].to(torch.float32)
        mask = torch.where(attention_map > attention_map.mean(), 1, 0).numpy().astype(np.uint8)
        mask = mask * 255
        # Erode then dilate to drop small isolated responses.
        iterations = mask.shape[0] // 16
        eroded_mask = cv2.erode(mask, kernel, iterations=iterations)
        masks.append(cv2.dilate(eroded_mask, kernel, iterations=iterations))

    # dict for prompts and views
    prompt_dict = {}
    view_dict = {}

    # NOTE(review): the first axis is named "w" here, matching the
    # (width, height) layout returned by aggregate_attention; for non-square
    # inputs the axis order should be double-checked in both places.
    ori_w, ori_h = masks[-1].shape
    window_size = max(ori_h, ori_w)

    for scale in range(2, scale_num + 1):
        # current height and width
        cur_w = ori_w * scale
        cur_h = ori_h * scale
        views = get_views(height=cur_h, width=cur_w, window_size=window_size, stride=window_size/2, random_jitter=random_jitter)

        words_in_patch = []
        for token_idx, base_mask in enumerate(masks):
            # skip the begin-of-text and end-of-text masks
            if token_idx == 0 or token_idx == len(masks) - 1:
                continue

            # upscale masks
            scaled_mask = cv2.resize(base_mask, (cur_w, cur_h), interpolation=cv2.INTER_NEAREST)
            if random_jitter:
                # Pad so that jittered views stay inside the mask.
                jitter_range = int((ori_h - ori_h/2) // 4)
                scaled_mask = np.pad(scaled_mask, ((jitter_range, jitter_range), (jitter_range, jitter_range)), 'constant', constant_values=(0, 0))

            word = decoder(int(tokens[token_idx]))
            word_in_patch = []
            for h_start, h_end, w_start, w_end in views:
                view_mask = scaled_mask[h_start:h_end, w_start:w_end]
                covered = (view_mask/255).sum() / (ori_h * ori_w)
                # '' marks "word not in this patch"
                word_in_patch.append(word if covered >= threthod else '')
            words_in_patch.append(word_in_patch)

        # get prompts for each view: join the kept words, collapsing the
        # blanks left by dropped ones
        result = [" ".join(' '.join(strings).split()) for strings in zip(*words_in_patch)]

        # save prompts and views in each scale
        prompt_dict[scale] = result
        view_dict[scale] = views

    return prompt_dict, view_dict
class ScaledAttnProcessor:
    r"""
    Wrapper that rescales ``attn.scale`` for the test resolution before
    delegating to another attention processor.

    The scale is multiplied by ``sqrt(log_{train_len}(test_len))`` where
    ``test_len`` is the current sequence length and ``train_len`` is that
    length divided by ``(test_res / train_res) ** 2``.
    """

    def __init__(self, processor, test_res, train_res):
        """Store the wrapped processor and the test/train resolutions."""
        self.processor = processor
        self.test_res = test_res
        self.train_res = train_res

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run the wrapped processor with a temporarily rescaled ``attn.scale``.

        ``hidden_states`` is either 4-D (batch, channel, height, width) or
        3-D (batch, sequence, features). Returns whatever the wrapped
        processor returns.
        """
        if hidden_states.ndim == 4:
            _, _, height, width = hidden_states.shape
            sequence_length = height * width
        else:
            _, sequence_length, _ = hidden_states.shape

        test_train_ratio = (self.test_res ** 2.0) / (self.train_res ** 2.0)
        train_sequence_length = sequence_length / test_train_ratio
        scale_factor = math.log(sequence_length, train_sequence_length) ** 0.5

        original_scale = attn.scale
        attn.scale = attn.scale * scale_factor
        try:
            hidden_states = self.processor(
                attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale=attn.scale
            )
        finally:
            # Always restore the shared module's scale, even if the wrapped
            # processor raises; otherwise later calls would compound it.
            attn.scale = original_scale

        return hidden_states