# latentnavigation-flux / clip_slider_pipeline.py
import diffusers
import torch
import random
from tqdm import tqdm
from constants import SUBJECTS, MEDIUMS
from PIL import Image
class CLIPSlider:
    """Steers a Stable Diffusion pipeline along a learned latent text direction.

    The direction is the average difference between text embeddings of prompts
    containing a target word (e.g. "happy") and prompts containing its opposite
    (e.g. "sad"), sampled over many random medium/subject combinations.
    """

    def __init__(
        self,
        sd_pipe,
        device: torch.device,
        target_word: str = "",
        opposite: str = "",
        target_word_2nd: str = "",
        opposite_2nd: str = "",
        iterations: int = 300,
    ):
        """
        Args:
            sd_pipe: a diffusers pipeline exposing `tokenizer` / `text_encoder`.
            device: device the (fp16) pipeline and all token tensors live on.
            target_word / opposite: word pair defining the primary direction.
            target_word_2nd / opposite_2nd: optional second direction pair.
            iterations: number of random prompt pairs averaged per direction.
        """
        self.device = device
        self.pipe = sd_pipe.to(self.device, torch.float16)
        self.iterations = iterations
        # Only compute a direction when at least one word of the pair is given.
        if target_word != "" or opposite != "":
            self.avg_diff = self.find_latent_direction(target_word, opposite)
        else:
            self.avg_diff = None
        if target_word_2nd != "" or opposite_2nd != "":
            self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
        else:
            self.avg_diff_2nd = None

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        """Identify a latent direction by averaging differences between
        opposing prompts (e.g. target_word="happy", opposite="sad").

        Returns:
            Tensor of shape [1, embed_dim]: mean pooled-embedding difference.
        """
        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                # BUGFIX: tokens go to self.device instead of hard-coded
                # .cuda(), so non-CUDA devices (mps/cpu) work too.
                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

            positives = torch.cat(positives, dim=0)
            negatives = torch.cat(negatives, dim=0)
            diffs = positives - negatives
            avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff

    def generate(self,
                 prompt="a photo of a house",
                 scale=2.,
                 scale_2nd=0.,  # scale for the 2nd direction when avg_diff_2nd is not None
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,  # whether to normalize the scales when avg_diff_2nd is not None
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Generate images with the prompt embedding shifted along the learned
        direction(s).

        If editing the full sequence, scales in [-0.3, 0.3] work well (higher
        when correlation weighting is used); if editing the pooler token only,
        [-4, 4] work well.
        """
        with torch.no_grad():
            toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
                                       max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
            prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state

            # BUGFIX: `if self.avg_diff_2nd` raised "Boolean value of Tensor
            # with more than one element is ambiguous" because avg_diff_2nd is
            # a [1, dim] tensor; compare against None explicitly (all sites).
            if self.avg_diff_2nd is not None and normalize_scales:
                denominator = abs(scale) + abs(scale_2nd)
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator

            if only_pooler:
                # Shift only the token at the argmax id position (for CLIP
                # tokenizers the EOS token has the highest id — presumably the
                # intent here; verify if swapping tokenizers).
                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
                if self.avg_diff_2nd is not None:
                    prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
            else:
                # Weight each token's shift by its cosine similarity to the
                # argmax token, interpolated toward 1 by correlation_weight_factor.
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                # Generalized: use the actual hidden width instead of the
                # hard-coded 768 so SD variants with other widths work.
                weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, prompt_embeds.shape[-1])
                standard_weights = torch.ones_like(weights)
                weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                prompt_embeds = prompt_embeds + (
                    weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
                if self.avg_diff_2nd is not None:
                    prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
        return images

    def spectrum(self,
                 prompt="a photo of a house",
                 low_scale=-2,
                 low_scale_2nd=-2,
                 high_scale=2,
                 high_scale_2nd=2,
                 steps=5,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Render `steps` images sweeping both scales from low to high and
        paste them side by side on one canvas (640 px per tile)."""
        images = []
        # Robustness: with steps == 1 the original divided by zero.
        denom = max(steps - 1, 1)
        for i in range(steps):
            scale = low_scale + (high_scale - low_scale) * i / denom
            scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / denom
            image = self.generate(prompt, scale, scale_2nd, seed, only_pooler,
                                  normalize_scales, correlation_weight_factor,
                                  **pipeline_kwargs)
            images.append(image[0])

        canvas = Image.new('RGB', (640 * steps, 640))
        for i, im in enumerate(images):
            canvas.paste(im, (640 * i, 0))
        return canvas
class CLIPSliderXL(CLIPSlider):
    """CLIPSlider for SDXL: learns one direction per text encoder and applies
    each direction to the matching encoder's hidden states."""

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        """Identify latent directions for both SDXL text encoders.

        Returns:
            (avg_diff, avg_diff2): mean embedding differences for
            text_encoder (pooler_output, 768-dim) and text_encoder_2
            (text_embeds, 1280-dim) respectively, each [1, dim].
        """
        with torch.no_grad():
            positives = []
            negatives = []
            positives2 = []
            negatives2 = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"

                # BUGFIX: tokens go to self.device instead of hard-coded .cuda().
                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

                pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
                positives2.append(pos2)
                negatives2.append(neg2)

            positives = torch.cat(positives, dim=0)
            negatives = torch.cat(negatives, dim=0)
            diffs = positives - negatives
            avg_diff = diffs.mean(0, keepdim=True)

            positives2 = torch.cat(positives2, dim=0)
            negatives2 = torch.cat(negatives2, dim=0)
            diffs2 = positives2 - negatives2
            avg_diff2 = diffs2.mean(0, keepdim=True)
        return (avg_diff, avg_diff2)

    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Generate SDXL images with both encoders' prompt embeddings shifted
        along the learned directions.

        If editing the full sequence, scales in [-0.3, 0.3] work well (higher
        when correlation weighting is used); pooler-token-only: [-4, 4].
        """
        text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
        tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
        with torch.no_grad():
            prompt_embeds_list = []
            for i, text_encoder in enumerate(text_encoders):
                tokenizer = tokenizers[i]
                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                toks = text_inputs.input_ids
                prompt_embeds = text_encoder(
                    toks.to(text_encoder.device),
                    output_hidden_states=True,
                )
                # We are only ALWAYS interested in the pooled output of the
                # final text encoder (loop leaves the last encoder's value).
                pooled_prompt_embeds = prompt_embeds[0]
                prompt_embeds = prompt_embeds.hidden_states[-2]

                # Explicit None check; second pass re-normalizes by 1, a no-op.
                if self.avg_diff_2nd is not None and normalize_scales:
                    denominator = abs(scale) + abs(scale_2nd)
                    scale = scale / denominator
                    scale_2nd = scale_2nd / denominator

                if only_pooler:
                    # BUGFIX: select the direction matching the current encoder
                    # (avg_diff[0] is 768-dim, avg_diff[1] is 1280-dim). The
                    # original always used index 0, which cannot broadcast
                    # against the second encoder's hidden states.
                    prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[i] * scale
                    if self.avg_diff_2nd is not None:
                        prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[i] * scale_2nd
                else:
                    # Cosine-similarity weighting relative to the argmax token.
                    normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                    sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                    if i == 0:
                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
                        standard_weights = torch.ones_like(weights)
                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                        prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
                        if self.avg_diff_2nd is not None:
                            prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
                    else:
                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
                        standard_weights = torch.ones_like(weights)
                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                        prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
                        if self.avg_diff_2nd is not None:
                            prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)

                bs_embed, seq_len, _ = prompt_embeds.shape
                prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
                prompt_embeds_list.append(prompt_embeds)

            # SDXL expects the two encoders' sequences concatenated on the
            # feature axis, plus the second encoder's pooled embedding.
            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
            pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                           **pipeline_kwargs).images
        return images
class CLIPSliderXL_inv(CLIPSlider):
    """SDXL slider variant whose pipeline applies the directions itself:
    `generate` forwards the raw directions to a custom pipeline via
    `editing_prompt` / `avg_diff` kwargs instead of editing embeddings here.

    NOTE(review): `find_latent_direction` duplicates CLIPSliderXL's — kept
    duplicated to preserve the existing class hierarchy.
    """

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        """Identify latent directions for both SDXL text encoders.

        Returns:
            (avg_diff, avg_diff2): mean embedding differences for
            text_encoder (pooler_output) and text_encoder_2 (text_embeds).
        """
        with torch.no_grad():
            positives = []
            negatives = []
            positives2 = []
            negatives2 = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"

                # BUGFIX: tokens go to self.device instead of hard-coded .cuda().
                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

                pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
                positives2.append(pos2)
                negatives2.append(neg2)

            positives = torch.cat(positives, dim=0)
            negatives = torch.cat(negatives, dim=0)
            diffs = positives - negatives
            avg_diff = diffs.mean(0, keepdim=True)

            positives2 = torch.cat(positives2, dim=0)
            negatives2 = torch.cat(negatives2, dim=0)
            diffs2 = positives2 - negatives2
            avg_diff2 = diffs2.mean(0, keepdim=True)
        return (avg_diff, avg_diff2)

    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Delegate editing to the pipeline: pass the prompt and directions
        through and let the custom pipeline apply the scales.

        `only_pooler`, `normalize_scales` and `correlation_weight_factor` are
        accepted for interface compatibility but unused here.
        """
        with torch.no_grad():
            torch.manual_seed(seed)
            images = self.pipe(editing_prompt=prompt,
                               avg_diff=self.avg_diff, avg_diff_2nd=self.avg_diff_2nd,
                               scale=scale, scale_2nd=scale_2nd,
                               **pipeline_kwargs).images
        return images
class CLIPSliderFlux(CLIPSlider):
    """Slider for Flux pipelines: the direction is learned on the CLIP
    encoder's pooled output and applied only to `pooled_prompt_embeds`,
    while the T5 encoder provides the (unmodified) sequence embeddings."""

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str,
                              num_iterations: int = None):
        """Identify a latent direction from the CLIP encoder's pooled output.

        Args:
            num_iterations: optional override for self.iterations.

        Returns:
            Tensor [1, embed_dim]: mean pooled-embedding difference.
        """
        if num_iterations is not None:
            iterations = num_iterations
        else:
            iterations = self.iterations

        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                # BUGFIX: tokens go to self.device instead of hard-coded .cuda().
                pos_toks = self.pipe.tokenizer(pos_prompt,
                                               padding="max_length",
                                               max_length=self.pipe.tokenizer_max_length,
                                               truncation=True,
                                               return_overflowing_tokens=False,
                                               return_length=False,
                                               return_tensors="pt",).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt,
                                               padding="max_length",
                                               max_length=self.pipe.tokenizer_max_length,
                                               truncation=True,
                                               return_overflowing_tokens=False,
                                               return_length=False,
                                               return_tensors="pt",).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

            positives = torch.cat(positives, dim=0)
            negatives = torch.cat(negatives, dim=0)
            diffs = positives - negatives
            avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff

    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 normalize_scales=False,
                 avg_diff=None,
                 avg_diff_2nd=None,
                 **pipeline_kwargs
                 ):
        """Generate a single image, shifting the pooled CLIP embedding by
        `avg_diff * scale` (and optionally `avg_diff_2nd * scale_2nd`).

        Note: unlike the base class, the directions are passed as arguments
        rather than read from self, and a single PIL image is returned.
        """
        with torch.no_grad():
            # CLIP branch: only the pooled output is used.
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
            prompt_embeds = prompt_embeds.pooler_output
            pooled_prompt_embeds = prompt_embeds.to(dtype=self.pipe.text_encoder.dtype, device=self.device)

            # T5 branch: full sequence embeddings, left unmodified.
            text_inputs = self.pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
                return_tensors="pt",
            )
            toks = text_inputs.input_ids
            prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
            dtype = self.pipe.text_encoder_2.dtype
            prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)

            if avg_diff_2nd is not None and normalize_scales:
                denominator = abs(scale) + abs(scale_2nd)
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator

            pooled_prompt_embeds = pooled_prompt_embeds + avg_diff * scale
            if avg_diff_2nd is not None:
                pooled_prompt_embeds += avg_diff_2nd * scale_2nd

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                           **pipeline_kwargs).images
        return images[0]

    def spectrum(self,
                 prompt="a photo of a house",
                 low_scale=-2,
                 low_scale_2nd=-2,
                 high_scale=2,
                 high_scale_2nd=2,
                 steps=5,
                 seed=15,
                 normalize_scales=False,
                 **pipeline_kwargs
                 ):
        """Render `steps` images sweeping the scales and paste them on one
        canvas. Directions must be supplied via pipeline_kwargs
        (`avg_diff=...`) since this class's generate takes them as args."""
        images = []
        # Robustness: with steps == 1 the original divided by zero.
        denom = max(steps - 1, 1)
        for i in range(steps):
            scale = low_scale + (high_scale - low_scale) * i / denom
            scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / denom
            image = self.generate(prompt, scale, scale_2nd, seed, normalize_scales, **pipeline_kwargs)
            # BUGFIX: generate already returns a single PIL image (images[0]);
            # indexing it with [0] raised TypeError.
            images.append(image.resize((512, 512)))

        canvas = Image.new('RGB', (640 * steps, 640))
        for i, im in enumerate(images):
            canvas.paste(im, (640 * i, 0))
        return canvas
