検索
連載

「Stable Diffusion」でノイズから画像が生成される過程を確認しようStable Diffusion入門(2/2 ページ)

ホントにノイズからノイズを除去していくとキレイな画像が生成されるのか。これを今回は自分の目で確認してみましょう。

PC用表示 関連情報
Share
Tweet
LINE
Hatena
前のページへ |       

ノイズが除去されてちゃんとした画像になっていく過程の確認

 上ではStableDiffusionPipelineクラスを使って画像を生成しました。また、そこではパイプラインを関数のように呼び出すことで、画像生成が行われると述べました。ここでは、StableDiffusionPipelineクラスを継承するMyStableDiffusionPipelineクラスを定義して、そこで画像生成過程の特定の時点で生成された画像を保持しておいて、それを最終的に生成された画像と一緒に返すようなクラスを作ります。


かわさき

 以下コードはかなり長いのですが、その内容を詳細に知る必要はありません。ありませんよ。ありませんからね!


 ここでは主に__call__メソッドをオーバーライド(上書き)していますが、その内容はStableDiffusionPipelineクラスの__call__メソッドとほぼ同様です。やったことはだいたい次のようなことです。

  • 画像生成を行うループ処理中に、特定の時点でそのときに生成された画像を保持するコードを追加
  • このときに画像生成を行うコードをインスタンスメソッドとして切り出す

 実際のコードを以下に示します。かなりの部分を省略して、変更または切り出した部分を強調書体で表示しています。

class MyStableDiffusionPipeline(StableDiffusionPipeline):
    def __init__(
        self, vae=None, text_encoder=None, tokenizer=None, unet=None,
        scheduler=None, safety_checker=None, feature_extractor=None):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor)

    @torch.no_grad()
    def __call__(
        self, prompt, height=512, width=512, num_inference_steps=50,
        guidance_scale=7.5, eta=0.0, generator=None, latents=None,
        output_type='pil', return_dict=True, step=5, **kwargs):

        # ……省略……

        images = []
        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):

            # ……省略……

            if i % step == 0:
                image, _ = self.make_image(latents, output_type)
                images.append(image[0])

        image, has_nsfw_concept = self.make_image(latents, output_type)

        if not return_dict:
            return (image, has_nsfw_concept, images)

        return (StableDiffusionPipelineOutput(images=image,
                    nsfw_content_detected=has_nsfw_concept), images)

    @torch.no_grad()
    def make_image(self, latents, output_type):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image),
                        return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image,
                        clip_input=safety_checker_input.pixel_values)

        if output_type == 'pil':
            image = self.numpy_to_pil(image)

        return image, has_nsfw_concept

MyStableDiffusionPipelineクラス

 make_imageメソッドの内容はStableDiffusionPipelineクラスの__call__メソッドに含まれていたコードを切り出しただけなのであまり説明することはありません。要するにここで、Stable Diffusionが生成した画像の潜在空間に置ける表現をVAEのデコーダーに入力して画像として、やばい画像かどうかのチェックをしているだけだと考えてください。

 StableDiffusionPipelineクラスと比べて、__call__メソッドではstepsというパラメーターが増えています(デフォルト値は5)。これは、ノイズ除去を行うループで何回ごとに生成された画像を保存しておくかを指定するものです。num_inference_stepsの値が50であればノイズ除去がおおよそ50回行われますが、このときにstepsの値が5ならループが5回実行されるたびにその時点での画像が保持されるようになるということです(実際にはスケジューラーと呼ばれる機構の仕組みによってはnum_inference_stepsの値が50だからといってループも50回とはならないこともあるので、これはあくまでもおおよその値だと考えてください)。

 保持された画像はリストにまとめられていて、StableDiffusionPipelineクラスの__call__メソッドが返していた画像と後述するNSFWフィルターに引っかかったかを知らせる値に加えて、このリストを返送するようになっています。

 このクラスを使って画像生成を行うコードは例えば次のようになります。

mypipe = MyStableDiffusionPipeline.from_pretrained(
    'CompVis/stable-diffusion-v1-4', use_auth_token=YOUR_TOKEN)
mypipe.to('cuda')

prompt = 'a photograph of an astronaut riding a horse'
generator = torch.Generator('cuda').manual_seed(2)

result = mypipe(prompt, guidance_scale=7.5, num_inference_steps=50, generator=generator)

