This module includes classes and functions to train BA-Net.

Callbacks

class SampleEpisode

SampleEpisode(data_source, n_episodes, sequence_len, n_sequences, info_df, nburned=100) :: Sampler

Inherits from torch.utils.data.Sampler, the base class for all samplers.

Every Sampler subclass has to provide an __iter__ method, which gives a way to iterate over the indices of dataset elements, and a __len__ method that returns the length of the returned iterators.

Note: the __len__ method isn't strictly required by torch.utils.data.DataLoader, but it is expected in any calculation involving the length of a DataLoader.
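The signature suggests that SampleEpisode draws n_episodes episodes, each made of n_sequences runs of sequence_len consecutive images, keeping only candidates with at least nburned burned pixels according to info_df. The sketch below is a guess at that behaviour under stated assumptions, not the module's implementation: the class name EpisodeSamplerSketch, the burned_counts list (a stand-in for info_df), the seed argument, and the rule that only the first image of a window is checked are all hypothetical. It relies only on the __iter__/__len__ protocol described above, so it needs no PyTorch import.

```python
import random
from typing import Iterator, List, Sequence


class EpisodeSamplerSketch:
    """Hypothetical sampler yielding indices of contiguous image sequences.

    burned_counts[i] stands in for info_df: the number of burned pixels in
    dataset item i. Items are assumed to be stored in temporal order.
    """

    def __init__(self, burned_counts: Sequence[int], n_episodes: int,
                 sequence_len: int, n_sequences: int, nburned: int = 100,
                 seed: int = 0) -> None:
        self.n_episodes = n_episodes
        self.sequence_len = sequence_len
        self.n_sequences = n_sequences
        self.rng = random.Random(seed)
        last_start = len(burned_counts) - sequence_len
        # A window may start at i if it fits in the dataset and its first
        # image has at least nburned burned pixels.
        self.starts: List[int] = [
            i for i in range(last_start + 1) if burned_counts[i] >= nburned
        ]
        if not self.starts:
            raise ValueError("no window satisfies the nburned threshold")

    def __iter__(self) -> Iterator[int]:
        # Each episode contributes n_sequences windows of consecutive indices.
        for _ in range(self.n_episodes):
            for _ in range(self.n_sequences):
                start = self.rng.choice(self.starts)
                yield from range(start, start + self.sequence_len)

    def __len__(self) -> int:
        # Number of indices produced by one full pass over the sampler.
        return self.n_episodes * self.n_sequences * self.sequence_len
```

Recent PyTorch versions accept any iterable of indices as the sampler argument of DataLoader; with batch_size=sequence_len, each batch drawn from this sketch would then be exactly one window of consecutive images.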

class ImageSequence

ImageSequence(learn, sequence_len=64, n_sequences=1) :: LearnerCallback

Inherits from fastai's LearnerCallback, the base class for creating callbacks for a Learner.
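One plausible job for a callback parameterised by sequence_len and n_sequences is to regroup each flat batch of frames into n_sequences time series of sequence_len consecutive frames before they reach the model. The helper below sketches only that regrouping, on plain Python lists; the function name to_sequences and the assumption that frames arrive in temporal order are hypothetical, and the actual callback would operate on tensors inside the fastai training loop.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def to_sequences(batch: Sequence[T], sequence_len: int,
                 n_sequences: int) -> List[List[T]]:
    """Hypothetical helper: split a flat batch into n_sequences consecutive chunks."""
    expected = sequence_len * n_sequences
    if len(batch) != expected:
        raise ValueError(f"got {len(batch)} frames, expected {expected}")
    # Chunk i holds frames i*sequence_len .. (i+1)*sequence_len - 1, in order.
    return [list(batch[i * sequence_len:(i + 1) * sequence_len])
            for i in range(n_sequences)]
```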

get_y_fn

get_y_fn(file, satellite='VIIRS750', target_product='MCD64A1C6')
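get_y_fn maps an input file to the file holding its label. A minimal sketch, assuming (hypothetically) that input and target files share the same name and directory layout and differ only in the satellite and target_product tokens of the path; the name get_y_fn_sketch is illustrative only.

```python
from pathlib import Path
from typing import Union


def get_y_fn_sketch(file: Union[str, Path], satellite: str = "VIIRS750",
                    target_product: str = "MCD64A1C6") -> Path:
    # Hypothetical naming convention: the label file lives at the same path as
    # the input, with the satellite token replaced by the product token.
    return Path(str(file).replace(satellite, target_product))
```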

open_mat

open_mat(fn, *args, **kwargs)

open_mask

open_mask(fn:Union[Path, str], div=False, convert_mode='L', after_open:Callable=None)

Return an ImageSegment object created from the mask in file fn. If div is True, divide pixel values by 255.

set_info_df

set_info_df(items_list, satellite='VIIRS750', target_product='MCD64A1C6')

class SegLabelListCustom

SegLabelListCustom(items:Iterator[T_co], classes:Collection[T_co]=None, **kwargs) :: SegmentationLabelList

ItemList for segmentation masks.

class SegItemListCustom

SegItemListCustom(*args, convert_mode='RGB', after_open:Callable=None, **kwargs) :: ImageList

ItemList suitable for computer vision.

class BCE

BCE() :: Module

Binary Cross Entropy loss.
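For a predicted probability p_i and a binary target t_i, binary cross-entropy averages -[t_i log p_i + (1 - t_i) log(1 - p_i)] over all elements (here, pixels). The pure-Python sketch below assumes that BCE receives probabilities rather than logits and uses mean reduction; both are assumptions, as is the eps clamping that keeps log() finite. The name bce_sketch is hypothetical.

```python
import math
from typing import Sequence


def bce_sketch(pred: Sequence[float], targ: Sequence[float],
               eps: float = 1e-7) -> float:
    """Mean binary cross-entropy of probabilities pred against 0/1 targets."""
    if len(pred) != len(targ) or not pred:
        raise ValueError("pred and targ must be non-empty and equally long")
    total = 0.0
    for p, t in zip(pred, targ):
        p = min(max(p, eps), 1.0 - eps)  # clamp so log() never sees 0
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(pred)
```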

accuracy

accuracy(input:Tensor, targs:Tensor)

Compute accuracy against targs when input has shape (bs, n_classes).
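With input of shape (bs, n_classes), the predicted class of each sample is the arg-max over its class scores, and accuracy is the fraction of samples whose prediction equals the target. A list-based sketch of that computation (the name accuracy_sketch is hypothetical):

```python
from typing import Sequence


def accuracy_sketch(scores: Sequence[Sequence[float]],
                    targs: Sequence[int]) -> float:
    """Fraction of rows whose arg-max class equals the target class."""
    # scores has shape (bs, n_classes); targs has shape (bs,).
    preds = [max(range(len(row)), key=row.__getitem__) for row in scores]
    return sum(p == t for p, t in zip(preds, targs)) / len(targs)
```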

dice2d

dice2d(pred, targs, thr=0.5)
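Judging by its name and the thr argument, dice2d presumably computes the Dice coefficient 2|P ∩ T| / (|P| + |T|) between the thresholded prediction mask P = {pred > thr} and the target mask T. The sketch below assumes flattened masks, binary (0/1) targs, a strict > comparison with thr, and a score of 1.0 when both masks are empty; all of these choices, and the name dice_sketch, are hypothetical.

```python
from typing import Sequence


def dice_sketch(pred: Sequence[float], targs: Sequence[int],
                thr: float = 0.5) -> float:
    """Dice coefficient 2|P & T| / (|P| + |T|) on flattened binary masks."""
    p = [1 if v > thr else 0 for v in pred]  # P: pixels predicted as burned
    t = [1 if v else 0 for v in targs]        # T: pixels labelled as burned
    denom = sum(p) + sum(t)
    if denom == 0:
        return 1.0  # both masks empty: treat as perfect agreement (assumption)
    inter = sum(a * b for a, b in zip(p, t))
    return 2.0 * inter / denom
```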

mean_absolute_error

mean_absolute_error(pred:Tensor, targ:Tensor)

Mean absolute error between pred and targ.
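Mean absolute error is the average of |pred_i - targ_i| over all elements. A list-based sketch (the name mae_sketch is hypothetical; the real function operates on tensors):

```python
from typing import Sequence


def mae_sketch(pred: Sequence[float], targ: Sequence[float]) -> float:
    """Average of |pred_i - targ_i| over all elements."""
    if len(pred) != len(targ) or not pred:
        raise ValueError("pred and targ must be non-empty and equally long")
    return sum(abs(p - t) for p, t in zip(pred, targ)) / len(pred)
```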

train_model

train_model(val_year, r_fold, path, model_path, n_epochs=8, lr=0.01, nburned=10, n_episodes_train=2000, n_episodes_valid=100, sequence_len=64, n_sequences=1, do_cutout=True, model_arch=None, pretrained_weights=None, satellite='VIIRS750', target_product='MCD64A1C6', get_learner=False)