From 4c8b71456f987eacd0b1eeeed7e1f2e5f7f807c0 Mon Sep 17 00:00:00 2001
From: goaT <11860652+goaT1031@user.noreply.gitee.com>
Date: Thu, 27 Oct 2022 12:23:13 +0000
Subject: [PATCH 1/2] Add directory: 韩阳 - change detection method for
 Sentinel-2 remote-sensing images based on self-supervised learning
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../.keep" | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 "code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/.keep"

diff --git "a/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/.keep" "b/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/.keep"
new file mode 100644
index 0000000..e69de29
--
Gitee


From d61b2cb447ddcd1d45bb725392b155ab7609d16d Mon Sep 17 00:00:00 2001
From: goaT <11860652+goaT1031@user.noreply.gitee.com>
Date: Thu, 27 Oct 2022 12:48:00 +0000
Subject: [PATCH 2/2] 韩阳 - change detection method for Sentinel-2
 remote-sensing images based on self-supervised learning. train_BYOL.py is
 the main training script, mydataset.py loads the data, and resnet.py is the
 backbone called by BYOL. Training uses the OSCD change detection dataset; it
 exceeds the upload size limit and cannot be attached here. Download link:
 https://ieee-dataport.org/open-access/oscd-onera-satellite-change-detection#files
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Signed-off-by: goaT <11860652+goaT1031@user.noreply.gitee.com>

---
 .../mydataset.py"  |  98 ++++
 .../resnet.py"     | 204 +++++++
 .../train_BYOL.py" | 546 ++++++++++++++++++
 3 files changed, 848 insertions(+)
 create mode 100644 "code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/mydataset.py"
"code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/resnet.py" create mode 100644 "code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/train_BYOL.py" diff --git "a/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/mydataset.py" "b/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/mydataset.py" new file mode 100644 index 0000000..c69b6a3 --- /dev/null +++ "b/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/mydataset.py" @@ -0,0 +1,98 @@ +import torch.nn.functional as F +import torch +import torch +import torch.nn as nn +from torch.autograd import Variable +import torchvision.models as models +from torchvision import transforms, utils +from torch.utils.data import Dataset, DataLoader +from PIL import Image +import numpy as np +import torch.optim as optim +import glob +import rasterio +import os + +# torch.cuda.set_device(gpu_id)#使用GPU +learning_rate = 0.0001 + + +root = os.getcwd() + '/full_data/' # 调用图像 + +class RandomCrop(object): + """ + Args: + output_size (tuple or int): 期望裁剪的图片大小。如果是 int,将得到一个正方形大小的图片. 
+ """ + + def __init__(self, output_size): + assert isinstance(output_size, (int, tuple)) + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + else: + assert len(output_size) == 2 + self.output_size = output_size + + def __call__(self, image): + + + + _, h, w = image.shape + new_h, new_w = self.output_size + + top = np.random.randint(0, h - new_h) + left = np.random.randint(0, w - new_w) + + image = image[:, top: top + new_h, left: left + new_w] + + return image + + + + + + + +def default_loader(path): + with rasterio.open(path) as data: + s2 = data.read() + return s2 + + + +class MyDataset(Dataset): + def __init__(self, txt, transform=None, target_transform=None, loader=default_loader): + super(MyDataset, self).__init__() + fh = open(txt, 'r') + imgs = [] + for line in fh: + line = line.strip('\n') + line = line.rstrip('\n') + words = line.split() + imgs.append(words) + + self.imgs = imgs + self.transform = transform + self.target_transform = target_transform + self.loader = loader + if transform: + self.transform = RandomCrop(400) + else: + self.transform = None + + def __getitem__(self, index): + fn = self.imgs[index] + img = self.loader(fn[0]) + if self.transform is not None: + img = self.transform(img) + img = img.astype(np.int16) + img = torch.from_numpy(img) + + img = img.float() + + return img + + def __len__(self): + return len(self.imgs) + + diff --git "a/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/resnet.py" "b/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/resnet.py" new file mode 100644 index 0000000..3031265 --- /dev/null +++ "b/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/resnet.py" @@ -0,0 +1,204 @@ +import torch +import torch.nn as nn +import math + +__all__ = ['ResNet', ] + +class MLPHead(nn.Module): + def __init__(self, in_channels, mlp_hidden_size, projection_size): + super(MLPHead, self).__init__() + + self.net = nn.Sequential( + nn.Linear(in_channels, mlp_hidden_size), + nn.BatchNorm1d(mlp_hidden_size), + nn.ReLU(inplace=True), + nn.Linear(mlp_hidden_size, projection_size) + ) + + def forward(self, x): + return self.net(x) + +class PreMLPHead(nn.Module): + def __init__(self, in_channels, mlp_hidden_size, projection_size): + super(PreMLPHead, self).__init__() + + self.net = nn.Sequential( + nn.Linear(in_channels, mlp_hidden_size), + nn.BatchNorm1d(mlp_hidden_size), + nn.ReLU(inplace=True), + nn.Linear(mlp_hidden_size, projection_size) + ) + + def forward(self, x): + return self.net(x) + + +def Net6(**kwargs): + return NetN(BasicBlock, [2, 2, 2, 2], **kwargs) + +def ResNet18(**kwargs): + return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + +def ResNet34(**kwargs): + return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + +def conv3x3(in_planes, out_planes, stride=1, 
diff --git "a/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/resnet.py" "b/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/resnet.py"
new file mode 100644
index 0000000..3031265
--- /dev/null
+++ "b/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/resnet.py"
@@ -0,0 +1,204 @@
+import math
+
+import torch
+import torch.nn as nn
+
+__all__ = ['ResNet', 'NetN', 'ResNet18', 'ResNet34', 'Net6', 'MLPHead', 'PreMLPHead']
+
+
+class MLPHead(nn.Module):
+    """Projection head applied after the backbone (Linear-BN-ReLU-Linear)."""
+
+    def __init__(self, in_channels, mlp_hidden_size, projection_size):
+        super(MLPHead, self).__init__()
+
+        self.net = nn.Sequential(
+            nn.Linear(in_channels, mlp_hidden_size),
+            nn.BatchNorm1d(mlp_hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(mlp_hidden_size, projection_size)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class PreMLPHead(nn.Module):
+    """Predictor head for BYOL; same structure as MLPHead."""
+
+    def __init__(self, in_channels, mlp_hidden_size, projection_size):
+        super(PreMLPHead, self).__init__()
+
+        self.net = nn.Sequential(
+            nn.Linear(in_channels, mlp_hidden_size),
+            nn.BatchNorm1d(mlp_hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(mlp_hidden_size, projection_size)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def Net6(**kwargs):
+    return NetN(BasicBlock, [2, 2, 2, 2], **kwargs)
+
+
+def ResNet18(**kwargs):
+    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+
+
+def ResNet34(**kwargs):
+    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_feats=128, width=1, in_channel=1, mi=False):
+        super(ResNet, self).__init__()
+
+        self._norm_layer = nn.BatchNorm2d
+        self.inplanes = max(int(64 * width), 64)
+        self.base = int(64 * width)
+        self.layer0 = nn.Sequential(
+            nn.Conv2d(in_channel, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
+            self._norm_layer(self.inplanes),
+            nn.ReLU(inplace=True))
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, self.base, layers[0])
+        # layers 2 and 3 keep stride 1 (instead of the usual 2) so that small
+        # training patches are not downsampled away
+        self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=1)
+        self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=1)
+        self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+        if mi:
+            self.final = nn.Linear(512 * block.expansion, num_feats)
+        else:
+            self.final = MLPHead(512 * block.expansion, 4096, num_feats)
+
+        # He initialization for conv layers
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        norm_layer = self._norm_layer
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion))
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.layer0(x)
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.final(x)
+
+        return x
+
+
+class NetN(nn.Module):
+    """Shallow variant with a single residual stage (used by Net6)."""
+
+    def __init__(self, block, layers, num_feats=64, width=1, in_channel=3, mi=False):
+        super(NetN, self).__init__()
+
+        self._norm_layer = nn.BatchNorm2d
+        self.inplanes = max(int(64 * width), 64)
+        self.base = int(64 * width)
+        self.layer0 = nn.Sequential(
+            nn.Conv2d(in_channel, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
+            self._norm_layer(self.inplanes),
+            nn.ReLU(inplace=True))
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, self.base, layers[0], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+        if mi:
+            self.final = nn.Linear(self.base, num_feats)
+        else:
+            self.final = MLPHead(self.base, self.base * 2, num_feats)
+
+        # He initialization for conv layers
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
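Editor's sketch (not part of the patch) to verify that the backbone accepts a 13-band Sentinel-2 input and returns a num_feats-dimensional projection; note that the BatchNorm1d inside MLPHead needs a batch size greater than 1 in training mode:

    import torch
    from resnet import ResNet34

    net = ResNet34(num_feats=256, width=1, in_channel=13)
    x = torch.randn(4, 13, 8, 8)  # four 8x8 training patches
    out = net(x)
    print(out.shape)              # expected: torch.Size([4, 256])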
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        norm_layer = self._norm_layer
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion))
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.layer0(x)
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.final(x)
+
+        return x
\ No newline at end of file
diff --git "a/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/train_BYOL.py" "b/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/train_BYOL.py"
new file mode 100644
index 0000000..da26e3d
--- /dev/null
+++ "b/code/2022_autumn/\351\237\251\351\230\263-\345\237\272\344\272\216\350\207\252\347\233\221\347\235\243\345\255\246\344\271\240\347\232\204\345\223\250\345\205\265\344\272\214\345\217\267\351\201\245\346\204\237\345\275\261\345\203\217\345\217\230\345\214\226\346\243\200\346\265\213\346\226\271\346\263\225/train_BYOL.py"
@@ -0,0 +1,546 @@
+from __future__ import print_function
+
+import argparse
+import copy
+import math
+import os
+import time
+from collections import Counter
+
+import matplotlib.image as mpimg
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import visdom
+from torch.utils.data import DataLoader
+
+from mydataset import MyDataset
+from resnet import ResNet34, PreMLPHead
+
+root = os.getcwd() + '/full_data/'  # root directory of the input images
+
+vis = visdom.Visdom(env='BYOL')
+opt1 = {
+    'xlabel': 'epochs',      # x-axis label
+    'ylabel': 'loss_value',  # y-axis label
+    'title': 'loss'          # window title
+}
+# open a new window in the visdom environment for the training loss
+loss_window = vis.line(
+    X=[0],
+    Y=[0],
+    opts=opt1
+)
+
+opt2 = {
+    'xlabel': 'epochs',      # x-axis label
+    'ylabel': 'loss_value',  # y-axis label
+    'title': 'acc',          # window title
+    'legend': ['Pre', 'Rec', 'F1', 'OA', 'Kappa']
+}
+acc_window = vis.line(
+    X=[0.0],
+    Y=[[0.0, 0.0, 0.0, 0.0, 0.0]],
+    opts=opt2
+)
+
+
+def parse_option():
+    parser = argparse.ArgumentParser('argument for training')
+    # load the original big image (if just one, it should be big enough)
+    parser.add_argument('--batch_size', type=int, default=1, help='batch_size for data training')
+    parser.add_argument('--crop_size', type=int, default=200, help='crop_size for ensuring same patch_size within all batches')
+    parser.add_argument('--num_workers', type=int, default=4, help='num of workers to use')
+    parser.add_argument('--epochs', type=int, default=500, help='number of training epochs')
+    # split image into small patches
+    parser.add_argument('--patch_size', type=int, default=8, help='patch_size for training')
+    parser.add_argument('--unfold_stride', type=int, default=4, help='stride between training patches')
+    parser.add_argument('--val_patch_size', type=int, default=8, help='patch_size for inference')
+    parser.add_argument('--val_unfold_stride', type=int, default=4, help='stride between inference patches')
+    parser.add_argument('--pbatch_size', type=int, default=256, help='batch_size of patches during inference')
+    # optimization
+    parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate')
+    parser.add_argument('--weight_decay', type=float, default=1.5e-6, help='weight decay')
+    parser.add_argument('--momentum', type=float, default=0.996, help='momentum')
+    # resume and test
+    parser.add_argument('--resume', action='store_true', help='resume training from a checkpoint')
+    parser.add_argument('--test', action='store_true', help='evaluate on the test data set only')
+    # model definition
+    parser.add_argument('--model', type=str, default='resnet18', choices=['resnet18', 'Net6'])
+    parser.add_argument('--feat_dim', type=int, default=256, help='dim of feat for inner product')
+    # input/output for data use
+    parser.add_argument('--use_s2hr', action='store_true', default=True, help='use sentinel-2 high-resolution (10 m) bands')
+    parser.add_argument('--use_s2mr', action='store_true', default=False, help='use sentinel-2 medium-resolution (20 m) bands')
+    parser.add_argument('--use_s2lr', action='store_true', default=False, help='use sentinel-2 low-resolution (60 m) bands')
+    parser.add_argument('--use_s1', action='store_true', default=True, help='use sentinel-1 data')
+    parser.add_argument('--no_savanna', action='store_true', default=False, help='ignore class savanna')
+
+    # specify folder
+    parser.add_argument('--data_dir_train', type=str, default='./InferS2-all', help='path to training data set')
+    parser.add_argument('--data_dir_eval', type=str, default='./InferS2', help='path to test data set')
+    parser.add_argument('--data_dir_ground_truth', type=str, default='./ground_truth', help='path to ground truth')
+    parser.add_argument('--save_path', type=str, default='./save_BYOL', help='path to save model')
+    parser.add_argument('--eval_freq', type=int, default=2, help='evaluation frequency (epochs)')
+    parser.add_argument('--save_freq', type=int, default=10, help='save frequency (epochs)')
+
+    opt = parser.parse_args()
+
+    return opt
+
+
+def rosin_thresholding(difference_img):
+    """Rosin's unimodal thresholding: pick the gray level farthest from the
+    line joining the histogram peak and the first empty bin after it."""
+    difference_img = torch.div((difference_img - difference_img.min()), (difference_img.max() - difference_img.min()))
+    difference_img = 255 * difference_img
+    difference_img = difference_img.round()
+    difference = difference_img.numpy()
+    difference = difference.astype('int32')
+
+    c = Counter(difference.flatten())
+    i, j = c.most_common(1)[0]      # histogram peak: (gray level, count)
+    for tt in range(i, 256):
+        if c[tt] == 0:              # first empty bin after the peak
+            m, n = tt, c[tt]
+            break
+    else:
+        m, n = 255, c[255]
+    # line through (i, j) and (m, n): A*x + B*y + C = 0
+    A = (n - j) / (m - i)
+    B = -1
+    C = j - A * i
+    d = np.zeros(256)
+    for t in range(i, m):
+        x, y = t, c[t]
+        d[t] = abs(A * x + B * y + C) / math.sqrt(A ** 2 + B ** 2)
+    max_d = d.argmax()              # threshold = gray level with maximum distance
+    difference[difference < max_d] = 0
+    difference[difference >= max_d] = 1
+
+    return difference
+
+
+def change_map(difference_img):
+    threshold_init = torch.mean(difference_img) + 3 * torch.std(difference_img)
+    difference_img = torch.where(difference_img < threshold_init, difference_img, threshold_init)
+
+    difference_img = (difference_img - difference_img.mean()) / difference_img.std()
+
+    threshold = difference_img.min().abs()
+
+    return difference_img >= threshold
+
+
+def get_train_val_loader(args):
+    train_set = MyDataset(txt=root + 'train.txt', transform=True)
+    eval_set = MyDataset(txt=root + 'eval.txt', transform=None)
+
+    train_loader = DataLoader(train_set,
+                              batch_size=args.batch_size,
+                              shuffle=True,
+                              num_workers=args.num_workers,
+                              pin_memory=True,
+                              drop_last=True)
+
+    eval_loader = DataLoader(eval_set,
+                             batch_size=1,
+                             shuffle=False,
+                             num_workers=args.num_workers,
+                             pin_memory=True,
+                             drop_last=False)
+
+    return train_loader, eval_loader
+
+
+class BYOLTrainer:
+    def __init__(self, args, online_network, target_network, predictor, optimizer, scheduler, device):
+        self.online_network = online_network
+        self.target_network = target_network
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.device = device
+        self.savepath = args.save_path
+        self.predictor = predictor
+        self.max_epochs = args.epochs
+        self.m = 0.996  # EMA coefficient for the target network
+        self.n_classes = args.n_classes
+        self.patch_size = args.patch_size
+        self.pbatch_size = args.pbatch_size
+        self.unfold_stride = args.unfold_stride
+        self.val_patch_size = args.val_patch_size
+        self.val_unfold_stride = args.val_unfold_stride
+        self.eval_freq = args.eval_freq
+        self.save_freq = args.save_freq
+        self.ground_truth = args.data_dir_ground_truth
+
+    @torch.no_grad()
+    def _update_target_network_parameters(self):
+        """
+        Momentum update of the key encoder
+        """
+        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
+            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
+
+    @staticmethod
+    def regression_loss(x, y):
+        x = F.normalize(x, dim=1)
+        y = F.normalize(y, dim=1)
+
+        return torch.norm(x - y, p='fro')
+
+    def initializes_target_network(self):
+        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
+            param_k.data.copy_(param_q.data)  # initialize
+            param_k.requires_grad = False     # not updated by gradient
+
+    def train(self, train_loader, eval_loader):
+        niter = 0
+        device = self.device
+
+        self.initializes_target_network()
+
+        for epoch_counter in range(self.max_epochs):
+            train_loss = 0.0
+            n_steps = 0
+
+            for idx, batch in enumerate(train_loader):
+                batch = batch.to(device)
+                patches = self.patchize(batch, self.patch_size, self.unfold_stride)  # img ==> patches
+                P, C, pH, pW = patches.shape
+
+                shuffle_ids = torch.randperm(P, device=patches.device)
+                this_patches = patches[shuffle_ids]
+
+                quotient, remainder = divmod(P, self.pbatch_size)
+                pbatch = max(quotient, 1)  # at least one (possibly partial) patch batch
+                for i in range(pbatch):
+                    start = i * self.pbatch_size
+                    end = start + self.pbatch_size
+
+                    patch = this_patches[start:end, :, :, :]
+
+                    loss = self.update(patch)
+
+                    self.optimizer.zero_grad()
+                    loss.backward()
+                    self.optimizer.step()
+
+                    self._update_target_network_parameters()  # update the key encoder
+
+                    niter += 1
+                    n_steps += 1
+                    train_loss += loss.item()
+
+            train_loss = train_loss / max(n_steps, 1)  # mean loss per optimizer step
+            print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch_counter, train_loss))
+            self.scheduler.step(train_loss)
+            vis.line(X=[epoch_counter], Y=[train_loss], win=loss_window, opts=opt1, update='append')
+
+            torch.cuda.empty_cache()
+
+            # evaluate the model
+            if (epoch_counter + 1) % self.eval_freq == 0:
+                self.validate(eval_loader, epoch_counter)
+                self.online_network.train()
+
+            # save checkpoints
+            if (epoch_counter + 1) % self.save_freq == 0:
+                self.save_model(os.path.join(self.savepath, 'BYOL_epoch_{epoch}_{loss}.pth'.
+                                             format(epoch=epoch_counter, loss=train_loss)))
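+
+    # Editor's worked example (not author code): MyDataset crops each training
+    # image to 400x400, so patchize() yields P = ((400 - 8) / 4 + 1)**2 = 99**2
+    # = 9801 patches of size 8x8 per image. With pbatch_size = 256 this gives
+    # floor(9801 / 256) = 38 optimizer steps per image; the remaining 73
+    # shuffled patches are dropped for that epoch. Note that the --crop_size
+    # flag is not wired to the crop in mydataset.py.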
+
+    def update(self, image):
+        # split pre- and post-change images (13 Sentinel-2 bands each)
+        batch_view_1, batch_view_2 = torch.split(image, [13, 13], dim=1)
+        batch_view_1 = batch_view_1.to(self.device)
+        batch_view_2 = batch_view_2.to(self.device)
+        # compute query features
+        predictions_from_view_1 = self.predictor(self.online_network(batch_view_1))
+        predictions_from_view_2 = self.predictor(self.online_network(batch_view_2))
+
+        # compute key features
+        with torch.no_grad():
+            targets_to_view_1 = self.target_network(batch_view_1)
+            targets_to_view_2 = self.target_network(batch_view_2)
+
+        # symmetric BYOL loss: each view's prediction regresses the other view's target
+        loss = self.regression_loss(predictions_from_view_1, targets_to_view_2)
+        loss += self.regression_loss(predictions_from_view_2, targets_to_view_1)
+        return loss.mean()
+
+    def with_no_gradient_update(self, image):
+        # split pre- and post-change images
+        batch_view_1, batch_view_2 = torch.split(image, [13, 13], dim=1)
+        batch_view_1 = batch_view_1.to(self.device)
+        batch_view_2 = batch_view_2.to(self.device)
+        # compute query features
+        predictions_from_view_1 = self.predictor(self.online_network(batch_view_1))
+        predictions_from_view_2 = self.predictor(self.online_network(batch_view_2))
+
+        # compute key features
+        targets_to_view_1 = self.target_network(batch_view_1)
+        targets_to_view_2 = self.target_network(batch_view_2)
+
+        loss = self.regression_loss(predictions_from_view_1, targets_to_view_2)
+        loss += self.regression_loss(predictions_from_view_2, targets_to_view_1)
+        return loss.mean()
+
+    def validate(self, val_loader, epoch_counter):
+        self.online_network.eval()
+
+        with torch.no_grad():
+            TP = FP = FN = TN = 0
+            for idx, batch in enumerate(val_loader):
+                start = time.time()
+                # ===================forward=====================
+                prediction = self.compute_heatmap(batch, self.val_patch_size, self.val_unfold_stride)
+                print('time elapsed:', time.time() - start)
+                change_pre = rosin_thresholding(prediction)
+
+                # assumes a single evaluation scene whose ground truth is 1.png
+                ground_truth = mpimg.imread(self.ground_truth + '/' + str(1) + '.png')
+
+                # accumulate pixel-wise confusion counts (vectorized)
+                change = np.asarray(change_pre) == 1
+                truth = np.asarray(ground_truth) == 1
+                TP += np.sum(change & truth)
+                TN += np.sum(~change & ~truth)
+                FN += np.sum(~change & truth)
+                FP += np.sum(change & ~truth)
+
+                for i in range(batch.shape[0]):
+                    plt.imsave(str(idx) + 'BYOLI' + str(i) + '.png', prediction.squeeze().cpu().detach().numpy(), cmap='gray')
+                    plt.imsave(str(idx) + 'BYOL' + str(i) + '.png', np.squeeze(change_pre), cmap='gray')
+
+            Pre = TP / (TP + FP)
+            Rec = TP / (TP + FN)
+            F1 = 2 * Pre * Rec / (Pre + Rec + 1e-8)
+            OA = (TP + TN) / (TP + TN + FP + FN)
+            PE = (TP + FP) * (TP + FN) / (TP + TN + FP + FN) ** 2 + (FN + TN) * (FP + TN) / (TP + TN + FP + FN) ** 2
+            Kappa = (OA - PE) / (1 - PE)
+            print(Pre, Rec, F1, OA, Kappa)
+
+            vis.line(X=[epoch_counter], Y=[[Pre, Rec, F1, OA, Kappa]], win=acc_window, opts=opt2, update='append')
+
+    def patchize(self, img: torch.Tensor, patch_size, unfold_stride) -> torch.Tensor:
+        B, C, iH, iW = img.shape
+        pH = patch_size
+        pW = patch_size
+
+        unfold = nn.Unfold(kernel_size=(pH, pW), dilation=1, stride=unfold_stride)
+
+        patches = unfold(img)                            # (B, V, P)
+        patches = patches.permute(0, 2, 1).contiguous()  # (B, P, V)
+        patches = patches.view(-1, C, pH, pW)            # (P, C, pH, pW)
+        return patches
+
+    def compute_squared_l2_distance(self, pred: torch.Tensor, surrogate_label: torch.Tensor) -> torch.Tensor:
+        # despite the name, this returns the (non-squared) per-patch L2 distance
+        losses = pred - surrogate_label
+        losses = torch.norm(losses, p='fro', dim=1)
+        # losses = losses.view(losses.shape[0], -1)
+        # losses = torch.mean(losses, dim=1)
+        losses = losses.cpu().detach()
+
+        return losses
+
+    def compute_heatmap(self, img: torch.Tensor, patch_size, unfold_stride):
+        patches = self.patchize(img, patch_size, unfold_stride)
+
+        B, C, iH, iW = img.shape
+        P, C, pH, pW = patches.shape
+
+        heatmap = torch.zeros(P)
+        quotient, remainder = divmod(P, self.pbatch_size)
+
+        for i in range(quotient):
+            start = i * self.pbatch_size
+            end = start + self.pbatch_size
+
+            patch = patches[start:end, :, :, :]
+            patch = patch.to(self.device)
+
+            patch1, patch2 = torch.split(patch, [13, 13], dim=1)
+            surrogate_label = self.online_network(patch1)
+            pred = self.online_network(patch2)
+            # standardize each feature vector before measuring the distance
+            surrogate_label_m = surrogate_label.mean(dim=1).view(self.pbatch_size, 1)
+            surrogate_label_s = surrogate_label.std(dim=1).view(self.pbatch_size, 1)
+            pred_m = pred.mean(dim=1).view(self.pbatch_size, 1)
+            pred_s = pred.std(dim=1).view(self.pbatch_size, 1)
+            surrogate_label = torch.div((surrogate_label - surrogate_label_m), surrogate_label_s)
+            pred = torch.div((pred - pred_m), pred_s)
+
+            losses = self.compute_squared_l2_distance(pred, surrogate_label)
+            heatmap[start:end] = losses
+
+        if remainder != 0:
+            patch = patches[-remainder:, :, :, :]
+            patch = patch.to(self.device)
+            patch1, patch2 = torch.split(patch, [13, 13], dim=1)
+            surrogate_label = self.online_network(patch1)
+            pred = self.online_network(patch2)
+            surrogate_label_m = surrogate_label.mean(dim=1).view(remainder, 1)
+            surrogate_label_s = surrogate_label.std(dim=1).view(remainder, 1)
+            pred_m = pred.mean(dim=1).view(remainder, 1)
+            pred_s = pred.std(dim=1).view(remainder, 1)
+            surrogate_label = torch.div((surrogate_label - surrogate_label_m), surrogate_label_s)
+            pred = torch.div((pred - pred_m), pred_s)
+
+            losses = self.compute_squared_l2_distance(pred, surrogate_label)
+            heatmap[-remainder:] = losses
+
+        fold = nn.Fold(
+            output_size=(iH, iW),
+            kernel_size=(pH, pW),
+            stride=unfold_stride,
+        )
+        Pbatch = int(P / B)
+
+        heatmap = heatmap.expand(1, pH * pW, P)
+        heatmap = heatmap.contiguous().view(B, pH * pW, Pbatch)
+        heatmap = fold(heatmap)
+        heatmap = heatmap.squeeze()
+
+        del patches
+        return heatmap
+
+    def save_model(self, PATH):
+        print('==> Saving...')
+        state = {
+            'online_network_state_dict': self.online_network.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+        }
+        torch.save(state, PATH)
+        # help release GPU memory
+        del state
+
+
+def main():
+    # parse the args
+    args = parse_option()
+
+    # set flags for GPU processing if available
+    if torch.cuda.is_available():
+        args.use_gpu = True
+        device = 'cuda'
+    else:
+        args.use_gpu = False
+        device = 'cpu'
+
+    # set the data loaders
+    train_loader, eval_loader = get_train_val_loader(args)
+    args.n_inputs = 8
+    args.n_classes = 2
+
+    # set the model
+    online_network = ResNet34(num_feats=args.feat_dim, width=1, in_channel=13).to(device)
+
+    if args.resume:
+        try:
+            print('loading pretrained models')
+            checkpoints_folder = os.path.join('.', 'pre_train')
+
+            # load pre-trained parameters
+            load_params = torch.load(os.path.join(checkpoints_folder, 'BYOL' + str(args.crop_size) + '.pth'),
+                                     map_location=device)
+
+            online_network.load_state_dict(load_params['online_network_state_dict'])
+
+            if args.test:
+                trainer = BYOLTrainer(args,
+                                      online_network=online_network,
+                                      target_network=None,
+                                      optimizer=None,
+                                      scheduler=None,
+                                      predictor=None,
+                                      device=device)
+
+                trainer.validate(train_loader, 0)
+                return
+
+        except FileNotFoundError:
+            print("Pre-trained weights not found. Training from scratch.")
+
+    # --> target network
+    target_network = copy.deepcopy(online_network)
+    target_network = target_network.to(device)
+    # predictor network
+    predictor = PreMLPHead(args.feat_dim, 4096, args.feat_dim).to(device)
+
+    # optimize the online network and the predictor (the target network is EMA-updated)
+    optimizer = torch.optim.SGD(list(online_network.parameters()) + list(predictor.parameters()),
+                                lr=args.learning_rate,
+                                momentum=args.momentum, weight_decay=args.weight_decay)
+
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, verbose=True,
+                                                           threshold=0.0001, threshold_mode='rel', cooldown=10, min_lr=0,
+                                                           eps=1e-08)
+
+    trainer = BYOLTrainer(args,
+                          online_network=online_network,
+                          target_network=target_network,
+                          optimizer=optimizer,
+                          scheduler=scheduler,
+                          predictor=predictor,
+                          device=device)
+
+    trainer.train(train_loader, eval_loader)
+
+
+if __name__ == '__main__':
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+    main()
--
Gitee
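Editor's aside (not part of the patch): regression_loss in train_BYOL.py returns the Frobenius norm of the difference between the L2-normalized prediction and target batches, whereas the BYOL paper minimizes 2 - 2 * cosine similarity per sample. For unit vectors the two agree, since ||x - y||^2 = 2 - 2 * <x, y>. A small check:

    import torch
    import torch.nn.functional as F

    x = F.normalize(torch.randn(4, 256), dim=1)
    y = F.normalize(torch.randn(4, 256), dim=1)
    lhs = ((x - y) ** 2).sum(dim=1)
    rhs = 2 - 2 * (x * y).sum(dim=1)
    print(torch.allclose(lhs, rhs, atol=1e-6))  # True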