class T5SliderFlux(CLIPSlider):
    """Slider for Flux pipelines that learns and applies the direction on the
    T5 encoder's full sequence embeddings (the CLIP pooled output is passed
    through unmodified)."""

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        """Identify a latent direction from the T5 encoder's sequence output.

        Returns:
            Tensor [1, seq_len, embed_dim]: mean sequence-embedding difference.
        """
        with torch.no_grad():
            positives = []
            negatives = []
            for _ in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                # BUGFIX: tokens go to self.device instead of hard-coded .cuda().
                pos_toks = self.pipe.tokenizer_2(pos_prompt,
                                                 return_tensors="pt",
                                                 padding="max_length",
                                                 truncation=True,
                                                 return_length=False,
                                                 return_overflowing_tokens=False,
                                                 max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer_2(neg_prompt,
                                                 return_tensors="pt",
                                                 padding="max_length",
                                                 truncation=True,
                                                 return_length=False,
                                                 return_overflowing_tokens=False,
                                                 max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                pos = self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0]
                neg = self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0]
                positives.append(pos)
                negatives.append(neg)

            positives = torch.cat(positives, dim=0)
            negatives = torch.cat(negatives, dim=0)
            diffs = positives - negatives
            avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff

    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        """Generate images with the T5 sequence embeddings shifted along the
        learned direction(s).

        If editing the full sequence, scales in [-0.3, 0.3] work well (higher
        when correlation weighting is used); pooler-token-only: [-4, 4].
        """
        with torch.no_grad():
            # CLIP branch: pooled output, passed through unmodified.
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
            prompt_embeds = prompt_embeds.pooler_output
            pooled_prompt_embeds = prompt_embeds.to(dtype=self.pipe.text_encoder.dtype, device=self.device)

            # T5 branch: full sequence embeddings, edited below.
            text_inputs = self.pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
                return_tensors="pt",
            )
            toks = text_inputs.input_ids
            prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
            dtype = self.pipe.text_encoder_2.dtype
            prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)

            # BUGFIX: `if self.avg_diff_2nd` raised "Boolean value of Tensor
            # ... is ambiguous" since avg_diff_2nd is a tensor; compare
            # against None explicitly (all sites below).
            if self.avg_diff_2nd is not None and normalize_scales:
                denominator = abs(scale) + abs(scale_2nd)
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator

            if only_pooler:
                # NOTE(review): avg_diff here is a full [1, seq, dim] sequence
                # direction; adding it to a single token slice relies on
                # broadcasting — confirm this branch is exercised/intended.
                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
                if self.avg_diff_2nd is not None:
                    prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
            else:
                # Cosine-similarity weighting relative to the argmax token.
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, prompt_embeds.shape[2])
                standard_weights = torch.ones_like(weights)
                weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                prompt_embeds = prompt_embeds + (
                    weights * self.avg_diff * scale)
                if self.avg_diff_2nd is not None:
                    prompt_embeds += (
                        weights * self.avg_diff_2nd * scale_2nd)

        torch.manual_seed(seed)
        images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                           **pipeline_kwargs).images
        return images