skbio.table.Augmentation.mixup
- Augmentation.mixup(n_samples, alpha=2, seed=None)
Data augmentation by vanilla mixup.
Randomly select two samples \(s_1\) and \(s_2\) from the OTU table and generate a new sample \(s\) as a linear combination of the two:
\[s = \lambda \cdot s_1 + (1 - \lambda) \cdot s_2\]
where \(\lambda\) is a random number sampled from a beta distribution with both shape parameters equal to \(\alpha\). The label of the new sample is the same linear combination of the two samples' labels:
\[y = \lambda \cdot y_1 + (1 - \lambda) \cdot y_2\]
- Parameters:
- n_samples : int
The number of new samples to generate.
- alpha : float
The shape parameter of the beta distribution from which \(\lambda\) is drawn; both shape parameters are set to this value. Default is 2.
- seed : int, Generator or RandomState, optional
A user-provided random seed or random generator instance. See details.
- Returns:
- augmented_matrix : numpy.ndarray
The augmented matrix, containing both the original and the newly generated samples, one row per sample.
- augmented_label : numpy.ndarray
The augmented labels, in one-hot encoding. To obtain discrete labels instead, call np.argmax(aug_label, axis=1).
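As a small illustration of the argmax conversion described for the returned labels, the sketch below uses a made-up label array (the values are hypothetical and are not produced by scikit-bio) in which mixed samples carry fractional class weights.

```python
import numpy as np

# Hypothetical label array for three samples and two classes.
# Rows with fractional values stand in for labels of mixed samples.
aug_label = np.array([
    [1.0, 0.0],   # an original sample of class 0
    [0.3, 0.7],   # a mixed sample leaning towards class 1
    [0.6, 0.4],   # a mixed sample leaning towards class 0
])

# Pick the class with the largest weight in each row.
hard_label = np.argmax(aug_label, axis=1)
print(hard_label)  # [0 1 0]
```

Note that this conversion discards the mixing proportions, so it is only appropriate when a downstream method cannot consume fractional labels.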
Notes
The mixup is based on [1] and shares the same core concept as PyTorch's MixUp. However, there are key differences:
- This implementation generates new samples to augment a dataset, while PyTorch's MixUp is applied on-the-fly during training to batches of data.
- This implementation randomly selects pairs of samples from the entire dataset, while PyTorch's implementation typically mixes consecutive samples in a batch (requiring prior shuffling).
- This implementation returns an augmented dataset with both original and new samples, while PyTorch's implementation transforms a batch in-place.
- This implementation is designed for omic data tables, while PyTorch's is primarily for image data. It is also built mainly on NumPy.
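The procedure described above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the scikit-bio source: the function name `mixup_sketch` is hypothetical, the input is assumed to be a samples-by-features matrix with one-hot labels, and the two samples of a pair are drawn independently (so a sample may occasionally be paired with itself).

```python
import numpy as np

def mixup_sketch(matrix, labels_one_hot, n_samples, alpha=2.0, seed=None):
    """Illustrative vanilla mixup (hypothetical helper, not scikit-bio's code)."""
    rng = np.random.default_rng(seed)
    n = matrix.shape[0]
    # Draw two sample indices per new sample from the whole dataset.
    idx1 = rng.integers(0, n, size=n_samples)
    idx2 = rng.integers(0, n, size=n_samples)
    # One mixing weight per new sample: lambda ~ Beta(alpha, alpha).
    lam = rng.beta(alpha, alpha, size=(n_samples, 1))
    # s = lambda * s1 + (1 - lambda) * s2, and likewise for the labels.
    new_x = lam * matrix[idx1] + (1 - lam) * matrix[idx2]
    new_y = lam * labels_one_hot[idx1] + (1 - lam) * labels_one_hot[idx2]
    # Return the original samples followed by the new ones.
    return (np.concatenate([matrix, new_x]),
            np.concatenate([labels_one_hot, new_y]))

# Four samples with ten features each, and two classes.
x = np.arange(40, dtype=float).reshape(4, 10)
y = np.eye(2)[[0, 1, 1, 0]]
aug_x, aug_y = mixup_sketch(x, y, n_samples=5, seed=0)
print(aug_x.shape, aug_y.shape)  # (9, 10) (9, 2)
```

Because each new label is a convex combination of two one-hot rows, every row of `aug_y` still sums to 1.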
References
[1] Zhang, H., Cisse, M., Dauphin, Y. N., & Lopez-Paz, D. (2017). mixup: Beyond Empirical Risk Minimization. arXiv preprint arXiv:1710.09412.
Examples
>>> import numpy as np
>>> from skbio.table import Table
>>> from skbio.table import Augmentation
>>> data = np.arange(40).reshape(10, 4)
>>> sample_ids = ['S%d' % i for i in range(4)]
>>> feature_ids = ['O%d' % i for i in range(10)]
>>> table = Table(data, feature_ids, sample_ids)
>>> label = np.random.randint(0, 2, size=table.shape[1])
>>> augmentation = Augmentation(table, label, num_classes=2)
>>> aug_matrix, aug_label = augmentation.mixup(n_samples=5)
>>> print(aug_matrix.shape)
(9, 10)
>>> print(aug_label.shape)
(9, 2)