def add_noise_to_value(value: int, noise_param: float)
@BatchSampler.register("bucket") class BucketBatchSampler(BatchSampler): | def __init__( | self, | batch_size: int, | sorting_keys: List[str] = None, | padding_noise: float = 0.1, | drop_last: bool = False | )
A sampler which, by default, argsorts batches with respect to the maximum input lengths per
batch. You can provide a list of field names and padding keys (or pass none, in which case they
will be inferred) by which the dataset will be sorted before batching. Inputs of similar
length are then batched together, which makes computation more efficient, as less time is
wasted on padded elements of the batch.
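The effect of bucketing can be sketched in plain Python, without AllenNLP. The helpers `bucket_batches` and `padded_tokens` below are hypothetical names used only for illustration: the first sorts instance indices by length and cuts them into consecutive batches, the second counts how many padding positions a given batching needs. The sketch leaves out the padding noise controlled by padding_noise and is not the sampler's actual implementation.

```python
from typing import List, Sequence


def bucket_batches(lengths: Sequence[int], batch_size: int) -> List[List[int]]:
    """Group indices of similar-length inputs into batches (illustrative only)."""
    # Sort indices by input length so that neighbours need similar padding.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    # Cut the sorted indices into consecutive chunks of batch_size.
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


def padded_tokens(lengths: Sequence[int], batches: List[List[int]]) -> int:
    """Count padding positions when each batch is padded to its longest input."""
    total = 0
    for batch in batches:
        longest = max(lengths[i] for i in batch)
        total += sum(longest - lengths[i] for i in batch)
    return total


lengths = [3, 50, 4, 48, 5, 52]
naive = [[0, 1, 2], [3, 4, 5]]             # batches in dataset order
bucketed = bucket_batches(lengths, batch_size=3)

print(bucketed)                            # [[0, 2, 4], [3, 1, 5]]
print(padded_tokens(lengths, naive))       # 144
print(padded_tokens(lengths, bucketed))    # 9
```

In this toy dataset, batching in dataset order mixes short and long inputs and wastes 144 padded positions, while the bucketed batches waste only 9.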
batch_size: int
The size of each batch of instances yielded when calling the data_loader.
sorting_keys: List[str], optional
To bucket inputs into batches, we want to group the instances by padding length, so that we minimize the amount of padding necessary per batch. In order to do this, we need to know which fields need what type of padding, and in what order.
Specifying the right keys for this is a bit cryptic, so if this is not given we try to auto-detect the right keys by iterating through a few instances upfront, reading all of the padding keys and seeing which one has the longest length. We use that one for padding. This should give reasonable results in most cases. Some cases where it might not be the right thing to do are when you have a ListField[TextField], or when you have a really long, constant-length field.
When you need to specify this yourself, you can create an instance from your dataset and call Instance.get_padding_lengths() to see a list of all keys used in your data. You should give one or more of those as the sorting keys here.
padding_noise: float, optional (default = 0.1)
When sorting by padding length, we add a bit of noise to the lengths, so that the sorting isn't deterministic. This parameter determines how much noise we add, as a percentage of the actual padding value for each instance.
drop_last: bool, optional (default = False)
If True, the sampler will drop the last batch if its size would be less than batch_size.
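To make the role of padding_noise concrete, here is a minimal sketch of one way such noise could be applied before sorting. The function name `noisy_length` is hypothetical, and drawing uniform noise from the range plus or minus noise_param times the value is an assumption made for illustration, not a statement of how the library computes it.

```python
import random


def noisy_length(value: int, noise_param: float, rng: random.Random) -> float:
    """Perturb a length by up to +/- noise_param * value (assumed uniform noise)."""
    bound = value * noise_param
    return value + rng.uniform(-bound, bound)


rng = random.Random(0)  # seeded only so the sketch is repeatable
lengths = [10, 11, 12, 40, 41]

# Sort indices by noisy length. Inputs of nearly equal length may swap places
# from run to run, while inputs of very different length keep their order.
order = sorted(range(len(lengths)), key=lambda i: noisy_length(lengths[i], 0.1, rng))
print(order)
```

With a noise parameter of 0.1, a length of 12 can grow to at most 13.2 and a length of 40 can shrink to at least 36, so the short inputs always sort before the long ones; only the order within each group varies.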
class BucketBatchSampler(BatchSampler): | ... | def get_batch_indices( | self, | instances: Sequence[Instance] | ) -> Iterable[List[int]]
class BucketBatchSampler(BatchSampler): | ... | def get_num_batches(self, instances: Sequence[Instance]) -> int
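The number of batches follows from the number of instances, batch_size, and drop_last. The arithmetic can be sketched with a hypothetical standalone function, `num_batches`, which is an illustration rather than the sampler's own method:

```python
def num_batches(num_instances: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches produced for a dataset of num_instances items."""
    full_batches, remainder = divmod(num_instances, batch_size)
    # A trailing partial batch is counted only when it is not dropped.
    if remainder and not drop_last:
        return full_batches + 1
    return full_batches


print(num_batches(10, 4, drop_last=False))  # 3 (batches of 4, 4 and 2)
print(num_batches(10, 4, drop_last=True))   # 2 (the batch of 2 is dropped)
```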
class BucketBatchSampler(BatchSampler): | ... | def get_batch_size(self) -> Optional[int]