skbio.table.Augmentation.mixup

Augmentation.mixup(n_samples, alpha=2, seed=None)

Data augmentation by vanilla mixup.

Randomly select two samples \(s_1\) and \(s_2\) from the OTU table, and generate a new sample \(s\) by a linear combination of \(s_1\) and \(s_2\), as follows:

\[s = \lambda \cdot s_1 + (1 - \lambda) \cdot s_2\]
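As a worked instance of this formula, the snippet below mixes two made-up three-feature samples with a fixed \(\lambda = 0.75\) (the counts and \(\lambda\) are illustrative only). Each entry of the result lies between the corresponding entries of \(s_1\) and \(s_2\), and mixing integer counts generally yields non-integer values.

```python
import numpy as np

# Two illustrative samples (feature counts) and a fixed mixing weight.
s1 = np.array([10.0, 0.0, 4.0])
s2 = np.array([2.0, 6.0, 0.0])
lam = 0.75

# s = lambda * s1 + (1 - lambda) * s2, applied feature by feature
s = lam * s1 + (1 - lam) * s2
print(s)  # values 8.0, 1.5, 3.0
```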

where \(\lambda\) is drawn from a symmetric beta distribution, \(\lambda \sim \mathrm{Beta}(\alpha, \alpha)\). The label of the new sample is the same linear combination of the labels \(y_1\) and \(y_2\) of the two selected samples:

\[y = \lambda \cdot y_1 + (1 - \lambda) \cdot y_2\]
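The two formulas above can be sketched in plain NumPy as follows. This is a minimal illustration under stated assumptions, not scikit-bio's implementation: `mixup_sketch` is a hypothetical name, the matrix is assumed to hold one sample per row, labels are assumed to be integer class indices that are one-hot encoded first, pairs are drawn uniformly with replacement (so a sample may occasionally be paired with itself), one \(\lambda\) is drawn per new sample, and the original rows are placed before the new ones.

```python
import numpy as np


def mixup_sketch(matrix, labels, num_classes, n_samples, alpha=2.0, seed=None):
    """Hypothetical vanilla-mixup sketch (not scikit-bio's code).

    matrix: (n, m) array with one sample per row (assumption).
    labels: (n,) integer class indices (assumption).
    seed: int, None or np.random.Generator (RandomState is not handled here).
    """
    rng = np.random.default_rng(seed)
    n = matrix.shape[0]
    onehot = np.eye(num_classes)[labels]  # (n, num_classes)

    # Pick two random samples, uniformly with replacement, for each new sample.
    i = rng.integers(0, n, size=n_samples)
    j = rng.integers(0, n, size=n_samples)

    # One mixing weight per new sample: lambda ~ Beta(alpha, alpha).
    lam = rng.beta(alpha, alpha, size=n_samples)[:, None]

    # s = lambda * s1 + (1 - lambda) * s2 and y = lambda * y1 + (1 - lambda) * y2
    new_x = lam * matrix[i] + (1 - lam) * matrix[j]
    new_y = lam * onehot[i] + (1 - lam) * onehot[j]

    # Keep the original samples and append the new ones.
    return np.vstack([matrix, new_x]), np.vstack([onehot, new_y])


X = np.arange(40, dtype=float).reshape(4, 10)  # 4 samples x 10 features
y = np.array([0, 1, 1, 0])
aug_X, aug_y = mixup_sketch(X, y, num_classes=2, n_samples=5, seed=0)
print(aug_X.shape, aug_y.shape)  # (9, 10) (9, 2)
```

Because each new label is a convex combination of two one-hot vectors, every row of `aug_y` still sums to 1.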

Parameters:
n_samples : int

The number of new samples to generate.

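alpha : float, optional

The parameter \(\alpha\) of the symmetric beta distribution \(\mathrm{Beta}(\alpha, \alpha)\) from which \(\lambda\) is drawn. Defaults to 2.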
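The effect of \(\alpha\) can be seen by sampling \(\lambda\) directly with NumPy. \(\mathrm{Beta}(\alpha, \alpha)\) is symmetric about 0.5 with variance \(1 / (4(2\alpha + 1))\): small values (\(\alpha < 1\)) push \(\lambda\) toward 0 or 1, so new samples stay close to one parent, while larger values concentrate \(\lambda\) near 0.5 and produce more even blends. The sketch below compares empirical and theoretical spreads for a few illustrative \(\alpha\) values.

```python
import numpy as np

rng = np.random.default_rng(0)
stds = {}
for alpha in (0.2, 1.0, 2.0, 10.0):
    lam = rng.beta(alpha, alpha, size=100_000)
    stds[alpha] = lam.std()
    # Beta(alpha, alpha) has mean 0.5 and variance 1 / (4 * (2 * alpha + 1)).
    expected = np.sqrt(1 / (4 * (2 * alpha + 1)))
    print(f"alpha={alpha}: mean={lam.mean():.3f}, "
          f"std={stds[alpha]:.3f} (expected {expected:.3f})")
```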

seed : int, Generator or RandomState, optional

A user-provided random seed or random generator instance. See details.
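A function that accepts all three seed types has to normalize them to a single random source. The sketch below shows one way to do that with NumPy; `as_generator` is a hypothetical helper, and deriving a `Generator` from a legacy `RandomState` by drawing an integer seed from it is an assumption for illustration, not necessarily how scikit-bio handles it. Passing the same integer twice reproduces the same draws.

```python
import numpy as np


def as_generator(seed=None):
    """Hypothetical helper: turn a seed-like value into a NumPy Generator."""
    if isinstance(seed, np.random.Generator):
        return seed  # use the caller's generator as-is
    if isinstance(seed, np.random.RandomState):
        # Assumption: derive a fresh Generator from the legacy RandomState's stream.
        return np.random.default_rng(seed.randint(0, 2**31 - 1))
    return np.random.default_rng(seed)  # int or None


# The same integer seed yields the same sequence of mixing weights.
a = as_generator(42).beta(2, 2, size=3)
b = as_generator(42).beta(2, 2, size=3)
print(np.array_equal(a, b))  # True
```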

Returns:
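augmented_matrix : numpy.ndarray

The augmented matrix, with one sample per row and one feature per column, containing both the original and the newly generated samples.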

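augmented_label : numpy.ndarray

The augmented labels, with one row per sample and one column per class. Rows for the original samples are one-hot encoded; rows for the new samples hold the corresponding convex combinations of two one-hot vectors (soft labels). To recover discrete class labels, for example for a classifier that expects hard labels, call np.argmax(aug_label, axis=1).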
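For example, with two classes, `np.argmax` maps each row to the index of its largest entry. The label values below are illustrative only: two one-hot rows for original samples followed by two soft rows for mixed samples. On an exact tie (\(\lambda = 0.5\) between two different classes), `np.argmax` returns the first, lower class index.

```python
import numpy as np

# Illustrative augmented labels: two original (one-hot) rows, two mixed (soft) rows.
aug_label = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [0.7, 0.3],
    [0.2, 0.8],
])

# Index of the largest entry in each row gives the discrete class label.
discrete = np.argmax(aug_label, axis=1)
print(discrete)  # [0 1 0 1]
```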

Notes

This mixup is based on [1] and shares its core concept with PyTorch’s MixUp, but there are key differences:

  1. This implementation generates new samples to augment a dataset, while PyTorch’s MixUp is applied on-the-fly during training to batches of data.

  2. This implementation randomly selects pairs of samples from the entire dataset, while PyTorch’s implementation typically mixes consecutive samples in a batch (requiring prior shuffling).

  3. This implementation returns an augmented dataset with both original and new samples, while PyTorch’s implementation transforms a batch in-place.

  4. This implementation is designed for omic data tables and is built mainly on NumPy, while PyTorch’s is primarily for image data.
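The pairing and output differences in points 2 and 3 can be contrasted in a small NumPy sketch. The batch-style half mimics the described PyTorch behaviour only loosely (it pairs each row with its neighbour via `np.roll` and uses a single \(\lambda\) for the batch, both assumptions of this sketch), and the dataset-style half mirrors the approach described for this method; the data are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.arange(12, dtype=float).reshape(6, 2)  # 6 samples x 2 features

# Batch-style (loosely mimicking PyTorch's MixUp): each row is mixed with its
# neighbour in an already shuffled batch, so the batch keeps its size and the
# original rows are replaced. A single lambda is used here for simplicity.
lam = rng.beta(2.0, 2.0)
batch_mixed = lam * X + (1 - lam) * np.roll(X, 1, axis=0)

# Dataset-style (as described for this method): draw random pairs from all
# samples and append the new rows to the originals.
n_new = 3
i = rng.integers(0, len(X), size=n_new)
j = rng.integers(0, len(X), size=n_new)
lam_new = rng.beta(2.0, 2.0, size=n_new)[:, None]
augmented = np.vstack([X, lam_new * X[i] + (1 - lam_new) * X[j]])

print(batch_mixed.shape, augmented.shape)  # (6, 2) (9, 2)
```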

References

[1] Zhang, H., Cisse, M., Dauphin, Y. N., & Lopez-Paz, D. (2017). mixup: Beyond Empirical Risk Minimization. arXiv preprint arXiv:1710.09412.

Examples

>>> import numpy as np
>>> from skbio.table import Table
>>> from skbio.table import Augmentation
>>> data = np.arange(40).reshape(10, 4)
>>> sample_ids = ['S%d' % i for i in range(4)]
>>> feature_ids = ['O%d' % i for i in range(10)]
>>> table = Table(data, feature_ids, sample_ids)
>>> label = np.random.randint(0, 2, size=table.shape[1])
>>> augmentation = Augmentation(table, label, num_classes=2)
>>> aug_matrix, aug_label = augmentation.mixup(n_samples=5)
>>> # 4 original + 5 new samples (rows) by 10 features (columns)
>>> print(aug_matrix.shape)
(9, 10)
>>> # one column per class
>>> print(aug_label.shape)
(9, 2)