
Batch Processing for BCI

Batch processing is essential for offline analysis, model training, research studies, and systematic evaluation of recorded data. NimbusSDK provides efficient batch inference optimized for processing many trials in a single call.

When to Use Batch Processing

Batch processing is ideal for:
  • Model training and calibration: Train custom models on labeled data
  • Offline data analysis: Analyze recorded BCI sessions
  • Cross-validation: Evaluate model performance systematically
  • Research studies: Process data from multiple subjects/sessions
  • Performance benchmarking: Compare different models or parameters
  • Quality assessment: Identify problematic trials or sessions
Batch processing trades real-time responsiveness for computational efficiency. Use streaming inference for online/real-time applications.

Basic Batch Inference

Simple Batch Setup

using NimbusSDK
using Statistics  # for mean()

# 1. Authenticate
NimbusSDK.install_core("your-api-key")

# 2. Load model
model = load_model(RxLDAModel, "motor_imagery_4class_v1")

# 3. Prepare batch data
# Features shape: (n_features × n_samples × n_trials)
features = load_preprocessed_features()  # Your preprocessed CSP/bandpower features
labels = load_trial_labels()             # Ground truth labels (optional)

metadata = BCIMetadata(
    sampling_rate = 250.0,
    paradigm = :motor_imagery,
    feature_type = :csp,
    n_features = 16,
    n_classes = 4,
    chunk_size = nothing  # Batch mode (not streaming)
)

data = BCIData(features, metadata, labels)

# 4. Run batch inference
results = predict_batch(model, data; iterations=10)

# 5. Analyze results
println("Processed $(length(results.predictions)) trials")
println("Mean confidence: $(round(mean(results.confidences), digits=3))")

if !isnothing(labels)
    accuracy = sum(results.predictions .== labels) / length(labels)
    println("Accuracy: $(round(accuracy * 100, digits=1))%")
end

Performance Characteristics

Typical performance:
  • 200 trials with 16 features: ~5 seconds (~25ms per trial)
  • 1000 trials with 32 features: ~30 seconds (~30ms per trial)
  • Scales linearly with number of trials and features
Batch inference is more efficient than processing trials individually. Process all your trials in one predict_batch() call when possible.
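These figures are hardware-dependent, so it is worth timing the call on your own machine. A minimal sketch, reusing model, data, and features from the example above:

# Time one batch call and derive the per-trial cost
elapsed = @elapsed predict_batch(model, data; iterations=10)

n_trials = size(features, 3)
println("Total: $(round(elapsed, digits=2)) s")
println("Per trial: $(round(1000 * elapsed / n_trials, digits=1)) ms")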

Training Custom Models

Basic Model Training

Train a custom model on your own labeled data:
using NimbusSDK

# Authenticate
NimbusSDK.install_core("your-api-key")

# Prepare training data with labels
train_features = load_training_features()  # (n_features × n_samples × n_trials)
train_labels = load_training_labels()      # Vector: [1, 2, 1, 3, 4, ...]

metadata = BCIMetadata(
    sampling_rate = 250.0,
    paradigm = :motor_imagery,
    feature_type = :csp,
    n_features = 16,
    n_classes = 4
)

train_data = BCIData(train_features, metadata, train_labels)  # labels are required for training

# Train Bayesian LDA model
model = train_model(
    RxLDAModel,
    train_data;
    iterations = 50,           # Number of inference iterations
    showprogress = true,       # Display progress bar
    name = "my_motor_imagery",
    description = "Custom 4-class motor imagery with CSP"
)

# Save trained model
save_model(model, "my_trained_model.jld2")
println("✓ Model trained and saved")

# Evaluate on held-out test data (test_features/test_labels prepared like the training set)
test_data = BCIData(test_features, metadata, test_labels)
test_results = predict_batch(model, test_data)

test_accuracy = sum(test_results.predictions .== test_labels) / length(test_labels)
println("Test accuracy: $(round(test_accuracy * 100, digits=1))%")

Subject-Specific Calibration

Fine-tune a pre-trained model with subject-specific data:
using NimbusSDK

# Load pre-trained baseline model
base_model = load_model(RxLDAModel, "motor_imagery_baseline_v1")

# Collect calibration data (10-20 trials is often sufficient)
calib_features = collect_calibration_trials()  # Your function
calib_labels = repeat([1, 2, 3, 4], 5)         # labels for calibration trials, e.g. 5 per class

calib_data = BCIData(
    calib_features,
    BCIMetadata(
        sampling_rate = 250.0,
        paradigm = :motor_imagery,
        feature_type = :csp,
        n_features = 16,
        n_classes = 4
    ),
    calib_labels
)

# Calibrate (faster than training from scratch)
personalized_model = calibrate_model(
    base_model,
    calib_data;
    iterations = 20  # Fewer iterations needed for calibration
)

# Save personalized model
save_model(personalized_model, "subject_001_calibrated.jld2")
println("✓ Model calibrated for subject")

Training Best Practices

Data requirements (see the per-class count sketch after this list):
  • Minimum: 50-100 trials per class
  • Recommended: 200+ trials per class for robust models
  • Calibration: 10-20 trials per class is sufficient
Model selection:
  • Bayesian LDA (RxLDA): Faster, shared covariance, good for well-separated classes
  • Bayesian GMM (RxGMM): More flexible, class-specific covariances, better for overlapping classes
Iteration count:
  • Training: 50-100 iterations (default: 50)
  • Calibration: 20-30 iterations (default: 20)
  • More iterations = better convergence but slower training
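Before training, verify that every class meets the trial-count guidance above. A minimal sketch, assuming train_labels from the training example:

# Count trials per class and flag under-represented classes
class_counts = Dict(c => count(==(c), train_labels) for c in unique(train_labels))

for (class, n) in sort(collect(class_counts), by = first)
    flag = n < 50 ? "  ⚠️ below the recommended minimum" : ""
    println("Class $class: $n trials$flag")
end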

Cross-Validation

K-Fold Cross-Validation

Evaluate model performance systematically:
using NimbusSDK
using Statistics  # for mean/std

# Prepare full dataset
all_features = load_all_features()
all_labels = load_all_labels()

metadata = BCIMetadata(
    sampling_rate = 250.0,
    paradigm = :motor_imagery,
    feature_type = :csp,
    n_features = 16,
    n_classes = 4
)

# K-fold cross-validation
k_folds = 5
n_trials = size(all_features, 3)
fold_size = div(n_trials, k_folds)  # any remainder trials are dropped from the folds

accuracies = Float64[]

for fold in 1:k_folds
    println("\n=== Fold $fold/$k_folds ===")
    
    # Split data
    test_idx = ((fold-1)*fold_size + 1):(fold*fold_size)
    train_idx = setdiff(1:n_trials, test_idx)
    
    train_features = all_features[:, :, train_idx]
    train_labels_fold = all_labels[train_idx]
    test_features = all_features[:, :, test_idx]
    test_labels_fold = all_labels[test_idx]
    
    # Train model
    train_data = BCIData(train_features, metadata, train_labels_fold)
    model = train_model(RxLDAModel, train_data; iterations=50)
    
    # Test model
    test_data = BCIData(test_features, metadata, test_labels_fold)
    results = predict_batch(model, test_data)
    
    accuracy = sum(results.predictions .== test_labels_fold) / length(test_labels_fold)
    push!(accuracies, accuracy)
    println("Fold accuracy: $(round(accuracy * 100, digits=1))%")
end

# Overall results
mean_acc = mean(accuracies)
std_acc = std(accuracies)
println("\n=== Cross-Validation Results ===")
println("Mean accuracy: $(round(mean_acc * 100, digits=1))% ± $(round(std_acc * 100, digits=1))%")

Performance Metrics

Information Transfer Rate (ITR)

Calculate BCI communication rate:
using NimbusSDK

# Run inference
results = predict_batch(model, data)

# Calculate accuracy
accuracy = sum(results.predictions .== labels) / length(labels)

# Calculate ITR (bits/minute)
n_classes = 4
trial_duration = 4.0  # seconds

itr = calculate_ITR(accuracy, n_classes, trial_duration)

println("Accuracy: $(round(accuracy * 100, digits=1))%")
println("ITR: $(round(itr, digits=1)) bits/minute")

Online Performance Tracking

Track metrics across trials:
using NimbusSDK

# Initialize tracker
tracker = OnlinePerformanceTracker(window_size=50)

# Process trials
for (pred, true_label, conf) in zip(results.predictions, labels, results.confidences)
    metrics = update_and_report!(tracker, pred, true_label, conf)
    
    # Print running metrics every 10 trials
    if length(tracker.recent_correct) % 10 == 0
        println("Running accuracy: $(round(metrics.accuracy * 100, digits=1))%")
    end
end

# Get final comprehensive metrics
final_metrics = get_metrics(tracker, n_classes=4, trial_duration=4.0)

println("\n=== Session Metrics ===")
println("Accuracy: $(round(final_metrics.accuracy * 100, digits=1))%")
println("ITR: $(round(final_metrics.information_transfer_rate, digits=1)) bits/min")
println("Mean confidence: $(round(final_metrics.mean_confidence, digits=3))")
println("Mean latency: $(round(final_metrics.mean_latency_ms, digits=1)) ms")

Quality Assessment

Trial Quality Analysis

Identify low-quality trials:
using NimbusSDK

# Run batch inference
results = predict_batch(model, data)

# Assess quality for each trial
low_confidence_trials = Int[]
for (i, confidence) in enumerate(results.confidences)
    if should_reject_trial(confidence, 0.7)
        push!(low_confidence_trials, i)
    end
end

println("Low confidence trials: $(length(low_confidence_trials))/$(length(results.predictions))")
println("Rejection rate: $(round(100 * length(low_confidence_trials) / length(results.predictions), digits=1))%")

# Overall quality assessment
quality = assess_trial_quality(results)
println("\nOverall quality score: $(round(quality.overall_score, digits=2))")
println("Confidence acceptable: $(quality.confidence_acceptable)")
println("Recommendation: $(quality.recommendation)")

Preprocessing Diagnostics

Check preprocessing quality:
using NimbusSDK

# Diagnose preprocessing issues
report = diagnose_preprocessing(data)

println("Preprocessing Quality: $(round(report.quality_score * 100, digits=1))%")

if !isempty(report.errors)
    println("\n⚠️  ERRORS:")
    for error in report.errors
        println("  • $error")
    end
end

if !isempty(report.warnings)
    println("\n⚠️  WARNINGS:")
    for warning in report.warnings
        println("  • $warning")
    end
end

if !isempty(report.recommendations)
    println("\n💡 RECOMMENDATIONS:")
    for rec in report.recommendations
        println("  • $rec")
    end
end

Model Comparison

Compare Different Models

using NimbusSDK
using Statistics  # for mean()

# Train multiple models
rxlda = train_model(RxLDAModel, train_data; iterations=50)
rxgmm = train_model(RxGMMModel, train_data; iterations=50)

# Test both models
models = [rxlda, rxgmm]
model_names = ["RxLDA", "RxGMM"]

println("=== Model Comparison ===")
for (model, name) in zip(models, model_names)
    results = predict_batch(model, test_data)
    accuracy = sum(results.predictions .== test_labels) / length(test_labels)
    mean_conf = mean(results.confidences)
    
    println("\n$name:")
    println("  Accuracy: $(round(accuracy * 100, digits=1))%")
    println("  Mean confidence: $(round(mean_conf, digits=3))")
end
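Aggregate accuracy can hide class-specific failures, so it helps to break the comparison down per class. A sketch for one model's results:

# Confusion matrix: rows are true classes, columns are predictions
n_classes = 4
confusion = zeros(Int, n_classes, n_classes)
for (pred, truth) in zip(results.predictions, test_labels)
    confusion[truth, pred] += 1
end

for c in 1:n_classes
    class_acc = confusion[c, c] / max(sum(confusion[c, :]), 1)
    println("Class $c accuracy: $(round(class_acc * 100, digits=1))%")
end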

Best Practices

Data Organization

using Dates

# Organize your data clearly
struct BCISession
    features::Array{Float64, 3}  # (n_features × n_samples × n_trials)
    labels::Vector{Int}          # Trial labels
    subject_id::String
    session_id::String
    date::Date
end

# Process multiple sessions with a shared model and metadata
function process_sessions(sessions::Vector{BCISession}, model, metadata)
    for session in sessions
        data = BCIData(session.features, metadata, session.labels)
        results = predict_batch(model, data)

        # Persist results (your own storage function)
        save_results(session.subject_id, session.session_id, results)
    end
end
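Usage is then just constructing sessions and passing them in. A sketch with hypothetical loader functions (load_session_features and load_session_labels stand in for your own I/O):

sessions = [
    BCISession(load_session_features("S001", "day1"),  # hypothetical loader
               load_session_labels("S001", "day1"),    # hypothetical loader
               "S001", "day1", Date(2024, 1, 15)),
]

process_sessions(sessions, model, metadata)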

Error Handling

using NimbusSDK

try
    results = predict_batch(model, data)
catch e
    if isa(e, DataValidationError)
        @error "Data validation failed" exception=e
        # Check data format and preprocessing
    elseif isa(e, ModelCompatibilityError)
        @error "Model incompatible with data" exception=e
        # Check n_features, n_classes match
    else
        @error "Inference failed" exception=e
        rethrow()
    end
end
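Many of these failures can be caught cheaply before calling the API at all. A sketch of up-front shape checks, assuming features, labels, and a BCIMetadata value named metadata as in the earlier examples (and that its keyword fields are accessible as properties):

# Catch common mismatches before calling predict_batch
size(features, 1) == metadata.n_features ||
    error("Expected $(metadata.n_features) features, got $(size(features, 1))")

if !isnothing(labels) && length(labels) != size(features, 3)
    error("Label count $(length(labels)) does not match trial count $(size(features, 3))")
end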

Next Steps

Next: Learn about streaming inference for real-time BCI applications.