MyStableDiffusionPipelineクラスを使って画像を生成するコード

 クラスがStableDiffusionPipelineからMyStableDiffusionPipelineクラスに変わっただけで、後は上で見たコードと同様ですね。実行結果は以下の通りです。

何やらヘンなメッセージが表示されている
何やらヘンなメッセージが表示されている

 表示されているメッセージは要するに「やばげな画像が生成されたので、代わりに真っ黒な画像を返すよ。今度はプロンプトを変えるか、シードを変えるかして試してみてね」という意味です。Stable DiffusionではNSFW(Not Safe For Work)フィルター機能が標準で組み込まれていて、職場などで閲覧するには問題がある画像が自動的に真っ黒な画像に差し替えられるようになっています。


かわさき

 エッチな画像を生成しているわけではないのですが、これについては後で対処することにしましょう。



一色

 エッチな画像を生成する方法というのが話題になっていたけど、これで制限を解除できるのか。ふむふむ。


 重要なのは、このパイプラインから返された値の第1要素(0始まり)には画像を含んだリストが格納されているということです。ここではStable Diffusion with Diffusersで紹介されているimage_grid関数をそのまま借用させてもらって、画像をグリッド状に並べることにしました。

from PIL import Image

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows*cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    grid_w, grid_h = grid.size
    
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

image_grid関数

 image_grid関数を呼び出す際には、第1引数に画像を含むリストを、第2引数に画像を並べるグリッドの行数を、第3引数に列数を指定します。実はパイプラインが返す画像の数はここでは11個だったので、ここでは最後に生成された画像をリストに追加して12個の画像を含むリストとして4行3列のグリッドに画像を並べることにしました。

result[1].append(result[0].images[0])
grid = image_grid(result[1], 4, 3)
grid

画像をグリッド状に並べて表示

 実行結果を見てみましょう。

上の方は真っ黒
上の方は真っ黒

 NSFWフィルターのおかげで最初の6枚が真っ黒になっていることが分かりました。が、後半の6枚は徐々にノイズが除去されて、だんだんとキレイになっていますね。

 というわけで、最後にNSFWフィルターを無効化することにします。ここでは、何もチェックしない関数を定義しました。画像がそのまま帰ってくればよいので、画像と同時に返送する値はTrueで固定です。


かわさき

 仕様を調べていなかったのですが、Falseでもいいのかな?


def my_checker(self=mypipe, images=None, clip_input=None):
    return images, True

何もチェックしないmy_checker関数

 これをパイプラインのsafety_checker属性(safety_checkerメソッド)にセットすればOKです。メソッドとするので第1パラメーターはパイプラインを参照するselfになります。通常、selfの値はメソッド呼び出し時に自動的にパイプラインを参照する値がセットされるのですが、筆者が試したところではなぜかこれがうまくいかないときがあったので、上のコードではデフォルト引数値として、MyStableDiffusionPipelineクラスのインスタンスであるmypipeを指定しています。

 setattr関数を使って、この関数をパイプラインのsafety_checkerメソッドにしたら、先ほどと同様に画像を生成してみましょう。

setattr(mypipe, 'safety_checker', my_checker)

prompt = 'a photograph of an astronaut riding a horse'
generator = torch.Generator('cuda').manual_seed(2)

result = mypipe(prompt, guidance_scale=7.5, num_inference_steps=50,
                generator=generator)

チェックを行わないようにして画像を生成

 最後にこれをimage_grid関数でグリッド状に並べます。

result[1].append(result[0].images[0])

grid = image_grid(result[1], 4, 3)
grid

image_grid関数で画像を並べる

 以下は実行結果です。

ノイズ除去の過程が分かるようになった
ノイズ除去の過程が分かるようになった

 どうでしょう。ノイズの中から宇宙飛行士と馬が徐々に現れてくる過程がよく分かるようになりました。


 確かにStable Diffusionでは純粋なノイズから徐々にノイズが除去されていき、最終的にキレイな画像が生成されることが確認できました。次回はパイプラインについてもう少し詳しく見てみることにしましょう。

「Stable Diffusion入門」のインデックス

Stable Diffusion入門

Copyright© Digital Advantage Corp. All Rights Reserved.

前のページへ |       
[an error occurred while processing this directive]
ページトップに戻る