https://github.com/caoyunkang/AdaCLIP

Defining Static Prompts & Dynamic Prompts

Static prompts are learnable parameters shared across all images, whereas dynamic prompts are generated per input image and injected at run time through set_dynamic_prompts.


# inside the prompting class's __init__
if self.enabled:  # parameters are constructed only when prompting is enabled
    if 'S' in prompting_type:  # static prompts
        # learnable prompt tokens: one (length, channel) tensor per transformer layer
        self.static_prompts = nn.ParameterList(
            [nn.Parameter(torch.empty(self.length, self.channel))
             for _ in range(self.depth)])

        for single_para in self.static_prompts:
            nn.init.normal_(single_para, std=0.02)

    if 'D' in prompting_type:  # dynamic prompts
        self.dynamic_prompts = [0.]  # placeholder; replaced per image via set_dynamic_prompts

# method of the same class
def set_dynamic_prompts(self, dynamic_prompts):
    self.dynamic_prompts = dynamic_prompts
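
The static prompts above are ordinary learned parameters shared by every image, while the dynamic prompts are filled in at run time through set_dynamic_prompts. The sketch below is a hypothetical illustration of how the two could be combined: a global image feature is projected into per-image dynamic prompts, and both prompt sets are prepended to a layer's token sequence. The re-created class, dynamic_projector, and the concatenation order are assumptions for illustration, not the repo's actual API.

import torch
import torch.nn as nn

class PromptLayer(nn.Module):
    # hypothetical re-creation of the snippet above, for illustration only
    def __init__(self, length=4, channel=768, depth=12, prompting_type='SD', enabled=True):
        super().__init__()
        self.length, self.channel, self.depth, self.enabled = length, channel, depth, enabled
        if self.enabled:
            if 'S' in prompting_type:
                self.static_prompts = nn.ParameterList(
                    [nn.Parameter(torch.empty(self.length, self.channel))
                     for _ in range(self.depth)])
                for p in self.static_prompts:
                    nn.init.normal_(p, std=0.02)
            if 'D' in prompting_type:
                self.dynamic_prompts = [0.]  # placeholder

    def set_dynamic_prompts(self, dynamic_prompts):
        self.dynamic_prompts = dynamic_prompts

# hypothetical usage: derive per-image dynamic prompts from a global image feature
batch, channel, length = 2, 768, 4
layer = PromptLayer(length=length, channel=channel)
dynamic_projector = nn.Linear(channel, length * channel)   # assumption, not the repo's module
image_feature = torch.randn(batch, channel)                 # e.g. a pooled CLIP image embedding
layer.set_dynamic_prompts(dynamic_projector(image_feature).view(batch, length, channel))

# prepend static + dynamic prompts to a token sequence at layer 0 (assumed injection scheme)
tokens = torch.randn(batch, 77, channel)
static = layer.static_prompts[0].unsqueeze(0).expand(batch, -1, -1)
prompted = torch.cat([static, layer.dynamic_prompts, tokens], dim=1)
print(prompted.shape)  # torch.Size([2, 85, 768])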

Loss Computation


def train_one_batch(self, items):
    image = items['img'].to(self.device)
    cls_name = items['cls_name']

    # forward pass: pixel-level anomaly map(s) and image-level anomaly score
    anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=False)

    if not isinstance(anomaly_map, list):
        anomaly_map = [anomaly_map]

    # ground-truth pixel mask, binarized below
    gt = items['img_mask'].to(self.device)
    gt = gt.squeeze()

    gt[gt > 0.5] = 1
    gt[gt <= 0.5] = 0

    is_anomaly = items['anomaly'].to(self.device)
    is_anomaly[is_anomaly > 0.5] = 1
    is_anomaly[is_anomaly <= 0.5] = 0
    loss = 0

    # classification loss
    classification_loss = self.loss_focal(anomaly_score, is_anomaly.unsqueeze(1))
    loss += classification_loss

    # segmentation loss: channel 1 of each map is the anomaly probability, channel 0 the normal one
    seg_loss = 0
    for am in anomaly_map:
        seg_loss += (self.loss_focal(am, gt) + self.loss_dice(am[:, 1, :, :], gt) +
                     self.loss_dice(am[:, 0, :, :], 1 - gt))

    loss += seg_loss

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    return loss
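
Here self.loss_focal and self.loss_dice are a focal loss (FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), which down-weights easy pixels) and a soft Dice loss. Below is a minimal sketch of both, assuming the prediction is already a probability of the positive class with the same shape as the binary target; the repo's FocalLoss additionally handles the two-channel map passed as am above, so treat this only as a reference for the idea.

import torch
import torch.nn as nn

class BinaryFocalLoss(nn.Module):
    # sketch: expects probabilities of the positive class, same shape as the binary target
    def __init__(self, alpha=0.25, gamma=2.0, eps=1e-6):
        super().__init__()
        self.alpha, self.gamma, self.eps = alpha, gamma, eps

    def forward(self, prob, target):
        prob = prob.clamp(self.eps, 1.0 - self.eps)
        pt = torch.where(target > 0.5, prob, 1.0 - prob)                  # p_t
        alpha_t = torch.where(target > 0.5,
                              torch.full_like(prob, self.alpha),
                              torch.full_like(prob, 1.0 - self.alpha))    # alpha_t
        return (-alpha_t * (1.0 - pt) ** self.gamma * torch.log(pt)).mean()

class SoftDiceLoss(nn.Module):
    # sketch: 1 - Dice overlap between the predicted map and the binary mask
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, prob, target):
        prob = prob.reshape(prob.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        inter = (prob * target).sum(dim=1)
        union = prob.sum(dim=1) + target.sum(dim=1)
        return (1.0 - (2.0 * inter + self.smooth) / (union + self.smooth)).mean()

# quick shape check with dummy predictions
focal, dice = BinaryFocalLoss(), SoftDiceLoss()
prob_map = torch.rand(2, 256, 256)
mask = (torch.rand(2, 256, 256) > 0.95).float()
print(focal(prob_map, mask).item(), dice(prob_map, mask).item())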

[ DL ] Focal Loss (Focal Loss for Dense Object Detection)