How to Split Data into Train/Test Sets

Split your datasets properly with scikit-learn's train_test_split for reliable machine learning model evaluation.

Data Splitting Done Right

Proper data splitting keeps evaluation honest: a model scored on data it has never seen cannot hide overfitting. Master train_test_split to build trustworthy machine learning pipelines.

Basic Train/Test Split

from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np

# Sample dataset
data = pd.DataFrame({
    'feature1': np.random.randn(1000),
    'feature2': np.random.randn(1000),
    'feature3': np.random.randn(1000),
    'target': np.random.randint(0, 2, 1000)
})

X = data[['feature1', 'feature2', 'feature3']]
y = data['target']

# Basic split (80% train, 20% test)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")
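Because random_state seeds the shuffle, the same call produces identical splits on every run. A quick check (the variable names here are just for illustration):

```python
import numpy as np
from sklearn.model_selection import train_test_split

X_demo = np.arange(20).reshape(10, 2)
y_demo = np.arange(10)

# Two calls with the same seed select the same rows...
a_train, a_test, _, _ = train_test_split(X_demo, y_demo, test_size=0.3, random_state=42)
b_train, b_test, _, _ = train_test_split(X_demo, y_demo, test_size=0.3, random_state=42)
print(np.array_equal(a_test, b_test))  # True

# ...while a different seed generally selects different ones.
c_train, c_test, _, _ = train_test_split(X_demo, y_demo, test_size=0.3, random_state=0)
print(np.array_equal(a_test, c_test))
```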

Stratified Splitting

# For classification - maintain class distribution
X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=0.2, 
    stratify=y,  # Maintain class proportions
    random_state=42
)

print("Original distribution:")
print(y.value_counts(normalize=True))
print("\nTrain distribution:")
print(y_train.value_counts(normalize=True))  # y_train is already a Series
print("\nTest distribution:")
print(y_test.value_counts(normalize=True))
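Stratification matters most when classes are imbalanced. A sketch with a hypothetical 90/10 target (variable names are illustrative):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
y_imb = pd.Series(rng.choice([0, 1], size=1000, p=[0.9, 0.1]))
X_imb = pd.DataFrame({'f': rng.normal(size=1000)})

_, _, _, y_te = train_test_split(
    X_imb, y_imb, test_size=0.2, stratify=y_imb, random_state=42
)

# With stratify, the test-set minority share closely tracks the overall share.
print(f"overall: {y_imb.mean():.3f}, test: {y_te.mean():.3f}")
```

Without `stratify`, a small test set can end up with far too few minority-class examples for a reliable estimate.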

Three-Way Split (Train/Validation/Test)

# First split: separate test set
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Second split: separate train and validation
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.25, random_state=42  # 0.25 * 0.8 = 0.2
)

print(f"Train: {X_train.shape[0]} ({X_train.shape[0]/len(X):.1%})")
print(f"Validation: {X_val.shape[0]} ({X_val.shape[0]/len(X):.1%})")
print(f"Test: {X_test.shape[0]} ({X_test.shape[0]/len(X):.1%})")
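The two-step pattern above can be wrapped in a small helper. `split_train_val_test` is a hypothetical convenience function, not part of scikit-learn:

```python
import numpy as np
from sklearn.model_selection import train_test_split

def split_train_val_test(X, y, val_size=0.2, test_size=0.2, random_state=None):
    """Three-way split; val_size and test_size are fractions of the full dataset."""
    X_tmp, X_test, y_tmp, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    # Rescale val_size so it is a fraction of the remaining (non-test) data.
    rel_val = val_size / (1.0 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_tmp, y_tmp, test_size=rel_val, random_state=random_state
    )
    return X_train, X_val, X_test, y_train, y_val, y_test

X_d = np.random.randn(100, 3)
y_d = np.random.randn(100)
parts = split_train_val_test(X_d, y_d, random_state=42)
print([len(p) for p in parts[:3]])  # [60, 20, 20]
```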

Time Series Splitting

from sklearn.model_selection import TimeSeriesSplit

# Time series data
dates = pd.date_range('2020-01-01', periods=1000, freq='D')
ts_data = pd.DataFrame({
    'date': dates,
    'value': np.cumsum(np.random.randn(1000)),
    'target': np.random.randn(1000)
})

# Sort by date
ts_data = ts_data.sort_values('date')
X_ts = ts_data[['value']]
y_ts = ts_data['target']

# Time series split
tscv = TimeSeriesSplit(n_splits=5)
for train_idx, test_idx in tscv.split(X_ts):
    print(f"Train: {len(train_idx)}, Test: {len(test_idx)}")
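For a single temporal holdout rather than rolling folds, train_test_split with shuffle=False keeps the most recent rows as the test set. A self-contained sketch that rebuilds a small time series like the one above:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

dates = pd.date_range('2020-01-01', periods=1000, freq='D')
ts = pd.DataFrame({'value': np.cumsum(np.random.randn(1000))}, index=dates)

# shuffle=False preserves order: the last 20% of days become the test set.
train, test = train_test_split(ts, test_size=0.2, shuffle=False)
print(train.index.max() < test.index.min())  # True
```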

Cross-Validation Splitting

from sklearn.model_selection import KFold, StratifiedKFold

# K-Fold cross validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X)):
    print(f"Fold {fold+1}: Train={len(train_idx)}, Val={len(val_idx)}")

# Stratified K-Fold for classification
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    X_fold_train, X_fold_val = X.iloc[train_idx], X.iloc[val_idx]
    y_fold_train, y_fold_val = y.iloc[train_idx], y.iloc[val_idx]
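In practice the fold loop is usually delegated to cross_val_score, which accepts a splitter through its cv argument (the classifier and synthetic data here are just for illustration):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(42)
X_cv = rng.normal(size=(200, 3))
y_cv = rng.integers(0, 2, size=200)

# cross_val_score fits and scores the model once per fold.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(LogisticRegression(), X_cv, y_cv, cv=skf)
print(f"accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
```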

Practical Examples

# Example: House price prediction
house_data = pd.DataFrame({
    'sqft': np.random.randint(1000, 5000, 500),
    'bedrooms': np.random.randint(1, 6, 500),
    'price': np.random.randint(200000, 800000, 500)
})

X_house = house_data[['sqft', 'bedrooms']]
y_house = house_data['price']

# Split with different test sizes
splits = [0.1, 0.2, 0.3]
for test_size in splits:
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_house, y_house, test_size=test_size, random_state=42
    )
    print(f"Test size {test_size}: Train={len(X_tr)}, Test={len(X_te)}")
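With a split in hand, the test set is used exactly once, for the final score. A sketch continuing the house-price setup (the model choice is illustrative; R² lands near zero here because the prices are random):

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

house = pd.DataFrame({
    'sqft': np.random.randint(1000, 5000, 500),
    'bedrooms': np.random.randint(1, 6, 500),
    'price': np.random.randint(200000, 800000, 500),
})
X_tr, X_te, y_tr, y_te = train_test_split(
    house[['sqft', 'bedrooms']], house['price'], test_size=0.2, random_state=42
)

model = LinearRegression().fit(X_tr, y_tr)  # fit on training data only
print(f"test R^2: {model.score(X_te, y_te):.3f}")
```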

Best Practices

  • Use random_state for reproducible splits
  • Stratify classification data to maintain class balance
  • A common default is ~20% for testing and ~20% for validation; adjust for small or very large datasets
  • Never use test data during model development
  • Time series: use temporal splits, never random shuffling
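"Never use test data during model development" includes preprocessing: fit scalers and encoders on the training split only, then apply them to the test split. A minimal sketch with StandardScaler:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X_all = np.random.randn(200, 2) * 10 + 5
X_tr, X_te = train_test_split(X_all, test_size=0.2, random_state=42)

scaler = StandardScaler().fit(X_tr)  # statistics come from the train split only
X_tr_s = scaler.transform(X_tr)
X_te_s = scaler.transform(X_te)      # test data is transformed, never fitted on

# Train data is exactly standardized; test data only approximately so.
print(np.allclose(X_tr_s.mean(axis=0), 0))  # True
```

Fitting the scaler on the full dataset would leak test-set statistics into training, a subtle but common source of optimistic scores.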

Master Model Validation

Explore cross-validation techniques, learn hyperparameter tuning, and discover model evaluation metrics.
