Bayesian LDA - Bayesian Linear Discriminant Analysis
Python: NimbusLDA | Julia: NimbusLDA | Mathematical Model: Pooled Gaussian Classifier (PGC)
Bayesian LDA is a Bayesian classification model with uncertainty quantification. It uses a shared precision matrix across all classes, making it fast and efficient for BCI classification.
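In compact form (matching the RxInfer learning model shown further below), the pooled Gaussian generative model is

\[
\mathbf{W} \sim \mathcal{W}(\nu, \mathbf{I}), \qquad
\boldsymbol{\mu}_k \sim \mathcal{N}\big(\mathbf{0}, (\alpha \mathbf{I})^{-1}\big), \qquad
\mathbf{y}_i \mid c_i = k \sim \mathcal{N}\big(\boldsymbol{\mu}_k, \mathbf{W}^{-1}\big),
\]

where \nu = n_features + dof_offset, \alpha = mean_prior_precision, and the precision matrix W is shared by all classes (only the class means differ).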
Available in Both SDKs:
Python SDK: NimbusLDA class (sklearn-compatible)
Julia SDK: NimbusLDA (RxInfer.jl-based)
Both implementations provide the same Bayesian inference with uncertainty quantification.
Bayesian LDA extends traditional Linear Discriminant Analysis with full Bayesian inference, providing:
✅ Posterior probability distributions (not just point estimates)
✅ Uncertainty quantification for each prediction
✅ Probabilistic confidence scores
✅ Fast inference (<20ms per trial)
✅ Training and calibration support
✅ Batch and streaming inference modes
from nimbus_bci import NimbusLDA
import numpy as np

# Create and fit classifier
clf = NimbusLDA(mu_scale=3.0)
clf.fit(X_train, y_train)

# Predict with uncertainty
predictions = clf.predict(X_test)
probabilities = clf.predict_proba(X_test)

# Online learning
clf.partial_fit(X_new, y_new)
struct NimbusLDA <: BCIModel
    mean_posteriors::Vector          # MvNormal posteriors for class means
    precision_posterior::Any         # Wishart posterior for shared precision
    priors::Vector{Float64}          # Empirical class priors
    metadata::ModelMetadata          # Model info
    dof_offset::Int                  # Degrees of freedom offset (training)
    mean_prior_precision::Float64    # Mean prior precision (training)
end
The Bayesian LDA model uses RxInfer.jl for variational Bayesian inference.
Learning Phase:
@model function RxLDA_learning_model(y, labels, n_features, n_classes, dof_offset, mean_prior_precision)
    # Prior on shared precision
    dof = n_features + dof_offset
    W ~ Wishart(dof, I)

    # Priors on class means
    for k in 1:n_classes
        m[k] ~ MvNormal(mean=zeros(n_features), precision=mean_prior_precision * I)
    end

    # Likelihood with known labels
    for i in eachindex(y)
        k = labels[i]
        y[i] ~ MvNormal(mean=m[k], precision=W)
    end
end
Prediction Phase (Per-Class Likelihood):
To avoid mixture collapse with single observations, inference computes per-class likelihoods independently:
@model function RxLDA_predictive_single_class(y, class_mean, class_precision)
    n_features = length(class_mean)

    # Prior for this specific class (illustrative; actual implementation uses full posteriors)
    m ~ MvNormalMeanPrecision(class_mean, 1e6 * I)
    w ~ Wishart(n_features + 2, class_precision)

    # Observations (no mixture variable)
    for i in eachindex(y)
        y[i] ~ MvNormal(mean=m, precision=w)
    end
end
For each class k, compute p(y | μ_k, W) independently, then combine using Bayes’ rule with softmax normalization.
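As an illustration of that combination step (a minimal sketch, not the SDK's internal code; combine_class_likelihoods is a hypothetical helper), working in log space for numerical stability:

# Hypothetical helper, not part of the SDK: combine per-class log-likelihoods
# log p(y | mu_k, W) with log-priors, then softmax-normalise in log space.
function combine_class_likelihoods(logliks::Vector{Float64}, priors::Vector{Float64})
    logpost = logliks .+ log.(priors)   # unnormalised log-posterior per class
    logpost .-= maximum(logpost)        # shift by the max for numerical stability
    p = exp.(logpost)
    return p ./ sum(p)                  # softmax normalisation -> class probabilities
end

# Example: 4-class problem with uniform priors
probs = combine_class_likelihoods([-12.3, -9.8, -15.1, -10.4], fill(0.25, 4))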
Python SDK: The Python SDK (nimbus-bci) trains models locally and doesn’t require pre-trained model loading. See Python SDK Quickstart for training examples.
Python
Julia
from nimbus_bci import NimbusLDA
import pickle

# Python SDK: Train your own model locally
# No authentication or model zoo needed

# Quick training example
clf = NimbusLDA(mu_scale=3.0)
clf.fit(X_train, y_train)

# Save for later use
with open("my_motor_imagery.pkl", "wb") as f:
    pickle.dump(clf, f)

# Load saved model
with open("my_motor_imagery.pkl", "rb") as f:
    clf = pickle.load(f)

print("Model info:")
print(f"  Classes: {clf.classes_}")
print(f"  Features: {clf.n_features_in_}")
using NimbusSDK

# Authenticate
NimbusSDK.install_core("nbci_live_your_key")

# Load from Nimbus model zoo
model = load_model(NimbusLDA, "motor_imagery_4class_v1")

println("Model loaded:")
println("  Features: $(get_n_features(model))")
println("  Classes: $(get_n_classes(model))")
println("  Paradigm: $(get_paradigm(model))")
Fine-tune a pre-trained model with subject-specific data (much faster than training from scratch):
Python
Julia
from nimbus_bci import NimbusLDA
import numpy as np
import pickle

# Load baseline model (trained on multiple subjects)
with open("motor_imagery_baseline.pkl", "rb") as f:
    base_clf = pickle.load(f)

# Collect 10-20 calibration trials from new subject
X_calib, y_calib = collect_calibration_trials()  # Your function

# Personalize the loaded baseline using online learning (partial_fit)
personalized_clf = base_clf  # Start from the baseline model

# Fine-tune on calibration data
for _ in range(10):  # Multiple passes for adaptation
    personalized_clf.partial_fit(X_calib, y_calib)

# Save personalized model
with open("subject_001_calibrated.pkl", "wb") as f:
    pickle.dump(personalized_clf, f)

print("Model personalized successfully!")
# Load base model
base_model = load_model(NimbusLDA, "motor_imagery_baseline_v1")

# Collect 10-20 calibration trials from new subject
calib_features = collect_calibration_trials()  # Your function
calib_labels = [1, 2, 3, 4, 1, 2, ...]
calib_data = BCIData(calib_features, metadata, calib_labels)

# Calibrate (personalize) the model
personalized_model = calibrate_model(
    base_model,
    calib_data;
    iterations = 20  # Fewer iterations needed
)

save_model(personalized_model, "subject_001_calibrated.jld2")
Calibration Benefits:
Requires only 10-20 trials per class (vs 50-100 for training from scratch)
Faster: 20 iterations vs 50-100
Better generalization: Uses pre-trained model as prior
Typical accuracy improvement: 5-15% over generic model
Hyperparameters automatically preserved: calibrate_model() inherits training hyperparameters from the base model (dof_offset, mean_prior_precision) ensuring consistency
Important: You don’t need to specify hyperparameters when calibrating - they are automatically inherited from the base model. This ensures the calibrated model uses the same regularization strategy that worked well during initial training.
Use lower regularization to let the data drive the model:
Python
Julia
from nimbus_bci import NimbusLDA

# Lower regularization for clean data
clf = NimbusLDA(
    mu_scale=1.0,     # Weaker prior, trust data more
    sigma_scale=0.1   # Less covariance regularization
)
clf.fit(X_train, y_train)
model = train_model(
    NimbusLDA,
    train_data;
    iterations = 50,
    dof_offset = 1,                 # Less regularization
    mean_prior_precision = 0.001    # Weaker prior, trust data more
)
Pro Tip: Start with defaults (dof_offset=2, mean_prior_precision=0.01) and only tune if performance is unsatisfactory. The defaults are optimized for typical BCI scenarios.
Minimum: 40 trials per class (160 total for 4-class)
Recommended: 80+ trials per class (320+ total for 4-class)
For calibration: 10-20 trials per class sufficient
Bayesian LDA requires at least 2 trials to estimate class statistics and the shared precision matrix. Single-trial training is not statistically valid for LDA and will raise an ArgumentError. Your training data must have shape (n_features, n_samples, n_trials) where n_trials >= 2.
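For illustration, a hypothetical shape check along those lines (not the SDK's actual validation code):

# Hypothetical illustration of the n_trials >= 2 requirement; not the SDK's actual code.
function check_trial_count(X::AbstractArray{<:Real,3})
    n_features, n_samples, n_trials = size(X)
    n_trials >= 2 || throw(ArgumentError(
        "Bayesian LDA needs at least 2 trials, got $n_trials " *
        "(expected shape (n_features, n_samples, n_trials))"))
    return X
end

check_trial_count(rand(8, 250, 40))    # OK: 8 features, 250 samples, 40 trials
# check_trial_count(rand(8, 250, 1))   # throws ArgumentError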
Critical for cross-session BCI performance! Normalize your features before training for a 15-30% accuracy improvement across sessions.
Python
Julia
from sklearn.preprocessing import StandardScaler
from nimbus_bci import NimbusLDA
import pickle

# Estimate normalization from training data
scaler = StandardScaler()
X_train_norm = scaler.fit_transform(X_train)

# Train with normalized features
clf = NimbusLDA()
clf.fit(X_train_norm, y_train)

# Save model and scaler together
with open("model_with_scaler.pkl", "wb") as f:
    pickle.dump({'model': clf, 'scaler': scaler}, f)

# Later: Apply same normalization to test data
X_test_norm = scaler.transform(X_test)
predictions = clf.predict(X_test_norm)
# Estimate normalization from training data
norm_params = estimate_normalization_params(train_features; method=:zscore)
train_norm = apply_normalization(train_features, norm_params)

# Train with normalized features
train_data = BCIData(train_norm, metadata, labels)
model = train_model(NimbusLDA, train_data)

# Save params with model
@save "model.jld2" model norm_params

# Later: Apply same params to test data
test_norm = apply_normalization(test_features, norm_params)
Critical Preprocessing Step: Before training, the SDK automatically aggregates the temporal dimension of each trial into a single feature vector. This prevents treating temporally correlated samples as independent observations, which would violate the i.i.d. assumption of the model.
The aggregation method depends on your feature type and paradigm:
CSP features: Log-variance aggregation (default for motor imagery)
Power spectral features: Mean or median aggregation
Other features: Configurable via BCIMetadata.temporal_aggregation
This aggregation happens automatically during train_model() and predict_batch() calls.
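As a rough sketch of the log-variance case (illustrative only; the SDK performs this internally, and logvar_aggregate here is a hypothetical helper):

using Statistics

# Illustrative only: collapse the (n_features, n_samples) temporal dimension of each
# trial to a single n_features vector via log-variance (the CSP / motor-imagery default).
function logvar_aggregate(X::AbstractArray{<:Real,3})
    n_features, _, n_trials = size(X)
    out = Matrix{Float64}(undef, n_features, n_trials)
    for t in 1:n_trials
        out[:, t] = vec(log.(var(@view(X[:, :, t]); dims = 2)))
    end
    return out   # (n_features, n_trials): one aggregated feature vector per trial
end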
import numpy as np

# Class means
print("Class means:")
for k, class_label in enumerate(clf.classes_):
    print(f"  Class {class_label}: {clf.model_['means'][k]}")

# Shared covariance matrix (first 3x3)
print("\nShared covariance matrix (first 3x3):")
print(clf.model_['covariance'][:3, :3])

# Model info
print("\nModel info:")
print(f"  Features: {clf.n_features_in_}")
print(f"  Classes: {clf.classes_}")

# Class priors (learned from data)
print("\nClass priors:")
for k, class_label in enumerate(clf.classes_):
    print(f"  Class {class_label}: {clf.model_['priors'][k]:.3f}")
# Extract point estimates from posterior distributions
using Distributions

# Class means (extract from posterior distributions)
println("Class means:")
for (k, mean_posterior) in enumerate(model.mean_posteriors)
    mean_point = mean(mean_posterior)  # Extract point estimate
    println("  Class $k: ", mean_point)
end

# Shared precision matrix (extract from posterior distribution)
println("\nShared precision matrix (first 3x3):")
precision_point = mean(model.precision_posterior)  # Extract point estimate
println(precision_point[1:3, 1:3])

# Model metadata
println("\nMetadata:")
println("  Name: ", model.metadata.name)
println("  Paradigm: ", model.metadata.paradigm)
println("  Features: ", model.metadata.n_features)
println("  Classes: ", model.metadata.n_classes)

# Class priors
println("\nClass priors:")
for (k, prior) in enumerate(model.priors)
    println("  Class $k: ", prior)
end
Accessing model parameters: The SDK stores full posterior distributions (not just point estimates) for proper Bayesian inference. To get point estimates, use mean(posterior) to extract the mean of the posterior distribution. For precision matrices, use mean(precision_posterior) to get the expected precision matrix.
❌ Shared Covariance Assumption: May not fit well if classes have very different spreads
❌ Linear Decision Boundary: Cannot capture non-linear class boundaries (see the note after this list for why the shared precision forces a linear boundary)
❌ Gaussian Assumption: Assumes normal class distributions
❌ Not Ideal for Overlapping Classes: Use NimbusQDA for complex distributions
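Why the boundary is linear: because every class shares the same precision W, the quadratic term in the Gaussian log-density is identical across classes and cancels when comparing any two classes. Written with point estimates \mu_k and W, the standard LDA identity is

\[
\log \frac{p(c = k \mid \mathbf{y})}{p(c = j \mid \mathbf{y})}
= \log \frac{\pi_k}{\pi_j}
+ \mathbf{y}^\top \mathbf{W} (\boldsymbol{\mu}_k - \boldsymbol{\mu}_j)
- \tfrac{1}{2} \big( \boldsymbol{\mu}_k^\top \mathbf{W} \boldsymbol{\mu}_k - \boldsymbol{\mu}_j^\top \mathbf{W} \boldsymbol{\mu}_j \big),
\]

which is linear in y. With class-specific precisions (as in NimbusQDA) the quadratic terms no longer cancel, which is what allows curved decision boundaries.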