https://github.com/caoyunkang/AdaCLIP
if self.enabled: # only when enabled, the parameters should be constructed
if 'S' in prompting_type: # static prompts
# learnable
self.static_prompts = nn.ParameterList(
[nn.Parameter(torch.empty(self.length, self.channel))
for _ in range(self.depth)])
for single_para in self.static_prompts:
nn.init.normal_(single_para, std=0.02)
if 'D' in prompting_type: # dynamic prompts
self.dynamic_prompts = [0.] # place holder
def set_dynamic_prompts(self, dynamic_prompts):
self.dynamic_prompts = dynamic_prompts
prompting_type
이 ‘SD’가 default인데, 이 경우 두 if문이 모두 실행torch.empty()
: 주어진 크기의 텐서를 생성
nn.Parameter()
: 텐서를 wrapping하여 PyTorch 모델의 학습 가능한 매개변수로 등록하는 클래스
for _ in range(self.depth)
: list comprehension
nn.ParameterList()
: PyTorch의 Module 클래스 내에서 매개변수를 리스트 형태로 관리
nn.init.normal_(single_para, std=0.02)
: 정규분포를 통해 신경망의 매개변수를 초기화self.dynamic_prompts = [0.]
: dynamic prompt를 실수 0.0 하나가 들어있는 리스트로 초기화
def train_one_batch(self, items):
image = items['img'].to(self.device)
cls_name = items['cls_name']
# pixel level
anomaly_map, anomaly_score = self.clip_model(image, cls_name, aggregation=False)
if not isinstance(anomaly_map, list):
anomaly_map = [anomaly_map]
# losses
gt = items['img_mask'].to(self.device)
gt = gt.squeeze()
gt[gt > 0.5] = 1
gt[gt <= 0.5] = 0
is_anomaly = items['anomaly'].to(self.device)
is_anomaly[is_anomaly > 0.5] = 1
is_anomaly[is_anomaly <= 0.5] = 0
loss = 0
# classification loss
classification_loss = self.loss_focal(anomaly_score, is_anomaly.unsqueeze(1))
loss += classification_loss
# seg loss
seg_loss = 0
for am, in zip(anomaly_map):
seg_loss += (self.loss_focal(am, gt) + self.loss_dice(am[:, 1, :, :], gt) +
self.loss_dice(am[:, 0, :, :], 1-gt))
loss += seg_loss
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss
Focal
+ Dice
am[:, 1, :, :]
: 모델이 예측한 비정상 픽셀의 확률 맵Focal
is_anomaly.unsqueeze(1)
: is_anomaly의 차원을 하나 늘려줌