How to split data by groups of items (instead of items only) with scikit-learn
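Today I learned about the nifty GroupShuffleSplit class from the scikit-learn library.
The class is useful when all rows that share a group label must land in the same partition, so that no group ends up split between the training and test data.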
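Example: suppose you want to create a training set and an out-of-time validation set, with each year appearing in only one of the two partitions. Keep in mind that GroupShuffleSplit assigns whole groups at random: it guarantees that the years are disjoint, but not that the validation years come after the training years. One way to do so would be like this: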
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Create a sample dataframe
df = pd.DataFrame({
    'feature_a': [1, 2, 3, 4, 5, 6],
    'feature_b': [1.1, 2.2, 3.3, 4.4, 5.5, 6.6],
    'year': [2001, 2002, 2002, 2003, 2003, 2003]
})

# Create GroupShuffleSplit object
gss = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=42)

# Split the dataframe into train and test sets, based on the 'year' column.
# split() yields positional indices, so .iloc is used rather than .loc;
# this also works when the dataframe does not have a default 0..n-1 index.
for train_idx, test_idx in gss.split(df, groups=df['year']):
    train_set = df.iloc[train_idx]
    test_set = df.iloc[test_idx]

# Print the train and test sets
print("Training set:")
print(train_set)
print("\nTest set:")
print(test_set)
Training set:
   feature_a  feature_b  year
1          2        2.2  2002
2          3        3.3  2002
3          4        4.4  2003
4          5        5.5  2003
5          6        6.6  2003

Test set:
   feature_a  feature_b  year
0          1        1.1  2001
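
To convince yourself that the split really keeps groups intact, a quick sanity check is to compare the years on each side. The sketch below is my own addition, not part of the original example: it reuses the same sample dataframe, asks for several splits (n_splits=5 is an arbitrary choice), and collects any year that shows up in both partitions. As far as I understand the scikit-learn documentation, a float test_size here is a proportion of groups rather than of rows, so the row counts on each side can differ from what test_size=0.5 might suggest.

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Same sample dataframe as in the example above
df = pd.DataFrame({
    'feature_a': [1, 2, 3, 4, 5, 6],
    'feature_b': [1.1, 2.2, 3.3, 4.4, 5.5, 6.6],
    'year': [2001, 2002, 2002, 2003, 2003, 2003],
})

# Several random splits this time (5 is an arbitrary, illustrative choice)
gss = GroupShuffleSplit(n_splits=5, test_size=0.5, random_state=42)

overlaps = []
for train_idx, test_idx in gss.split(df, groups=df['year']):
    # Years present on each side of this particular split
    train_years = set(df['year'].iloc[train_idx].tolist())
    test_years = set(df['year'].iloc[test_idx].tolist())
    # Years that appear in both partitions; expected to be empty
    overlaps.append(train_years & test_years)
    print("train years:", sorted(train_years), "| test years:", sorted(test_years))

# No year should ever be shared between train and test
assert all(len(shared) == 0 for shared in overlaps)
```

Printing the years for each split also makes it easy to see that the assignment is random: nothing forces the test years to be the most recent ones.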