CLIPSeg의 원리: CLIP의 텍스트-이미지 임베딩

CLIPSeg는 CLIP 모델을 기반으로 개발된 제로샷 이미지 분할 모델이다. 텍스트와 이미지의 상호 관계를 학습하여, 사용자가 입력한 텍스트 프롬프트에 맞는 이미지의 특정 부분을 자동으로 분할한다. 예를 들어, 고양이가 있는 이미지와 “고양이”라는 텍스트 프롬프트를 입력하면 이미지 속의 고양이 부분만을 자동으로 선택할 수 있다.

기존의 컴퓨터 비전 모델은 pre-trained된 작업에 대해 우수한 성능을 보인다. 하지만 새로운 작업을 할당하기 위해서는 새로운 데이터셋에 대한 학습이 필요하기 때문에 많은 비용이 든다. CLIP은 인터넷 상의 이미지와 텍스트 데이터를 수집하고, 이를 학습하여 미세조정 없이도 일반화된 이미지의 특징을 잘 추출해내는 기법을 적용하였다.

CLIP(Contrastive Language–Image Pre-training)은 텍스트와 이미지의 의미를 동일한 공간에 매핑하는 멀티모달(Multimodal) 임베딩 기법을 사용한다. 즉, 이미지와 텍스트를 연관 짓고, 서로 매칭되는 특성 공간에서 두 데이터를 이해한다. CLIP은 텍스트-이미지 쌍 데이터를 학습함으로써 특정 단어나 문장이 어떤 이미지와 연관되는지를 설명할 수 있다.

CLIPSeg의 이미지 분할 수행

  1. 텍스트 프롬프트 입력: 원하는 객체를 지칭하는 텍스트 프롬프트를 입력한다.
  2. CLIP 임베딩 생성: 텍스트 프롬프트와 이미지를 CLIP 모델의 임베딩 공간에 매핑하여 서로 비교 가능한 벡터로 변환한다. 이 과정에서, 텍스트와 이미지의 의미적 유사성을 파악할 수 있는 벡터가 생성된다.
  3. 분할 마스크 생성: 이미지의 각 픽셀이 텍스트 프롬프트와 얼마나 유사한지를 계산하고 유사도가 높은 영역만 선택하여 분할 마스크를 생성한다.

CLIPSeg의 주요 특징

  • 텍스트 기반의 이미지 분할: 복잡한 사전 훈련 없이, 간단한 텍스트-이미지 쌍 입력만으로 원하는 객체를 분할할 수 있다.
  • 범용성: 특정 데이터셋에 국한되지 않고, 다양한 이미지에 적용할 수 있다.
  • 빠른 속도와 정확도: 기존의 이미지 분할 모델에 비해 빠르고 정확하게 결과를 얻을 수 있다.

CLIPSeg의 한계

  • 복잡한 객체 분할: 단순한 객체나 잘 정의된 형태의 경우 분할이 용이하지만, 서로 겹치는 복잡한 객체들이 많은 이미지에서는 성능이 떨어질 수 있다.
  • 텍스트 표현의 한계: 프롬프트에 따라 객체를 찾는 방식으로, 추상적인 개념이나 잘 정의되지 않은 객체는 분할이 어려울 수 있다.

CLIPSeg 사용 가이드

모델 불러오기

Hugging Face 트랜스포머를 사용하면 이미지에 사전 학습된 CLIPSeg 모델을 쉽게 다운로드하고 사용할 수있다.

from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

이미지 불러오기

이제 모델을 사용하여 이미지 분할을 수행할 수 있다.

url = "https://i.pinimg.com/564x/f9/92/79/f992799d34ed72382794c1abcebeb50f.jpg"

response = requests.get(url, stream=True)

image = Image.open(response.raw)

image.show()

프롬프트 정의하기

prompts = ["tv","sofa","flowers","painting","lamps"]

모델 예측 수행

image = image.convert('RGB')

tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float()
inputs = processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")

with torch.no_grad():
  outputs = model(**inputs)
preds = outputs.logits.unsqueeze(1)
masks = torch.sigmoid(preds).squeeze(1)
maskss = outputs.logits

결과 시각화

fig, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))
[ax[i].axis('off') for i in range(len(prompts) + 1)]
ax[0].imshow(image)

for i, prompt in enumerate(prompts):
    # 예측된 히트맵
    pred_heatmap = torch.sigmoid(preds[i][0])

    # 히트맵 크기를 원본 이미지 크기로 조정
    pred_heatmap = pred_heatmap.cpu().numpy()
    pred_heatmap = cv2.resize(pred_heatmap, (image.width, image.height))

    # 히트맵 임계값을 설정해 노이즈 제거
    threshold = 0.5
    pred_heatmap[pred_heatmap < threshold] = 0
    pred_heatmap[pred_heatmap >= threshold] = 1

    # 원본 이미지에 히트 맵 적용
    heatmap = np.uint8(255 * pred_heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    overlayed_img = cv2.addWeighted(np.array(image), 0.5, heatmap, 0.5, 0)

    # 오버레이된 이미지 표시
    ax[i+1].imshow(overlayed_img)
    ax[i+1].set_title(prompt)
    
plt.show()

image