Source code for skmatter.model_selection._split

import sklearn.model_selection
from sklearn.utils import indexable
from sklearn.utils.validation import _num_samples


[docs] def train_test_split(*arrays, **options): """Extended version of the sklearn train test split supporting overlapping train and test sets. See `sklearn.model_selection.train_test_split (external link) <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html>`_ . Parameters ---------- *arrays : sequence of indexables with same length / shape[0] Allowed inputs are lists, numpy arrays, scipy-sparse matrices or pandas dataframes. test_size : float or int, default=None If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test samples. If :obj:`None`, the value is set to the complement of the train size. If ``train_size`` is also None, it will be set to 0.25. train_size : float or int, default=None If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the train split. If int, represents the absolute number of train samples. If :obj:`None`, the value is automatically set to the complement of the test size. random_state : int or :class`numpy.random.RandomState` instance, default=None Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls. See `random state glossary from sklearn (external link) <https://scikit-learn.org/stable/glossary.html#term-random-state>`_ shuffle : bool, default=True Whether or not to shuffle the data before splitting. If shuffle=False then stratify must be :obj:`None`. stratify : array-like, default=None If not :obj:`None`, data is split in a stratified fashion, using this as the class labels. train_test_overlap : bool, default=False If :obj:`True`, and train and test set are both not :obj:`None`, the train and test set may overlap. Returns ------- splitting : list, length=2 * len(arrays) List containing train-test split of inputs. """ # NoQa: E501 train_test_overlap = options.pop("train_test_overlap", False) test_size = options.get("test_size", None) train_size = options.get("train_size", None) if train_test_overlap and train_size is not None and test_size is not None: # checks from sklearn arrays = indexable(*arrays) n_samples = _num_samples(arrays[0]) if test_size == 1.0 or test_size == n_samples: test_sets = arrays else: options["train_size"] = None test_sets = sklearn.model_selection.train_test_split(*arrays, **options)[ 1::2 ] options["train_size"] = train_size if train_size == 1.0 or train_size == n_samples: train_sets = arrays else: options["test_size"] = None train_sets = sklearn.model_selection.train_test_split(*arrays, **options)[ ::2 ] options["test_size"] = test_size train_test_sets = [] for i in range(len(train_sets)): train_test_sets += [train_sets[i], test_sets[i]] return train_test_sets else: return sklearn.model_selection.train_test_split(*arrays, **options)