Bayesian GMM - Bayesian Gaussian Mixture Model

Python: NimbusGMM | Julia: RxGMMModel
Mathematical Model: Heteroscedastic Gaussian Classifier (HGC)
Bayesian GMM is a Bayesian classification model with class-specific covariance matrices, making it more flexible than Bayesian LDA for modeling complex class distributions.
Available in Both SDKs:
  • Python SDK: NimbusGMM class (sklearn-compatible)
  • Julia SDK: RxGMMModel (RxInfer.jl-based)
Both implementations provide class-specific covariances for flexible modeling.

Overview

Bayesian GMM extends traditional Gaussian classifiers by allowing each class to have its own covariance structure:
  • Class-specific covariances (unlike Bayesian LDA’s shared covariance)
  • More flexible modeling of heterogeneous distributions
  • Posterior probability distributions with uncertainty quantification
  • Fast inference (~15-25ms per trial)
  • Training and calibration support
  • Batch and streaming inference modes

Quick Start

from nimbus_bci import NimbusGMM
import numpy as np

# Create and fit classifier
clf = NimbusGMM(mu_scale=3.0)
clf.fit(X_train, y_train)

# Predict with uncertainty
predictions = clf.predict(X_test)
probabilities = clf.predict_proba(X_test)

# Better for P300 and overlapping distributions

When to Use Bayesian GMM

Bayesian GMM is ideal for:
  • Complex, overlapping class distributions
  • Classes with significantly different variances
  • P300 detection (target/non-target with different spreads)
  • When Bayesian LDA accuracy is unsatisfactory
  • When you need maximum flexibility
Use Bayesian LDA instead if:
  • Classes are well-separated and have similar spreads
  • Speed is critical (Bayesian LDA is faster)
  • Training data is limited (Bayesian LDA needs less data)
  • Memory is constrained (Bayesian LDA uses less memory)

Model Architecture

Mathematical Foundation (Heteroscedastic Gaussian Classifier)

Bayesian GMM implements a Heteroscedastic Gaussian Classifier (HGC), which models class-conditional distributions with class-specific precision matrices:
p(x | y=k) = N(μ_k, W_k^-1)
Where:
  • μ_k = mean vector for class k
  • W_k = class-specific precision matrix (different for each class)
  • Allows different covariance structures per class
Key Difference from Bayesian LDA: Each class can have its own covariance structure, making the model more flexible but also more parameter-heavy.
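As a concrete illustration of the formula above, the class posterior p(y=k | x) can be computed with plain NumPy/SciPy. The means, covariances, and priors below are toy values chosen for illustration; this is a sketch of the math, not the SDK implementation:

```python
import numpy as np
from scipy.stats import multivariate_normal

# Toy 2-class setup with class-specific covariances (W_k^-1)
mu = [np.array([0.0, 0.0]), np.array([2.0, 2.0])]
cov = [np.diag([1.0, 1.0]), np.diag([4.0, 0.25])]  # different structure per class
priors = np.array([0.5, 0.5])

x = np.array([1.8, 1.9])  # a single feature vector

# p(y=k | x) ∝ p(x | y=k) p(y=k), with a class-specific Gaussian per class
log_lik = np.array([
    multivariate_normal.logpdf(x, mean=mu[k], cov=cov[k]) for k in range(2)
])
log_post = np.log(priors) + log_lik
post = np.exp(log_post - log_post.max())  # subtract max for numerical stability
post /= post.sum()
print(post)  # normalized posterior class probabilities
```

Note how class 1's tight variance along the second axis sharpens its likelihood near its mean, even though the point is not far from class 0 either; with a shared covariance (the LDA assumption) this distinction is lost.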

Hyperparameters

Bayesian GMM exposes configurable hyperparameters for tuning performance to your data. Available hyperparameters (training):

Parameter            | Type    | Default | Range        | Description
dof_offset           | Int     | 2       | [1, 5]       | Degrees-of-freedom offset for the Wishart priors
mean_prior_precision | Float64 | 0.01    | [0.001, 0.1] | Prior precision for class means
Parameter Effects:
  • dof_offset: Controls regularization strength
    • Lower values (1) → More data-driven, less regularization
    • Higher values (3-5) → More regularization, more conservative
  • mean_prior_precision: Controls prior strength on class means
    • Lower values (0.001) → Weaker prior, trusts data more
    • Higher values (0.05-0.1) → Stronger prior, more regularization

Model Structure

struct RxGMMModel <: BCIModel
    mean_posteriors::Vector        # MvNormal posteriors for class means
    precision_posteriors::Vector   # Wishart posteriors for class precisions
    priors::Vector{Float64}        # Empirical class priors
    metadata::ModelMetadata        # Model info
    dof_offset::Int                # Degrees of freedom offset (training)
    mean_prior_precision::Float64  # Mean prior precision (training)
end

RxInfer Implementation

Learning Phase:
@model function RxGMM_learning_model(y, labels, n_features, n_classes)
    # Priors on class means
    for k in 1:n_classes
        m[k] ~ MvNormal(zeros(n_features), 10 * I)  # Broad prior on each class mean
    end
    
    # Priors on class-specific precisions
    for k in 1:n_classes
        W[k] ~ Wishart(n_features + 5, I)  # Each class has its own W
    end
    
    # Likelihood
    for i in eachindex(y)
        k = labels[i]
        y[i] ~ MvNormal(m[k], inv(W[k]))  # Class-specific precision
    end
end
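To make the generative assumptions of this model concrete, here is a NumPy/SciPy sketch that samples from the same priors and likelihood. This is a toy forward simulation of the model structure, not RxInfer inference:

```python
import numpy as np
from scipy.stats import wishart

rng = np.random.default_rng(1)
n_features, n_classes, n_trials = 2, 2, 6

# Sample class means from the broad Gaussian prior
m = [rng.multivariate_normal(np.zeros(n_features), 10 * np.eye(n_features))
     for _ in range(n_classes)]

# Sample a class-specific precision matrix from each Wishart prior
W = [wishart.rvs(df=n_features + 5, scale=np.eye(n_features), random_state=rng)
     for _ in range(n_classes)]

# Each trial draws from its class's Gaussian with covariance inv(W[k])
labels = rng.integers(0, n_classes, size=n_trials)
y = np.array([
    rng.multivariate_normal(m[k], np.linalg.inv(W[k])) for k in labels
])
print(y.shape)  # (n_trials, n_features)
```

Training inverts this process: given observed y and labels, variational inference recovers posteriors over each m[k] and W[k].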

Usage

1. Load Pre-trained Model

Python SDK: The Python SDK (nimbus-bci) trains models locally. See Python SDK Quickstart for training examples.
from nimbus_bci import NimbusGMM
import pickle

# Python SDK: Train locally, no model zoo
clf = NimbusGMM(mu_scale=3.0)
clf.fit(X_train, y_train)

# Save and load
with open("my_p300_gmm.pkl", "wb") as f:
    pickle.dump(clf, f)

with open("my_p300_gmm.pkl", "rb") as f:
    clf = pickle.load(f)

print(f"Model info: {len(clf.classes_)} classes, {clf.n_features_in_} features")

2. Train Custom Model

from nimbus_bci import NimbusGMM
from nimbus_bci.compat import from_mne_epochs
import mne
import pickle

# Load P300 data
raw = mne.io.read_raw_fif("p300_oddball.fif", preload=True)
raw.filter(0.5, 10)

events = mne.find_events(raw)
event_id = {'target': 1, 'non-target': 2}
epochs = mne.Epochs(raw, events, event_id, tmin=-0.2, tmax=0.8, preload=True)

# Extract features
X, y = from_mne_epochs(epochs)

# Train NimbusGMM with default hyperparameters
clf = NimbusGMM()
clf.fit(X, y)

# Or train with custom hyperparameters
clf_tuned = NimbusGMM(
    mu_scale=3.0,        # Prior strength
    sigma_scale=1.0      # Covariance regularization
)
clf_tuned.fit(X, y)

# Save
with open("my_p300_gmm.pkl", "wb") as f:
    pickle.dump(clf_tuned, f)

print(f"Training accuracy: {clf_tuned.score(X, y):.1%}")
Training Parameters (Julia SDK train_model keywords):
  • iterations: Number of variational inference iterations (default: 50)
    • More iterations = better convergence, typical range: 50-100
  • showprogress: Display progress bar during training
  • name: Model identifier
  • description: Model description
  • dof_offset: Degrees of freedom offset (default: 2, range: [1, 5])
  • mean_prior_precision: Prior precision for means (default: 0.01, range: [0.001, 0.1])

3. Subject-Specific Calibration

from nimbus_bci import NimbusGMM
import pickle

# Load baseline model
with open("p300_baseline.pkl", "rb") as f:
    base_clf = pickle.load(f)

# Collect calibration trials (10-20 per class)
X_calib, y_calib = collect_calibration_trials()

# Personalize the loaded baseline using online learning
personalized_clf = base_clf

for _ in range(10):
    personalized_clf.partial_fit(X_calib, y_calib)

# Save
with open("subject_001_p300_calibrated.pkl", "wb") as f:
    pickle.dump(personalized_clf, f)
Calibration Benefits:
  • Requires only 10-20 trials per class
  • Faster than training from scratch
  • Adapts to subject-specific characteristics
  • Hyperparameters preserved: the Julia SDK's calibrate_model() automatically reuses the base model's hyperparameters (v0.2.0+)

4. Batch Inference

import numpy as np

# Run batch inference
predictions = clf.predict(X_test)
probabilities = clf.predict_proba(X_test)
confidences = np.max(probabilities, axis=1)

# Analyze results
print(f"Predictions: {predictions}")
print(f"Mean confidence: {np.mean(confidences):.3f}")

# Calculate metrics
accuracy = np.mean(predictions == y_test)
print(f"Accuracy: {accuracy * 100:.1f}%")

5. Streaming Inference

For detailed Python streaming examples, see Python SDK Streaming Inference.
from nimbus_bci import StreamingSession

# Initialize streaming session
# (metadata_with_chunk_size: ModelMetadata configured with your chunk size; see the streaming guide)
session = StreamingSession(clf.model_, metadata_with_chunk_size)

# Process chunks
for chunk in eeg_feature_stream:
    result = session.process_chunk(chunk)
    print(f"Chunk: pred={result.prediction}, conf={result.confidence:.3f}")

# Finalize trial
final_result = session.finalize_trial()
print(f"Final: pred={final_result.prediction}, conf={final_result.confidence:.3f}")

session.reset()

Hyperparameter Tuning (v0.2.0+)

Fine-tune Bayesian GMM for optimal performance on your specific dataset.

When to Tune Hyperparameters

Consider tuning when:
  • Default performance is unsatisfactory
  • You have specific data characteristics (very noisy or very clean)
  • You have limited or extensive training data
  • Working with complex, overlapping class distributions
  • P300 or other paradigms with heterogeneous class variances

Tuning Strategies

For High SNR / Clean Data / Many Trials

Use lower regularization to let the data drive the model:
from nimbus_bci import NimbusGMM

# Lower regularization for clean data
clf = NimbusGMM(
    mu_scale=1.0,      # Weaker prior
    sigma_scale=0.1    # Less covariance regularization
)
clf.fit(X_train, y_train)
Use when:
  • SNR > 5 dB
  • 100+ trials per class
  • Clean, artifact-free data
  • Well-controlled experimental conditions

For Low SNR / Noisy Data / Few Trials

Use higher regularization for stability (especially important for GMM with class-specific covariances):
from nimbus_bci import NimbusGMM

# Higher regularization for noisy data
clf = NimbusGMM(
    mu_scale=5.0,      # Stronger prior
    sigma_scale=10.0   # More covariance regularization
)
clf.fit(X_train, y_train)
Use when:
  • SNR < 2 dB
  • 40-80 trials per class
  • Noisy data or limited artifact removal
  • Challenging recording conditions
  • Risk of overfitting to class-specific noise

Balanced / Default Settings

The defaults work well for most scenarios:
model = train_model(
    RxGMMModel,
    train_data;
    iterations = 50,
    dof_offset = 2,                    # Balanced (default)
    mean_prior_precision = 0.01        # Balanced (default)
)
Use when:
  • Moderate SNR (2-5 dB)
  • 80-150 trials per class
  • Standard BCI recording conditions
  • Starting point for experimentation

P300-Specific Tuning

For P300 paradigms where target/non-target classes have different variances:
# P300 often benefits from GMM's class-specific covariances
# Use moderate regularization to avoid overfitting
model = train_model(
    RxGMMModel,
    p300_train_data;
    iterations = 50,
    dof_offset = 2,                    # Standard regularization
    mean_prior_precision = 0.02        # Slightly stronger than default
)

Hyperparameter Search Example

Systematically search for optimal hyperparameters:
using NimbusSDK

# Define search grid
dof_values = [1, 2, 3, 4]
prior_values = [0.001, 0.01, 0.03, 0.05]

# Split data
train_data, val_data = split_data(all_data, ratio=0.8)

best_accuracy = 0.0
best_params = nothing

println("Searching hyperparameters for RxGMM...")
for dof in dof_values
    for prior in prior_values
        # Train model
        model = train_model(
            RxGMMModel,
            train_data;
            iterations = 50,
            dof_offset = dof,
            mean_prior_precision = prior,
            predictive_dof_offset = dof,
            showprogress = false
        )
        
        # Validate
        results = predict_batch(model, val_data)
        accuracy = sum(results.predictions .== val_data.labels) / length(val_data.labels)
        
        println("  dof=$dof, prior=$prior: $(round(accuracy*100, digits=1))%")
        
        if accuracy > best_accuracy
            best_accuracy = accuracy
            best_params = (dof=dof, prior=prior)
        end
    end
end

println("\nBest hyperparameters:")
println("  dof_offset: $(best_params.dof)")
println("  mean_prior_precision: $(best_params.prior)")
println("  Validation accuracy: $(round(best_accuracy*100, digits=1))%")

# Retrain with best params
final_model = train_model(
    RxGMMModel,
    all_data;
    iterations = 50,
    dof_offset = best_params.dof,
    mean_prior_precision = best_params.prior,
    predictive_dof_offset = best_params.dof
)

Quick Tuning Guidelines

Scenario               | dof_offset  | mean_prior_precision | Notes
Excellent data quality | 1           | 0.001                | Minimal regularization
Good data quality      | 2 (default) | 0.01 (default)       | Balanced approach
Moderate data quality  | 2-3         | 0.01-0.03            | Slight regularization
Poor data quality      | 3-4         | 0.05-0.1             | Strong regularization
Very limited trials    | 4           | 0.1                  | Maximum regularization
P300 target/non-target | 2           | 0.02                 | Moderate; class-specific covariances help
Pro Tip: Bayesian GMM’s class-specific covariances can overfit to noise with poor data. When in doubt, start with defaults and increase regularization (dof_offset=3, mean_prior_precision=0.03) if you see overfitting.
Important: Always set predictive_dof_offset to match dof_offset for consistency between training and inference phases.

Training Requirements

Data Requirements

  • Minimum: 40 trials per class (80 total for 2-class)
  • Recommended: 80+ trials per class
  • For calibration: 10-20 trials per class
Bayesian GMM requires at least 2 observations per class to estimate class-specific statistics. Training will fail if any class has fewer than 2 observations.
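A quick pre-flight check along these lines can catch that failure mode before training. This is plain NumPy; `check_min_trials_per_class` is a hypothetical helper for illustration, not an SDK function:

```python
import numpy as np

def check_min_trials_per_class(y, min_per_class=2):
    """Raise if any class has fewer than `min_per_class` observations."""
    classes, counts = np.unique(y, return_counts=True)
    too_few = classes[counts < min_per_class]
    if too_few.size:
        raise ValueError(
            f"Classes {too_few.tolist()} have fewer than {min_per_class} "
            "observations; class-specific covariances cannot be estimated."
        )
    return dict(zip(classes.tolist(), counts.tolist()))

# Example: class 2 has only one trial, so this raises before training starts
y = np.array([1, 1, 1, 2])
try:
    check_min_trials_per_class(y)
except ValueError as e:
    print(e)
```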

Feature Normalization

Critical for cross-session BCI performance! Normalize your features before training; this typically yields a 15-30% accuracy improvement across sessions.
from nimbus_bci import NimbusGMM
from sklearn.preprocessing import StandardScaler
import pickle

# Estimate normalization from training data
scaler = StandardScaler()
X_train_norm = scaler.fit_transform(X_train)

# Train with normalized features
clf = NimbusGMM()
clf.fit(X_train_norm, y_train)

# Save model and scaler together
with open("model_with_scaler.pkl", "wb") as f:
    pickle.dump({'model': clf, 'scaler': scaler}, f)

# Later: Apply same normalization to test data
X_test_norm = scaler.transform(X_test)
predictions = clf.predict(X_test_norm)
See the Feature Normalization guide for complete details.

Feature Requirements

Bayesian GMM expects preprocessed features, not raw EEG. Required preprocessing:
  • Bandpass filtering (paradigm-specific)
  • Artifact removal
  • Feature extraction (CSP, ERP amplitude, bandpower, etc.)
  • Proper temporal aggregation
NOT accepted:
  • Raw EEG channels
  • Unfiltered data
See Preprocessing Requirements.
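As one illustration of the bandpass-filtering step, a zero-phase Butterworth filter via scipy.signal might look like this. The parameters (4th order, P300-style 0.5-10 Hz band, 250 Hz sampling) are illustrative; adapt the band to your paradigm:

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

def bandpass(eeg, lo_hz, hi_hz, fs):
    """Zero-phase Butterworth bandpass over the last (time) axis."""
    sos = butter(4, [lo_hz, hi_hz], btype="bandpass", fs=fs, output="sos")
    return sosfiltfilt(sos, eeg, axis=-1)

# Example: 8 channels, 1 s of data at 250 Hz
fs = 250.0
eeg = np.random.default_rng(0).standard_normal((8, 250))
filtered = bandpass(eeg, 0.5, 10.0, fs)
print(filtered.shape)  # (8, 250)
```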

Performance Characteristics

Computational Performance

Operation       | Latency            | Notes
Training        | 15-40 seconds      | 50 iterations, 100 trials per class
Calibration     | 8-20 seconds       | 20 iterations, 20 trials per class
Batch inference | 15-25 ms per trial | 10 iterations
Streaming chunk | 15-25 ms per chunk | 10 iterations per chunk
Slightly slower than RxLDA due to class-specific covariances.

Classification Accuracy

Paradigm      | Classes               | Typical Accuracy | When to Use Bayesian GMM
P300          | 2 (target/non-target) | 85-95%           | Target/non-target have different variances
Motor Imagery | 2-4                   | 70-85%           | When Bayesian LDA accuracy is insufficient
SSVEP         | 2-6                   | 85-98%           | Complex frequency responses
Bayesian GMM typically provides 2-5% higher accuracy than Bayesian LDA when class covariances differ significantly, at the cost of ~5-10ms additional latency.

Model Inspection

View Model Parameters

import numpy as np

# Class means
print("Class means:")
for k, class_label in enumerate(clf.classes_):
    print(f"  Class {class_label}: {clf.model_['means'][k]}")

# Class-specific covariance matrices
print("\nClass-specific covariance matrices:")
for k, class_label in enumerate(clf.classes_):
    print(f"  Class {class_label} (first 3x3):")
    print(clf.model_['covariances'][k][:3, :3])

# Compare covariances across classes
print("\nCovariance structure comparison:")
for k, class_label in enumerate(clf.classes_):
    cov_k = clf.model_['covariances'][k]
    print(f"  Class {class_label} variance (diagonal): {np.diag(cov_k)}")

# Class priors
print("\nClass priors:")
for k, class_label in enumerate(clf.classes_):
    print(f"  Class {class_label}: {clf.model_['priors'][k]:.3f}")
Accessing model parameters: The SDK stores full posterior distributions (not just point estimates) for proper Bayesian inference. To get point estimates, use mean(posterior) to extract the mean of the posterior distribution. For precision matrices, use mean(precision_posterior) to get the expected precision matrix.
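For reference, the mean of a Wishart(ν, V) posterior is E[W] = ν·V, so the point-estimate precision and the covariance it implies can be computed directly. The ν and V below are hypothetical posterior parameters, not values from a real model:

```python
import numpy as np

# Hypothetical Wishart posterior parameters for one class
nu = 12.0                      # degrees of freedom
V = np.array([[0.10, 0.02],    # scale matrix
              [0.02, 0.08]])

expected_precision = nu * V                        # E[W] for Wishart(nu, V)
implied_covariance = np.linalg.inv(expected_precision)

print(expected_precision)
print(implied_covariance)
```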

Visualize Class Differences

using Plots
using Distributions

# Compare class covariances
for k in 1:length(model.precision_posteriors)
    prec_point = mean(model.precision_posteriors[k])  # Extract point estimate
    cov_k = inv(prec_point)  # Convert precision to covariance
    heatmap(cov_k, title="Class $k Covariance", colorbar=true)
end

Advantages & Limitations

Advantages

Flexible Modeling: Each class has its own covariance
Better for Complex Data: Handles heterogeneous distributions
Higher Accuracy: 2-5% improvement when classes differ significantly
Uncertainty Quantification: Full Bayesian posteriors
Production-Ready: Battle-tested in P300 applications

Limitations

More Parameters: Requires more training data than RxLDA
Slower Inference: ~15-25ms vs ~10-15ms for RxLDA
Higher Memory: Stores n_classes precision matrices
More Complex: Longer training time

Comparison: Bayesian GMM vs Bayesian LDA

Aspect             | Bayesian GMM (RxGMM)                      | Bayesian LDA (RxLDA)
Precision matrix   | Class-specific                            | Shared (one for all classes)
Mathematical model | Heteroscedastic Gaussian Classifier (HGC) | Pooled Gaussian Classifier (PGC)
Training speed     | Slower                                    | Faster
Inference speed    | 15-25 ms                                  | 10-15 ms
Flexibility        | High                                      | Moderate
Data requirements  | More                                      | Less
Memory usage       | Higher                                    | Lower
Best for           | Heterogeneous classes                     | Homogeneous classes
Accuracy gain      | +2-5% (when applicable)                   | Baseline

Decision Tree

Is speed critical (<15ms)?
├─ Yes → Use Bayesian LDA
└─ No → Do classes have similar spreads?
    ├─ Yes → Use Bayesian LDA (faster, same accuracy)
    └─ No → Use Bayesian GMM (more accurate)

Practical Examples

P300 Detection

# Bayesian GMM is ideal for P300 where target/non-target have different variances

# Train on P300 data
p300_data = BCIData(erp_features, p300_metadata, labels)  # 1=target, 2=non-target
model = train_model(RxGMMModel, p300_data; iterations=50)

# Test
results = predict_batch(model, test_data)

# Calculate sensitivity/specificity
targets = test_labels .== 1
target_preds = results.predictions .== 1

sensitivity = sum(target_preds .& targets) / sum(targets)
specificity = sum(.!target_preds .& .!targets) / sum(.!targets)

println("Sensitivity: $(round(sensitivity * 100, digits=1))%")
println("Specificity: $(round(specificity * 100, digits=1))%")

When Bayesian LDA Fails

# Try Bayesian LDA first
rxlda_model = train_model(RxLDAModel, train_data)
rxlda_results = predict_batch(rxlda_model, test_data)
rxlda_acc = sum(rxlda_results.predictions .== test_labels) / length(test_labels)

# If accuracy insufficient, try Bayesian GMM
if rxlda_acc < 0.75
    println("Bayesian LDA accuracy low ($(round(rxlda_acc*100, digits=1))%), trying Bayesian GMM...")
    rxgmm_model = train_model(RxGMMModel, train_data)
    rxgmm_results = predict_batch(rxgmm_model, test_data)
    rxgmm_acc = sum(rxgmm_results.predictions .== test_labels) / length(test_labels)
    
    println("Bayesian GMM accuracy: $(round(rxgmm_acc*100, digits=1))%")
    println("Improvement: +$(round((rxgmm_acc-rxlda_acc)*100, digits=1))%")
end

References

Theory:
  • McLachlan, G. J., & Peel, D. (2000). “Finite Mixture Models”
  • Bishop, C. M. (2006). “Pattern Recognition and Machine Learning” (Chapter 9)
  • Heteroscedastic Gaussian Classifier (HGC) with class-specific covariances
BCI Applications:
  • Farwell, L. A., & Donchin, E. (1988). “Talking off the top of your head”
  • Lotte et al. (2018). “A review of classification algorithms for EEG-based BCI”