CLIPモデルとは、OpenAIによって公開されたAIモデルで、一言でいうと、画像とテキストの関係性を理解できるモデルです。
画像とテキストを多次元の特徴ベクトルに変換することで、両者の類似度を評価することができます。
今回、Githubで公開されている学習済みモデルを利用して、入力テキストの特徴と類似した画像をデータベースから検索するツールを作ってみようと思います。

CLIPについての詳しい原理は こちら

環境

  • Python 3.11
    • FastAPI 0.104.1
    • PyTorch 2.1.1
    • Transformers 4.35.2
    • SQLAlchemy 2.0.23
  • PostgreSQL 15

画像の登録・検索をWeb API 化したかったのでFastAPIを利用し、
登録した画像はCLIPモデルによってベクトルに変換してDBに保存することにしました。

実装

今回、画像はURLで渡して保存、検索した際には保存したURLを返すことにします。
CLIPモデルを利用した処理部分(画像のベクトル化、テキストのベクトル化、類似度計算)はまとめてクラスに実装して、APIが呼び出された時には適宜そのクラスのメソッドを使うといった方針で実装します。

CLIPモデルクラス

import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from typing import List, Union
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CLIPService:
"""CLIP画像・テキスト処理サービス"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
"""
CLIPモデルの初期化
Args:
model_name: 使用するCLIPモデル名
"""
logger.info(f"Loading CLIP model: {model_name}")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
self.model = CLIPModel.from_pretrained(model_name).to(self.device)
self.processor = CLIPProcessor.from_pretrained(model_name)
self.model.eval()
logger.info("CLIP model loaded successfully")
def load_image_from_url(self, url: str) -> Image.Image:
"""
URLから画像をダウンロード
Args:
url: 画像URL
Returns:
PIL Image
"""
try:
response = requests.get(url, timeout=10)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")
return image
except Exception as e:
logger.error(f"Failed to load image from {url}: {e}")
raise
def encode_image(self, image: Union[str, Image.Image]) -> np.ndarray:
"""
画像を特徴ベクトルに変換
Args:
image: 画像URL または PIL Image
Returns:
特徴ベクトル (512次元)
"""
try:
if isinstance(image, str):
image = self.load_image_from_url(image)
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
image_features = self.model.get_image_features(**inputs)
# 正規化
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().flatten()
except Exception as e:
logger.error(f"Failed to encode image: {e}")
raise
def encode_text(self, text: str) -> np.ndarray:
"""
テキストを特徴ベクトルに変換
Args:
text: テキストクエリ
Returns:
特徴ベクトル (512次元)
"""
try:
inputs = self.processor(text=[text], return_tensors="pt", padding=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
text_features = self.model.get_text_features(**inputs)
# 正規化
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().flatten()
except Exception as e:
logger.error(f"Failed to encode text: {e}")
raise
@staticmethod
def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
"""
コサイン類似度を計算
Args:
vec1: ベクトル1
vec2: ベクトル2
Returns:
コサイン類似度 [-1, -1]
"""
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
# グローバルインスタンス(シングルトン)
_clip_service = None
def get_clip_service() -> CLIPService:
"""CLIPServiceのシングルトンインスタンスを取得"""
global _clip_service
if _clip_service is None:
_clip_service = CLIPService()
return _clip_service

APIの処理の中で、このクラスをimport、get_clip_service() でクラスインスタンスを作成してもらい、必要に応じて各々メソッドを利用することを想定しています。

利用例:

  • 画像登録API: encode_image で画像をベクトルに変換
  • 画像検索API: encode_text で入力テキストをベクトルに変換、cosine_similarity で登録されている画像と入力テキストとの類似度を計算

次回予告

今回作ったCLIPモデルクラスを利用して、実際にWebAPIを作ってみようと思います。
ゆくゆくは実際にサーバーに載せてみて動かしてみたいな!

実はこのツールを使って最終的にできたらいいなと考えていることがあるので密かに応援いただけると幸いです。がんばります(`・ω・´)



ギャップロを運営しているアップフロンティア株式会社では、一緒に働いてくれる仲間を随時、募集しています。 興味がある!一緒に働いてみたい!という方は下記よりご応募お待ちしております。
採用情報をみる