Skip to main content

Decision Trees: From Basics to Random Forests

Posted on:  at 
Machine Learning

Decision Trees: From Basics to Random Forests

Decision trees are among the most intuitive machine learning algorithms. Having used decision tree-based solutions that saved millions in operational costs, I want to share practical insights on when and how to use them effectively.

What Are Decision Trees?

Decision trees create predictive models by learning simple decision rules from data features. Think of them as a series of yes/no questions that lead to a prediction.

How Decision Trees Work

The algorithm builds a tree by:

  1. Selecting the Best Split: Choose the feature and threshold that best separates the data
  2. Recursive Partitioning: Apply the same process to each subset
  3. Stopping: When criteria are met (max depth, min samples, etc.)
  4. Prediction: Follow the path from root to leaf

Mathematical Foundation

Impurity Measures

Decision trees use impurity measures to find the best splits:

Gini Impurity (Classification): $$Gini(t) = 1 - \sum_{i} [p(i|t)]^2$$

Entropy (Classification): $$Entropy(t) = -\sum_{i} p(i|t) \log_2 p(i|t)$$

Mean Squared Error (Regression): $$MSE(t) = \frac{1}{n} \sum_{i} (y_i - \bar{y})^2$$

Information Gain

Information gain measures the reduction in impurity: $$IG = Impurity(parent) - \sum_{j} \left( \frac{n_j}{n} \times Impurity(child_j) \right)$$

Complete Implementation

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.datasets import make_classification, make_regression
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, mean_squared_error, classification_report
from sklearn.tree import plot_tree
import seaborn as sns

# Generate sample classification data
X_class, y_class = make_classification(
    n_samples=1000, 
    n_features=4, 
    n_informative=3, 
    n_redundant=1, 
    n_clusters_per_class=1, 
    random_state=42
)

# Generate sample regression data
X_reg, y_reg = make_regression(
    n_samples=1000, 
    n_features=4, 
    noise=0.1, 
    random_state=42
)

# Split data
X_train_c, X_test_c, y_train_c, y_test_c = train_test_split(
    X_class, y_class, test_size=0.2, random_state=42
)

X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(
    X_reg, y_reg, test_size=0.2, random_state=42
)

Decision Tree Implementation

Classification Example

# Create and train decision tree classifier
dt_classifier = DecisionTreeClassifier(
    max_depth=5,
    min_samples_split=20,
    min_samples_leaf=10,
    random_state=42
)

dt_classifier.fit(X_train_c, y_train_c)

# Make predictions
y_pred_dt = dt_classifier.predict(X_test_c)
dt_accuracy = accuracy_score(y_test_c, y_pred_dt)

print("Decision Tree Classifier Results:")
print(f"Accuracy: {dt_accuracy:.4f}")
print("\nDetailed Classification Report:")
print(classification_report(y_test_c, y_pred_dt))

# Feature importance
feature_importance = dt_classifier.feature_importances_
print(f"\nFeature Importances: {feature_importance}")

Regression Example

# Decision tree regressor
dt_regressor = DecisionTreeRegressor(
    max_depth=5,
    min_samples_split=20,
    min_samples_leaf=10,
    random_state=42
)

dt_regressor.fit(X_train_r, y_train_r)

# Predictions and evaluation
y_pred_dt_reg = dt_regressor.predict(X_test_r)
dt_mse = mean_squared_error(y_test_r, y_pred_dt_reg)
dt_rmse = np.sqrt(dt_mse)

print(f"Decision Tree Regressor RMSE: {dt_rmse:.4f}")

Hyperparameter Tuning

Key parameters to tune:

from sklearn.model_selection import GridSearchCV

# Define parameter grid
param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 10, 20],
    'min_samples_leaf': [1, 5, 10],
    'max_features': ['sqrt', 'log2', None]
}

# Grid search
grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1
)

grid_search.fit(X_train_c, y_train_c)

print("Best Parameters:", grid_search.best_params_)
print("Best Cross-validation Score:", grid_search.best_score_)

# Use best model
best_dt = grid_search.best_estimator_

Tree Visualization

# Visualize the decision tree
plt.figure(figsize=(20, 10))
plot_tree(
    dt_classifier,
    max_depth=3,  # Limit depth for readability
    feature_names=[f'Feature_{i}' for i in range(X_class.shape[1])],
    class_names=['Class_0', 'Class_1'],
    filled=True,
    rounded=True,
    fontsize=10
)
plt.title("Decision Tree Visualization (Max Depth 3)")
plt.show()

Random Forest: Ensemble Power

Decision trees are prone to overfitting. Random Forests overcome this by building multiple decision trees and aggregating their predictions.

