上ではStableDiffusionPipelineクラスを使って画像を生成しました。また、そこではパイプラインを関数のように呼び出すことで、画像生成が行われると述べました。ここでは、StableDiffusionPipelineクラスを継承するMyStableDiffusionPipelineクラスを定義して、そこで画像生成過程の特定の時点で生成された画像を保持しておいて、それを最終的に生成された画像と一緒に返すようなクラスを作ります。
以下コードはかなり長いのですが、その内容を詳細に知る必要はありません。ありませんよ。ありませんからね!
ここでは主に__call__メソッドをオーバーライド(上書き)していますが、その内容はStableDiffusionPipelineクラスの__call__メソッドとほぼ同様です。やったことはだいたい次のようなことです。
実際のコードを以下に示します。かなりの部分を省略して、変更または切り出した部分を強調書体で表示しています。
class MyStableDiffusionPipeline(StableDiffusionPipeline):
def __init__(
self, vae=None, text_encoder=None, tokenizer=None, unet=None,
scheduler=None, safety_checker=None, feature_extractor=None):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor)
@torch.no_grad()
def __call__(
self, prompt, height=512, width=512, num_inference_steps=50,
guidance_scale=7.5, eta=0.0, generator=None, latents=None,
output_type='pil', return_dict=True, step=5, **kwargs):
# ……省略……
images = []
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# ……省略……
if i % step == 0:
image, _ = self.make_image(latents, output_type)
images.append(image[0])
image, has_nsfw_concept = self.make_image(latents, output_type)
if not return_dict:
return (image, has_nsfw_concept, images)
return (StableDiffusionPipelineOutput(images=image,
nsfw_content_detected=has_nsfw_concept), images)
@torch.no_grad()
def make_image(self, latents, output_type):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image),
return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image,
clip_input=safety_checker_input.pixel_values)
if output_type == 'pil':
image = self.numpy_to_pil(image)
return image, has_nsfw_concept
make_imageメソッドの内容はStableDiffusionPipelineクラスの__call__メソッドに含まれていたコードを切り出しただけなのであまり説明することはありません。要するにここで、Stable Diffusionが生成した画像の潜在空間に置ける表現をVAEのデコーダーに入力して画像として、やばい画像かどうかのチェックをしているだけだと考えてください。
StableDiffusionPipelineクラスと比べて、__call__メソッドではstepsというパラメーターが増えています(デフォルト値は5)。これは、ノイズ除去を行うループで何回ごとに生成された画像を保存しておくかを指定するものです。num_inference_stepsの値が50であればノイズ除去がおおよそ50回行われますが、このときにstepsの値が5ならループが5回実行されるたびにその時点での画像が保持されるようになるということです(実際にはスケジューラーと呼ばれる機構の仕組みによってはnum_inference_stepsの値が50だからといってループも50回とはならないこともあるので、これはあくまでもおおよその値だと考えてください)。
保持された画像はリストにまとめられていて、StableDiffusionPipelineクラスの__call__メソッドが返していた画像と後述するNSFWフィルターに引っかかったかを知らせる値に加えて、このリストを返送するようになっています。
このクラスを使って画像生成を行うコードは例えば次のようになります。
mypipe = MyStableDiffusionPipeline.from_pretrained(
'CompVis/stable-diffusion-v1-4', use_auth_token=YOUR_TOKEN)
mypipe.to('cuda')
prompt = 'a photograph of an astronaut riding a horse'
generator = torch.Generator('cuda').manual_seed(2)
result = mypipe(prompt, guidance_scale=7.5, num_inference_steps=50, generator=generator)
クラスがStableDiffusionPipelineからMyStableDiffusionPipelineクラスに変わっただけで、後は上で見たコードと同様ですね。実行結果は以下の通りです。
表示されているメッセージは要するに「やばげな画像が生成されたので、代わりに真っ黒な画像を返すよ。今度はプロンプトを変えるか、シードを変えるかして試してみてね」という意味です。Stable DiffusionではNSFW(Not Safe For Work)フィルター機能が標準で組み込まれていて、職場などで閲覧するには問題がある画像が自動的に真っ黒な画像に差し替えられるようになっています。
エッチな画像を生成しているわけではないのですが、これについては後で対処することにしましょう。
エッチな画像を生成する方法というのが話題になっていたけど、これで制限を解除できるのか。ふむふむ。
重要なのは、このパイプラインから返された値の第1要素(0始まり)には画像を含んだリストが格納されているということです。ここではStable Diffusion with Diffusersで紹介されているimage_grid関数をそのまま借用させてもらって、画像をグリッド状に並べることにしました。
from PIL import Image
def image_grid(imgs, rows, cols):
assert len(imgs) == rows*cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
image_grid関数を呼び出す際には、第1引数に画像を含むリストを、第2引数に画像を並べるグリッドの行数を、第3引数に列数を指定します。実はパイプラインが返す画像の数はここでは11個だったので、ここでは最後に生成された画像をリストに追加して12個の画像を含むリストとして4行3列のグリッドに画像を並べることにしました。
result[1].append(result[0].images[0])
grid = image_grid(result[1], 4, 3)
grid
実行結果を見てみましょう。
NSFWフィルターのおかげで最初の6枚が真っ黒になっていることが分かりました。が、後半の6枚は徐々にノイズが除去されて、だんだんとキレイになっていますね。
というわけで、最後にNSFWフィルターを無効化することにします。ここでは、何もチェックしない関数を定義しました。画像がそのまま帰ってくればよいので、画像と同時に返送する値はTrueで固定です。
仕様を調べていなかったのですが、Falseでもいいのかな?
def my_checker(self=mypipe, images=None, clip_input=None):
return images, True
これをパイプラインのsafety_checker属性(safety_checkerメソッド)にセットすればOKです。メソッドとするので第1パラメーターはパイプラインを参照するselfになります。通常、selfの値はメソッド呼び出し時に自動的にパイプラインを参照する値がセットされるのですが、筆者が試したところではなぜかこれがうまくいかないときがあったので、上のコードではデフォルト引数値として、MyStableDiffusionPipelineクラスのインスタンスであるmypipeを指定しています。
setattr関数を使って、この関数をパイプラインのsafety_checkerメソッドにしたら、先ほどと同様に画像を生成してみましょう。
setattr(mypipe, 'safety_checker', my_checker)
prompt = 'a photograph of an astronaut riding a horse'
generator = torch.Generator('cuda').manual_seed(2)
result = mypipe(prompt, guidance_scale=7.5, num_inference_steps=50,
generator=generator)
最後にこれをimage_grid関数でグリッド状に並べます。
result[1].append(result[0].images[0])
grid = image_grid(result[1], 4, 3)
grid
以下は実行結果です。
どうでしょう。ノイズの中から宇宙飛行士と馬が徐々に現れてくる過程がよく分かるようになりました。
確かにStable Diffusionでは純粋なノイズから徐々にノイズが除去されていき、最終的にキレイな画像が生成されることが確認できました。次回はパイプラインについてもう少し詳しく見てみることにしましょう。
Copyright© Digital Advantage Corp. All Rights Reserved.