Bruno Arine

How to split data by groups of items (instead of items only) with scikit-learn

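Today I learned about the nifty GroupShuffleSplit class from the scikit-learn library. It is handy when you need to split data by groups: every row that shares a group label lands in the same partition, so no group is ever divided between the training and test sets.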
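To make "split by groups" concrete, here is a minimal sketch on a made-up array of ten rows with four hypothetical group labels. It draws five random splits and checks that no label ever appears on both sides. It also shows that `test_size` is measured in groups rather than rows.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Toy data: 10 rows spread over 4 hypothetical groups (labels are made up).
X = np.arange(10).reshape(-1, 1)
groups = np.array(["a", "a", "a", "b", "b", "c", "c", "c", "c", "d"])

# test_size counts groups, not rows: 0.5 of 4 groups -> 2 test groups.
gss = GroupShuffleSplit(n_splits=5, test_size=0.5, random_state=0)

splits = []
for train_idx, test_idx in gss.split(X, groups=groups):
    train_groups = set(groups[train_idx].tolist())
    test_groups = set(groups[test_idx].tolist())
    # Whole groups move together, so the two label sets never overlap.
    assert train_groups.isdisjoint(test_groups)
    splits.append((sorted(train_groups), sorted(test_groups)))

for train_groups, test_groups in splits:
    print("train:", train_groups, "test:", test_groups)
```

Because whole groups move together, the share of rows in the test set can drift away from `test_size` when groups have different sizes.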

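Example: suppose you want to create a training set and an out-of-time validation set, with each year assigned to exactly one of the two partitions. Keep in mind that GroupShuffleSplit picks the held-out years at random, so they never overlap with the training years but are not necessarily the most recent ones. One way to set this up: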

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Create a sample dataframe
df = pd.DataFrame({
    'feature_a': [1, 2, 3, 4, 5, 6],
    'feature_b': [1.1, 2.2, 3.3, 4.4, 5.5, 6.6],
    'year': [2001, 2002, 2002, 2003, 2003, 2003]
})

# Create a GroupShuffleSplit object that holds out one group (year).
# An integer test_size counts groups; a float is a fraction of the groups,
# rounded up, so test_size=0.5 on 3 years would hold out 2 of them.
gss = GroupShuffleSplit(n_splits=1, test_size=1, random_state=42)

# Split the dataframe into train and test sets, based on the 'year' column.
# split() yields positional indices, so select rows with .iloc.
for train_idx, test_idx in gss.split(df, groups=df['year']):
    train_set = df.iloc[train_idx]
    test_set = df.iloc[test_idx]

# Print the train and test sets
print("Training set:")
print(train_set)

print("\nTest set:")
print(test_set)
```
Running it prints:

```
Training set:
   feature_a  feature_b  year
1          2        2.2  2002
2          3        3.3  2002
3          4        4.4  2003
4          5        5.5  2003
5          6        6.6  2003

Test set:
   feature_a  feature_b  year
0          1        1.1  2001
```
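In this run the held-out year turned out to be 2001, the earliest one. If out-of-time validation should mean that every test year comes after every training year, a plain cutoff on the `year` column is a simpler alternative. The sketch below swaps in that technique (it does not use GroupShuffleSplit); `n_test_years` is a hypothetical parameter introduced only for illustration.

```python
import pandas as pd

# Same toy dataframe as in the example above.
df = pd.DataFrame({
    'feature_a': [1, 2, 3, 4, 5, 6],
    'feature_b': [1.1, 2.2, 3.3, 4.4, 5.5, 6.6],
    'year': [2001, 2002, 2002, 2003, 2003, 2003]
})

# Hypothetical knob: how many of the most recent years to hold out.
n_test_years = 1

# The earliest year that goes into the test set.
years = sorted(df['year'].unique())
cutoff = years[-n_test_years]

# Everything before the cutoff trains; the cutoff year and later validate.
train_set = df[df['year'] < cutoff]
test_set = df[df['year'] >= cutoff]

print("Training set:")
print(train_set)

print("\nTest set:")
print(test_set)
```

Here the partitions are still disjoint by year, as with GroupShuffleSplit, but the split is deterministic and respects chronological order.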