How Random Forests Work

  1. Bagging: Each tree is built on a bootstrap sample of the data.
  2. Feature Randomness: At each split, only a random subset of features is considered.
  3. Aggregation:
    • Classification: Majority vote
    • Regression: Average of predictions

Random Forest Implementation

# Random Forest Classifier
rf_classifier = RandomForestClassifier(
    n_estimators=100,  # Number of trees
    max_depth=10,
    min_samples_split=10,
    min_samples_leaf=5,
    random_state=42,
    n_jobs=-1  # Use all available cores
)

rf_classifier.fit(X_train_c, y_train_c)
y_pred_rf = rf_classifier.predict(X_test_c)
rf_accuracy = accuracy_score(y_test_c, y_pred_rf)

print("\nRandom Forest Classifier Results:")
print(f"Accuracy: {rf_accuracy:.4f}")
print("\nDetailed Classification Report (Random Forest):")
print(classification_report(y_test_c, y_pred_rf))

# Random Forest Regressor
rf_regressor = RandomForestRegressor(
    n_estimators=100,
    max_depth=10,
    min_samples_split=10,
    min_samples_leaf=5,
    random_state=42,
    n_jobs=-1
)

rf_regressor.fit(X_train_r, y_train_r)
y_pred_rf_reg = rf_regressor.predict(X_test_r)
rf_mse = mean_squared_error(y_test_r, y_pred_rf_reg)
rf_rmse = np.sqrt(rf_mse)

print(f"Random Forest Regressor RMSE: {rf_rmse:.4f}")

# Feature importance from Random Forest
rf_feature_importance = rf_classifier.feature_importances_
print(f"\nRandom Forest Feature Importances: {rf_feature_importance}")

Advantages of Decision Trees and Random Forests

Decision Trees:

  • Interpretability: Easy to understand and visualize (especially small trees).
  • No Scaling Needed: Not sensitive to feature scaling.
  • Handle Mixed Data: Can handle both numerical and categorical data.

Random Forests:

  • High Accuracy: Generally perform very well.
  • Robust to Overfitting: Due to bagging and feature randomness.
  • Feature Importance: Provide a good measure of feature importance.
  • Handle Missing Values: Can handle missing values implicitly.

Disadvantages

Decision Trees:

  • Overfitting: Prone to overfitting if not pruned or depth-limited.
  • Instability: Small changes in data can lead to a very different tree.

Random Forests:

  • Less Interpretable: More of a "black box" compared to single trees.
  • Computational Cost: More trees mean longer training times.

Real-World Applications

As a Staff Data Scientist, I've leveraged these algorithms for critical business problems:

1. Fraud Detection

Used Random Forests to identify anomalous transactions and fraudulent driver behavior, significantly reducing financial losses.

# Example for fraud detection
# features = ['transaction_amount', 'location_deviation', 'time_of_day', 'user_history_score']
# fraud_model = RandomForestClassifier()

2. Customer Churn Prediction

Decision trees helped understand key drivers of customer churn by identifying decision rules (e.g., "if customer calls support > 3 times and plan cost > X, then churn risk is high").

3. Medical Diagnosis

Random Forests can be used for classifying diseases based on patient symptoms and test results.

4. Supply Chain Optimization

Used decision tree-based models to predict equipment failures, optimizing maintenance schedules and reducing downtime.

When to Use Which?

Decision Tree: When interpretability is paramount, for quick baselines, or for rule extraction.

Random Forest: When high predictive accuracy is needed, robust performance against overfitting, and dealing with complex, high-dimensional data.

Production Best Practices

Deploying tree-based models requires careful attention:

1. Model Monitoring

Monitor feature drift and model performance. Tree models can be sensitive to shifts in data distributions.

2. Explainability (XAI)

For Random Forests, use SHAP or LIME to provide local interpretability, even though the overall model is a black box.

3. Cross-Validation

Always use robust cross-validation (e.g., K-fold) to get reliable performance estimates and tune hyperparameters.

4. Ensemble More Than Trees

Consider Gradient Boosting Machines (like XGBoost, LightGBM) for even higher performance, especially if you have highly correlated features or need to capture complex interactions.

Conclusion

Decision Trees and Random Forests are fundamental tools in a data scientist's toolkit. While simple decision trees offer unparalleled interpretability, Random Forests elevate their power by combining multiple trees into a robust, high-performing ensemble.

Mastering these algorithms, from their theoretical underpinnings to practical production deployment, is key to delivering impactful machine learning solutions.

Do you prefer the simplicity of a single decision tree or the power of a Random Forest? Share your experiences and use cases in the comments!