
Bayesian MPR (RxPolya)

Primary Name: Bayesian MPR (Bayesian Multinomial Probit Regression)
API Name: RxPolyaModel
Mathematical Model: Bayesian Multinomial Probit Regression
Bayesian MPR is a Bayesian multinomial classification model with uncertainty quantification, implemented using RxInfer.jl's reactive message passing framework. It uses continuous transitions to map features into a (K-1)-dimensional latent space and a MultinomialPolya likelihood for probabilistic classification, making it well suited to complex multinomial tasks.
Bayesian MPR is currently implemented in NimbusSDK.jl (API name: RxPolyaModel) and ready for use in production BCI applications. This model provides full Bayesian uncertainty quantification over multinomial distributions, offering advanced capabilities beyond traditional Gaussian classifiers.

Overview

Bayesian MPR implements Bayesian Multinomial Probit Regression, providing a powerful alternative to traditional Gaussian classifiers:
  • Full Bayesian inference with posterior probability distributions
  • Uncertainty quantification for each prediction
  • Continuous transition mapping from features to latent space
  • MultinomialPolya likelihood for natural multinomial classification
  • (K-1) dimensional latent representation with sum-to-one constraint
  • Fast inference (~15-25ms per trial)
  • Training and calibration support
  • Batch and streaming inference modes

When to Use Bayesian MPR

Bayesian MPR is ideal for:
  • Complex multinomial classification tasks
  • When you need full posterior uncertainty over multinomial distributions
  • Advanced BCI applications requiring flexible classification
  • Tasks where Bayesian multinomial probit regression is theoretically appropriate
  • When RxLDA/RxGMM accuracy is insufficient for complex distributions
Consider Bayesian LDA or Bayesian GMM instead if:
  • You need faster inference (RxLDA/RxGMM are 5-10ms faster)
  • Classes are well-separated with Gaussian distributions
  • Interpretability of class centers is important (Bayesian MPR is discriminative)
  • You want Mahalanobis distance-based outlier detection

Model Architecture

Mathematical Foundation (Bayesian Multinomial Probit Regression)

Bayesian MPR implements Bayesian Multinomial Probit Regression using continuous transitions and a MultinomialPolya likelihood.

Generative Model:
B ~ N(ξβ, Wβ⁻¹)           # Regression coefficients prior
W ~ Wishart(ν, Λ)          # Precision matrix prior
Ψᵢ ~ ContinuousTransition(xᵢ, B, W)  # Latent transformation
yᵢ ~ MultinomialPolya(N, Ψᵢ)         # Observation model
Where:
  • B ∈ ℝ^((K-1)×D) = regression coefficient matrix
  • W = precision matrix in (K-1)-dimensional latent space
  • xᵢ = feature vector for observation i
  • Ψᵢ = latent (K-1)-dimensional representation
  • N = number of trials per observation (typically 1 for classification)
Key Feature: The model works in (K-1) dimensions due to the sum-to-one constraint in multinomial distributions, with the K-th class serving as the reference category.
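To make the (K-1)-dimensional representation concrete, the sketch below shows one common stick-breaking logistic parameterization that maps a latent vector Ψ ∈ ℝ^(K-1) to K class probabilities. This is an illustration only; the exact link function used by MultinomialPolya in RxInfer may differ in detail.

# Conceptual sketch (assumption): stick-breaking logistic map from a
# (K-1)-dimensional latent vector Ψ to K class probabilities.
σ(x) = 1 / (1 + exp(-x))

function latent_to_probs(Ψ::AbstractVector)
    K = length(Ψ) + 1
    p = zeros(K)
    remaining = 1.0                     # probability mass not yet assigned
    for k in 1:K-1
        p[k] = remaining * σ(Ψ[k])      # take a σ(Ψ[k]) fraction of what is left
        remaining -= p[k]
    end
    p[K] = remaining                    # reference (K-th) category absorbs the rest
    return p                            # sums to one by construction
end

latent_to_probs([0.8, -0.3, 0.1])       # 4-class example → 4 probabilities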

Hyperparameters

Bayesian MPR supports configurable hyperparameters for tuning performance to your dataset.

Available Hyperparameters:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| N | Int | 1 | Number of trials per observation (typically 1 for classification) |
| ξβ | Vector | ones((K-1)×D) | Prior mean for regression coefficients B |
| Wβ | Matrix | 1e-5 × I | Prior precision for regression coefficients B |
| W_df | Float64 | K + 5 | Wishart degrees of freedom for precision matrix W |
| W_scale | Matrix | I | Wishart scale matrix ((K-1) × (K-1)) |
Parameter Effects:
  • N: Number of trials per observation
    • Typically set to 1 for classification tasks
    • Higher values model count data (less common in BCI)
  • ξβ: Prior mean for regression coefficients
    • Default: ones((K-1)×D) provides mild regularization
    • Can be customized if you have prior knowledge of coefficient values
  • Wβ: Prior precision for regression coefficients
    • Lower values (1e-6) → Weaker prior, more data-driven
    • Higher values (1e-4) → Stronger prior, more regularization
    • Default (1e-5) provides balanced regularization
  • W_df: Wishart degrees of freedom
    • Controls strength of precision matrix prior
    • Higher values → Stronger regularization
    • Default (K + 5) provides reasonable regularization
  • W_scale: Wishart scale matrix
    • Shape of the prior precision distribution
    • Default (identity matrix) assumes no prior covariance structure
Hyperparameter configuration allows you to optimize model behavior for your specific dataset characteristics (SNR, trial count, data quality).
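For concreteness, the snippet below builds the documented defaults by hand for a 4-class, 16-feature problem. train_model auto-configures these values when they are not supplied, so this is just a reference sketch of the defaults listed above.

using LinearAlgebra

n_classes  = 4
n_features = 16
B_dim = (n_classes - 1) * n_features                  # flattened (K-1)×D coefficient dimension

N       = 1                                           # trials per observation
ξβ      = ones(B_dim)                                 # prior mean for B
Wβ      = 1e-5 * Matrix(1.0I, B_dim, B_dim)           # prior precision for B
W_df    = Float64(n_classes + 5)                      # Wishart degrees of freedom
W_scale = Matrix(1.0I, n_classes - 1, n_classes - 1)  # (K-1)×(K-1) Wishart scale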

Model Structure

struct RxPolyaModel <: BCIModel
    B_posterior                        # Learned regression coefficients posterior
    W_posterior                        # Learned precision matrix posterior
    metadata::ModelMetadata            # Model info
    N::Int                            # Trials per observation
    ξβ::Vector{Float64}               # Prior mean for B
    Wβ::Matrix{Float64}               # Prior precision for B
    W_df::Float64                     # Wishart DOF
    W_scale::Matrix{Float64}          # Wishart scale matrix
end

RxInfer Implementation

The RxPolya model uses RxInfer.jl for variational Bayesian inference.

Learning Phase:
@model function RxPolya_learning_model(obs, N, X, n_classes, n_features, 
                                       ξβ, Wβ, W_df, W_scale)
    # Prior on regression coefficients
    B ~ MvNormalWeightedMeanPrecision(ξβ, Wβ)
    
    # Prior on precision matrix
    W ~ Wishart(W_df, W_scale)
    
    # Likelihood with continuous transition
    for i in eachindex(obs)
        Ψ[i] ~ ContinuousTransition(X[i], B, W)
        obs[i] ~ MultinomialPolya(N, Ψ[i])
    end
end
Prediction Phase:
@model function RxPolya_predictive(obs, N, X, B_posterior, W_posterior, 
                                   n_classes, n_features)
    B ~ B_posterior
    W ~ W_posterior
    
    # Predict latent representations
    for i in eachindex(X)
        Ψ[i] ~ ContinuousTransition(X[i], B, W)
        obs[i] ~ MultinomialPolya(N, Ψ[i])
    end
end
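NimbusSDK drives these models internally; the schematic below only illustrates how such a model would be run with RxInfer's infer function. Variable names (train_obs, train_X) are placeholders, and a real run of a non-conjugate model like this typically also needs @initialization and factorization @constraints blocks, which are omitted here.

using RxInfer

results = infer(
    model = RxPolya_learning_model(N = 1, n_classes = 4, n_features = 16,
                                   ξβ = ξβ, Wβ = Wβ, W_df = W_df, W_scale = W_scale),
    data  = (obs = train_obs, X = train_X),   # placeholder labelled observations and features
    iterations  = 50,
    free_energy = true                        # track convergence during training
)

# learned posteriors over B and W are then available via results.posteriors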

Usage

1. Load Pre-trained Model

using NimbusSDK

# Authenticate
NimbusSDK.install_core("nbci_live_your_key")

# Load from Nimbus model zoo
model = load_model(RxPolyaModel, "motor_imagery_polya_v1")

println("Model loaded:")
println("  Features: $(get_n_features(model))")
println("  Classes: $(get_n_classes(model))")
println("  Paradigm: $(get_paradigm(model))")

2. Train Custom Model

using NimbusSDK

# Prepare training data with labels
train_features = csp_features  # (16 × 250 × 100)
train_labels = [1, 2, 3, 4, 1, 2, ...]  # 100 labels

train_data = BCIData(
    train_features,
    BCIMetadata(
        sampling_rate = 250.0,
        paradigm = :motor_imagery,
        feature_type = :csp,
        n_features = 16,
        n_classes = 4,
        chunk_size = nothing
    ),
    train_labels  # Required for training!
)

# Train RxPolya model with default hyperparameters
model = train_model(
    RxPolyaModel,
    train_data;
    iterations = 50,        # Inference iterations
    showprogress = true,    # Show progress bar
    name = "my_motor_imagery_polya",
    description = "4-class MI classifier with RxPolya"
)

# Or train with custom hyperparameters
model = train_model(
    RxPolyaModel,
    train_data;
    iterations = 50,
    showprogress = true,
    name = "my_motor_imagery_polya_tuned",
    description = "4-class MI with tuned hyperparameters",
    N = 1,                           # Trials per observation (default: 1)
    ξβ = nothing,                    # Auto-configured if not provided
    Wβ = nothing,                    # Auto-configured if not provided
    W_df = nothing,                  # Auto-configured if not provided
    W_scale = nothing                # Auto-configured if not provided
)

# Save for later use
save_model(model, "my_model.jld2")
Training Parameters:
  • iterations: Number of variational inference iterations (default: 50)
    • More iterations = better convergence but slower training
    • 50-100 is typically sufficient
  • showprogress: Display progress bar during training
  • name: Model identifier
  • description: Model description for documentation
  • N: Trials per observation (default: 1, range: [1, ∞))
  • ξβ: Prior mean for B (default: ones((K-1)×D))
  • Wβ: Prior precision for B (default: 1e-5 × I)
  • W_df: Wishart degrees of freedom (default: K + 5)
  • W_scale: Wishart scale matrix (default: I)

3. Subject-Specific Calibration

Fine-tune a pre-trained model with subject-specific data (much faster than training from scratch):
# Load base model
base_model = load_model(RxPolyaModel, "motor_imagery_polya_baseline_v1")

# Collect 10-20 calibration trials from new subject
calib_features = collect_calibration_trials()  # Your function
calib_labels = [1, 2, 3, 4, 1, 2, ...]

calib_data = BCIData(calib_features, metadata, calib_labels)

# Calibrate (personalize) the model
personalized_model = calibrate_model(
    base_model,
    calib_data;
    iterations = 20  # Fewer iterations needed
)

save_model(personalized_model, "subject_001_calibrated.jld2")
Calibration Benefits:
  • Requires only 10-20 trials per class (vs 50-100 for training from scratch)
  • Faster: 20 iterations vs 50-100
  • Better generalization: Uses pre-trained model as prior
  • Typical accuracy improvement: 5-15% over generic model
  • Hyperparameters preserved: calibrate_model() automatically uses the same hyperparameters as the base model
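To quantify the improvement on a particular subject, you can evaluate the generic and personalized models side by side on held-out labelled trials. The sketch below reuses predict_batch from the batch-inference workflow and assumes a labelled eval_data BCIData object.

# eval_data: held-out labelled BCIData from the same subject (assumption)
for (label, m) in [("generic", base_model), ("personalized", personalized_model)]
    results  = predict_batch(m, eval_data; iterations = 10)
    accuracy = sum(results.predictions .== eval_data.labels) / length(eval_data.labels)
    println("$label model accuracy: $(round(accuracy * 100, digits = 1))%")
end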

4. Batch Inference

Process multiple trials efficiently:
# Prepare test data
test_data = BCIData(test_features, metadata, test_labels)

# Run batch inference
results = predict_batch(model, test_data; iterations=10)

# Analyze results
println("Predictions: ", results.predictions)
println("Mean confidence: ", mean(results.confidences))

# Calculate accuracy
accuracy = sum(results.predictions .== test_labels) / length(test_labels)
println("Accuracy: $(round(accuracy * 100, digits=1))%")

# Calculate ITR
itr = calculate_ITR(accuracy, 4, 4.0)  # 4 classes, 4-second trials
println("ITR: $(round(itr, digits=1)) bits/minute")

5. Streaming Inference

Real-time chunk-by-chunk processing:
# Initialize streaming session
session = init_streaming(model, metadata_with_chunk_size)

# Process chunks as they arrive
for chunk in eeg_feature_stream
    result = process_chunk(session, chunk; iterations=10)
    println("Chunk: pred=$(result.prediction), conf=$(round(result.confidence, digits=3))")
end

# Finalize trial with aggregation
final_result = finalize_trial(session; method=:weighted_vote)
println("Final: pred=$(final_result.prediction), conf=$(round(final_result.confidence, digits=3))")

Hyperparameter Tuning

Fine-tune Bayesian MPR for optimal performance on your specific dataset.

When to Tune Hyperparameters

Consider tuning when:
  • Default performance is unsatisfactory
  • You have specific data characteristics (very noisy or very clean)
  • You have limited or extensive training data
  • You want to optimize for your specific paradigm

Tuning Strategies

For High SNR / Clean Data / Many Trials

Use weaker priors to let the data drive the model:
model = train_model(
    RxPolyaModel,
    train_data;
    iterations = 50,
    N = 1,
    Wβ = 1e-6 * diageye((n_classes - 1) * n_features),  # Weaker prior
    W_df = Float64(n_classes + 2)                       # Less regularization
)
Use when:
  • SNR > 5 dB
  • 100+ trials per class
  • Clean, artifact-free data
  • Well-controlled experimental conditions

For Low SNR / Noisy Data / Few Trials

Use stronger priors for stability:
model = train_model(
    RxPolyaModel,
    train_data;
    iterations = 50,
    N = 1,
    Wβ = 1e-4 * diageye((n_classes - 1) * n_features),  # Stronger prior
    W_df = Float64(n_classes + 10)                      # More regularization
)
Use when:
  • SNR < 2 dB
  • 40-80 trials per class
  • Noisy data or limited artifact removal
  • Challenging recording conditions

Balanced / Default Settings

The defaults work well for most scenarios:
model = train_model(
    RxPolyaModel,
    train_data;
    iterations = 50,
    N = 1,                           # Standard (default)
    Wβ = 1e-5 * diageye((n_classes - 1) * n_features),  # Balanced (default)
    W_df = Float64(n_classes + 5)    # Balanced (default)
)
Use when:
  • Moderate SNR (2-5 dB)
  • 80-150 trials per class
  • Standard BCI recording conditions
  • Starting point for experimentation

Hyperparameter Search Example

Systematically search for optimal hyperparameters:
using NimbusSDK

# Define search grid
Wβ_scales = [1e-6, 1e-5, 1e-4]
W_df_offsets = [2, 5, 10]

# Split data into train/validation
train_data, val_data = split_data(all_data, ratio=0.8)

best_accuracy = 0.0
best_params = nothing

println("Searching hyperparameters...")
for Wβ_scale in Wβ_scales
    for df_offset in W_df_offsets
        B_dim = (n_classes - 1) * n_features
        
        # Train model with these hyperparameters
        model = train_model(
            RxPolyaModel,
            train_data;
            iterations = 50,
            N = 1,
            Wβ = Wβ_scale * diageye(B_dim),
            W_df = Float64(n_classes + df_offset),
            showprogress = false
        )
        
        # Validate
        results = predict_batch(model, val_data)
        accuracy = sum(results.predictions .== val_data.labels) / length(val_data.labels)
        
        println("  Wβ_scale=$Wβ_scale, df_offset=$df_offset: $(round(accuracy*100, digits=1))%")
        
        # Track best
        if accuracy > best_accuracy
            best_accuracy = accuracy
            best_params = (Wβ_scale=Wβ_scale, df_offset=df_offset)
        end
    end
end

println("\nBest hyperparameters:")
println("  Wβ_scale: $(best_params.Wβ_scale)")
println("  W_df offset: $(best_params.df_offset)")
println("  Validation accuracy: $(round(best_accuracy*100, digits=1))%")

# Retrain on all data with best hyperparameters
final_model = train_model(
    RxPolyaModel,
    all_data;
    iterations = 50,
    N = 1,
    Wβ = best_params.Wβ_scale * diageye((n_classes - 1) * n_features),
    W_df = Float64(n_classes + best_params.df_offset)
)

Quick Tuning Guidelines

| Scenario | Wβ scale | W_df offset | Notes |
|----------|----------|-------------|-------|
| Excellent data quality | 1e-6 | 2 | Minimal regularization |
| Good data quality | 1e-5 (default) | 5 (default) | Balanced approach |
| Moderate data quality | 1e-5 to 1e-4 | 5-8 | Slight regularization |
| Poor data quality | 1e-4 | 10 | Strong regularization |
| Very limited trials | 1e-4 | 15 | Maximum regularization |
Pro Tip: Start with defaults (Wβ = 1e-5 × I, W_df = K + 5) and only tune if performance is unsatisfactory. The defaults are optimized for typical BCI scenarios.

Training Requirements

Data Requirements

  • Minimum: 40 trials per class (160 total for 4-class)
  • Recommended: 80+ trials per class (320+ total for 4-class)
  • For calibration: 10-20 trials per class sufficient
RxPolya requires at least 2 observations to estimate regression coefficients and the precision matrix. Single-trial training will raise an error.
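Before training, it is worth checking per-class trial counts against these minimums. A quick sketch, assuming labels are stored as a plain integer vector on the BCIData object:

labels = train_data.labels                     # integer class labels, one per trial
for c in sort(unique(labels))
    n = count(==(c), labels)
    status = n >= 80 ? "recommended" : (n >= 40 ? "minimum" : "too few")
    println("Class $c: $n trials ($status)")
end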

Feature Requirements

Bayesian MPR expects preprocessed features, not raw EEG.

Required preprocessing:
  • Bandpass filtering (8-30 Hz for motor imagery)
  • Artifact removal (ICA recommended)
  • Spatial filtering (CSP for motor imagery)
  • Feature extraction (log-variance for CSP features)
NOT accepted:
  • Raw EEG channels
  • Unfiltered data
  • Non-extracted features
See Preprocessing Requirements for details.
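As a rough illustration of the final step in that pipeline, log-variance features are typically computed per trial from the CSP-projected signal. A minimal sketch, assuming csp_projected is a (components × samples × trials) array:

using Statistics

# Log-variance features from CSP-projected trials (components × samples × trials).
function log_variance_features(csp_projected::AbstractArray{<:Real,3})
    n_comp, _, n_trials = size(csp_projected)
    feats = zeros(n_comp, n_trials)
    for t in 1:n_trials, c in 1:n_comp
        feats[c, t] = log(var(csp_projected[c, :, t]))   # one feature per CSP component
    end
    return feats                                          # n_features × n_trials matrix
end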

Performance Characteristics

Computational Performance

| Operation | Latency | Notes |
|-----------|---------|-------|
| Training | 15-40 seconds | 50 iterations, 100 trials per class |
| Calibration | 8-20 seconds | 20 iterations, 20 trials per class |
| Batch Inference | 15-25 ms per trial | 10 iterations |
| Streaming Chunk | 15-25 ms | 10 iterations per chunk |
Slightly slower than RxLDA/RxGMM due to continuous transition complexity.

Classification Accuracy

| Paradigm | Classes | Typical Accuracy | When to Use RxPolya |
|----------|---------|------------------|---------------------|
| Motor Imagery | 2-4 | 70-85% | When RxLDA/RxGMM insufficient |
| P300 | 2 (Target/Non-target) | 85-95% | Complex distributions |
| SSVEP | 2-6 | 85-98% | Advanced applications |
Bayesian MPR typically provides comparable or slightly better accuracy than RxLDA/RxGMM for complex distributions, at the cost of ~5-10ms additional latency.

Model Inspection

View Model Parameters

# Learned posteriors
println("B posterior (regression coefficients):")
println("  Type: ", typeof(model.B_posterior))
println("  Mean: ", mean(model.B_posterior))

println("\nW posterior (precision matrix):")
println("  Type: ", typeof(model.W_posterior))
println("  Mean: ", mean(model.W_posterior))

# Hyperparameters
println("\nHyperparameters:")
println("  N: ", model.N)
println("  ξβ dimensions: ", size(model.ξβ))
println("  Wβ dimensions: ", size(model.Wβ))
println("  W_df: ", model.W_df)
println("  W_scale dimensions: ", size(model.W_scale))

# Model metadata
println("\nMetadata:")
println("  Name: ", model.metadata.name)
println("  Paradigm: ", model.metadata.paradigm)
println("  Features: ", model.metadata.n_features)
println("  Classes: ", model.metadata.n_classes)

Compare Models

# Train multiple models and compare
models = []
for n_iter in [20, 50, 100]
    model = train_model(RxPolyaModel, train_data; iterations=n_iter)
    results = predict_batch(model, test_data)
    accuracy = sum(results.predictions .== test_labels) / length(test_labels)
    
    println("Iterations: $n_iter, Accuracy: $(round(accuracy*100, digits=1))%")
    push!(models, (n_iter, model, accuracy))
end

Advantages & Limitations

Advantages

  • Flexible Multinomial Classification: Natural handling of multinomial distributions
  • Continuous Transition Mapping: Sophisticated feature-to-latent-space transformation
  • Full Bayesian Uncertainty: Complete posterior distributions over predictions
  • No Gaussian Assumption: More flexible than Gaussian classifiers (RxLDA/RxGMM)
  • Production-Ready: Battle-tested in real BCI applications
  • Calibration Support: Fast subject-specific adaptation

Limitations

  • More Complex: More parameters than RxLDA, requires careful hyperparameter tuning
  • Slower Inference: ~15-25ms vs ~10-15ms for RxLDA
  • (K-1)-Dimensional Space: Works in a reduced dimension (not always intuitive)
  • No Mahalanobis Distance: Discriminative model lacks explicit class centers for outlier detection
  • Requires Multiple Trials: Cannot train on a single trial (minimum 2 observations)
  • Less Interpretable: Harder to interpret than generative models with explicit means

Comparison: Bayesian MPR vs RxLDA vs RxGMM

| Aspect | Bayesian MPR | RxLDA | RxGMM |
|--------|--------------|-------|-------|
| Mathematical Model | Bayesian Multinomial Probit Regression | Pooled Gaussian Classifier (PGC) | Heteroscedastic Gaussian Classifier (HGC) |
| Representation | (K-1)-dimensional latent space | K class means with shared precision | K class means with class-specific precisions |
| Training Speed | Moderate | Fastest | Moderate |
| Inference Speed | 15-25 ms | 10-15 ms | 15-20 ms |
| Flexibility | Highest | Lowest | High |
| Best For | Complex multinomial tasks | Well-separated classes | Overlapping classes with different covariances |
| Interpretability | Low (discriminative) | High (generative with means) | High (generative with means) |
| Mahalanobis Distance | ❌ No (no explicit means) | ✅ Yes | ✅ Yes |
| Entropy Metrics | ✅ Yes | ✅ Yes | ✅ Yes |
| Free Energy | ✅ Yes (training only) | ✅ Yes | ✅ Yes |

Decision Tree

Do you need Mahalanobis-based outlier detection?
├─ Yes → Use RxLDA or RxGMM
└─ No → Are classes well-separated with Gaussian structure?
    ├─ Yes → Use RxLDA (fastest)
    └─ No → Do classes have different covariances?
        ├─ Yes → Use RxGMM
        └─ No → Are distributions complex/non-Gaussian?
            ├─ Yes → Use Bayesian MPR
            └─ No → Use RxLDA (simplest)

Practical Examples

Motor Imagery Classification

# Bayesian MPR for complex motor imagery with overlapping classes

# Train on motor imagery data
mi_data = BCIData(csp_features, mi_metadata, labels)  # 4-class MI
model = train_model(RxPolyaModel, mi_data; iterations=50)

# Test
results = predict_batch(model, test_data)

# Calculate per-class accuracy
for class_id in 1:4
    class_mask = test_labels .== class_id
    class_accuracy = sum(results.predictions[class_mask] .== class_id) / sum(class_mask)
    println("Class $class_id accuracy: $(round(class_accuracy * 100, digits=1))%")
end

P300 Detection

# Bayesian MPR for P300 target/non-target classification

# Train on P300 data
p300_data = BCIData(erp_features, p300_metadata, labels)  # 1=target, 2=non-target
model = train_model(RxPolyaModel, p300_data; iterations=50)

# Test
results = predict_batch(model, test_data)

# Calculate sensitivity/specificity
targets = test_labels .== 1
target_preds = results.predictions .== 1

sensitivity = sum(target_preds .& targets) / sum(targets)
specificity = sum(.!target_preds .& .!targets) / sum(.!targets)

println("Sensitivity: $(round(sensitivity * 100, digits=1))%")
println("Specificity: $(round(specificity * 100, digits=1))%")

Comparing with RxLDA/RxGMM

# Compare Bayesian MPR with other models

# Train all three models
rxlda_model = train_model(RxLDAModel, train_data)
rxgmm_model = train_model(RxGMMModel, train_data)
rxpolya_model = train_model(RxPolyaModel, train_data)

# Evaluate all three
models = [
    ("RxLDA", rxlda_model),
    ("RxGMM", rxgmm_model),
    ("Bayesian MPR", rxpolya_model)
]

for (name, model) in models
    results = predict_batch(model, test_data)
    accuracy = sum(results.predictions .== test_labels) / length(test_labels)
    mean_conf = mean(results.confidences)
    mean_entropy = mean(results.entropy)
    
    println("$name:")
    println("  Accuracy: $(round(accuracy*100, digits=1))%")
    println("  Mean confidence: $(round(mean_conf, digits=3))")
    println("  Mean entropy: $(round(mean_entropy, digits=3)) bits")
    println("  Latency: $(results.latency_ms) ms")
    println()
end

References

Theory:
  • Albert, J. H., & Chib, S. (1993). “Bayesian analysis of binary and polychotomous response data”
  • Imai, K., & van Dyk, D. A. (2005). “A Bayesian analysis of the multinomial probit model”
  • Multinomial probit regression with Bayesian inference
BCI Applications:
  • Blankertz et al. (2008). “Optimizing spatial filters for robust EEG single-trial analysis”
  • Lotte et al. (2018). “A review of classification algorithms for EEG-based BCI”