📜  Facebook 使用检测转换器 (DETR) 进行对象检测

📅  最后修改于: 2022-05-13 01:54:41.981000             🧑  作者: Mango

Facebook 使用检测转换器 (DETR) 进行对象检测

Facebook 刚刚于 2020 年 5 月 27 日发布了其最先进的对象检测模型。他们称它为 DERT 代表检测转换器,因为它使用转换器来检测对象。这是转换器第一次用于此类对象检测任务以及卷积神经网络。还有其他对象检测模型,例如 RCNN 系列、YOLO(You Look Only Once)和 SSD(Single Shot Detection),但它们都没有使用变压器来完成这项任务。这个模型最好的部分是,由于它使用了一个转换器,它使得架构非常简单,不像提到的所有其他技术都有各种超参数和层。因此,无需进一步告别,让我们开始吧。
什么是物体检测?
给定一张照片,如果您需要确定照片是否有单个特定对象,您可以通过分类来完成。但是,如果您想在图像中也获得该对象的位置……那么即使这不是对象检测任务……它也称为分类和定位。但是,如果图像中有多个对象,并且您想要每个对象的每个位置的位置,那么这就是对象检测。
之前的一些技术试图让 RPN(区域提议网络)提出可能包含对象的潜在区域,然后我们可以使用锚框、NMS(非最大抑制)和 IOU 的概念来生成相关的框并识别对象。尽管这些概念有效,但推理需要一些时间,因此由于其复杂性,无法实现高精度的实时使用。
在高层次上,这使用 CNN,然后使用转换器来检测对象,并且通过二分匹配训练对象来检测对象。这就是它如此简单的主要原因。

来源 - https://arxiv.org/pdf/2005.12872.pdf

第1步:
我们将图像通过卷积神经网络编码器进行处理,因为 CNN 最适合处理图像。所以通过CNN后,图像特征得到了保留。这是具有更多特征通道的图像的高阶表示。
第2步:
图像的这个丰富的特征图被提供给一个转换器编码器 - 解码器,它输出一组框预测。这些盒子中的每一个都由一个元组组成。元组将是一个类和一个边界框。注意:这也包括类 NULL 或 Nothing 类及其位置。
现在,这是一个真正的问题,因为在注释中没有对象类被注释为空。比较和处理彼此相邻的相似对象是另一个主要问题,在本文中,它通过使用二分匹配损失来解决。通过将每个类和边界框与其对应的类和包含无类的框(假设为 N)与包含添加的不包含任何内容的部分的注释进行比较来比较损失,从而使总框为 N。预测到实际是一对一的分配,以使总损失最小化。有一种非常著名的算法称为匈牙利方法来计算这些最小匹配。
主要成分:

来源 - https://arxiv.org/pdf/2005.12872.pdf

主干——从卷积神经网络中提取的特征和位置编码被传递
变压器编码器——变压器自然是一个序列处理单元,出于同样的原因,我们输入的张量是扁平的。它将序列转换为同样长的特征序列。
Transformer 解码器——接受对象查询,因此它是一个解码器,作为条件信息的侧输入。
预测前馈网络 (FFN) –输出通过分类器输出前面讨论的类标签和边界框输出
评估员:
评估是在 COCO 数据集上完成的,它的主要竞争对手是 RCNN 家族,它在一段时间内统治了这一类别,被认为是最经典的目标检测技术。

来源 - https://arxiv.org/pdf/2005.12872.pdf

优点:

  • 这个新模型非常简单,您无需安装任何库即可使用它。
  • DETR 在大对象上表现出明显更好的性能,而不是在可以进一步改进的小对象上。
  • 好消息是他们甚至在论文中提供了代码,所以现在我们也将实现它以了解它的真正能力。

代码:

Python3
# Write Python3 code here
import torch
from torch import nn
from torchvision.models import resnet50
 
class DETR(nn.Module):
 
