This module include classes and functions to train BA-Net. **Note: This module uses fastai version 1.**
# export import fastai from fastai.vision import * from fastai.callbacks import * import scipy.io as sio import sys from banet.models import BA_Net

Callbacks

# export class SampleEpisode(Sampler): def __init__(self, data_source, n_episodes, sequence_len, n_sequences, info_df, nburned=100): self.ds, self.epoch_size = data_source, n_episodes self.sequence_len, self.n_sequences = sequence_len, n_sequences self._epochs = [] self.df = info_df self.nburned = nburned def __len__(self): return self.epoch_size*self.sequence_len*self.n_sequences def __iter__(self): return iter(self.get_epoch()) def get_epoch(self): """Get indices for one epoch of size epoch_size""" idx = [] for n in range(self.epoch_size): idx = [*idx, *self.get_batch()] return idx def get_batch(self): """Get indices for one mini-batch""" idx = [] n = 0 while n < self.n_sequences: k = np.random.choice(self.df.loc[self.df.ba>self.nburned].index, size=1, replace=False)[0] s = self.random_sample(k) if s is not None: idx = [*idx, *s] n += 1 return idx def random_sample(self, k): """Random samples are n-way k-shot""" idx = [] condition = ((self.df.name == self.df.loc[k, 'name']) & (self.df.time == self.df.loc[k, 'time'] + pd.Timedelta(days=self.sequence_len)) & (self.df.r == self.df.loc[k, 'r']) & (self.df.c == self.df.loc[k, 'c'])) where = self.df.loc[condition].index.values if len(where) == 0: idx = None else: times = pd.date_range(self.df.loc[k-self.sequence_len//2, 'time'], periods=2*self.sequence_len, freq='D') condition = ((self.df.name == self.df.loc[k, 'name']) & (self.df.time.isin(times)) & (self.df.r == self.df.loc[k, 'r']) & (self.df.c == self.df.loc[k, 'c'])) where = self.df.loc[condition].sort_values(by='time').index.values idx = where[:self.sequence_len] if len(idx) != self.sequence_len: idx = None return idx class ImageSequence(LearnerCallback): def __init__(self, learn, sequence_len=64, n_sequences=1): super().__init__(learn) self.sequence_len = sequence_len self.n_sequences = n_sequences def on_batch_begin(self, last_input, last_target, epoch, iteration, **kwargs): bs, ch, sz1, sz2 = last_input.size() last_input = last_input.view(self.sequence_len, self.n_sequences, ch, sz1, sz2).permute(1, 2, 0, 3, 4) last_target = last_target.view(self.sequence_len, self.n_sequences, 1, sz1, sz2).permute(1, 2, 0, 3, 4)#.max(2)[0] return {'last_input': last_input, 'last_target': last_target}# export def get_y_fn(file, satellite='VIIRS750', target_product='MCD64A1C6'): f = str(Path(str(file))).replace('images', 'masks') f = f.replace(satellite, target_product) return f def open_mat(fn, *args, **kwargs): data = sio.loadmat(fn) data = np.array([data[r] for r in ['Red', 'NIR', 'MIR', 'FRP']]) data[np.isnan(data)] = 0 data[-1, ...] = np.log1p(data[-1,...]) data[np.isnan(data)] = 0 data = torch.from_numpy(data).float() return Image(data) def open_mask(fn, *args, **kwargs): data = sio.loadmat(fn)['bafrac'] data[np.isnan(data)] = 0 data = torch.from_numpy(data).float() return Image(data.view(-1, data.size()[0], data.size()[1])) def set_info_df(items_list, satellite='VIIRS750', target_product='MCD64A1C6'): names, dates = [], [] rs, cs = [], [] for o in items_list: name, date, r, c = Path(o).stem.split('_') date = pd.Timestamp(date) names.append(name) dates.append(date) rs.append(r) cs.append(c) ba = [open_mask(get_y_fn(str(o), satellite=satellite, target_product=target_product) ).data.sum().item() for o in progress_bar(items_list)] return pd.DataFrame({'name': names, 'time': dates, 'r':rs, 'c':cs, 'ba':ba}) class SegLabelListCustom(SegmentationLabelList): def open(self, fn): return open_mask(fn, div=True) class SegItemListCustom(ImageList): _label_cls = SegLabelListCustom def open(self, fn): return open_mat(fn) def _cutout(x, n_holes:uniform_int=1, length:uniform_int=40): "Cut out `n_holes` number of square holes of size `length` in image at random locations." h,w = x.shape[1:] for n in range(n_holes): h_y = np.random.randint(0, h) h_x = np.random.randint(0, w) y1 = int(np.clip(h_y - length / 2, 0, h)) y2 = int(np.clip(h_y + length / 2, 0, h)) x1 = int(np.clip(h_x - length / 2, 0, w)) x2 = int(np.clip(h_x + length / 2, 0, w)) #x[:2, y1:y2, x1:x2] = 1 x[-1, y1:y2, x1:x2] = 0 return x cutout = TfmPixel(_cutout, order=20, ) def _cutout2(x, n_holes:uniform_int=1, length:uniform_int=40): "Cut out `n_holes` number of square holes of size `length` in image at random locations." h,w = x.shape[1:] h_y = np.random.randint(0, h) h_x = np.random.randint(0, w) y1 = int(np.clip(h_y - length / 2, 0, h)) y2 = int(np.clip(h_y + length / 2, 0, h)) x1 = int(np.clip(h_x - length / 2, 0, w)) x2 = int(np.clip(h_x + length / 2, 0, w)) x[0, y1:y2, x1:x2] = torch.rand(1) x[1, y1:y2, x1:x2] = torch.rand(1) x[2, y1:y2, x1:x2] = torch.rand(1) return x cutout2 = TfmPixel(_cutout2, order=20) class BCE(Module): "Binary Cross Entropy loss." def forward(self, x, y): bce = nn.BCEWithLogitsLoss() return 100*bce(x.view(x.size()[0],-1),y.view(y.size()[0], -1)) def accuracy(input:Tensor, targs:Tensor, thr:int=0.5)->Rank0Tensor: "Compute accuracy with `targs` when `input` is bs * n_classes." input = (input.sigmoid()>thr).long() targs = (targs>thr).long() return (input==targs).float().mean() def dice2d(pred, targs, thr=0.5): pred = pred.squeeze() targs = targs.squeeze().sum(0) pred = (pred.sigmoid().sum(0)>thr).float() targs = (targs>thr).float() return 2. * (pred*targs).sum() / (pred+targs).sum() def mae(pred, targs, thr=0.5): a = pred.squeeze().sigmoid().sum(0)>thr pred = pred.squeeze().max(0)[1] targs = targs.squeeze().max(0)[1] pred = pred[a.byte()] targs = targs[a.byte()] return (pred-targs).abs().float().mean() def train_model(val_year, r_fold, path, model_path, n_epochs=8, lr=1e-2, nburned=10, n_episodes_train=2000, n_episodes_valid=100, sequence_len=64, n_sequences=1, do_cutout=True, model_arch=None, pretrained_weights=None, satellite='VIIRS750', target_product='MCD64A1C6', get_learner=False): path_img = path/'images' train_files = sorted([f.name for f in path_img.iterdir()]) times = pd.DatetimeIndex([pd.Timestamp(t.split('_')[1]) for t in train_files]) train_df = pd.DataFrame({'times': times, 'ID': train_files}) valid_idx = train_df.loc[train_df.times.dt.year == val_year].index.values if do_cutout: tfms = get_transforms(do_flip=False, max_zoom=0, max_warp=0, max_rotate=0, xtra_tfms=[cutout(n_holes=(1, 5), length=(5, 50), p=0.5), cutout2(n_holes=(1, 5), length=(5, 50), p=0.5)]) else: tfms = get_transforms(do_flip=False, max_zoom=0, max_warp=0, max_rotate=0) data = (SegItemListCustom.from_df(train_df, path, cols='ID', folder='images') .split_by_idx(valid_idx) .label_from_func( partial(get_y_fn, satellite=satellite, target_product=target_product), classes=['Burned']) .transform(tfms, size=128, tfm_y=False)) info_train_df = set_info_df(data.train.items, satellite=satellite, target_product=target_product) info_valid_df = set_info_df(data.valid.items, satellite=satellite, target_product=target_product) bs = sequence_len*n_sequences train_dl = DataLoader( data.train, batch_size=bs, sampler=SampleEpisode(data.train, n_episodes=n_episodes_train, sequence_len=sequence_len, n_sequences=n_sequences, info_df=info_train_df, nburned=nburned)) valid_dl = DataLoader( data.valid, batch_size=bs, sampler=SampleEpisode(data.valid, n_episodes=n_episodes_valid, sequence_len=sequence_len, n_sequences=n_sequences, info_df=info_valid_df, nburned=nburned)) databunch = ImageDataBunch(train_dl, valid_dl, path='.') databunch = databunch.normalize([tensor([0.2349, 0.3548, 0.1128, 0.0016]), tensor([0.1879, 0.1660, 0.0547, 0.0776])]) if model_arch is None: model = BA_Net(4, 1, sequence_len) else: model = model_arch(4, 1, sequence_len) if pretrained_weights is not None: print(f'Loading pretrained_weights from {pretrained_weights}\n') if torch.cuda.is_available(): model.load_state_dict(torch.load(pretrained_weights)['model']) else: model.load_state_dict( torch.load(pretrained_weights, map_location=torch.device('cpu'))['model']) learn = Learner(databunch, model, callback_fns=[ partial(ImageSequence, sequence_len=sequence_len, n_sequences=n_sequences)], loss_func=BCE(), wd=1e-2, metrics=[accuracy, dice2d, mae]) learn.clip_grad = 1 if get_learner: return learn print('Starting traning loop\n') learn.fit_one_cycle(n_epochs, lr) model_path.mkdir(exist_ok=True) torch.save(learn.model.state_dict(), model_path/f'banet-val{val_year}-fold{r_fold}-test.pth') print(f'Completed! banet-val{val_year}-fold{r_fold}-test.pth saved to {model_path}.')