Batch Processing for BCI
Batch processing is essential for offline analysis, model training, research studies, and comprehensive data evaluation. NimbusSDK provides efficient batch inference capabilities optimized for processing multiple trials simultaneously.
When to Use Batch Processing
Batch processing is ideal for:
- Model training and calibration: Train custom models on labeled data
- Offline data analysis: Analyze recorded BCI sessions
- Cross-validation: Evaluate model performance systematically
- Research studies: Process data from multiple subjects/sessions
- Performance benchmarking: Compare different models or parameters
- Quality assessment: Identify problematic trials or sessions
Batch processing trades real-time responsiveness for computational efficiency. Use streaming inference for online/real-time applications.
Basic Batch Inference
Simple Batch Setup
using NimbusSDK
using Statistics  # for mean() below
# 1. Authenticate
NimbusSDK.install_core("your-api-key")
# 2. Load model
model = load_model(RxLDAModel, "motor_imagery_4class_v1")
# 3. Prepare batch data
# Features shape: (n_features × n_samples × n_trials)
features = load_preprocessed_features() # Your preprocessed CSP/bandpower features
labels = load_trial_labels() # Ground truth labels (optional)
metadata = BCIMetadata(
    sampling_rate = 250.0,
    paradigm = :motor_imagery,
    feature_type = :csp,
    n_features = 16,
    n_classes = 4,
    chunk_size = nothing  # Batch mode (not streaming)
)
data = BCIData(features, metadata, labels)
# 4. Run batch inference
results = predict_batch(model, data; iterations=10)
# 5. Analyze results
println("Processed $(length(results.predictions)) trials")
println("Mean confidence: $(round(mean(results.confidences), digits=3))")
if !isnothing(labels)
    accuracy = sum(results.predictions .== labels) / length(labels)
    println("Accuracy: $(round(accuracy * 100, digits=1))%")
end
Typical performance:
- 200 trials with 16 features: ~5 seconds (~25ms per trial)
- 1000 trials with 32 features: ~30 seconds (~30ms per trial)
- Scales linearly with number of trials and features
Batch inference is more efficient than processing trials individually. Process all your trials in one predict_batch() call when possible.
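To see the difference directly, here is a rough timing sketch. It assumes the `model`, `features`, and `metadata` from the example above, and that `BCIData` accepts `(features, metadata)` without labels (labels are marked optional above); the comparison is illustrative, not a benchmark.
t_batch = @elapsed predict_batch(model, BCIData(features, metadata))

t_loop = @elapsed for i in 1:size(features, 3)
    trial = features[:, :, i:i]  # keep the 3-D (features × samples × 1) shape
    predict_batch(model, BCIData(trial, metadata))
end

println("batched: $(round(t_batch, digits=2))s  per-trial: $(round(t_loop, digits=2))s")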
Training Custom Models
Basic Model Training
Train a custom model on your own labeled data:
using NimbusSDK
# Authenticate
NimbusSDK.install_core("your-api-key")
# Prepare training data with labels
train_features = load_training_features() # (n_features × n_samples × n_trials)
train_labels = load_training_labels() # Vector: [1, 2, 1, 3, 4, ...]
metadata = BCIMetadata(
    sampling_rate = 250.0,
    paradigm = :motor_imagery,
    feature_type = :csp,
    n_features = 16,
    n_classes = 4
)
train_data = BCIData(
    train_features,
    metadata,
    train_labels  # Labels required for training!
)
# Train Bayesian LDA model
model = train_model(
    RxLDAModel,
    train_data;
    iterations = 50,       # Number of inference iterations
    showprogress = true,   # Display progress bar
    name = "my_motor_imagery",
    description = "Custom 4-class motor imagery with CSP"
)
# Save trained model
save_model(model, "my_trained_model.jld2")
println("✓ Model trained and saved")
# Evaluate on held-out test data (test_features/test_labels prepared like the training set)
test_data = BCIData(test_features, metadata, test_labels)
test_results = predict_batch(model, test_data)
test_accuracy = sum(test_results.predictions .== test_labels) / length(test_labels)
println("Test accuracy: $(round(test_accuracy * 100, digits=1))%")
Subject-Specific Calibration
Fine-tune a pre-trained model with subject-specific data:
using NimbusSDK
# Load pre-trained baseline model
base_model = load_model(RxLDAModel, "motor_imagery_baseline_v1")
# Collect calibration data (10-20 trials per class is often sufficient)
calib_features = collect_calibration_trials() # Your function
calib_labels = [1, 2, 3, 4, 1, 2, 3, 4, ...] # Labels for calibration
calib_data = BCIData(
    calib_features,
    BCIMetadata(
        sampling_rate = 250.0,
        paradigm = :motor_imagery,
        feature_type = :csp,
        n_features = 16,
        n_classes = 4
    ),
    calib_labels
)
# Calibrate (faster than training from scratch)
personalized_model = calibrate_model(
    base_model,
    calib_data;
    iterations = 20  # Fewer iterations needed for calibration
)
# Save personalized model
save_model(personalized_model, "subject_001_calibrated.jld2")
println("✓ Model calibrated for subject")
Training Best Practices
Data requirements:
- Minimum: 50-100 trials per class
- Recommended: 200+ trials per class for robust models
- Calibration: 10-20 trials per class is usually sufficient
Model selection:
- Bayesian LDA (RxLDA): Faster, shared covariance, good for well-separated classes
- Bayesian GMM (RxGMM): More flexible, class-specific covariances, better for overlapping classes
Iteration count:
- Training: 50-100 iterations (default: 50)
- Calibration: 20-30 iterations (default: 20)
- More iterations = better convergence but slower training
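If you would rather choose an iteration count empirically than rely on the defaults, a small sweep is one option. A sketch, assuming the train_data, test_data, and test_labels from the training example above:
using NimbusSDK

# Train at several iteration counts and compare held-out accuracy.
for iters in (20, 50, 100)
    m = train_model(RxLDAModel, train_data; iterations=iters, showprogress=false)
    r = predict_batch(m, test_data)
    acc = sum(r.predictions .== test_labels) / length(test_labels)
    println("iterations=$iters: accuracy=$(round(acc * 100, digits=1))%")
end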
Cross-Validation
K-Fold Cross-Validation
Evaluate model performance systematically:
using NimbusSDK
using Statistics  # for mean() and std() below
# Prepare full dataset
all_features = load_all_features()
all_labels = load_all_labels()
# K-fold cross-validation
k_folds = 5
n_trials = size(all_features, 3)
fold_size = div(n_trials, k_folds)
accuracies = Float64[]
for fold in 1:k_folds
    println("\n=== Fold $fold/$k_folds ===")
    # Split data (any remainder trials beyond k_folds * fold_size are ignored)
    test_idx = ((fold-1)*fold_size + 1):(fold*fold_size)
    train_idx = setdiff(1:n_trials, test_idx)
    train_features = all_features[:, :, train_idx]
    train_labels_fold = all_labels[train_idx]
    test_features = all_features[:, :, test_idx]
    test_labels_fold = all_labels[test_idx]
    # Train model (metadata defined as in the earlier examples)
    train_data = BCIData(train_features, metadata, train_labels_fold)
    model = train_model(RxLDAModel, train_data; iterations=50)
    # Test model
    test_data = BCIData(test_features, metadata, test_labels_fold)
    results = predict_batch(model, test_data)
    accuracy = sum(results.predictions .== test_labels_fold) / length(test_labels_fold)
    push!(accuracies, accuracy)
    println("Fold accuracy: $(round(accuracy * 100, digits=1))%")
end
# Overall results
mean_acc = mean(accuracies)
std_acc = std(accuracies)
println("\n=== Cross-Validation Results ===")
println("Mean accuracy: $(round(mean_acc * 100, digits=1))% ± $(round(std_acc * 100, digits=1))%")
Performance Metrics
Information Transfer Rate
Calculate the BCI communication rate:
using NimbusSDK
# Run inference
results = predict_batch(model, data)
# Calculate accuracy
accuracy = sum(results.predictions .== labels) / length(labels)
# Calculate ITR (bits/minute)
n_classes = 4
trial_duration = 4.0 # seconds
itr = calculate_ITR(accuracy, n_classes, trial_duration)
println("Accuracy: $(round(accuracy * 100, digits=1))%")
println("ITR: $(round(itr, digits=1)) bits/minute")
Online Performance Tracking
Track metrics across trials as they are processed:
using NimbusSDK
# Initialize tracker
tracker = OnlinePerformanceTracker(window_size=50)
# Process trials
for (pred, true_label, conf) in zip(results.predictions, labels, results.confidences)
    metrics = update_and_report!(tracker, pred, true_label, conf)
    # Print running metrics every 10 trials
    if length(tracker.recent_correct) % 10 == 0
        println("Running accuracy: $(round(metrics.accuracy * 100, digits=1))%")
    end
end
# Get final comprehensive metrics
final_metrics = get_metrics(tracker, n_classes=4, trial_duration=4.0)
println("\n=== Session Metrics ===")
println("Accuracy: $(round(final_metrics.accuracy * 100, digits=1))%")
println("ITR: $(round(final_metrics.information_transfer_rate, digits=1)) bits/min")
println("Mean confidence: $(round(final_metrics.mean_confidence, digits=3))")
println("Mean latency: $(round(final_metrics.mean_latency_ms, digits=1)) ms")
Quality Assessment
Trial Quality Analysis
Identify low-quality trials:
using NimbusSDK
# Run batch inference
results = predict_batch(model, data)
# Assess quality for each trial
low_confidence_trials = Int[]
for (i, confidence) in enumerate(results.confidences)
    if should_reject_trial(confidence, 0.7)
        push!(low_confidence_trials, i)
    end
end
println("Low confidence trials: $(length(low_confidence_trials))/$(length(results.predictions))")
println("Rejection rate: $(round(100 * length(low_confidence_trials) / length(results.predictions), digits=1))%")
# Overall quality assessment
quality = assess_trial_quality(results)
println("\nOverall quality score: $(round(quality.overall_score, digits=2))")
println("Confidence acceptable: $(quality.confidence_acceptable)")
println("Recommendation: $(quality.recommendation)")
Preprocessing Diagnostics
Check preprocessing quality:
using NimbusSDK
# Diagnose preprocessing issues
report = diagnose_preprocessing(data)
println("Preprocessing Quality: $(round(report.quality_score * 100, digits=1))%")
if !isempty(report.errors)
    println("\n⚠️ ERRORS:")
    for error in report.errors
        println("  • $error")
    end
end
if !isempty(report.warnings)
    println("\n⚠️ WARNINGS:")
    for warning in report.warnings
        println("  • $warning")
    end
end
if !isempty(report.recommendations)
    println("\n💡 RECOMMENDATIONS:")
    for rec in report.recommendations
        println("  • $rec")
    end
end
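One way to act on the report is to gate inference on its quality score (the 0.8 threshold below is illustrative, not an SDK recommendation):
# Proceed only when preprocessing quality is acceptable.
if report.quality_score >= 0.8 && isempty(report.errors)
    results = predict_batch(model, data)
else
    @warn "Fix preprocessing before running inference" score=report.quality_score
end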
Model Comparison
Compare Different Models
using NimbusSDK
using Statistics  # for mean() below
# Train multiple models
rxlda = train_model(RxLDAModel, train_data; iterations=50)
rxgmm = train_model(RxGMMModel, train_data; iterations=50)
# Test both models
models = [rxlda, rxgmm]
model_names = ["RxLDA", "RxGMM"]
println("=== Model Comparison ===")
for (model, name) in zip(models, model_names)
    results = predict_batch(model, test_data)
    accuracy = sum(results.predictions .== test_labels) / length(test_labels)
    mean_conf = mean(results.confidences)
    println("\n$name:")
    println("  Accuracy: $(round(accuracy * 100, digits=1))%")
    println("  Mean confidence: $(round(mean_conf, digits=3))")
end
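To act on the comparison, collect the scores and keep the winner. A sketch using save_model from earlier:
# Score each candidate model and persist the best one.
accs = map(models) do m
    r = predict_batch(m, test_data)
    sum(r.predictions .== test_labels) / length(test_labels)
end
best = argmax(accs)
println("Best model: $(model_names[best]) at $(round(accs[best] * 100, digits=1))%")
save_model(models[best], "best_model.jld2")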
Best Practices
Data Organization
# Organize your data clearly
using Dates  # for the Date field below

struct BCISession
    features::Array{Float64, 3}  # (n_features × n_samples × n_trials)
    labels::Vector{Int}          # Trial labels
    subject_id::String
    session_id::String
    date::Date
end
# Process multiple sessions (assumes `model`, `metadata`, and a
# `save_results` helper are defined as in the earlier examples)
function process_sessions(sessions::Vector{BCISession})
    for session in sessions
        data = BCIData(session.features, metadata, session.labels)
        results = predict_batch(model, data)
        # Save results
        save_results(session.subject_id, session.session_id, results)
    end
end
Error Handling
using NimbusSDK
try
    results = predict_batch(model, data)
catch e
    if isa(e, DataValidationError)
        @error "Data validation failed" exception=e
        # Check data format and preprocessing
    elseif isa(e, ModelCompatibilityError)
        @error "Model incompatible with data" exception=e
        # Check n_features, n_classes match
    else
        @error "Inference failed" exception=e
        rethrow()
    end
end
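The pieces above compose naturally. A sketch of a wrapper that runs the preprocessing diagnostics before inference and degrades gracefully on failure:
# Diagnose first, then predict; return nothing on failure.
function safe_predict(model, data)
    report = diagnose_preprocessing(data)
    isempty(report.errors) || @warn "Preprocessing reported errors" report.errors
    try
        return predict_batch(model, data)
    catch e
        @error "Inference failed" exception=e
        return nothing
    end
end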
Next Steps