def __init__(self, num_classes, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
  super().__init__()
  # We take only convolutional layers from ResNet-50 model
  self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
  self.conv = nn.Conv2d(2048, hidden_dim, 1)
  self.transformer = nn.Transformer(hidden_dim, heads,
  num_encoder_layers, num_decoder_layers)
  self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
  self.linear_bbox = nn.Linear(hidden_dim, 4)
  self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
  self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
  self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
  def forward(self, inputs):
  x = self.backbone(inputs)
  h = self.conv(x)
  H , W = h.shape[-2:]
  pos = torch.cat([
  self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
  self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
  h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
  self.query_pos.unsqueeze(1))
  return self.linear_class(h), self.linear_bbox(h).sigmoid()
detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)
 
 
Listing 1: DETR PyTorch inference code. For clarity, it uses learned positional encodings in the encoder instead of fixed, and positional encodings are added to the input
only instead of at each transformer layer. Making these changes requires going beyond
PyTorch implementation of transformers, which hampers readability. The entire code
to reproduce the experiments will be made available before the conference.


Python3
import torch as th
import torchvision.transforms as T
import requests
from PIL import Image, ImageDraw, ImageFont


Python3
model = th.hub.load('facebookresearch/detr', 'detr_resnet101', pretrained=True)
model.eval()
model = model.cuda()


Python3
# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
 
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]


Python3
url = input()


Python3
img = Image.open(requests.get(url, stream=True).raw).resize((800,600)).convert('RGB')
img


Python3
img_tens = transform(img).unsqueeze(0).cuda()
with th.no_grad():
  output = model(img_tens)
 
draw = ImageDraw.Draw(img)
pred_logits=output['pred_logits'][0][:, :len(CLASSES)]
pred_boxes=output['pred_boxes'][0]
 
max_output = pred_logits.softmax(-1).max(-1)
topk = max_output.values.topk(15)
 
pred_logits = pred_logits[topk.indices]
pred_boxes = pred_boxes[topk.indices]
pred_logits.shape


Python3
for logits, box in zip(pred_logits, pred_boxes):
  cls = logits.argmax()
  if cls >= len(CLASSES):
    continue
  label = CLASSES[cls]
  print(label)
  box = box.cpu() * th.Tensor([800, 600, 800, 600])
  x, y, w, h = box
  x0, x1 = x-w//2, x+w//2
  y0, y1 = y-h//2, y+h//2
  draw.rectangle([x0, y0, x1, y1], outline='red', width=5)
  draw.text((x, y), label, fill='white')


Python3
img


我们只从 ResNet-50 模型中提取卷积层
代码取自论文
代码:尝试在 colab 上运行此代码,或者直接访问此链接,复制并运行完整的文件。

Python3

import torch as th
import torchvision.transforms as T
import requests
from PIL import Image, ImageDraw, ImageFont

我们将使用 ResNet 101 作为主干架构,我们将直接从 Pytorch Hub 加载该架构。
代码:

Python3

model = th.hub.load('facebookresearch/detr', 'detr_resnet101', pretrained=True)
model.eval()
model = model.cuda()

Python3

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
 
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

在此处输入图像的 URL。我使用的是 https://i.ytimg.com/vi/vrlX3cwr3ww/maxresdefault.jpg
代码:

Python3

url = input()

显示图像

Python3

img = Image.open(requests.get(url, stream=True).raw).resize((800,600)).convert('RGB')
img

代码:

Python3

img_tens = transform(img).unsqueeze(0).cuda()
with th.no_grad():
  output = model(img_tens)
 
draw = ImageDraw.Draw(img)
pred_logits=output['pred_logits'][0][:, :len(CLASSES)]
pred_boxes=output['pred_boxes'][0]
 
max_output = pred_logits.softmax(-1).max(-1)
topk = max_output.values.topk(15)
 
pred_logits = pred_logits[topk.indices]
pred_boxes = pred_boxes[topk.indices]
pred_logits.shape

代码:

Python3

for logits, box in zip(pred_logits, pred_boxes):
  cls = logits.argmax()
  if cls >= len(CLASSES):
    continue
  label = CLASSES[cls]
  print(label)
  box = box.cpu() * th.Tensor([800, 600, 800, 600])
  x, y, w, h = box
  x0, x1 = x-w//2, x+w//2
  y0, y1 = y-h//2, y+h//2
  draw.rectangle([x0, y0, x1, y1], outline='red', width=5)
  draw.text((x, y), label, fill='white')

代码:显示检测到的图像

Python3

img

这是 colab 笔记本和 github 代码的链接。另外,请随时查看官方 GitHub 以获得相同的信息
缺点:
训练需要很长时间。它在 8 个 GPU 上训练了六天。当您将它与这种规模的语言模型进行比较时,它并没有那么多,因为它们使用了转换器,但仍然如此。