Skip to main content

Code Samples

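Complete Julia code examples for BCI inference with NimbusSDK.jl, following recommended practices for production use. Calls marked `# Your function` are placeholders for your own data-loading or acquisition code; the Error Handling example shows how to add robust error handling and recovery.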

Basic Motor Imagery Inference

Complete example for 4-class Motor Imagery classification using Bayesian LDA (RxLDA).
using NimbusSDK
using Statistics

# Step 1: Setup (one-time)
println("Installing core...")
NimbusSDK.install_core("nbci_live_your_api_key_here")
println("✓ Core installation successful")

# Step 2: Load model
println("\nLoading model...")
model = load_model(RxLDAModel, "motor_imagery_4class_v1")
println("✓ Model loaded: $(model.metadata.name)")
println("  - Features: $(get_n_features(model))")
println("  - Classes: $(get_n_classes(model))")

# Step 3: Prepare data (features already preprocessed!)
# All dimensions are defined once here and reused below (labels, metadata, ITR)
# so they cannot drift out of sync.
println("\nPreparing data...")
n_features = 16
n_classes = 4
sampling_rate = 250.0
n_samples = 250  # 1 second at 250 Hz
n_trials = 20

# In practice, load your preprocessed CSP features. Random values are used here
# only so the script runs end to end; expect accuracy near chance.
features = randn(n_features, n_samples, n_trials)
labels = rand(1:n_classes, n_trials)

metadata = BCIMetadata(
    sampling_rate = sampling_rate,
    paradigm = :motor_imagery,
    feature_type = :csp,
    n_features = n_features,
    n_classes = n_classes,
    chunk_size = nothing
)

data = BCIData(features, metadata, labels)
println("✓ Data prepared: $n_trials trials")

# Step 4: Run inference
println("\nRunning batch inference...")
results = predict_batch(model, data; iterations=10)

# Step 5: Analyze results
println("\nResults:")
for (i, (pred, conf)) in enumerate(zip(results.predictions, results.confidences))
    println("  Trial $i: Class $pred (confidence: $(round(conf, digits=3)))")
end

println("\nOverall Performance:")
# Fraction of trials whose prediction equals the true label
accuracy = mean(results.predictions .== labels)
println("  Accuracy: $(round(accuracy * 100, digits=1))%")
println("  Mean confidence: $(round(mean(results.confidences), digits=3))")

# Calculate ITR. The trial duration must describe the data actually classified:
# each trial here spans n_samples / sampling_rate = 1 second. If your real
# trials also include cue or rest periods, pass that full duration instead.
trial_duration = n_samples / sampling_rate
itr = calculate_ITR(accuracy, n_classes, trial_duration)
println("  ITR: $(round(itr, digits=1)) bits/minute")

# Quality assessment
quality = assess_trial_quality(results)
println("  Quality score: $(round(quality.overall_score, digits=3))")

if !quality.confidence_acceptable
    println("⚠️  Warning: Low confidence trials detected")
end

Training Custom Model

Example of training a Bayesian LDA (RxLDA) model on your own labeled data.
using NimbusSDK

println(repeat("=", 60))
println("Training Custom RxLDA Model")
println(repeat("=", 60))

# Step 1: Setup (one-time) — substitute your own API key for the placeholder.
NimbusSDK.install_core("your-api-key")

# Step 2: Prepare training data.
# `load_csp_features` and `load_labels` are placeholders for your own loaders.
println("\n1. Loading training data...")
train_features = load_csp_features("train_data.mat")  # Your function
train_labels = load_labels("train_labels.mat")         # Your function

println("  Data loaded: $(size(train_features))")

# Step 3: Create BCIData. The feature count (first dimension of the features)
# and class count (distinct labels) are read off the data, not hard-coded.
train_data = BCIData(
    train_features,
    BCIMetadata(
        sampling_rate = 250.0,
        paradigm = :motor_imagery,
        feature_type = :csp,
        n_features = size(train_features, 1),
        n_classes = length(unique(train_labels)),
        chunk_size = nothing,  # no chunking for this batch data
    ),
    train_labels,
)

println("✓ Training data prepared")
println("  - Features: $(get_n_features(train_data.metadata))")
println("  - Classes: $(get_n_classes(train_data.metadata))")
println("  - Trials: $(length(train_labels))")

# Step 4: Train model (50 iterations with a progress display).
println("\n2. Training model...")
model = train_model(RxLDAModel, train_data;
    iterations = 50,
    showprogress = true,
    name = "my_motor_imagery_model",
    description = "4-class motor imagery with CSP",
)

println("\n✓ Model trained successfully")
println("  Model name: $(model.metadata.name)")

# Step 5: Save model
save_path = "my_motor_imagery_model.jld2"
save_model(model, save_path)
println("✓ Model saved to: $save_path")

# Step 6: Test on held-out data. The training metadata is reused so the test
# set is described exactly like the training set.
println("\n3. Testing on held-out data...")
test_features = load_csp_features("test_data.mat")
test_labels = load_labels("test_labels.mat")

test_data = BCIData(test_features, train_data.metadata, test_labels)
test_results = predict_batch(model, test_data)

# Fraction of held-out trials classified correctly
test_accuracy = count(test_results.predictions .== test_labels) / length(test_labels)
println("✓ Test accuracy: $(round(test_accuracy * 100, digits=1))%")

println("\n", repeat("=", 60))
println("Training complete!")
println(repeat("=", 60))

Subject-Specific Calibration

Fine-tune a pre-trained model with a small subject-specific dataset.
using NimbusSDK

println("Subject-Specific Calibration")

# Step 1: Load baseline model
println("\n1. Loading baseline model...")
base_model = load_model(RxLDAModel, "motor_imagery_baseline_v1")
println("✓ Baseline model loaded")

# Step 2: Collect calibration data from new subject
println("\n2. Collecting calibration data...")
println("  (Collect 10-20 trials per class from subject)")

calib_features = collect_calibration_trials()  # Your function
# One class label (1-4) per collected trial, in recording order. This 20-trial
# list is illustrative only; replace it with the labels of your recorded trials.
calib_labels = [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 1, 2, 3, 4, 2, 3, 1, 4, 3, 2]

calib_data = BCIData(
    calib_features,
    BCIMetadata(
        sampling_rate = 250.0,
        paradigm = :motor_imagery,
        feature_type = :csp,
        # Derived from the data (first dimension), as in the training example
        n_features = size(calib_features, 1),
        n_classes = 4,
        chunk_size = nothing
    ),
    calib_labels
)

println("✓ Calibration data prepared: $(length(calib_labels)) trials")

# Step 3: Calibrate model
println("\n3. Calibrating model...")
println("  (Fine-tuning with subject-specific data)")

personalized_model = calibrate_model(
    base_model,
    calib_data;
    iterations = 20  # Fewer than training from scratch
)

println("✓ Model calibrated")

# Step 4: Save personalized model
save_path = "subject_001_calibrated.jld2"
save_model(personalized_model, save_path)
println("✓ Saved: $save_path")

# Step 5: Compare performance
println("\n4. Performance comparison...")

# Held-out trials from the same subject. They must be separate from the
# calibration trials; scoring on calibration data would overstate the
# personalized model's accuracy.
test_features = load_csp_features("subject_001_test_data.mat")  # Your function
test_labels = load_labels("subject_001_test_labels.mat")        # Your function
test_data = BCIData(test_features, calib_data.metadata, test_labels)

# Evaluate both models on the same held-out trials
test_results_baseline = predict_batch(base_model, test_data)
test_results_personalized = predict_batch(personalized_model, test_data)

acc_baseline = count(test_results_baseline.predictions .== test_labels) / length(test_labels)
acc_personalized = count(test_results_personalized.predictions .== test_labels) / length(test_labels)

println("  Baseline accuracy: $(round(acc_baseline * 100, digits=1))%")
println("  Personalized accuracy: $(round(acc_personalized * 100, digits=1))%")

# The difference of two percentages is in percentage points. The sign is chosen
# explicitly so a regression prints "-x" rather than "+-x", and -0.0 prints as 0.0.
delta_pp = round((acc_personalized - acc_baseline) * 100, digits=1)
delta_prefix = delta_pp > 0 ? "+" : delta_pp < 0 ? "-" : ""
println("  Improvement: $(delta_prefix)$(abs(delta_pp)) percentage points")

Streaming Inference

Real-time chunk-by-chunk processing for continuous BCI control.
using NimbusSDK

# Initialize streaming
# Load the pretrained model and describe the incoming stream:
# 16 CSP features per sample, 4 classes, chunks of 250 samples at 250 Hz
# (i.e. 1-second chunks).
model = load_model(RxLDAModel, "motor_imagery_4class_v1")

metadata = BCIMetadata(
    sampling_rate = 250.0,
    paradigm = :motor_imagery,
    feature_type = :csp,
    n_features = 16,
    n_classes = 4,
    chunk_size = 250  # 1 second chunks
)

# Streaming session; process_real_time_eeg below reads it as a global.
session = init_streaming(model, metadata)

# Simulate real-time EEG stream
"""
    process_real_time_eeg()

Simulate a real-time BCI session on the global streaming `session`.

For each of 20 trials, feed 4 chunks of random `16 × 250` data to
`process_chunk` and print per-chunk feedback. Predictions with confidence
above 0.7 are forwarded to `handle_bci_command`. At the end of each trial,
call `finalize_trial(session; method=:weighted_vote)` and print its result.
Returns `nothing`.
"""
function process_real_time_eeg()
    # Defined before the loops because it is used both inside the chunk loop
    # and after it (final trial summary). A variable first assigned inside a
    # `for` body is local to that loop and would be undefined afterwards.
    class_names = ["Left", "Right", "Feet", "Tongue"]

    for trial in 1:20  # 20 trials
        println("\n--- Trial $trial ---")

        for chunk_idx in 1:4  # 4 chunks per trial (4 seconds total)
            # In practice: get chunk from EEG hardware
            chunk = randn(16, 250)  # 16 features × 250 samples

            # Process chunk (`session` is the global created by init_streaming)
            result = process_chunk(session, chunk)

            # Real-time feedback
            println("  Chunk $chunk_idx: $(class_names[result.prediction]) " *
                    "(confidence: $(round(result.confidence, digits=3)))")

            # Only act on confident predictions
            if result.confidence > 0.7
                handle_bci_command(result.prediction)
            end
        end

        # Finalize trial; presumably combines this trial's chunk results by
        # weighted vote (per the method name) — confirm in the SDK docs.
        final = finalize_trial(session; method=:weighted_vote)
        println("✓ Final: $(class_names[final.prediction]) " *
                "(confidence: $(round(final.confidence, digits=3)))")
    end
    return nothing
end

# Your BCI control function
"""
    handle_bci_command(prediction::Int)

Translate a predicted class index into cursor-movement feedback.

Classes 1, 2, 3, 4 print "left", "right", "up", "down" respectively; any
other index is ignored. Always returns `nothing`.
"""
function handle_bci_command(prediction::Int)
    # Entry k is the feedback line for predicted class k.
    movements = (
        "    ↖ Moving left",
        "    ↗ Moving right",
        "    ↑ Moving up",
        "    ↓ Moving down",
    )
    # Unknown classes produce no output
    1 <= prediction <= length(movements) || return nothing
    println(movements[prediction])
    return nothing
end

# Run the simulated session (20 trials × 4 chunks). Relies on `session` and
# `handle_bci_command` defined above.
process_real_time_eeg()

Error Handling

Robust error handling and recovery for production use.
using NimbusSDK
using Logging

# Configure logging: show messages at Info level and above, written to stderr
# so @error/@warn diagnostics stay separate from results printed to stdout.
global_logger(ConsoleLogger(stderr, Logging.Info))

"""
    robust_inference(model, data)

Run batch inference with comprehensive error handling.

Validates `data` with `validate_data`, then checks it against `model` with
`check_model_compatibility`. Next it runs
`predict_batch(model, data; iterations=10)` and logs a warning if
`assess_trial_quality` reports unacceptable confidence.

Returns the inference results, or `nothing` when validation, the
compatibility check, or inference fails. Failures are reported through the
logging system. An `AuthenticationError` additionally triggers
`reauthenticate()` before returning `nothing`. Interrupts (Ctrl-C) are
re-thrown rather than swallowed.
"""
function robust_inference(model, data)
    try
        # Validate data first; validate_data returns (is_valid, errors)
        is_valid, errors = validate_data(data)
        if !is_valid
            @error "Data validation failed" errors
            return nothing
        end

        # Check model compatibility
        if !check_model_compatibility(model, data)
            @error "Model incompatible with data"
            return nothing
        end

        # Run inference
        results = predict_batch(model, data; iterations=10)

        # Check quality: low confidence is reported but results are still returned
        quality = assess_trial_quality(results)
        if !quality.confidence_acceptable
            @warn "Low quality inference" quality
        end

        return results

    catch e
        # A user interrupt is not an inference failure; never swallow it.
        e isa InterruptException && rethrow()

        if e isa AuthenticationError
            @error "Authentication failed - re-authenticate"
            reauthenticate()  # Your function; caller may retry afterwards
        elseif e isa DataValidationError
            @error "Invalid data format" error_msg(e)
        elseif e isa QuotaExceededError
            @error "API quota exceeded" e
        else
            @error "Unexpected error" exception=(e, catch_backtrace())
        end
        return nothing
    end
end

# Usage
# `data` is a BCIData built as in the examples above (e.g. with create_bcidata).
model = load_model(RxLDAModel, "motor_imagery_4class_v1")
results = robust_inference(model, data)

if !isnothing(results)
    println("Inference successful!")
    # Score against the labels stored in `data` itself, so the accuracy always
    # refers to the same trials that were classified.
    println("Accuracy: ", count(results.predictions .== data.labels) / length(data.labels))
else
    println("Inference failed - check logs")
end

Data Loading and Preprocessing

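Load preprocessed features from a MATLAB `.mat` file and validate them before passing them to NimbusSDK. The same validation applies to features loaded from other formats, such as CSV.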
using NimbusSDK, MAT, CSV, DataFrames

"""
    load_csp_features_from_mat(filepath::String) -> (features, labels)

Load CSP features and class labels from a MATLAB `.mat` file.

The file must contain a `"features"` array laid out as
`n_features × n_samples × n_trials` and a `"labels"` array with one integer
label per trial (converted with `Int.`). Prints the feature size and the
distinct labels found.

# Throws
- `ArgumentError` if a required variable is missing or `features` is not 3-D.
- `DimensionMismatch` if the number of labels differs from the number of trials.
- `ErrorException` if any feature value is NaN or ±Inf.
"""
function load_csp_features_from_mat(filepath::String)
    data = matread(filepath)

    for key in ("features", "labels")
        haskey(data, key) ||
            throw(ArgumentError("Variable \"$key\" not found in $filepath"))
    end

    features = data["features"]  # (n_features × n_samples × n_trials)
    labels = Int.(vec(data["labels"]))

    # Explicit checks rather than @assert: this validates external input, and
    # @assert may be disabled at higher optimization levels.
    ndims(features) == 3 || throw(ArgumentError(
        "Features must be 3D, got a $(ndims(features))-D array of size $(size(features))"))
    length(labels) == size(features, 3) || throw(DimensionMismatch(
        "Labels must match number of trials: got $(length(labels)) labels " *
        "for $(size(features, 3)) trials"))

    # isfinite is false for both NaN and ±Inf, so a single pass covers both.
    if !all(isfinite, features)
        error("Features contain NaN or Inf values")
    end

    println("Loaded: $(size(features))")
    println("Classes: $(unique(labels))")

    return features, labels
end

"""
    create_bcidata(features, labels, metadata::BCIMetadata) -> BCIData

Validate `features` and `labels` against `metadata`, wrap them in a
`BCIData`, and print the quality score from `diagnose_preprocessing`.

# Throws
- `DimensionMismatch` if `size(features, 1) != metadata.n_features`, or if the
  number of labels differs from `size(features, 3)` (the trial count).
- `ArgumentError` if the number of distinct labels differs from
  `metadata.n_classes`, or any label lies outside `1:metadata.n_classes`.

Entries in the diagnostics report's `errors` are logged as warnings, not thrown.
"""
function create_bcidata(features, labels, metadata::BCIMetadata)
    n_features = size(features, 1)
    n_trials = size(features, 3)
    n_classes = length(unique(labels))

    # Explicit checks rather than @assert: this validates caller-supplied data,
    # and @assert may be disabled at higher optimization levels.
    n_features == metadata.n_features || throw(DimensionMismatch(
        "Feature count mismatch: data has $n_features, metadata expects $(metadata.n_features)"))
    length(labels) == n_trials || throw(DimensionMismatch(
        "Label count mismatch: $(length(labels)) labels for $n_trials trials"))
    n_classes == metadata.n_classes || throw(ArgumentError(
        "Class count mismatch: labels contain $n_classes classes, " *
        "metadata expects $(metadata.n_classes)"))

    # Labels are 1-indexed class ids. The range comes from the metadata, so it
    # does not depend on which classes happen to occur in this dataset.
    all(l -> 1 <= l <= metadata.n_classes, labels) || throw(ArgumentError(
        "Labels must be 1-indexed and within 1:$(metadata.n_classes); " *
        "got values in $(extrema(labels))"))

    data = BCIData(features, metadata, labels)

    # Run diagnostics; reported issues are warnings, not failures
    report = diagnose_preprocessing(data)
    if !isempty(report.errors)
        @warn "Preprocessing issues" report.errors
    end

    println("Data quality score: $(round(report.quality_score * 100, digits=1))%")

    return data
end

# Usage
# Load features and labels (validated for shape, label count and NaN/Inf).
features, labels = load_csp_features_from_mat("motor_imagery.mat")

# The expected dimensions are written out explicitly so create_bcidata can
# verify that the file matches this configuration (16 features, 4 classes).
# chunk_size is not passed here; presumably the BCIMetadata default suits
# batch data — confirm against the SDK docs.
metadata = BCIMetadata(
    sampling_rate = 250.0,
    paradigm = :motor_imagery,
    feature_type = :csp,
    n_features = 16,
    n_classes = 4
)

data = create_bcidata(features, labels, metadata)

Performance Metrics

Calculate performance metrics: overall and per-class accuracy, the confusion matrix, confidence statistics, and information transfer rate (ITR).
using NimbusSDK, Statistics

"""
    calculate_comprehensive_metrics(results, true_labels; n_classes) -> NamedTuple

Compute classification metrics for `results`, an object with `predictions` and
`confidences` fields, against `true_labels`. Class ids are 1-based integers.

# Keywords
- `n_classes`: number of classes. Defaults to the largest class id seen in
  either `true_labels` or `results.predictions`, so every label fits in the
  confusion matrix even when some classes are absent.

# Returns
A named tuple with:
- `accuracy`: fraction of predictions equal to the true label.
- `per_class_accuracy`: for each class `i`, the fraction of trials with true
  label `i` that were predicted as `i` (`NaN` if class `i` never occurs).
- `confusion_matrix`: `n_classes × n_classes` counts indexed
  `[predicted, true]` (rows = predicted class, columns = true class).
- `mean_confidence`: mean of `results.confidences`.
- `high_confidence_rate`: fraction of confidences strictly above 0.7.

# Throws
`DimensionMismatch` if the numbers of predictions and true labels differ.
"""
function calculate_comprehensive_metrics(
    results, true_labels;
    n_classes::Integer = max(maximum(true_labels; init=0),
                             maximum(results.predictions; init=0)),
)
    predictions = results.predictions
    confidences = results.confidences
    length(predictions) == length(true_labels) || throw(DimensionMismatch(
        "got $(length(predictions)) predictions for $(length(true_labels)) true labels"))

    # Basic accuracy
    accuracy = count(predictions .== true_labels) / length(true_labels)

    # Confusion matrix: rows = predicted class, columns = true class
    confusion = zeros(Int, n_classes, n_classes)
    for (pred, true_lbl) in zip(predictions, true_labels)
        confusion[pred, true_lbl] += 1
    end

    # Per-class accuracy: correct predictions of class i (diagonal) divided by
    # the number of trials whose true class is i (column i). A class with no
    # trials gives 0/0 = NaN.
    per_class_accuracy = [confusion[i, i] / sum(view(confusion, :, i)) for i in 1:n_classes]

    mean_confidence = mean(confidences)

    # Share of trials with confidence strictly above 0.7
    high_conf = count(>(0.7), confidences) / length(confidences)

    return (
        accuracy = accuracy,
        per_class_accuracy = per_class_accuracy,
        confusion_matrix = confusion,
        mean_confidence = mean_confidence,
        high_confidence_rate = high_conf
    )
end

"""
    print_performance_report([io::IO,] results, true_labels, trial_duration::Real;
                             class_names = ["Left", "Right", "Feet", "Tongue"])

Print a formatted performance report for `results`, an object with
`predictions` and `confidences` fields, against `true_labels`.

All figures come from a single `calculate_comprehensive_metrics` call, plus
`calculate_ITR(accuracy, n_classes, trial_duration)`. The report covers
overall accuracy, ITR, mean confidence, per-class accuracy, the confusion
matrix (rows = predicted, columns = true), and the high-confidence rate.
Output goes to `io` (default `stdout`).

`class_names[i]` labels class `i`; classes beyond the supplied names are
shown as `"Class i"`.
"""
function print_performance_report(io::IO, results, true_labels, trial_duration::Real;
                                  class_names = ["Left", "Right", "Feet", "Tongue"])
    # Compute every metric once so all sections of the report agree
    metrics = calculate_comprehensive_metrics(results, true_labels)
    n_classes = size(metrics.confusion_matrix, 1)
    itr = calculate_ITR(metrics.accuracy, n_classes, float(trial_duration))

    # Fall back to generic names when there are more classes than names
    display_names = [i <= length(class_names) ? class_names[i] : "Class $i" for i in 1:n_classes]

    println(io, "="^60)
    println(io, "Performance Report")
    println(io, "="^60)

    println(io, "\nOverall Metrics:")
    println(io, "  Accuracy: $(round(metrics.accuracy * 100, digits=1))%")
    println(io, "  ITR: $(round(itr, digits=1)) bits/minute")
    println(io, "  Mean confidence: $(round(metrics.mean_confidence, digits=3))")

    # Per-class performance
    println(io, "\nPer-Class Accuracy:")
    for (name, acc) in zip(display_names, metrics.per_class_accuracy)
        println(io, "  $name: $(round(acc * 100, digits=1))%")
    end

    # calculate_comprehensive_metrics fills confusion_matrix[predicted, true].
    # The header row names each column, padded to the 12-character cell width,
    # after an 8-character gap matching the row-label column.
    println(io, "\nConfusion Matrix (rows: predicted, columns: true):")
    println(io, rpad("", 8), join(rpad(name, 12) for name in display_names))
    for i in 1:n_classes
        print(io, rpad(display_names[i], 8))
        for j in 1:n_classes
            print(io, rpad(metrics.confusion_matrix[i, j], 12))
        end
        println(io)
    end

    println(io, "\nQuality Metrics:")
    println(io, "  High confidence trials: $(round(metrics.high_confidence_rate * 100, digits=1))%")

    println(io, "="^60)
    return nothing
end

# Backward-compatible form: print the report to stdout.
print_performance_report(results, true_labels, trial_duration::Real; kwargs...) =
    print_performance_report(stdout, results, true_labels, trial_duration; kwargs...)

# Usage
# NOTE(review): `model` and `data` are not defined in this snippet; they are
# assumed to come from the earlier examples (load_model / create_bcidata).
# 4.0 is the `trial_duration` passed to calculate_ITR (presumably seconds).
results = predict_batch(model, data)
print_performance_report(results, data.labels, 4.0)

Next Steps