This module includes classes and functions to train BA-Net.
SampleEpisode(data_source, n_episodes, sequence_len, n_sequences, info_df, nburned=100) :: Sampler
Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a way to iterate over indices of dataset elements, and a :meth:`__len__` method that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by :class:`~torch.utils.data.DataLoader`, but is expected in any calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
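The iteration protocol described above can be illustrated with a minimal sketch. The class below is hypothetical and is not the SampleEpisode implementation, which is not shown on this page. It deliberately avoids importing torch so that it runs on its own; in real code it would subclass torch.utils.data.Sampler. The name EpisodeIndexSampler, the seed argument, and the draw-with-replacement behaviour are all assumptions made for illustration.

```python
import random


class EpisodeIndexSampler:
    """Hypothetical sampler: yields `n_episodes` random dataset indices.

    Assumption: real code would subclass torch.utils.data.Sampler; this
    sketch stays torch-free and only follows the same protocol.
    """

    def __init__(self, data_source, n_episodes, seed=None):
        self.data_source = data_source
        self.n_episodes = n_episodes
        self.rng = random.Random(seed)

    def __iter__(self):
        # Yield one index per episode, drawn with replacement.
        n = len(self.data_source)
        return iter(self.rng.randrange(n) for _ in range(self.n_episodes))

    def __len__(self):
        # Number of indices produced by one pass over __iter__.
        return self.n_episodes


sampler = EpisodeIndexSampler(data_source=list("abcdef"), n_episodes=4, seed=0)
indices = list(sampler)
```

A DataLoader consumes such an object by iterating it for indices and, where needed, calling len() on it, which is why both methods are provided.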
ImageSequence(learn, sequence_len=64, n_sequences=1) :: LearnerCallback
Base class for creating callbacks for a Learner.
get_y_fn(file, satellite='VIIRS750', target_product='MCD64A1C6')
open_mat(fn, *args, **kwargs)
open_mask(fn:Union[Path, str], div=False, convert_mode='L', after_open:Callable=None)
Return an ImageSegment object created from the mask in file fn. If div is set, pixel values are divided by 255.
set_info_df(items_list, satellite='VIIRS750', target_product='MCD64A1C6')
SegLabelListCustom(items:Iterator[T_co], classes:Collection[T_co]=None, **kwargs) :: SegmentationLabelList
ItemList for segmentation masks.
SegItemListCustom(*args, convert_mode='RGB', after_open:Callable=None, **kwargs) :: ImageList
ItemList suitable for computer vision.
BCE() :: Module
Binary Cross Entropy loss.
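As a reminder of what binary cross entropy computes, here is a pure-Python sketch for predicted probabilities in (0, 1). It is not the module's BCE class, whose implementation is not shown on this page; the function name, the eps clamp, and the mean reduction are assumptions.

```python
import math


def binary_cross_entropy(probs, targets, eps=1e-7):
    """Mean of -[t*log(p) + (1-t)*log(1-p)] over paired elements."""
    total = 0.0
    for p, t in zip(probs, targets):
        # Assumption: clamp probabilities so log(0) is never evaluated.
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(probs)


loss = binary_cross_entropy([0.9, 0.1, 0.8], [1, 0, 1])
```

Confident, correct predictions drive the loss toward zero, while confident wrong ones are penalised heavily because of the logarithm.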
accuracy(input:Tensor, targs:Tensor)
Computes accuracy with targs when input is bs * n_classes.
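To make the shape convention concrete, the following sketch computes the same kind of accuracy on plain nested lists instead of tensors: each row holds n_classes scores, the predicted class is the index of the largest score, and the result is the fraction of rows that match the target. The function name accuracy_sketch is hypothetical and the list-based input is an assumption made to keep the example dependency-free.

```python
def accuracy_sketch(scores, targets):
    """Fraction of rows whose highest-scoring class equals the target."""
    correct = 0
    for row, target in zip(scores, targets):
        # argmax over the class dimension
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == target)
    return correct / len(targets)


scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]  # bs=3, n_classes=2
acc = accuracy_sketch(scores, [1, 0, 0])       # predictions are 1, 0, 1
```

Here two of the three predictions agree with the targets, so the accuracy is 2/3.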
dice2d(pred, targs, thr=0.5)
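This page gives no description for dice2d. Judging by its name and the thr argument, it presumably computes a Dice coefficient on thresholded 2D predictions, but that is an assumption. The sketch below shows the standard Dice coefficient, 2|P ∩ T| / (|P| + |T|), on nested lists; the name dice_sketch, the strict `> thr` comparison, and the handling of two empty masks are illustrative choices, not the module's behaviour.

```python
def dice_sketch(pred, targs, thr=0.5):
    """Dice coefficient 2*|P & T| / (|P| + |T|) on a 2D grid of scores."""
    # Binarise both grids with the same threshold and flatten them.
    p = [int(v > thr) for row in pred for v in row]
    t = [int(v > thr) for row in targs for v in row]
    intersection = sum(a * b for a, b in zip(p, t))
    denom = sum(p) + sum(t)
    # Assumption: two empty masks count as a perfect match.
    return 1.0 if denom == 0 else 2.0 * intersection / denom


pred = [[0.9, 0.2], [0.7, 0.6]]
targs = [[1, 0], [0, 1]]
score = dice_sketch(pred, targs)
```

With these inputs the thresholded prediction has three positive pixels, the target has two, and they overlap on two, giving 2·2 / (3 + 2) = 0.8.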
mean_absolute_error(pred:Tensor, targ:Tensor)
Mean absolute error between pred and targ.
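For reference, mean absolute error is the average of |pred - targ| over all elements. The sketch below works on flat Python lists rather than tensors; the name mean_absolute_error_sketch and the length check are assumptions added for illustration.

```python
def mean_absolute_error_sketch(pred, targ):
    """Average absolute difference between paired elements."""
    if len(pred) != len(targ):
        raise ValueError("pred and targ must have the same length")
    return sum(abs(p - t) for p, t in zip(pred, targ)) / len(pred)


mae = mean_absolute_error_sketch([2.0, 0.0, 3.0], [1.0, 0.0, 5.0])
```

The absolute differences here are 1, 0, and 2, so the mean is 1.0.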
train_model(val_year, r_fold, path, model_path, n_epochs=8, lr=0.01, nburned=10,
    n_episodes_train=2000, n_episodes_valid=100, sequence_len=64, n_sequences=1,
    do_cutout=True, model_arch=None, pretrained_weights=None, satellite='VIIRS750',
    target_product='MCD64A1C6', get_learner=False)