Bayesian GMM - Bayesian Gaussian Mixture Model
Python: NimbusGMM | Julia: RxGMMModel
Mathematical Model: Heteroscedastic Gaussian Classifier (HGC)
Bayesian GMM is a Bayesian classification model with class-specific covariance matrices, making it more flexible than Bayesian LDA for modeling complex class distributions.
Available in Both SDKs:
- Python SDK: NimbusGMM class (sklearn-compatible)
- Julia SDK: RxGMMModel (RxInfer.jl-based)
Both implementations provide class-specific covariances for flexible modeling.
Overview
Bayesian GMM goes beyond traditional Gaussian classifiers by giving each class its own covariance structure:
- ✅ Class-specific covariances (unlike Bayesian LDA’s shared covariance)
- ✅ More flexible modeling of heterogeneous distributions
- ✅ Posterior probability distributions with uncertainty quantification
- ✅ Fast inference (~15-25ms per trial)
- ✅ Training and calibration support
- ✅ Batch and streaming inference modes
Quick Start
Python:
from nimbus_bci import NimbusGMM
import numpy as np
# Create and fit classifier
clf = NimbusGMM(mu_scale=3.0)
clf.fit(X_train, y_train)
# Predict with uncertainty
predictions = clf.predict(X_test)
probabilities = clf.predict_proba(X_test)
# Better for P300 and overlapping distributions
Julia:
using NimbusSDK
# Train model
model = train_model(
RxGMMModel,
train_data;
iterations=50
)
# Predict
results = predict_batch(model, test_data)
When to Use Bayesian GMM
Bayesian GMM is ideal for:
- Complex, overlapping class distributions
- Classes with significantly different variances
- P300 detection (target/non-target with different spreads)
- When Bayesian LDA accuracy is unsatisfactory
- When you need maximum flexibility
Use Bayesian LDA instead if:
- Classes are well-separated and have similar spreads
- Speed is critical (Bayesian LDA is faster)
- Training data is limited (Bayesian LDA needs less data)
- Memory is constrained (Bayesian LDA uses less memory)
Model Architecture
Mathematical Foundation (Heteroscedastic Gaussian Classifier)
Bayesian GMM implements a Heteroscedastic Gaussian Classifier (HGC), which models class-conditional distributions with class-specific precision matrices:
p(x | y=k) = N(μ_k, W_k^-1)
where:
- μ_k = mean vector for class k
- W_k = class-specific precision matrix (one per class), giving each class its own covariance structure
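Classification combines these class-conditional likelihoods with the empirical class priors π_k via Bayes' rule:
p(y=k | x) = π_k N(x; μ_k, W_k^-1) / Σ_j π_j N(x; μ_j, W_j^-1)
These posterior class probabilities underlie predict_proba (Python) and the confidences in predict_batch results (Julia).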
Key Difference from Bayesian LDA: Each class can have its own covariance structure, making the model more flexible but also more parameter-heavy.
Hyperparameters
Bayesian GMM supports configurable hyperparameters for optimal performance tuning:
Available Hyperparameters (training):
| Parameter | Type | Default | Range | Description |
|---|---|---|---|---|
| dof_offset | Int | 2 | [1, 5] | Degrees of freedom offset for Wishart priors |
| mean_prior_precision | Float64 | 0.01 | [0.001, 0.1] | Prior precision for class means |
Parameter Effects:
- dof_offset: Controls regularization strength
- Lower values (1) → More data-driven, less regularization
- Higher values (3-5) → More regularization, more conservative
- mean_prior_precision: Controls prior strength on class means
- Lower values (0.001) → Weaker prior, trusts data more
- Higher values (0.05-0.1) → Stronger prior, more regularization
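As a rough illustration (an assumed mapping that mirrors the RxInfer model below, not SDK source), here is how the two hyperparameters enter the priors:
using LinearAlgebra  # for I

n_features = 8                  # example feature count (assumption)
dof_offset = 2                  # default
mean_prior_precision = 0.01     # default

mean_prior_cov = (1 / mean_prior_precision) * I(n_features)  # 0.01 → 100·I, a weak prior on each class mean
wishart_dof = n_features + dof_offset                        # ν = 10 for each class's Wishart prior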
Model Structure
struct RxGMMModel <: BCIModel
mean_posteriors::Vector # MvNormal posteriors for class means
precision_posteriors::Vector # Wishart posteriors for class precisions
priors::Vector{Float64} # Empirical class priors
metadata::ModelMetadata # Model info
dof_offset::Int # Degrees of freedom offset (training)
mean_prior_precision::Float64 # Mean prior precision (training)
end
RxInfer Implementation
Learning Phase:
@model function RxGMM_learning_model(y, labels, n_features, n_classes, dof_offset, mean_prior_precision)
    # Priors on class means (precision mean_prior_precision ⇒ covariance (1/mean_prior_precision)·I)
    for k in 1:n_classes
        m[k] ~ MvNormal(mean = zeros(n_features), covariance = (1 / mean_prior_precision) * diageye(n_features))
    end
    # Priors on class-specific precisions (each class has its own W)
    for k in 1:n_classes
        W[k] ~ Wishart(n_features + dof_offset, diageye(n_features))
    end
    # Likelihood
    for i in eachindex(y)
        k = labels[i]
        y[i] ~ MvNormal(mean = m[k], precision = W[k])  # Class-specific precision
    end
end
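For intuition, prediction amounts to comparing prior-weighted class likelihoods. Below is a minimal plug-in sketch built from posterior point estimates (plugin_classify is a hypothetical helper, not the SDK's inference path, which works with the full posteriors):
using Distributions, LinearAlgebra

function plugin_classify(x, means, precisions, priors)
    # log p(y = k | x) ∝ log π_k + log N(x; μ_k, W_k⁻¹)
    scores = [log(priors[k]) +
              logpdf(MvNormal(means[k], Symmetric(inv(precisions[k]))), x)
              for k in eachindex(priors)]
    return argmax(scores)  # most probable class index
end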
Usage
1. Load Pre-trained Model
Python SDK: models are trained locally with nimbus-bci (there is no model zoo); see the Python SDK Quickstart for training examples.
Python:
from nimbus_bci import NimbusGMM
import pickle
# Python SDK: Train locally, no model zoo
clf = NimbusGMM(mu_scale=3.0)
clf.fit(X_train, y_train)
# Save and load
with open("my_p300_gmm.pkl", "wb") as f:
    pickle.dump(clf, f)
with open("my_p300_gmm.pkl", "rb") as f:
    clf = pickle.load(f)
print(f"Model info: {clf.classes_} classes, {clf.n_features_in_} features")
Julia:
using NimbusSDK
# Authenticate
NimbusSDK.install_core("nbci_live_your_key")
# Load from Nimbus model zoo
model = load_model(RxGMMModel, "p300_gmm_v1")
println("Model loaded:")
println(" Features: $(get_n_features(model))")
println(" Classes: $(get_n_classes(model))")
println(" Paradigm: $(get_paradigm(model))")
2. Train Custom Model
Python:
from nimbus_bci import NimbusGMM
from nimbus_bci.compat import from_mne_epochs
import mne
import pickle
# Load P300 data
raw = mne.io.read_raw_fif("p300_oddball.fif", preload=True)
raw.filter(0.5, 10)
events = mne.find_events(raw)
event_id = {'target': 1, 'non-target': 2}
epochs = mne.Epochs(raw, events, event_id, tmin=-0.2, tmax=0.8, preload=True)
# Extract features
X, y = from_mne_epochs(epochs)
# Train NimbusGMM with default hyperparameters
clf = NimbusGMM()
clf.fit(X, y)
# Or train with custom hyperparameters
clf_tuned = NimbusGMM(
mu_scale=3.0, # Prior strength
sigma_scale=1.0 # Covariance regularization
)
clf_tuned.fit(X, y)
# Save
with open("my_p300_gmm.pkl", "wb") as f:
    pickle.dump(clf_tuned, f)
print(f"Training accuracy: {clf_tuned.score(X, y):.1%}")
Julia:
using NimbusSDK
# Prepare training data with labels
train_features = erp_features # (8 × 200 × 150) - preprocessed P300 features
train_labels = [1, 2, 1, 2, ...] # 150 labels (1=target, 2=non-target)
train_data = BCIData(
train_features,
BCIMetadata(
sampling_rate = 250.0,
paradigm = :p300,
feature_type = :erp,
n_features = 8,
n_classes = 2,
chunk_size = nothing
),
train_labels # Required for training!
)
# Train RxGMM model with default hyperparameters
model = train_model(
RxGMMModel,
train_data;
iterations = 50, # Inference iterations
showprogress = true, # Show progress bar
name = "my_p300_gmm",
description = "P300 binary classifier with RxGMM"
)
# Or train with custom hyperparameters (v0.2.0+)
model = train_model(
RxGMMModel,
train_data;
iterations = 50,
showprogress = true,
name = "my_p300_gmm_tuned",
description = "P300 classifier with tuned hyperparameters",
dof_offset = 2, # DOF offset (default: 2)
mean_prior_precision = 0.01 # Prior precision (default: 0.01)
)
# Save for later use
save_model(model, "my_p300_gmm.jld2")
Training Parameters:
- iterations: Number of variational inference iterations (default: 50); more iterations = better convergence, typical range 50-100
- showprogress: Display progress bar during training
- name: Model identifier
- description: Model description
- dof_offset: Degrees of freedom offset (default: 2, range: [1, 5])
- mean_prior_precision: Prior precision for means (default: 0.01, range: [0.001, 0.1])
- predictive_dof_offset: DOF offset used at inference; set equal to dof_offset (see the note under Hyperparameter Tuning)
3. Subject-Specific Calibration
Python:
from nimbus_bci import NimbusGMM
import pickle
# Load baseline model
with open("p300_baseline.pkl", "rb") as f:
    base_clf = pickle.load(f)
# Collect calibration trials (10-20 per class)
X_calib, y_calib = collect_calibration_trials()
# Personalize the loaded baseline model using online learning
personalized_clf = base_clf
for _ in range(10):
    personalized_clf.partial_fit(X_calib, y_calib)
# Save
with open("subject_001_p300_calibrated.pkl", "wb") as f:
    pickle.dump(personalized_clf, f)
Julia:
# Load base model
base_model = load_model(RxGMMModel, "p300_baseline_v1")
# Collect calibration trials (10-20 per class)
calib_data = BCIData(calib_features, metadata, calib_labels)
# Calibrate model
personalized_model = calibrate_model(
base_model,
calib_data;
iterations = 20 # Fewer iterations needed
)
save_model(personalized_model, "subject_001_p300_calibrated.jld2")
Calibration Benefits:
- Requires only 10-20 trials per class
- Faster than training from scratch
- Adapts to subject-specific characteristics
- Hyperparameters preserved: calibrate_model() automatically reuses the base model's hyperparameters (v0.2.0+)
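A quick sanity check along these lines confirms calibration actually helped (a sketch; test_data and test_labels are assumed held-out data for this subject):
# Compare base vs calibrated accuracy on held-out trials
base_results  = predict_batch(base_model, test_data)
calib_results = predict_batch(personalized_model, test_data)
acc(r) = sum(r.predictions .== test_labels) / length(test_labels)
println("Base: $(round(acc(base_results) * 100, digits=1))%  →  calibrated: $(round(acc(calib_results) * 100, digits=1))%")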
4. Batch Inference
Python:
import numpy as np
# Run batch inference
predictions = clf.predict(X_test)
probabilities = clf.predict_proba(X_test)
confidences = np.max(probabilities, axis=1)
# Analyze results
print(f"Predictions: {predictions}")
print(f"Mean confidence: {np.mean(confidences):.3f}")
# Calculate metrics
accuracy = np.mean(predictions == y_test)
print(f"Accuracy: {accuracy * 100:.1f}%")
Julia:
# Prepare test data
test_data = BCIData(test_features, metadata, test_labels)
# Run batch inference
results = predict_batch(model, test_data; iterations=10)
# Analyze results
println("Predictions: ", results.predictions)
println("Mean confidence: ", mean(results.confidences))
# Calculate metrics
accuracy = sum(results.predictions .== test_labels) / length(test_labels)
println("Accuracy: $(round(accuracy * 100, digits=1))%")
5. Streaming Inference
Python:
from nimbus_bci import StreamingSession
# Initialize streaming session
session = StreamingSession(clf.model_, metadata_with_chunk_size)
# Process chunks
for chunk in eeg_feature_stream:
    result = session.process_chunk(chunk)
    print(f"Chunk: pred={result.prediction}, conf={result.confidence:.3f}")
# Finalize trial
final_result = session.finalize_trial()
print(f"Final: pred={final_result.prediction}, conf={final_result.confidence:.3f}")
session.reset()
Julia:
# Initialize streaming session
session = init_streaming(model, metadata_with_chunk_size)
# Process chunks
for chunk in eeg_feature_stream
result = process_chunk(session, chunk; iterations=10)
println("Chunk: pred=$(result.prediction), conf=$(round(result.confidence, digits=3))")
end
# Finalize trial
final_result = finalize_trial(session; method=:weighted_vote)
println("Final: pred=$(final_result.prediction), conf=$(round(final_result.confidence, digits=3))")
Hyperparameter Tuning (v0.2.0+)
Fine-tune Bayesian GMM for optimal performance on your specific dataset.
When to Tune Hyperparameters
Consider tuning when:
- Default performance is unsatisfactory
- You have specific data characteristics (very noisy or very clean)
- You have limited or extensive training data
- Working with complex, overlapping class distributions
- P300 or other paradigms with heterogeneous class variances
Tuning Strategies
For High SNR / Clean Data / Many Trials
Use lower regularization to let the data drive the model:
Python:
from nimbus_bci import NimbusGMM
# Lower regularization for clean data
clf = NimbusGMM(
mu_scale=1.0, # Weaker prior
sigma_scale=0.1 # Less covariance regularization
)
clf.fit(X_train, y_train)
Julia:
model = train_model(
RxGMMModel,
train_data;
iterations = 50,
dof_offset = 1, # Less regularization
mean_prior_precision = 0.001 # Weaker prior, trust data more
)
Use when:
- SNR > 5 dB
- 100+ trials per class
- Clean, artifact-free data
- Well-controlled experimental conditions
For Low SNR / Noisy Data / Few Trials
Use higher regularization for stability (especially important for GMM with class-specific covariances):
Python:
from nimbus_bci import NimbusGMM
# Higher regularization for noisy data
clf = NimbusGMM(
mu_scale=5.0, # Stronger prior
sigma_scale=10.0 # More covariance regularization
)
clf.fit(X_train, y_train)
Julia:
model = train_model(
RxGMMModel,
train_data;
iterations = 50,
dof_offset = 3, # More regularization
mean_prior_precision = 0.05 # Stronger prior
)
Use when:
- SNR < 2 dB
- 40-80 trials per class
- Noisy data or limited artifact removal
- Challenging recording conditions
- Risk of overfitting to class-specific noise
Balanced / Default Settings
The defaults work well for most scenarios:
model = train_model(
RxGMMModel,
train_data;
iterations = 50,
dof_offset = 2, # Balanced (default)
mean_prior_precision = 0.01 # Balanced (default)
)
Use when:
- Moderate SNR (2-5 dB)
- 80-150 trials per class
- Standard BCI recording conditions
- Starting point for experimentation
P300-Specific Tuning
For P300 paradigms where target/non-target classes have different variances:
# P300 often benefits from GMM's class-specific covariances
# Use moderate regularization to avoid overfitting
model = train_model(
RxGMMModel,
p300_train_data;
iterations = 50,
dof_offset = 2, # Standard regularization
mean_prior_precision = 0.02 # Slightly stronger than default
)
Hyperparameter Search Example
Systematically search for optimal hyperparameters:
using NimbusSDK
# Define search grid
dof_values = [1, 2, 3, 4]
prior_values = [0.001, 0.01, 0.03, 0.05]
# Split data
train_data, val_data = split_data(all_data, ratio=0.8)
best_accuracy = 0.0
best_params = nothing
println("Searching hyperparameters for RxGMM...")
for dof in dof_values
for prior in prior_values
# Train model
model = train_model(
RxGMMModel,
train_data;
iterations = 50,
dof_offset = dof,
mean_prior_precision = prior,
predictive_dof_offset = dof,
showprogress = false
)
# Validate
results = predict_batch(model, val_data)
accuracy = sum(results.predictions .== val_data.labels) / length(val_data.labels)
println(" dof=$dof, prior=$prior: $(round(accuracy*100, digits=1))%")
if accuracy > best_accuracy
best_accuracy = accuracy
best_params = (dof=dof, prior=prior)
end
end
end
println("\nBest hyperparameters:")
println(" dof_offset: $(best_params.dof)")
println(" mean_prior_precision: $(best_params.prior)")
println(" Validation accuracy: $(round(best_accuracy*100, digits=1))%")
# Retrain with best params
final_model = train_model(
RxGMMModel,
all_data;
iterations = 50,
dof_offset = best_params.dof,
mean_prior_precision = best_params.prior,
predictive_dof_offset = best_params.dof
)
Quick Tuning Guidelines
| Scenario | dof_offset | mean_prior_precision | Notes |
|---|---|---|---|
| Excellent data quality | 1 | 0.001 | Minimal regularization |
| Good data quality | 2 (default) | 0.01 (default) | Balanced approach |
| Moderate data quality | 2-3 | 0.01-0.03 | Slight regularization |
| Poor data quality | 3-4 | 0.05-0.1 | Strong regularization |
| Very limited trials | 4 | 0.1 | Maximum regularization |
| P300 target/non-target | 2 | 0.02 | Moderate, class-specific covariances helpful |
Pro Tip: Bayesian GMM’s class-specific covariances can overfit to noise with poor data. When in doubt, start with defaults and increase regularization (dof_offset=3, mean_prior_precision=0.03) if you see overfitting.
Important: Always set predictive_dof_offset to match dof_offset for consistency between training and inference phases.
Training Requirements
Data Requirements
- Minimum: 40 trials per class (80 total for 2-class)
- Recommended: 80+ trials per class
- For calibration: 10-20 trials per class
Bayesian GMM requires at least 2 observations per class to estimate class-specific statistics. Training will fail if any class has fewer than 2 observations.
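A minimal pre-flight check can catch this before training (a sketch; assumes integer labels in train_labels):
using StatsBase  # for countmap

class_counts = countmap(train_labels)
for (class, n) in class_counts
    n >= 2 || error("Class $class has only $n trial(s); Bayesian GMM needs at least 2")
end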
Feature Normalization
Critical for cross-session BCI performance! Normalize your features before training; this typically yields a 15-30% accuracy improvement across sessions.
Python:
from sklearn.preprocessing import StandardScaler
import pickle
# Estimate normalization from training data
scaler = StandardScaler()
X_train_norm = scaler.fit_transform(X_train)
# Train with normalized features
clf = NimbusGMM()
clf.fit(X_train_norm, y_train)
# Save model and scaler together
with open("model_with_scaler.pkl", "wb") as f:
    pickle.dump({'model': clf, 'scaler': scaler}, f)
# Later: Apply same normalization to test data
X_test_norm = scaler.transform(X_test)
predictions = clf.predict(X_test_norm)
Julia:
# Estimate normalization from training data
norm_params = estimate_normalization_params(train_features; method=:zscore)
train_norm = apply_normalization(train_features, norm_params)
# Train with normalized features
train_data = BCIData(train_norm, metadata, labels)
model = train_model(RxGMMModel, train_data)
# Save params with model
@save "model.jld2" model norm_params
# Later: Apply same params to test data
test_norm = apply_normalization(test_features, norm_params)
See the Feature Normalization guide for complete details.
Feature Requirements
Bayesian GMM expects preprocessed features, not raw EEG:
✅ Required preprocessing:
- Bandpass filtering (paradigm-specific)
- Artifact removal
- Feature extraction (CSP, ERP amplitude, bandpower, etc.)
- Proper temporal aggregation
❌ NOT accepted:
- Raw EEG channels
- Unfiltered data
See Preprocessing Requirements.
Latency
| Operation | Latency | Notes |
|---|---|---|
| Training | 15-40 seconds | 50 iterations, 100 trials per class |
| Calibration | 8-20 seconds | 20 iterations, 20 trials per class |
| Batch Inference | 15-25ms per trial | 10 iterations |
| Streaming Chunk | 15-25ms | 10 iterations per chunk |
Slightly slower than RxLDA due to class-specific covariances.
Classification Accuracy
| Paradigm | Classes | Typical Accuracy | When to Use Bayesian GMM |
|---|---|---|---|
| P300 | 2 (Target/Non-target) | 85-95% | Target/non-target have different variances |
| Motor Imagery | 2-4 | 70-85% | When Bayesian LDA accuracy insufficient |
| SSVEP | 2-6 | 85-98% | Complex frequency responses |
Bayesian GMM typically provides 2-5% higher accuracy than Bayesian LDA when class covariances differ significantly, at the cost of ~5-10ms additional latency.
Model Inspection
View Model Parameters
Python:
import numpy as np
# Class means
print("Class means:")
for k, class_label in enumerate(clf.classes_):
    print(f" Class {class_label}: {clf.model_['means'][k]}")
# Class-specific covariance matrices
print("\nClass-specific covariance matrices:")
for k, class_label in enumerate(clf.classes_):
    print(f" Class {class_label} (first 3x3):")
    print(clf.model_['covariances'][k][:3, :3])
# Compare covariances across classes
print("\nCovariance structure comparison:")
for k, class_label in enumerate(clf.classes_):
    cov_k = clf.model_['covariances'][k]
    print(f" Class {class_label} variance (diagonal): {np.diag(cov_k)}")
# Class priors
print("\nClass priors:")
for k, class_label in enumerate(clf.classes_):
    print(f" Class {class_label}: {clf.model_['priors'][k]:.3f}")
Julia:
# Extract point estimates from posterior distributions
using Distributions
# Class means (extract from posterior distributions)
println("Class means:")
for (k, mean_posterior) in enumerate(model.mean_posteriors)
mean_point = mean(mean_posterior) # Extract point estimate
println(" Class $k: ", mean_point)
end
# Class-specific precisions (extract from posterior distributions)
println("\nClass-specific precision matrices:")
for (k, precision_posterior) in enumerate(model.precision_posteriors)
prec_point = mean(precision_posterior) # Extract point estimate
println(" Class $k (first 3x3):")
println(prec_point[1:3, 1:3])
end
# Compare covariances across classes
println("\nCovariance structure comparison:")
for k in 1:length(model.precision_posteriors)
prec_point = mean(model.precision_posteriors[k])
cov_k = inv(prec_point) # Convert precision to covariance
println(" Class $k variance (diagonal): ", diag(cov_k))
end
# Class priors
println("\nClass priors:")
for (k, prior) in enumerate(model.priors)
println(" Class $k: ", prior)
end
Accessing model parameters: The SDK stores full posterior distributions (not just point estimates) for proper Bayesian inference. To get point estimates, use mean(posterior) to extract the mean of the posterior distribution. For precision matrices, use mean(precision_posterior) to get the expected precision matrix.
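Because full posteriors are stored, you can also quantify how certain the model is about its own parameters; a sketch, assuming the mean posteriors are Distributions.jl objects as above:
using Distributions, LinearAlgebra

for (k, mp) in enumerate(model.mean_posteriors)
    posterior_sd = sqrt.(diag(cov(mp)))  # per-dimension posterior std. dev. of the class mean
    println(" Class $k mean ± sd: ", mean(mp), " ± ", posterior_sd)
end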
Visualize Class Differences
using Plots
using Distributions
# Compare class covariances
for k in 1:length(model.precision_posteriors)
prec_point = mean(model.precision_posteriors[k]) # Extract point estimate
cov_k = inv(prec_point) # Convert precision to covariance
heatmap(cov_k, title="Class $k Covariance", colorbar=true)
end
Advantages & Limitations
Advantages
✅ Flexible Modeling: Each class has its own covariance
✅ Better for Complex Data: Handles heterogeneous distributions
✅ Higher Accuracy: 2-5% improvement when classes differ significantly
✅ Uncertainty Quantification: Full Bayesian posteriors
✅ Production-Ready: Battle-tested in P300 applications
Limitations
❌ More Parameters: Requires more training data than RxLDA
❌ Slower Inference: ~15-25ms vs ~10-15ms for RxLDA
❌ Higher Memory: Stores n_classes precision matrices
❌ More Complex: Longer training time
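The "more parameters" point is easy to quantify: a symmetric d×d covariance has d(d+1)/2 free entries, and Bayesian GMM keeps one per class while Bayesian LDA shares one. A back-of-the-envelope sketch:
d, K = 8, 2                           # example: 8 features, 2 classes
gmm_cov_params = K * d * (d + 1) ÷ 2  # class-specific: 72 covariance parameters
lda_cov_params = d * (d + 1) ÷ 2      # shared: 36 covariance parameters
println("GMM: $gmm_cov_params vs LDA: $lda_cov_params covariance parameters")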
Comparison: Bayesian GMM vs Bayesian LDA
| Aspect | Bayesian GMM (RxGMM) | Bayesian LDA (RxLDA) |
|---|---|---|
| Precision Matrix | Class-specific | Shared (one for all) |
| Mathematical Model | Heteroscedastic Gaussian Classifier (HGC) | Pooled Gaussian Classifier (PGC) |
| Training Speed | Slower | Faster |
| Inference Speed | 15-25ms | 10-15ms |
| Flexibility | High | Moderate |
| Data Requirements | More | Less |
| Memory Usage | Higher | Lower |
| Best For | Heterogeneous classes | Homogeneous classes |
| Accuracy Gain | +2-5% (when applicable) | Baseline |
Decision Tree
Is speed critical (<15ms)?
├─ Yes → Use Bayesian LDA
└─ No → Do classes have similar spreads?
├─ Yes → Use Bayesian LDA (faster, same accuracy)
└─ No → Use Bayesian GMM (more accurate)
Practical Examples
P300 Detection
# Bayesian GMM is ideal for P300 where target/non-target have different variances
# Train on P300 data
p300_data = BCIData(erp_features, p300_metadata, labels) # 1=target, 2=non-target
model = train_model(RxGMMModel, p300_data; iterations=50)
# Test
results = predict_batch(model, test_data)
# Calculate sensitivity/specificity
targets = test_labels .== 1
target_preds = results.predictions .== 1
sensitivity = sum(target_preds .& targets) / sum(targets)
specificity = sum(.!target_preds .& .!targets) / sum(.!targets)
println("Sensitivity: $(round(sensitivity * 100, digits=1))%")
println("Specificity: $(round(specificity * 100, digits=1))%")
When Bayesian LDA Fails
# Try Bayesian LDA first
rxlda_model = train_model(RxLDAModel, train_data)
rxlda_results = predict_batch(rxlda_model, test_data)
rxlda_acc = sum(rxlda_results.predictions .== test_labels) / length(test_labels)
# If accuracy insufficient, try Bayesian GMM
if rxlda_acc < 0.75
println("Bayesian LDA accuracy low ($(round(rxlda_acc*100, digits=1))%), trying Bayesian GMM...")
rxgmm_model = train_model(RxGMMModel, train_data)
rxgmm_results = predict_batch(rxgmm_model, test_data)
rxgmm_acc = sum(rxgmm_results.predictions .== test_labels) / length(test_labels)
println("Bayesian GMM accuracy: $(round(rxgmm_acc*100, digits=1))%")
println("Improvement: +$(round((rxgmm_acc-rxlda_acc)*100, digits=1))%")
end
References
Theory:
- McLachlan, G. J., & Peel, D. (2000). “Finite Mixture Models”
- Bishop, C. M. (2006). “Pattern Recognition and Machine Learning” (Chapter 9)
- Heteroscedastic Gaussian Classifier (HGC) with class-specific covariances
BCI Applications:
- Farwell, L. A., & Donchin, E. (1988). “Talking off the top of your head: Toward a mental prosthesis utilizing event-related brain potentials”
- Lotte, F., et al. (2018). “A review of classification algorithms for EEG-based brain-computer interfaces: A 10 year update”