CLIPモデルとは、OpenAIによって公開されたAIモデルで、一言でいうと、画像とテキストの関係性を理解できるモデルです。
画像とテキストを多次元の特徴ベクトルに変換することで、両者の類似度を評価することができます。
今回、Githubで公開されている学習済みモデルを利用して、入力テキストの特徴と類似した画像をデータベースから検索するツールを作ってみようと思います。
CLIPについての詳しい原理は こちら
環境
- Python 3.11
- FastAPI 0.104.1
- PyTorch 2.1.1
- Transformers 4.35.2
- SQLAlchemy 2.0.23
- PostgreSQL 15
画像の登録・検索をWeb API 化したかったのでFastAPIを利用し、
登録した画像はCLIPモデルによってベクトルに変換してDBに保存することにしました。
実装
今回、画像はURLで渡して保存、検索した際には保存したURLを返すことにします。
CLIPモデルを利用した処理部分(画像のベクトル化、テキストのベクトル化、類似度計算)はまとめてクラスに実装して、APIが呼び出された時には適宜そのクラスのメソッドを使うといった方針で実装します。
CLIPモデルクラス
import torch from transformers import CLIPProcessor, CLIPModelfrom PIL import Imageimport requestsfrom io import BytesIOimport numpy as npfrom typing import List, Unionimport logginglogging.basicConfig(level=logging.INFO)logger = logging.getLogger(__name__)class CLIPService: """CLIP画像・テキスト処理サービス""" def __init__(self, model_name: str = "openai/clip-vit-base-patch32"): """ CLIPモデルの初期化 Args: model_name: 使用するCLIPモデル名 """ logger.info(f"Loading CLIP model: {model_name}") self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device: {self.device}") self.model = CLIPModel.from_pretrained(model_name).to(self.device) self.processor = CLIPProcessor.from_pretrained(model_name) self.model.eval() logger.info("CLIP model loaded successfully") def load_image_from_url(self, url: str) -> Image.Image: """ URLから画像をダウンロード Args: url: 画像URL Returns: PIL Image """ try: response = requests.get(url, timeout=10) response.raise_for_status() image = Image.open(BytesIO(response.content)).convert("RGB") return image except Exception as e: logger.error(f"Failed to load image from {url}: {e}") raise def encode_image(self, image: Union[str, Image.Image]) -> np.ndarray: """ 画像を特徴ベクトルに変換 Args: image: 画像URL または PIL Image Returns: 特徴ベクトル (512次元) """ try: if isinstance(image, str): image = self.load_image_from_url(image) inputs = self.processor(images=image, return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} with torch.no_grad(): image_features = self.model.get_image_features(**inputs) # 正規化 image_features = image_features / image_features.norm(dim=-1, keepdim=True) return image_features.cpu().numpy().flatten() except Exception as e: logger.error(f"Failed to encode image: {e}") raise def encode_text(self, text: str) -> np.ndarray: """ テキストを特徴ベクトルに変換 Args: text: テキストクエリ Returns: 特徴ベクトル (512次元) """ try: inputs = self.processor(text=[text], return_tensors="pt", padding=True) inputs = {k: v.to(self.device) for k, v in inputs.items()} with torch.no_grad(): text_features = self.model.get_text_features(**inputs) # 正規化 text_features = text_features / text_features.norm(dim=-1, keepdim=True) return text_features.cpu().numpy().flatten() except Exception as e: logger.error(f"Failed to encode text: {e}") raise @staticmethod def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float: """ コサイン類似度を計算 Args: vec1: ベクトル1 vec2: ベクトル2 Returns: コサイン類似度 [-1, -1] """ return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))# グローバルインスタンス(シングルトン)_clip_service = Nonedef get_clip_service() -> CLIPService: """CLIPServiceのシングルトンインスタンスを取得""" global _clip_service if _clip_service is None: _clip_service = CLIPService() return _clip_service
APIの処理の中で、このクラスをimport、get_clip_service() でクラスインスタンスを作成してもらい、必要に応じて各々メソッドを利用することを想定しています。
利用例:
- 画像登録API: encode_image で画像をベクトルに変換
- 画像検索API: encode_text で入力テキストをベクトルに変換、cosine_similarity で登録されている画像と入力テキストとの類似度を計算
次回予告
今回作ったCLIPモデルクラスを利用して、実際にWebAPIを作ってみようと思います。
ゆくゆくは実際にサーバーに載せてみて動かしてみたいな!
実はこのツールを使って最終的にできたらいいなと考えていることがあるので密かに応援いただけると幸いです。がんばります(`・ω・´)








