This module includes classes and functions to train BA-Net.
SampleEpisode(data_source, n_episodes, sequence_len, n_sequences, info_df, nburned=100) :: Sampler
Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
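The sampler protocol described above (an `__iter__` that yields dataset indices and a `__len__` that matches the number of yielded items) can be illustrated with a small pure-Python sketch. The class `EpisodeSamplerSketch` and its `seed` parameter are hypothetical. This is not the SampleEpisode implementation, and a real sampler would subclass `torch.utils.data.Sampler`. Here it simply draws one random index per episode.

```python
import random
from typing import Iterator, Sized


class EpisodeSamplerSketch:
    """Hypothetical sampler: yields n_episodes random indices into data_source."""

    def __init__(self, data_source: Sized, n_episodes: int, seed: int = 0):
        self.data_source = data_source
        self.n_episodes = n_episodes
        self.rng = random.Random(seed)  # seeded for reproducible index draws

    def __iter__(self) -> Iterator[int]:
        # Yield one random dataset index per episode.
        for _ in range(self.n_episodes):
            yield self.rng.randrange(len(self.data_source))

    def __len__(self) -> int:
        # Must equal the number of indices produced by __iter__.
        return self.n_episodes


sampler = EpisodeSamplerSketch(list(range(10)), n_episodes=5)
indices = list(sampler)
```

An object like this would be passed to a DataLoader through its `sampler=` argument, and `len(loader)` would then be derived from `__len__`.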
ImageSequence(learn, sequence_len=64, n_sequences=1) :: LearnerCallback
Base class for creating callbacks for a Learner.
get_y_fn(file, satellite='VIIRS750', target_product='MCD64A1C6')
open_mat(fn, *args, **kwargs)
open_mask(fn:Union[Path, str], div=False, convert_mode='L', after_open:Callable=None)
Return an ImageSegment object created from the mask in file fn. If div is True, pixel values are divided by 255.
set_info_df(items_list, satellite='VIIRS750', target_product='MCD64A1C6')
SegLabelListCustom(items:Iterator[T_co], classes:Collection[T_co]=None, **kwargs) :: SegmentationLabelList
ItemList for segmentation masks.
SegItemListCustom(*args, convert_mode='RGB', after_open:Callable=None, **kwargs) :: ImageList
ItemList suitable for computer vision.
BCE() :: Module
Binary Cross Entropy loss.
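For reference, binary cross entropy averages -(t·log p + (1 - t)·log(1 - p)) over all elements, where p is a predicted probability and t a 0/1 target. The pure-Python sketch below (the name `bce_sketch` and the clamping constant `eps` are assumptions) shows the standard formula with mean reduction. Whether BA-Net's BCE expects probabilities or raw logits is not stated here. In PyTorch, `nn.BCELoss` takes probabilities and `nn.BCEWithLogitsLoss` takes logits.

```python
import math
from typing import Sequence


def bce_sketch(probs: Sequence[float], targets: Sequence[float], eps: float = 1e-7) -> float:
    """Mean binary cross entropy between predicted probabilities and 0/1 targets."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)  # clamp so log() never sees 0
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(probs)  # mean reduction


loss_half = bce_sketch([0.5], [1.0])           # equals log(2)
loss_pair = bce_sketch([0.9, 0.1], [1.0, 0.0])  # equals -log(0.9)
```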
accuracy(input:Tensor, targs:Tensor)
Computes accuracy against targs when input has shape bs * n_classes.
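When the input has shape bs * n_classes, the predicted class for each sample is the argmax of its row, and accuracy is the fraction of rows whose argmax equals the target index. A minimal pure-Python sketch of that computation follows; the name `accuracy_sketch` is hypothetical and nested lists stand in for tensors.

```python
from typing import Sequence


def accuracy_sketch(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of rows whose argmax equals the target class index."""
    correct = 0
    for row, t in zip(logits, targets):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over n_classes
        correct += int(pred == t)
    return correct / len(targets)


acc = accuracy_sketch([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0])
```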
dice2d(pred, targs, thr=0.5)
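dice2d has no docstring. Its name and `thr` parameter suggest a Dice coefficient, 2|A∩B| / (|A| + |B|), computed on 2D masks after thresholding the predictions, but that reading is an assumption. The sketch below (hypothetical name `dice_sketch`) binarizes `pred` with a strict `> thr`, assumes `targs` is already 0/1, and treats two empty masks as perfect agreement. Each of these choices is assumed rather than taken from BA-Net.

```python
from typing import Sequence


def dice_sketch(pred: Sequence[Sequence[float]], targs: Sequence[Sequence[float]], thr: float = 0.5) -> float:
    """Dice coefficient between a thresholded 2D prediction and a 0/1 target mask."""
    p = [1 if x > thr else 0 for row in pred for x in row]  # binarize and flatten
    t = [int(y) for row in targs for y in row]              # targets assumed 0/1
    inter = sum(a * b for a, b in zip(p, t))
    denom = sum(p) + sum(t)
    # Assumed convention: if both masks are empty, report perfect agreement.
    return 1.0 if denom == 0 else 2.0 * inter / denom


score = dice_sketch([[0.9, 0.2], [0.7, 0.1]], [[1, 0], [0, 0]])
```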
mean_absolute_error(pred:Tensor, targ:Tensor)
Mean absolute error between pred and targ.
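Mean absolute error is the average of |pred - targ| over all elements. A one-line pure-Python sketch of the formula (hypothetical name `mae_sketch`, with flat lists in place of tensors):

```python
from typing import Sequence


def mae_sketch(pred: Sequence[float], targ: Sequence[float]) -> float:
    """Average absolute difference between predictions and targets."""
    return sum(abs(p - t) for p, t in zip(pred, targ)) / len(pred)


err = mae_sketch([1.0, 2.0, 3.0], [1.5, 2.0, 2.0])  # (0.5 + 0.0 + 1.0) / 3
```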
train_model(val_year, r_fold, path, model_path, n_epochs=8, lr=0.01, nburned=10, n_episodes_train=2000, n_episodes_valid=100, sequence_len=64, n_sequences=1, do_cutout=True, model_arch=None, pretrained_weights=None, satellite='VIIRS750', target_product='MCD64A1C6', get_learner=False)