Bayesian GMM (RxGMM) - Bayesian Gaussian Mixture Model

API Name: RxGMMModel
Mathematical Model: Heteroscedastic Gaussian Classifier (HGC)
Bayesian GMM (RxGMM in the codebase) is a Bayesian classification model with class-specific covariance matrices, making it more flexible than Bayesian LDA for modeling complex class distributions. It is implemented using RxInfer.jl's reactive message passing.
Bayesian GMM (RxGMM) is implemented in NimbusSDK.jl and is ready for production BCI applications. "GMM" is the widely recognized machine-learning term, while "Bayesian" signals the model's uncertainty quantification and posterior probability outputs.

Overview

Bayesian GMM extends beyond traditional Gaussian classifiers by allowing each class to have its own covariance structure:
  • Class-specific covariances (unlike RxLDA’s shared covariance)
  • More flexible modeling of heterogeneous distributions
  • Posterior probability distributions with uncertainty quantification
  • Fast inference (~15-25ms per trial)
  • Training and calibration support
  • Batch and streaming inference modes

When to Use Bayesian GMM

Bayesian GMM is ideal for:
  • Complex, overlapping class distributions
  • Classes with significantly different variances
  • P300 detection (target/non-target with different spreads)
  • When Bayesian LDA accuracy is unsatisfactory
  • When you need maximum flexibility
Use Bayesian LDA instead if:
  • Classes are well-separated and have similar spreads
  • Speed is critical (Bayesian LDA is faster)
  • Training data is limited (Bayesian LDA needs less data)
  • Memory is constrained (Bayesian LDA uses less memory)

Model Architecture

Mathematical Foundation (Heteroscedastic Gaussian Classifier)

Bayesian GMM implements a Heteroscedastic Gaussian Classifier (HGC), which models class-conditional distributions with class-specific precision matrices:
p(x | y=k) = N(μ_k, W_k^-1)
Where:
  • μ_k = mean vector for class k
  • W_k = class-specific precision matrix (different for each class)
  • Allows different covariance structures per class
Key Difference from Bayesian LDA: Each class can have its own covariance structure, making the model more flexible but also more parameter-heavy.
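For intuition, here is a minimal, self-contained sketch (plain Julia, not the SDK API) of how an HGC turns these quantities into class posteriors, given point estimates of the per-class means and precisions. The function name class_posteriors and the uniform default priors are illustrative assumptions:

using LinearAlgebra

# Class posterior p(y=k | x) ∝ π_k · N(x; μ_k, W_k⁻¹), computed in log space
# for numerical stability. `means` and `precisions` are per-class point estimates.
function class_posteriors(x, means, precisions;
                          priors = fill(1 / length(means), length(means)))
    logp = map(eachindex(means)) do k
        d = x .- means[k]
        # log N(x; μ_k, W_k⁻¹) up to an additive constant: ½·log|W_k| − ½·dᵀW_k d
        0.5 * logdet(precisions[k]) - 0.5 * dot(d, precisions[k] * d) + log(priors[k])
    end
    p = exp.(logp .- maximum(logp))  # stabilized softmax
    return p ./ sum(p)
end

Because each class carries its own W_k, the decision boundaries are quadratic in x, unlike the linear boundaries of a shared-covariance (LDA-style) model.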

Model Structure

struct RxGMMModel <: BCIModel
    means::Vector{Vector{Float64}}       # Class means [μ₁, μ₂, ..., μₖ]
    precisions::Vector                   # Class-specific precisions [W₁, W₂, ..., Wₖ]
    metadata::ModelMetadata              # Model info
end

RxInfer Implementation

Learning Phase:
using RxInfer

@model function RxGMM_learning_model(y, labels, n_features, n_classes)
    # Priors on class means (broad zero-mean Gaussians)
    for k in 1:n_classes
        m[k] ~ MvNormal(μ = zeros(n_features), Σ = 10 * diageye(n_features))
    end

    # Priors on class-specific precisions (one Wishart per class)
    for k in 1:n_classes
        W[k] ~ Wishart(n_features + 5, diageye(n_features))
    end

    # Likelihood: each trial is drawn from its class's Gaussian.
    # The precision parameterization (Λ = W[k]) keeps the factor graph conjugate;
    # inverting the random variable W[k] with inv() is not supported.
    for i in eachindex(y)
        k = labels[i]
        y[i] ~ MvNormal(μ = m[k], Λ = W[k])
    end
end
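
The SDK wraps fitting behind train_model (see Usage below), but for readers familiar with RxInfer, the following is a rough sketch of how variational inference could be run on this model directly. The variable names (train_trials, train_labels) and the exact constraints/initialization are assumptions that may need adaptation to your RxInfer version, not the SDK's actual internals:

using RxInfer

# Mean-field factorization between means and precisions is required for
# variational message passing on the Normal-Wishart pair.
constraints = @constraints begin
    q(m, W) = q(m)q(W)
end

# Initial marginals for the precisions so the first iteration is well defined.
init = @initialization begin
    q(W) = Wishart(13, diageye(8))  # n_features + 5 = 13 for 8 features
end

results = infer(
    model          = RxGMM_learning_model(labels = train_labels, n_features = 8, n_classes = 2),
    data           = (y = train_trials,),  # Vector of 8-dimensional feature vectors (assumed)
    constraints    = constraints,
    initialization = init,
    iterations     = 50,
)

# Per-class posterior marginals (per-iteration history; take the final iteration)
posterior_means      = results.posteriors[:m][end]
posterior_precisions = results.posteriors[:W][end]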

Usage

1. Load Pre-trained Model

using NimbusSDK

# Authenticate
NimbusSDK.install_core("nbci_live_your_key")

# Load from Nimbus model zoo
model = load_model(RxGMMModel, "p300_gmm_v1")

println("Model loaded:")
println("  Features: $(get_n_features(model))")
println("  Classes: $(get_n_classes(model))")
println("  Paradigm: $(get_paradigm(model))")

2. Train Custom Model

using NimbusSDK

# Prepare training data with labels
train_features = erp_features  # 8 features × 200 samples × 150 trials - preprocessed P300 features
train_labels = [1, 2, 1, 2, ...]  # 150 labels (1=target, 2=non-target)

train_data = BCIData(
    train_features,
    BCIMetadata(
        sampling_rate = 250.0,
        paradigm = :p300,
        feature_type = :erp,
        n_features = 8,
        n_classes = 2,
        chunk_size = nothing
    ),
    train_labels  # Required for training!
)

# Train RxGMM model
model = train_model(
    RxGMMModel,
    train_data;
    iterations = 50,        # Inference iterations
    showprogress = true,    # Show progress bar
    name = "my_p300_gmm",
    description = "P300 binary classifier with RxGMM"
)

# Save for later use
save_model(model, "my_p300_gmm.jld2")
Training Parameters:
  • iterations: Number of variational inference iterations (default: 50)
    • More iterations = better convergence, typical range: 50-100
  • showprogress: Display progress bar during training
  • name: Model identifier
  • description: Model description

3. Subject-Specific Calibration

# Load base model
base_model = load_model(RxGMMModel, "p300_baseline_v1")

# Collect calibration trials (10-20 per class)
calib_data = BCIData(calib_features, metadata, calib_labels)

# Calibrate model
personalized_model = calibrate_model(
    base_model,
    calib_data;
    iterations = 20  # Fewer iterations needed
)

save_model(personalized_model, "subject_001_p300_calibrated.jld2")

4. Batch Inference

using Statistics  # for mean()

# Prepare test data
test_data = BCIData(test_features, metadata, test_labels)

# Run batch inference
results = predict_batch(model, test_data; iterations=10)

# Analyze results
println("Predictions: ", results.predictions)
println("Mean confidence: ", mean(results.confidences))

# Calculate metrics
accuracy = sum(results.predictions .== test_labels) / length(test_labels)
println("Accuracy: $(round(accuracy * 100, digits=1))%")

5. Streaming Inference

# Initialize streaming session
session = init_streaming(model, metadata_with_chunk_size)

# Process chunks
for chunk in eeg_feature_stream
    result = process_chunk(session, chunk; iterations=10)
    println("Chunk: pred=$(result.prediction), conf=$(round(result.confidence, digits=3))")
end

# Finalize trial
final_result = finalize_trial(session; method=:weighted_vote)
println("Final: pred=$(final_result.prediction), conf=$(round(final_result.confidence, digits=3))")

Training Requirements

Data Requirements

  • Minimum: 40 trials per class (80 total for 2-class)
  • Recommended: 80+ trials per class
  • For calibration: 10-20 trials per class
Bayesian GMM requires at least 2 observations per class to estimate class-specific statistics. Training will fail if any class has fewer than 2 observations.
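As a quick pre-flight check (plain Julia, not an SDK call; train_labels and the class count as in the training example above), one might verify per-class counts before training:

# Verify every class has at least 2 labeled trials before training.
n_classes = 2
for k in 1:n_classes
    n = count(==(k), train_labels)
    n >= 2 || error("Class $k has only $n trial(s); Bayesian GMM needs at least 2 per class.")
end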

Feature Requirements

Bayesian GMM expects preprocessed features, not raw EEG. Required preprocessing:
  • Bandpass filtering (paradigm-specific)
  • Artifact removal
  • Feature extraction (CSP, ERP amplitude, bandpower, etc.)
  • Proper temporal aggregation
NOT accepted:
  • Raw EEG channels
  • Unfiltered data
See Preprocessing Requirements.
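For reference, a hedged illustration of the expected feature layout, with the shape convention inferred from the training example above (the preprocessing itself is covered in Preprocessing Requirements):

# Feature tensor layout assumed by the training example:
# n_features × n_samples × n_trials (here 8 × 200 × 150).
n_features, n_samples, n_trials = 8, 200, 150
erp_features = Array{Float64}(undef, n_features, n_samples, n_trials)
# ... fill with preprocessed feature values per trial ...
@assert size(erp_features) == (8, 200, 150)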

Performance Characteristics

Computational Performance

| Operation       | Latency           | Notes                               |
| --------------- | ----------------- | ----------------------------------- |
| Training        | 15-40 seconds     | 50 iterations, 100 trials per class |
| Calibration     | 8-20 seconds      | 20 iterations, 20 trials per class  |
| Batch Inference | 15-25ms per trial | 10 iterations                       |
| Streaming Chunk | 15-25ms           | 10 iterations per chunk             |

Slightly slower than RxLDA due to class-specific covariances.

Classification Accuracy

| Paradigm      | Classes               | Typical Accuracy | When to Use Bayesian GMM                   |
| ------------- | --------------------- | ---------------- | ------------------------------------------ |
| P300          | 2 (Target/Non-target) | 85-95%           | Target/non-target have different variances |
| Motor Imagery | 2-4                   | 70-85%           | When Bayesian LDA accuracy is insufficient |
| SSVEP         | 2-6                   | 85-98%           | Complex frequency responses                |

Bayesian GMM typically provides 2-5% higher accuracy than Bayesian LDA when class covariances differ significantly, at the cost of ~5-10ms additional latency.

Model Inspection

View Model Parameters

using LinearAlgebra  # for inv() and diag()

# Class means
println("Class means:")
for (k, mean) in enumerate(model.means)
    println("  Class $k: ", mean)
end

# Class-specific precisions
println("\nClass-specific precision matrices:")
for (k, prec) in enumerate(model.precisions)
    println("  Class $k (first 3x3):")
    println(prec[1:3, 1:3])
end

# Compare covariances across classes
println("\nCovariance structure comparison:")
for k in eachindex(model.precisions)
    cov_k = inv(model.precisions[k])
    println("  Class $k variance (diagonal): ", diag(cov_k))
end
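
To check whether the class covariances actually differ enough to justify Bayesian GMM over Bayesian LDA, one simple heuristic (plain Julia for the 2-class case; the interpretation thresholds are illustrative assumptions, not an SDK recommendation) is the relative Frobenius distance between the class covariances:

using LinearAlgebra

# Relative Frobenius distance between the two class covariances.
# Values near 0 suggest nearly identical covariances (Bayesian LDA may suffice);
# larger values suggest heteroscedasticity that Bayesian GMM can exploit.
cov1 = inv(model.precisions[1])
cov2 = inv(model.precisions[2])
rel_diff = norm(cov1 - cov2) / norm(cov1 + cov2)
println("Relative covariance difference: $(round(rel_diff, digits=3))")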

Visualize Class Differences

using Plots, LinearAlgebra

# Compare class covariances side by side
plots = [heatmap(inv(model.precisions[k]), title = "Class $k Covariance", colorbar = true)
         for k in eachindex(model.precisions)]
plot(plots...; layout = (1, length(plots)))

Advantages & Limitations

Advantages

Flexible Modeling: Each class has its own covariance
Better for Complex Data: Handles heterogeneous distributions
Higher Accuracy: 2-5% improvement when classes differ significantly
Uncertainty Quantification: Full Bayesian posteriors
Production-Ready: Battle-tested in P300 applications

Limitations

More Parameters: Requires more training data than RxLDA
Slower Inference: ~15-25ms vs ~10-15ms for RxLDA
Higher Memory: Stores n_classes precision matrices
More Complex: Longer training time

Comparison: Bayesian GMM vs Bayesian LDA

| Aspect             | Bayesian GMM (RxGMM)                      | Bayesian LDA (RxLDA)             |
| ------------------ | ----------------------------------------- | -------------------------------- |
| Precision Matrix   | Class-specific                            | Shared (one for all classes)     |
| Mathematical Model | Heteroscedastic Gaussian Classifier (HGC) | Pooled Gaussian Classifier (PGC) |
| Training Speed     | Slower                                    | Faster                           |
| Inference Speed    | 15-25ms                                   | 10-15ms                          |
| Flexibility        | High                                      | Moderate                         |
| Data Requirements  | More                                      | Less                             |
| Memory Usage       | Higher                                    | Lower                            |
| Best For           | Heterogeneous classes                     | Homogeneous classes              |
| Accuracy Gain      | +2-5% (when applicable)                   | Baseline                         |

Decision Tree

Is speed critical (<15ms)?
├─ Yes → Use Bayesian LDA
└─ No → Do classes have similar spreads?
    ├─ Yes → Use Bayesian LDA (faster, same accuracy)
    └─ No → Use Bayesian GMM (more accurate)

Practical Examples

P300 Detection

# Bayesian GMM is ideal for P300 where target/non-target have different variances

# Train on P300 data
p300_data = BCIData(erp_features, p300_metadata, labels)  # 1=target, 2=non-target
model = train_model(RxGMMModel, p300_data; iterations=50)

# Test
results = predict_batch(model, test_data)

# Calculate sensitivity/specificity
targets = test_labels .== 1
target_preds = results.predictions .== 1

sensitivity = sum(target_preds .& targets) / sum(targets)
specificity = sum(.!target_preds .& .!targets) / sum(.!targets)

println("Sensitivity: $(round(sensitivity * 100, digits=1))%")
println("Specificity: $(round(specificity * 100, digits=1))%")

When Bayesian LDA Fails

# Try Bayesian LDA first
rxlda_model = train_model(RxLDAModel, train_data)
rxlda_results = predict_batch(rxlda_model, test_data)
rxlda_acc = sum(rxlda_results.predictions .== test_labels) / length(test_labels)

# If accuracy insufficient, try Bayesian GMM
if rxlda_acc < 0.75
    println("Bayesian LDA accuracy low ($(round(rxlda_acc*100, digits=1))%), trying Bayesian GMM...")
    rxgmm_model = train_model(RxGMMModel, train_data)
    rxgmm_results = predict_batch(rxgmm_model, test_data)
    rxgmm_acc = sum(rxgmm_results.predictions .== test_labels) / length(test_labels)
    
    println("Bayesian GMM accuracy: $(round(rxgmm_acc*100, digits=1))%")
    println("Improvement: +$(round((rxgmm_acc-rxlda_acc)*100, digits=1))%")
end

References

Theory:
  • McLachlan, G. J., & Peel, D. (2000). Finite Mixture Models. Wiley.
  • Bishop, C. M. (2006). Pattern Recognition and Machine Learning, Chapter 9. Springer.
  • Heteroscedastic Gaussian classification with class-specific covariance structure
BCI Applications:
  • Farwell, L. A., & Donchin, E. (1988). "Talking off the top of your head: Toward a mental prosthesis utilizing event-related brain potentials." Electroencephalography and Clinical Neurophysiology, 70(6), 510-523.
  • Lotte, F., et al. (2018). "A review of classification algorithms for EEG-based brain-computer interfaces: A 10 year update." Journal of Neural Engineering, 15(3), 031005.