pipeline.mojo
from helpers.utils import *
from clip import CLIP
from vae import Encoder, Decoder
from diffusion import Diffusion
from sampler import DDPMSampler
# Image width / height. Make sure it is a multiple of 8! Here, we assume a square image.
alias image_size = 8
# We set the number of inference steps to 1, as we only want a single forward pass. Typical values are around 50.
# This also runs with a batch size of 1 (as in stochastic gradient descent). To use the same code with a larger batch size, create a Matrix_Array struct (available in utils.mojo) and parallelize the generate() code over all its elements.
fn generate(
    prompt: String,
    backup_prompt: String = "",
    strength: Float32 = 0.8,
    cfg: Bool = True,
    cfg_scale: Float32 = 7.5,
    inference_steps: Int = 1,
    seed_val: Int = 0,
    input_image: Matrix[float_dtype] = Matrix[float_dtype](0, 0, 0),
) -> Matrix[float_dtype]:
    if (
        not SIMD[DType.float32, 1].splat(0.0)
        <= strength
        <= SIMD[DType.float32, 1].splat(1.0)
    ):
        print("Strength must be between 0 and 1. Returning empty matrix")
        return Matrix[float_dtype](0, 0, 0)
    var clip = CLIP()
    var tokenizer_ref = StringRef("tokenizer_clip.bin")
    var tokenizer_buffer = FileBuf()
    read_file(tokenizer_ref, tokenizer_buffer)
    # Using a vocab size of 49408, since we rely on the CLIP Tokenizer
    var tokenizer = Tokenizer(49408, tokenizer_buffer)
    var context: Matrix[float_dtype]
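    # CLIP's BPE vocabulary marks word boundaries with "</w>", so spaces are
    # replaced with that marker before encoding.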
    var processed_prompt = prompt.replace(" ", "</w>")
    var processed_backup = backup_prompt.replace(" ", "</w>")
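    # With classifier-free guidance (CFG), two contexts are built: one from the
    # prompt and one from the backup (unconditional) prompt, stacked along dim 0.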
    if cfg:
        var cond_tokens_vector = bpe_encode(processed_prompt, tokenizer)
        var cond_tokens = vector_to_matrix(cond_tokens_vector)
        var cond_context = clip.forward(cond_tokens)
        var backup_tokens_vector = bpe_encode(processed_backup, tokenizer)
        var backup_tokens = vector_to_matrix(backup_tokens_vector)
        var backup_context = clip.forward(backup_tokens)
        context = cond_context.concat(backup_context, dim=0)
    else:
        var tokens_vector = bpe_encode(processed_prompt, tokenizer)
        var tokens = vector_to_matrix(tokens_vector)
        context = clip.forward(tokens)
    print("CLIP forward pass concluded")
    var sampler = DDPMSampler(seed_val)
    sampler.set_inference_timesteps(inference_steps)
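    # The latent space has 4 channels and is downsampled 8x in each spatial
    # dimension relative to the image.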
    var latents_shape = (4, image_size // 8, image_size // 8)
    var latents = Matrix[float_dtype](
        Tuple.get[0, Int](latents_shape),
        Tuple.get[1, Int](latents_shape),
        Tuple.get[2, Int](latents_shape),
    )
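    # Image-to-image: encode the input image into latents and noise it according
    # to `strength`. Text-to-image: start from randomly initialized latents.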
    if input_image.size() > 0:
        var encoder = Encoder()
        print("Encoder instance created")
        var resized_input = resize_image(input_image, image_size, image_size)
        var rescaled_input = resized_input.rescale((0, 255), (-1, 1))
        var encoder_noise = Matrix[float_dtype](
            Tuple.get[0, Int](latents_shape),
            Tuple.get[1, Int](latents_shape),
            Tuple.get[2, Int](latents_shape),
        )
        encoder_noise.init_weights_seed(seed_val)
        latents = encoder.forward(rescaled_input, encoder_noise)
        sampler.set_strength(strength)
        latents = sampler.add_noise(latents, sampler.timesteps[0])
    else:
        latents.init_weights_seed(seed_val)
    var diffusion = Diffusion()
    print("Diffusion instance created")
    var num_timesteps = sampler.timesteps.num_elements()
    for i in range(num_timesteps):
        var timestep = sampler.timesteps[i]
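        # Map the timestep to an embedding vector so the diffusion model is
        # conditioned on how much noise remains at this step.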
        var time_embedding = get_time_embedding(timestep)
        var model_input = latents
        var model_output: Matrix[float_dtype]
        if not cfg:
            model_output = diffusion.forward(model_input, context, time_embedding)
        else:
            # Due to a current Mojo limitation, I am unable to run this
            # parallelized code, so I use the sequential approach below.
            # Hopefully Mojo will support non-embarrassingly-parallel tasks soon.
            # @parameter
            # fn parallel_diffusion(idx: Int):
            #     if idx == 0:
            #         model_output1 = diffusion.forward(model_input, context, time_embedding)
            #     else:
            #         model_output2 = diffusion.forward(model_input, context, time_embedding)
            # parallelize[parallel_diffusion](2, 2)
            var model_output1 = diffusion.forward(model_input, context, time_embedding)
            var model_output2 = diffusion.forward(model_input, context, time_embedding)
            model_output = model_output1.concat(model_output2, dim=0)
            var chunked_output = model_output.chunk(0, 2)
            var conditional_output = chunked_output[0]
            var backup_output = chunked_output[1]
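            # Classifier-free guidance combines the two predictions:
            # output = backup + cfg_scale * (conditional - backup),
            # pushing the result toward the prompt-conditioned prediction.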
            var cfg_scale_f32 = SIMD[DType.float32, 1].splat(
                cfg_scale.cast[DType.float32]()
            )
            model_output = (
                conditional_output - backup_output
            ) * cfg_scale_f32 + backup_output
        latents = sampler.step(int(timestep), latents, model_output)
        print("Timestep", i, "concluded")
    var decoder = Decoder()
    var images = decoder.forward(latents)
    print("Decoder forward pass concluded")
    images = images.rescale((-1, 1), (0, 255), clamp=True)
    return images
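
# Example usage: a minimal sketch, assuming this file is run as the entry point
# and that no other main() exists in the project. The prompt string and argument
# values below are illustrative only.
fn main():
    var image = generate(
        prompt="a photo of a cat",
        cfg=True,
        cfg_scale=7.5,
        inference_steps=1,
        seed_val=0,
    )
    # generate() returns the image rescaled to the (0, 255) pixel range.
    print("Generated image with", image.size(), "elements")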