Decision Tree Data Analysis Guide

Author: P Baburaj Ambalam • Created: 2025-12-27 • Adaptable template for any tabular dataset

This page stands in for a textbook chapter: it walks you through a complete Decision Tree workflow with motivations, concise explanations, and runnable Python snippets. Adapt the code by changing the dataset path and the target column name.

Step 1

Set up the notebook

Reason for next step: Establish required libraries and reproducibility so results match across runs.

Explanation: Import analysis, plotting, and model utilities. Fix a random seed for deterministic splits and trees.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn import metrics

RANDOM_STATE = 1
pd.set_option("display.precision", 3)
np.random.seed(RANDOM_STATE)

Step 2

Load the dataset

Reason for next step: Bring data into memory for inspection and modeling.

Explanation: Multiple loading methods support different workflows—local files, cloud storage, or web URLs.

Loading from Local File

Use when the file is in the same directory as the notebook:

DATA_PATH = "data.csv"  # change to your dataset
TARGET = "result"       # change to your target column

df = pd.read_csv(DATA_PATH)
df.head()

Loading from Google Drive (Colab)

Method 1: Mount drive

from google.colab import drive
drive.mount('/content/drive')
df = pd.read_csv('/content/drive/MyDrive/path/to/data.csv')

Method 2: Using file ID with gdown

!pip install gdown
import gdown
file_id = 'YOUR_FILE_ID_HERE'
gdown.download(f'https://drive.google.com/uc?id={file_id}', 'data.csv', quiet=False)
df = pd.read_csv('data.csv')

Extract file_id from shareable link: https://drive.google.com/file/d/FILE_ID/view?usp=sharing

Loading from URL

Use for publicly accessible datasets; ensure URL points to raw CSV content:

df = pd.read_csv("https://raw.githubusercontent.com/user/repo/main/data.csv")

Step 3

Understand structure

Reason for next step: Confirm column types, row counts, and target distribution.

Explanation: Quick scans help spot obvious issues (mixed types, severe imbalance).

df.info()
df.describe().T
print(df[TARGET].value_counts(normalize=True))

Common Data Issues and Fixes

Issue 1: Mixed Data Types in Numeric Columns

Problem: Strings like 'N/A' or '?' in numeric columns cause dtype=object

df['column_name'] = pd.to_numeric(df['column_name'], errors='coerce')
# 'coerce' converts invalid values to NaN

Issue 2: Duplicate Rows

Problem: Repeated records inflate dataset size and bias the model

print(f"Duplicates: {df.duplicated().sum()}")
df = df.drop_duplicates(keep='first')  # keep first occurrence

Issue 3: Inconsistent Categorical Values

Problem: Case variations ('Yes'/'yes'/'YES') treated as different categories

df['category_col'] = df['category_col'].str.lower().str.strip()

Issue 4: Outlier Detection

Problem: Extreme values can skew tree splits

Q1 = df['numeric_col'].quantile(0.25)
Q3 = df['numeric_col'].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
outliers = df[(df['numeric_col'] < lower_bound) | (df['numeric_col'] > upper_bound)]
print(f"Outliers detected: {len(outliers)}")
# Option: remove or cap
df = df[(df['numeric_col'] >= lower_bound) & (df['numeric_col'] <= upper_bound)]

Issue 5: High Cardinality Categorical Features

Problem: 100+ unique categories create sparse, overfitting-prone trees

value_counts = df['high_card_col'].value_counts()
top_categories = value_counts.head(10).index
df['high_card_col'] = df['high_card_col'].apply(
    lambda x: x if x in top_categories else 'Other'
)

Exploratory Data Analysis & Visuals (Generic CSV)

Goals: spot distribution issues, outliers, dominant categories, missingness, class balance, and correlations before modeling.

Identify column types

num_cols = df.select_dtypes(include=["number"]).columns.tolist()
cat_cols = df.select_dtypes(exclude=["number"]).columns.tolist()
print(f"Numeric cols ({len(num_cols)}):", num_cols)
print(f"Categorical cols ({len(cat_cols)}):", cat_cols)

Repeated values per numeric column (top 5)

for col in num_cols:
    counts = df[col].value_counts().head(5)
    print(f"\nTop repeats for {col}:")
    print(counts)

Outlier visualization (box + violin)

for col in num_cols:
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    sns.boxplot(x=df[col], ax=axes[0], color="#7cc7ff")
    sns.violinplot(x=df[col], ax=axes[1], color="#4f81c7")
    axes[0].set_title(f"Boxplot: {col}")
    axes[1].set_title(f"Violin: {col}")
    plt.tight_layout()
    plt.show()

Top categories per categorical column (top 5)

for col in cat_cols:
    counts = df[col].value_counts().head(5)
    print(f"\nTop categories for {col}:")
    print(counts)
    plt.figure(figsize=(6, 4))
    sns.barplot(x=counts.values, y=counts.index, color="#7cc7ff")
    plt.title(f"Top categories: {col}")
    plt.xlabel("Count")
    plt.ylabel(col)
    plt.tight_layout()
    plt.show()

Missingness overview

missing = df.isna().sum().sort_values(ascending=False)
print(missing[missing > 0])

plt.figure(figsize=(10, 4))
sns.heatmap(df.isna(), cbar=False, cmap="mako")
plt.title("Missingness heatmap (rows vs columns)")
plt.tight_layout()
plt.show()

Target / class balance

class_counts = df[TARGET].value_counts()
class_ratio = class_counts / class_counts.sum()
print("Class counts:\n", class_counts)
print("Class ratio:\n", class_ratio)

plt.figure(figsize=(5, 4))
sns.barplot(x=class_counts.index, y=class_counts.values, color="#7cc7ff")
plt.title("Class balance")
plt.ylabel("Count")
plt.tight_layout()
plt.show()

Correlation heatmap (numeric)

if len(num_cols) > 1:
    corr = df[num_cols].corr()
    plt.figure(figsize=(8, 6))
    sns.heatmap(corr, annot=False, cmap="coolwarm", center=0, square=True)
    plt.title("Correlation heatmap (numeric features)")
    plt.tight_layout()
    plt.show()

Optional: focused pairplot

plot_cols = num_cols[:5]
if len(plot_cols) > 1:
    sns.pairplot(df[plot_cols], corner=True, diag_kind="kde")
    plt.suptitle("Pairplot (first numeric features)", y=1.02)
    plt.show()

Business interpretation prompts:

  • Which numeric features show heavy skew or many outliers?
  • Are any categories overwhelmingly dominant or too sparse?
  • Is the target imbalanced? Plan for class_weight or resampling.
  • Do strong correlations hint at redundancy or leakage?
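To answer the first prompt with numbers, pandas' skew() gives a quick check. A minimal sketch on an illustrative frame (in practice run it on df[num_cols]):

```python
import pandas as pd

# Illustrative frame; in practice use df[num_cols] from the cells above
demo = pd.DataFrame({
    "income": [20, 22, 25, 30, 35, 40, 500],  # one extreme value
    "age":    [30, 32, 35, 37, 40, 42, 45],
})

# |skew| > 1 is a common rule of thumb for "heavy" skew
skew = demo.skew().sort_values(ascending=False)
print(skew)
heavy = skew[skew.abs() > 1].index.tolist()
print("Heavily skewed:", heavy)
```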

Step 4

Basic cleaning

Reason for next step: Handle missing values and tidy column names before modeling.

Explanation: Decision Trees tolerate unscaled numerics but still require complete rows; consistent names simplify code.

# check missingness
missing = df.isna().sum()
print(missing[missing > 0])

# example: drop or impute
# df = df.dropna()
# or:
# num_cols = df.select_dtypes(include=["number"]).columns
# df[num_cols] = df[num_cols].fillna(df[num_cols].median())

# optional: standardize column names
# df.columns = [c.strip().replace(" ", "_") for c in df.columns]

Step 5

Encode target and features

Reason for next step: Convert labels to numeric and separate predictors.

Explanation: Trees need numeric y; mapping keeps Positive/Negative interpretable.

y = df[TARGET].map({"Positive": 1, "Negative": 0})  # adjust mapping to your labels
X = df.drop(columns=[TARGET])
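If X still contains text (object) columns, scikit-learn trees will raise an error, since they need numeric inputs. One common fix is one-hot encoding with pd.get_dummies; a sketch on an illustrative frame (apply the same call to your X):

```python
import pandas as pd

# Illustrative feature frame; apply the same call to your X
X_demo = pd.DataFrame({
    "age":    [25, 40, 33],
    "smoker": ["yes", "no", "yes"],
})

# One-hot encode object columns; drop_first=True removes one
# redundant indicator per category
X_encoded = pd.get_dummies(X_demo, drop_first=True)
print(X_encoded.columns.tolist())  # ['age', 'smoker_yes']
```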

Step 6

Train/test split

Reason for next step: Reserve unseen data to estimate generalization.

Explanation: Stratification preserves class balance; fixed random_state keeps runs comparable.

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=RANDOM_STATE
)

Step 7

Fit baseline Decision Tree

Reason for next step: Establish a reference model before tuning.

Explanation: A simple tree highlights overfitting risk and sets a performance floor. We'll measure comprehensive classification metrics to fully evaluate model performance.

dt = DecisionTreeClassifier(random_state=RANDOM_STATE)
dt.fit(X_train, y_train)

y_pred = dt.predict(X_test)
y_pred_proba = dt.predict_proba(X_test)[:, 1]  # Probability for positive class

# Basic metrics
print("Confusion Matrix:")
print(metrics.confusion_matrix(y_test, y_pred))
print("\nClassification Report:")
print(metrics.classification_report(y_test, y_pred))

Comprehensive Classification Metrics

Beyond basic accuracy, we measure precision, recall, F1-score, and ROC-AUC to understand model behavior across different aspects:

# Calculate all classification metrics
accuracy = metrics.accuracy_score(y_test, y_pred)
precision = metrics.precision_score(y_test, y_pred)
recall = metrics.recall_score(y_test, y_pred)
f1 = metrics.f1_score(y_test, y_pred)
roc_auc = metrics.roc_auc_score(y_test, y_pred_proba)

print(f"Accuracy:  {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"ROC-AUC:   {roc_auc:.4f}")

Understanding Each Metric

Accuracy: Overall correctness = (TP + TN) / (TP + TN + FP + FN). Use when classes are balanced. Can be misleading with imbalanced data (e.g., 95% negative class → predicting all negative gives 95% accuracy).

Precision: Of all positive predictions, how many were correct? = TP / (TP + FP). High precision means few false alarms. Critical when false positives are costly (e.g., spam filtering - don't want legitimate emails marked as spam).

Recall (Sensitivity): Of all actual positives, how many did we catch? = TP / (TP + FN). High recall means few missed cases. Critical when false negatives are costly (e.g., disease detection - don't want to miss sick patients).

F1 Score: Harmonic mean of precision and recall = 2 × (Precision × Recall) / (Precision + Recall). Balances both metrics; useful when you need good performance on both false positives and false negatives. Use when classes are imbalanced.

ROC-AUC: Area Under the Receiver Operating Characteristic curve. Measures model's ability to distinguish between classes across all classification thresholds. Ranges 0.5 (random) to 1.0 (perfect). Threshold-independent; robust to class imbalance. Higher values indicate better discrimination between positive and negative classes.
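The formulas above can be verified by hand on a tiny made-up example, independent of the dataset:

```python
from sklearn import metrics

# Made-up labels: 2 TP, 1 FN, 1 FP, 2 TN
y_true = [1, 1, 1, 0, 0, 0]
y_hat  = [1, 1, 0, 1, 0, 0]

tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_hat).ravel()
precision = tp / (tp + fp)                                 # 2/3
recall    = tp / (tp + fn)                                 # 2/3
f1        = 2 * precision * recall / (precision + recall)  # 2/3

# The hand computations agree with sklearn's functions
assert abs(precision - metrics.precision_score(y_true, y_hat)) < 1e-12
assert abs(recall - metrics.recall_score(y_true, y_hat)) < 1e-12
print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
```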

Visualizing ROC Curve

The ROC curve plots True Positive Rate (Recall) vs False Positive Rate at various threshold settings:

from sklearn.metrics import roc_curve, auc

# Calculate ROC curve
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
roc_auc_calc = auc(fpr, tpr)

# Plot ROC curve
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, 
         label=f'ROC curve (AUC = {roc_auc_calc:.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', 
         label='Random classifier (AUC = 0.50)')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate (Recall)')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid(alpha=0.3)
plt.show()

Metric Trade-offs and Selection

Scenario | Primary Metric | Reasoning
Balanced classes, equal cost for both error types | Accuracy | Simple, intuitive, reflects overall correctness
Minimize false positives (e.g., spam detection) | Precision | High precision = fewer legitimate items flagged
Minimize false negatives (e.g., cancer screening) | Recall | High recall = catch all positive cases
Imbalanced classes, care about both errors | F1 Score | Balances precision and recall
Need threshold-independent assessment | ROC-AUC | Evaluates all possible classification thresholds
Highly imbalanced (e.g., fraud: 0.1% positive) | Precision-Recall AUC | ROC-AUC can be optimistic; PR-AUC better reflects performance

Best Practice: Always report multiple metrics. A single metric rarely tells the full story. For production systems, define acceptable thresholds for each metric based on business requirements.
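The last row of the table mentions Precision-Recall AUC; scikit-learn's average_precision_score is a common summary of the PR curve. A sketch on made-up imbalanced labels and scores:

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

# Made-up, highly imbalanced labels (5% positive) with noisy scores
rng = np.random.default_rng(1)
y_true = np.array([0] * 95 + [1] * 5)
scores = np.where(y_true == 1,
                  rng.uniform(0.6, 1.0, size=100),   # positives score high
                  rng.uniform(0.0, 0.7, size=100))   # negatives overlap a little

roc = roc_auc_score(y_true, scores)
pr  = average_precision_score(y_true, scores)  # PR-AUC summary
print(f"ROC-AUC: {roc:.3f}   PR-AUC (average precision): {pr:.3f}")
```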


Step 8

Inspect feature importance

Reason for next step: Understand which predictors drive splits.

Explanation: Importance scores are model-specific and help with interpretation or pruning.

importances = pd.Series(dt.feature_importances_, index=X.columns)
print(importances.sort_values(ascending=False))
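Importance scores for a fitted tree are normalized to sum to 1, which makes a handy sanity check; a sketch on synthetic data standing in for your X_train/y_train:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the fitted tree above
X_demo, y_demo = make_classification(n_samples=200, n_features=6,
                                     n_informative=2, random_state=1)
tree = DecisionTreeClassifier(max_depth=4, random_state=1).fit(X_demo, y_demo)

imp = pd.Series(tree.feature_importances_,
                index=[f"f{i}" for i in range(6)]).sort_values(ascending=False)
print(imp)
print("Sum of importances:", imp.sum())  # normalized to 1.0 for a fitted tree
```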

Tree Pruning Techniques

Pruning reduces overfitting by limiting tree growth (pre-pruning) or removing branches after construction (post-pruning).

Pre-Pruning Parameters

Parameter | Effect | Example | Use Case
max_depth | Limits tree height | max_depth=5 | Prevents deep, overfit trees
min_samples_split | Min samples to split a node | min_samples_split=10 | Avoids splits on tiny subsets
min_samples_leaf | Min samples in a leaf | min_samples_leaf=5 | Ensures each leaf has sufficient support
max_features | Features per split | max_features='sqrt' | Adds randomness, reduces correlation
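Pre-pruning just means passing these caps to the constructor. A sketch on synthetic data (swap in your own X_train/y_train) showing how the caps shrink the tree:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for X_train / y_train
X_demo, y_demo = make_classification(n_samples=300, n_features=8, random_state=1)

unpruned = DecisionTreeClassifier(random_state=1).fit(X_demo, y_demo)
pruned = DecisionTreeClassifier(max_depth=5, min_samples_split=10,
                                min_samples_leaf=5,
                                random_state=1).fit(X_demo, y_demo)

# Pre-pruning caps yield a smaller, simpler tree
print("Unpruned depth:", unpruned.get_depth(), "leaves:", unpruned.get_n_leaves())
print("Pruned depth:  ", pruned.get_depth(), "leaves:", pruned.get_n_leaves())
```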

Post-Pruning (Cost Complexity Pruning)

Cost complexity pruning uses ccp_alpha to penalize tree complexity. Higher alpha = more aggressive pruning.

# Find optimal alpha
path = dt.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

# Train trees with different alphas
clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=RANDOM_STATE, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)

# Plot accuracy vs alpha
train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]

plt.figure(figsize=(10, 5))
plt.plot(ccp_alphas, train_scores, marker='o', label='Train', drawstyle="steps-post")
plt.plot(ccp_alphas, test_scores, marker='o', label='Test', drawstyle="steps-post")
plt.xlabel("ccp_alpha (complexity parameter)")
plt.ylabel("Accuracy")
plt.title("Accuracy vs Cost Complexity Parameter")
plt.legend()
plt.show()

# Select alpha maximizing test accuracy (shown for illustration;
# in practice choose alpha via cross-validation so the test set stays untouched)
optimal_idx = np.argmax(test_scores)
optimal_alpha = ccp_alphas[optimal_idx]
print(f"Optimal ccp_alpha: {optimal_alpha:.4f}")

# Build final pruned tree
pruned_tree = DecisionTreeClassifier(random_state=RANDOM_STATE, ccp_alpha=optimal_alpha)
pruned_tree.fit(X_train, y_train)
print(f"Pruned tree test accuracy: {pruned_tree.score(X_test, y_test):.4f}")

Step 9

Hyperparameter tuning

Reason for next step: Control depth/complexity to balance bias and variance.

Explanation: Grid search over depth and leaf constraints typically stabilizes generalization.

Understanding Impurity Criteria

Gini Impurity

Formula: Gini = 1 - Σ(pi²) where pi is proportion of class i

Interpretation: Measures probability of incorrect classification; ranges 0 (pure) to 0.5 (binary, balanced)

Computational note: Faster to compute; default in scikit-learn

Example: If node has 70% class A, 30% class B: Gini = 1 - (0.7² + 0.3²) = 0.42

Entropy (Information Gain)

Formula: Entropy = -Σ(pi × log₂(pi))

Interpretation: Measures disorder/uncertainty; ranges 0 (pure) to 1 (binary, balanced)

Connection: Rooted in information theory (Shannon entropy); measures bits needed to encode class labels

Example: For 70%/30% split: Entropy = -(0.7×log₂(0.7) + 0.3×log₂(0.3)) ≈ 0.88

When to Use Each

  • Gini: Default choice; computationally efficient; tends toward isolating the most frequent class
  • Entropy: Slightly more balanced splits; better when all classes matter equally; ~2-3% slower
  • Practical advice: Try both; performance difference usually < 2% on most datasets
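Both worked examples above can be checked with a few lines of pure Python:

```python
from math import log2

def gini(probs):
    """Gini = 1 - sum(p_i^2)."""
    return 1 - sum(p * p for p in probs)

def entropy(probs):
    """Entropy = -sum(p_i * log2(p_i)); p = 0 terms contribute 0."""
    return -sum(p * log2(p) for p in probs if p > 0)

print(f"Gini(0.7, 0.3)    = {gini([0.7, 0.3]):.2f}")     # 0.42, as above
print(f"Entropy(0.7, 0.3) = {entropy([0.7, 0.3]):.2f}")  # 0.88, as above
```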

Extended Parameter Grid Options

Parameter | Options | Purpose
criterion | ['gini', 'entropy'] | Controls split quality metric
max_depth | [None, 3, 5, 7, 9, 12, 15] | Limits tree height; None = unlimited
min_samples_split | [2, 5, 10, 20, 50] | Min samples to attempt a split
min_samples_leaf | [1, 2, 4, 8, 10] | Min samples required in leaf nodes
max_features | [None, 'sqrt', 'log2', 0.5, 0.8] | Features considered per split; None = all
min_impurity_decrease | [0.0, 0.001, 0.01, 0.05] | Min impurity decrease to justify a split
class_weight | [None, 'balanced', {0:1, 1:3}] | Handles imbalanced classes
splitter | ['best', 'random'] | Split selection strategy; 'random' adds stochasticity

Parameter Interaction Guide

Depth parameters (max_depth) control tree height; split parameters (min_samples_split, min_impurity_decrease, max_features) control when and how splits happen; leaf parameters (min_samples_leaf) control terminal nodes.

Best Practices:

  • Start with max_depth tuning (biggest impact on overfitting)
  • Use min_samples_leaf with small datasets (< 1000 rows)
  • Set max_features='sqrt' for high-dimensional data (> 50 features)
  • Use class_weight='balanced' for imbalanced targets
  • Computational cost: max_depth × (# split candidates) × (# feature subsets)

Grid Search Examples

Quick Search (faster, fewer combinations):

param_grid_quick = {
    "criterion": ["gini"],
    "max_depth": [3, 5, 7],
    "min_samples_leaf": [5, 10]
}

Comprehensive Search (thorough, slower):

param_grid_full = {
    "criterion": ["gini", "entropy"],
    "max_depth": [None, 3, 5, 7, 9, 12],
    "min_samples_split": [2, 5, 10, 20],
    "min_samples_leaf": [1, 2, 4, 8],
    "max_features": [None, "sqrt", "log2"],
    "min_impurity_decrease": [0.0, 0.001, 0.01]
}

grid = GridSearchCV(
    DecisionTreeClassifier(random_state=RANDOM_STATE),
    param_grid_full,
    cv=5,
    scoring="accuracy",
    n_jobs=-1,
)
grid.fit(X_train, y_train)
print("Best params", grid.best_params_)
print("CV accuracy", grid.best_score_)

best = grid.best_estimator_
print("Test accuracy (tuned)", metrics.accuracy_score(y_test, best.predict(X_test)))

Step 10

Cross-validation baseline

Reason for next step: Get a dataset-wide stability estimate beyond one split.

Understanding K-Fold Cross-Validation

Why Cross-Validation?

A single train/test split gives ONE performance estimate, which varies depending on which rows land in the test set. Cross-validation runs MULTIPLE splits and averages the results for a stable, reliable estimate.

How 5-Fold CV Works (Step-by-Step)

  1. Split dataset into 5 equal parts (folds): [Fold1][Fold2][Fold3][Fold4][Fold5]
  2. Iteration 1: Train on Folds 1,2,3,4 → Test on Fold 5 → Record accuracy
  3. Iteration 2: Train on Folds 1,2,3,5 → Test on Fold 4 → Record accuracy
  4. Iteration 3: Train on Folds 1,2,4,5 → Test on Fold 3 → Record accuracy
  5. Iteration 4: Train on Folds 1,3,4,5 → Test on Fold 2 → Record accuracy
  6. Iteration 5: Train on Folds 2,3,4,5 → Test on Fold 1 → Record accuracy
  7. Report: Mean accuracy ± standard deviation across all 5 runs

Benefits

  • Every data point used for both training AND testing (but never simultaneously)
  • Reduces variance from random split luck
  • Detects overfitting: large gap between train and CV scores indicates problem
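The five iterations above can be written out explicitly with StratifiedKFold, which is essentially what cross_val_score does internally; synthetic data stands in for your X and y:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for X and y
X_demo, y_demo = make_classification(n_samples=200, random_state=1)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)

fold_scores = []
for i, (train_idx, test_idx) in enumerate(skf.split(X_demo, y_demo), start=1):
    clf = DecisionTreeClassifier(random_state=1)
    clf.fit(X_demo[train_idx], y_demo[train_idx])          # train on 4 folds
    acc = clf.score(X_demo[test_idx], y_demo[test_idx])    # test on the 5th
    fold_scores.append(acc)
    print(f"Iteration {i}: test fold size {len(test_idx)}, accuracy {acc:.3f}")

print(f"Mean accuracy: {np.mean(fold_scores):.3f} ± {np.std(fold_scores):.3f}")
```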

cross_val_score Function Parameters (Detailed)

Parameter | Type | Description
estimator | unfitted model object | The classifier/regressor to evaluate; must have fit() and predict() methods (e.g., DecisionTreeClassifier(random_state=1))
X | array-like (n_samples, n_features) | Complete feature matrix (do NOT pre-split); CV handles splitting internally
y | array-like (n_samples,) | Complete target vector; must match X row count
cv | int or cross-validator | Number of folds (default 5); or pass StratifiedKFold, TimeSeriesSplit, etc. (e.g., cv=5 or cv=StratifiedKFold(n_splits=10, shuffle=True))
scoring | string or callable | Metric to compute; common options: 'accuracy' (classifier default), 'f1', 'f1_macro', 'roc_auc', 'precision', 'recall', 'neg_mean_squared_error' (regression)
n_jobs | int | Parallelism; -1 uses all CPU cores; 1 (default) runs sequentially

Accessing Individual Fold Scores

cv_scores = cross_val_score(
    DecisionTreeClassifier(random_state=RANDOM_STATE),
    X, y,
    cv=5,
    scoring="accuracy",
    n_jobs=-1,
)
print("Individual fold accuracies:", cv_scores)
print(f"Mean CV accuracy: {cv_scores.mean():.4f}")
print(f"Std deviation: {cv_scores.std():.4f}")
print(f"Approx. 95% CI for the mean: {cv_scores.mean():.4f} ± {1.96 * cv_scores.std() / np.sqrt(len(cv_scores)):.4f}")

Choosing Number of Folds

Dataset Size | Recommended Folds | Reasoning
< 100 rows | Leave-One-Out CV (cv=n) | Maximize training data per fold
100-1000 rows | 5-10 folds | Balance between bias and computational cost
1000-10000 rows | 5 folds (default) | Good trade-off; industry standard
> 10000 rows | 3-5 folds | Reduces computation; still reliable
Time-series data | TimeSeriesSplit | Respects temporal ordering

Final note: Higher k = less bias (more training data per fold) but higher variance (fewer test samples per fold) and longer runtime. k=5 or k=10 are empirically robust choices.


Step 11

Visualize the tree

Reason for next step: Provide interpretable structure for stakeholders.

Explanation: Plotting shallow trees reveals decision logic; limit depth for readability.

plt.figure(figsize=(12, 6))
plot_tree(dt, feature_names=X.columns, class_names=["Negative", "Positive"],
          filled=True, max_depth=3, fontsize=8)
plt.tight_layout()
plt.show()
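When a figure is inconvenient (logs, reports, quick diffs), sklearn.tree.export_text renders the same rules as plain text; a sketch on a synthetic tree:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic stand-in for the fitted tree above
X_demo, y_demo = make_classification(n_samples=100, n_features=4, random_state=1)
tree = DecisionTreeClassifier(max_depth=2, random_state=1).fit(X_demo, y_demo)

# Same split logic as plot_tree, rendered as indented text rules
rules = export_text(tree, feature_names=[f"f{i}" for i in range(4)])
print(rules)
```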

Step 12

Report and reuse

Reason for next step: Summarize results for quizzes or production handoff.

Explanation: Capture all classification metrics, important features, and chosen hyperparameters so others can replicate and assess model quality comprehensively.

# Generate predictions for both baseline and tuned models
y_pred_baseline = dt.predict(X_test)
y_pred_tuned = best.predict(X_test)
y_pred_proba_baseline = dt.predict_proba(X_test)[:, 1]
y_pred_proba_tuned = best.predict_proba(X_test)[:, 1]

# Comprehensive metrics summary
summary = {
    "baseline_metrics": {
        "accuracy": metrics.accuracy_score(y_test, y_pred_baseline),
        "precision": metrics.precision_score(y_test, y_pred_baseline),
        "recall": metrics.recall_score(y_test, y_pred_baseline),
        "f1_score": metrics.f1_score(y_test, y_pred_baseline),
        "roc_auc": metrics.roc_auc_score(y_test, y_pred_proba_baseline),
    },
    "tuned_metrics": {
        "accuracy": metrics.accuracy_score(y_test, y_pred_tuned),
        "precision": metrics.precision_score(y_test, y_pred_tuned),
        "recall": metrics.recall_score(y_test, y_pred_tuned),
        "f1_score": metrics.f1_score(y_test, y_pred_tuned),
        "roc_auc": metrics.roc_auc_score(y_test, y_pred_proba_tuned),
    },
    "best_params": grid.best_params_,
    "top_5_features": importances.sort_values(ascending=False).head(5).to_dict(),
}

# Print formatted summary
import json
print(json.dumps(summary, indent=2))

# Export to file for documentation
with open('model_evaluation_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)