Streaming Inference

Process EEG data in real-time with chunk-by-chunk inference for online BCI applications.

Overview

Streaming inference lets you process EEG data as it arrives, making predictions on partial trials before the full trial is complete. This is essential for real-time BCI applications where low latency is critical.

Supported models: NimbusLDA, NimbusQDA, NimbusSoftmax, and NimbusSTS.

Key Features:
  • Chunk-by-chunk processing: Process data as it streams in
  • Weighted aggregation: Combine chunk predictions, giving confident chunks more weight
  • Quality assessment: Monitor confidence and reject low-quality trials
  • Low latency: Per-chunk predictions typically within a few tens of milliseconds

Basic Streaming Setup

1. Train a Model

First, train a classifier on your calibration data:
from nimbus_bci import NimbusLDA
import numpy as np

# Train on calibration data
clf = NimbusLDA()
clf.fit(X_calibration, y_calibration)

2. Create Streaming Session

Set up a streaming session with metadata:
from nimbus_bci import StreamingSession
from nimbus_bci.data import BCIMetadata

metadata = BCIMetadata(
    sampling_rate=250.0,          # Hz
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,               # 500ms chunks at 250 Hz
    temporal_aggregation="logvar" # For CSP features
)

session = StreamingSession(clf.model_, metadata)

3. Process Chunks

Process data chunks as they arrive:
# Simulate streaming (in practice, from real-time EEG)
for i, chunk in enumerate(eeg_stream):
    # chunk shape: (n_features, chunk_size)
    result = session.process_chunk(chunk)
    
    print(f"Chunk {i+1}:")
    print(f"  Prediction: class {result.prediction}")
    print(f"  Confidence: {result.confidence:.2%}")
    print(f"  Entropy: {result.entropy:.2f} bits")
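
For testing without hardware, `eeg_stream` can be simulated by slicing a recorded (n_features, n_samples) array into non-overlapping chunks. The sketch below does that and also shows how an entropy in bits is computed from a class-probability vector. `simulate_eeg_stream` and `entropy_bits` are hypothetical helpers, and treating `result.entropy` as the Shannon entropy of the predicted class distribution is an assumption.

```python
import numpy as np


def simulate_eeg_stream(data, chunk_size):
    """Yield non-overlapping (n_features, chunk_size) chunks from a 2-D array.

    Hypothetical helper for offline testing; trailing samples that do not
    fill a whole chunk are dropped.
    """
    n_features, n_samples = data.shape
    for start in range(0, n_samples - chunk_size + 1, chunk_size):
        yield data[:, start:start + chunk_size]


def entropy_bits(probs):
    """Shannon entropy in bits: H = sum(p * log2(1 / p)) over nonzero p."""
    p = np.asarray(probs, dtype=float)
    p = p[p > 0]  # terms with p = 0 contribute nothing
    return float(np.sum(p * np.log2(1.0 / p)))


# 4 s of fake 16-feature data at 250 Hz -> 8 chunks of 125 samples
rng = np.random.default_rng(0)
recording = rng.standard_normal((16, 1000))
chunks = list(simulate_eeg_stream(recording, chunk_size=125))
print(len(chunks), chunks[0].shape)            # 8 (16, 125)
print(entropy_bits([0.25, 0.25, 0.25, 0.25]))  # 2.0: maximally uncertain over 4 classes
print(entropy_bits([1.0, 0.0, 0.0, 0.0]))      # 0.0: fully confident
```

With 4 classes, entropy ranges from 0 bits (certain) to 2 bits (uniform), so lower values indicate a more decisive chunk.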

4. Finalize Trial

After all chunks, get the final prediction:
final = session.finalize_trial(method="weighted_vote")

print(f"\nFinal Prediction: class {final.prediction}")
print(f"Final Confidence: {final.confidence:.2%}")
print(f"Chunk predictions: {final.chunk_predictions}")

5. Reset for Next Trial

Reset the session for the next trial:
session.reset()

Complete Streaming Example

from nimbus_bci import NimbusLDA, StreamingSession
from nimbus_bci.data import BCIMetadata
import numpy as np

# 1. Train model
clf = NimbusLDA()
clf.fit(X_train, y_train)

# 2. Setup streaming
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,  # 500ms chunks
    temporal_aggregation="logvar"
)

session = StreamingSession(clf.model_, metadata)

# 3. Process multiple trials
n_trials = 10
n_chunks_per_trial = 4

for trial_idx in range(n_trials):
    print(f"\n=== Trial {trial_idx + 1} ===")
    
    # Process chunks for this trial
    for chunk_idx in range(n_chunks_per_trial):
        # Get chunk from EEG stream (shape: n_features x chunk_size)
        chunk = get_next_chunk()  # Your streaming function
        
        result = session.process_chunk(chunk)
        print(f"Chunk {chunk_idx + 1}: class {result.prediction} "
              f"({result.confidence:.2%})")
    
    # Finalize trial
    final = session.finalize_trial(method="weighted_vote")
    print(f"Final: class {final.prediction} ({final.confidence:.2%})")
    
    # Check if we should reject this trial
    if final.confidence < 0.7:
        print("⚠️ Low confidence - trial rejected")
    
    # Reset for next trial
    session.reset()

Aggregation Methods

Different methods for combining chunk predictions:

Weighted Vote

Weight chunks by confidence (default):
final = session.finalize_trial(method="weighted_vote")
Chunks with higher confidence have more influence on the final prediction.

Majority Vote

Simple majority voting:
final = session.finalize_trial(method="majority_vote")
Each chunk gets equal weight regardless of confidence.

Last Chunk

Use only the most recent chunk:
final = session.finalize_trial(method="last_chunk")
Useful when later chunks are more informative.

Average Probabilities

Average probability distributions:
final = session.finalize_trial(method="average_probs")
Smooths predictions by averaging the full class-probability distributions across all chunks.
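
To make the four methods concrete, here is a NumPy sketch of how per-chunk class probabilities could be combined. It illustrates the ideas and is not the library's implementation; in particular, weighting each chunk's vote by its maximum probability is one plausible reading of "weight chunks by confidence", and `aggregate` is a hypothetical helper.

```python
import numpy as np


def aggregate(chunk_probs, method="weighted_vote"):
    """Combine per-chunk class probabilities into one trial-level prediction.

    chunk_probs: array of shape (n_chunks, n_classes), one row per chunk.
    Returns (predicted_class, trial_probabilities).
    """
    probs = np.asarray(chunk_probs, dtype=float)
    n_chunks, n_classes = probs.shape
    votes = probs.argmax(axis=1)      # each chunk's predicted class
    confidences = probs.max(axis=1)   # each chunk's confidence

    if method == "weighted_vote":
        # Each chunk votes for its class with weight = its confidence
        scores = np.bincount(votes, weights=confidences, minlength=n_classes)
    elif method == "majority_vote":
        # Every chunk gets one equal vote
        scores = np.bincount(votes, minlength=n_classes).astype(float)
    elif method == "last_chunk":
        scores = probs[-1].copy()
    elif method == "average_probs":
        scores = probs.mean(axis=0)
    else:
        raise ValueError(f"Unknown method: {method}")

    trial_probs = scores / scores.sum()
    return int(trial_probs.argmax()), trial_probs


# Three chunks of a 4-class trial
chunk_probs = [
    [0.40, 0.35, 0.15, 0.10],  # weakly favors class 0
    [0.45, 0.30, 0.15, 0.10],  # weakly favors class 0
    [0.05, 0.90, 0.03, 0.02],  # strongly favors class 1
]
for m in ["weighted_vote", "majority_vote", "last_chunk", "average_probs"]:
    print(m, aggregate(chunk_probs, m)[0])
# weighted_vote 1, majority_vote 0, last_chunk 1, average_probs 1
```

With these numbers the methods disagree: two weakly confident chunks outvote one confident chunk under majority voting (0.85 vs. 0.90 total confidence is not enough to win there), but not under confidence weighting.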

Temporal Aggregation

For time-series features, aggregate temporally before classification:

Log Variance (CSP)

For CSP features in motor imagery:
metadata = BCIMetadata(
    # ... other params ...
    temporal_aggregation="logvar"
)
Computes the log variance of each feature across the time dimension.

Mean

Simple temporal average:
metadata = BCIMetadata(
    # ... other params ...
    temporal_aggregation="mean"
)

RMS (Root Mean Square)

For power-based features:
metadata = BCIMetadata(
    # ... other params ...
    temporal_aggregation="rms"
)
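
Each option collapses a (n_features, chunk_size) chunk into one value per feature. A minimal NumPy sketch of the three reductions follows; `temporal_aggregate` is a hypothetical helper, and the use of population variance (`ddof=0`) for `logvar` is an assumption that may differ from the library.

```python
import numpy as np


def temporal_aggregate(chunk, method="logvar"):
    """Collapse a (n_features, n_samples) chunk to a (n_features,) vector."""
    x = np.asarray(chunk, dtype=float)
    if method == "logvar":
        return np.log(np.var(x, axis=1))          # log variance over time
    if method == "mean":
        return np.mean(x, axis=1)                 # temporal average
    if method == "rms":
        return np.sqrt(np.mean(x ** 2, axis=1))   # root mean square
    raise ValueError(f"Unknown method: {method}")


# 2 features x 4 samples: an oscillation around 0 and the same oscillation around 2
chunk = np.array([[1.0, -1.0, 1.0, -1.0],
                  [3.0, 1.0, 3.0, 1.0]])
print(temporal_aggregate(chunk, "logvar"))  # [0, 0]: same variance, offset ignored
print(temporal_aggregate(chunk, "mean"))    # [0, 2]: offset kept, oscillation averaged out
print(temporal_aggregate(chunk, "rms"))     # [1, sqrt(5)]: mixes offset and oscillation
```

This is why `logvar` suits CSP features: band power lives in the variance of the filtered signal, not in its mean.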

Quality Assessment

Monitor and reject low-quality trials:
from nimbus_bci import should_reject_trial

# Process trial
for chunk in trial_chunks:
    result = session.process_chunk(chunk)

final = session.finalize_trial()

# Check quality
if should_reject_trial(final.confidence, threshold=0.7):
    print("Trial rejected - low confidence")
    # Ask user to repeat or skip
else:
    # Use prediction
    execute_command(final.prediction)

Adaptive Thresholds

Set the rejection threshold according to what a wrong command would cost:
# High-stakes application (e.g., wheelchair control)
safety_threshold = 0.85

# Low-stakes application (e.g., game control)
game_threshold = 0.6

if final.confidence < safety_threshold:
    print("Confidence below safety threshold - command withheld")
else:
    execute_command(final.prediction)
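
One way to keep per-application thresholds in one place is a small lookup table plus a decision function. This is a hypothetical sketch: the application names and the `decide` helper are illustrative, and the threshold values simply reuse the ones above.

```python
from typing import Optional

# Hypothetical per-application rejection thresholds (values from the example above)
THRESHOLDS = {
    "wheelchair": 0.85,  # high stakes: act only on very confident predictions
    "game": 0.60,        # low stakes: tolerate more uncertainty
}


def decide(prediction: int, confidence: float, application: str) -> Optional[int]:
    """Return the command to execute, or None if the trial should be rejected."""
    threshold = THRESHOLDS[application]
    if confidence < threshold:
        return None  # reject: ask the user to repeat, or do nothing
    return prediction


print(decide(2, 0.72, "wheelchair"))  # None (rejected)
print(decide(2, 0.72, "game"))        # 2 (executed)
```

The same trial can be acceptable in one application and rejected in another, which is the point of making the threshold a per-application setting rather than a constant.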

Real-Time Motor Imagery BCI

Complete example for motor imagery:
from nimbus_bci import NimbusLDA, StreamingSession
from nimbus_bci.data import BCIMetadata
from nimbus_bci.compat import extract_csp_features
import mne

# 1. Calibration phase
print("=== Calibration Phase ===")

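# Load calibration data and band-pass filter the continuous signal
# (filtering before epoching avoids edge artifacts at epoch boundaries)
raw = mne.io.read_raw_gdf("calibration.gdf", preload=True)
raw.filter(8., 30.)  # Mu + Beta bands

# GDF files store events as annotations; restrict epochs to your
# motor-imagery cue codes via event_id if the file contains other events
events, _ = mne.events_from_annotations(raw)
epochs = mne.Epochs(raw, events, tmin=0, tmax=4, baseline=None, preload=True)

# Extract CSP features
csp_features, csp = extract_csp_features(epochs, n_components=8)

# Train classifier
labels = epochs.events[:, 2]
clf = NimbusLDA()
clf.fit(csp_features, labels)
# Accuracy on the calibration data itself is optimistic; cross-validate for a realistic estimate
print(f"Calibration complete. Training accuracy: {clf.score(csp_features, labels):.2%}")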

# 2. Online phase
print("\n=== Online Phase ===")

metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=8,
    n_classes=4,
    chunk_size=125,  # 500ms chunks
    temporal_aggregation="logvar"
)

session = StreamingSession(clf.model_, metadata)

# Process real-time EEG
while True:
    # Wait for trial start cue
    wait_for_cue()
    
    # Process 4 seconds of data in chunks
    for t in range(8):  # 8 chunks of 500ms
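        # Get raw EEG chunk (acquire_eeg_chunk is your acquisition function);
        # it must be band-pass filtered (8-30 Hz) like the calibration data
        raw_chunk = acquire_eeg_chunk(duration=0.5)
        
        # Apply the CSP spatial filters. The result must be the CSP time series,
        # shape (n_features, chunk_size); the session computes the log-variance
        # itself (temporal_aggregation="logvar")
        csp_chunk = csp.transform(raw_chunk)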
        
        # Process chunk
        result = session.process_chunk(csp_chunk)
        print(f"Chunk {t+1}: {result.prediction} ({result.confidence:.2%})")
    
    # Finalize and execute
    final = session.finalize_trial(method="weighted_vote")
    
    if final.confidence > 0.75:
        print(f"✓ Command: {final.prediction}")
        execute_command(final.prediction)
    else:
        print("✗ Low confidence - no action")
    
    session.reset()
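
The shapes in this loop are easy to get wrong. Assuming the fitted CSP object boils down to a spatial filter matrix W of shape (n_components, n_channels), transforming a raw chunk is a single matrix product that yields the (n_features, chunk_size) time series the session expects. The sketch uses random data and a random W purely to check shapes; `project_chunk` is a hypothetical helper, not part of nimbus_bci or MNE, and 22 channels is an arbitrary example.

```python
import numpy as np


def project_chunk(W, raw_chunk):
    """Apply spatial filters W (n_components, n_channels) to a raw chunk
    (n_channels, n_samples), returning a (n_components, n_samples) time series."""
    if raw_chunk.shape[0] != W.shape[1]:
        raise ValueError(f"Expected {W.shape[1]} channels, got {raw_chunk.shape[0]}")
    return W @ raw_chunk


rng = np.random.default_rng(42)
n_channels, n_components, chunk_size = 22, 8, 125   # 500 ms at 250 Hz
W = rng.standard_normal((n_components, n_channels))  # stand-in for fitted CSP filters
raw_chunk = rng.standard_normal((n_channels, chunk_size))

csp_chunk = project_chunk(W, raw_chunk)
print(csp_chunk.shape)  # (8, 125): matches n_features=8, chunk_size=125
```

If the transform instead returned one log-variance value per component (shape (8,)), the session's own `logvar` aggregation would have no time axis to work on, so check this shape once before going online.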

P300 Speller Streaming

For P300-based BCI:
from nimbus_bci import NimbusQDA, StreamingSession
from nimbus_bci.data import BCIMetadata

# Train on P300 data
clf = NimbusQDA()  # QDA models a separate covariance per class; compare with NimbusLDA on your data
clf.fit(erp_features, labels)  # 0=non-target, 1=target

# Streaming setup
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="p300",
    feature_type="erp",
    n_features=32,
    n_classes=2,
    chunk_size=50,  # 200ms chunks
    temporal_aggregation="mean"
)

session = StreamingSession(clf.model_, metadata)

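# Process P300 epochs
for flash in flashing_sequence:
    # Get ERP response (extract_erp is your own epoching function)
    erp_chunk = extract_erp(flash)
    
    result = session.process_chunk(erp_chunk)
    
    if result.prediction == 1 and result.confidence > 0.8:
        print(f"Target detected: {flash.character}")
    
    # Each flash is scored on its own; clear the session before the next one
    session.reset()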
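
`extract_erp` is left to you. A minimal sketch, assuming the continuous signal is a (n_features, n_samples) NumPy array and each flash is identified by its onset sample index (rather than the `flash` object used above), cuts a fixed-length window after the onset. The 50-sample length matches `chunk_size=50` (200 ms at 250 Hz); the `delay` of 200 ms, which places the window at 200-400 ms after the flash, is an illustrative choice you should tune to your paradigm.

```python
import numpy as np


def extract_erp(signal, onset, chunk_size=50, delay=0):
    """Cut a (n_features, chunk_size) window starting `delay` samples after `onset`.

    Hypothetical helper: `signal` is continuous (n_features, n_samples) data and
    `onset` is the sample index at which the flash occurred.
    """
    start = onset + delay
    stop = start + chunk_size
    if start < 0 or stop > signal.shape[1]:
        raise ValueError("Window extends beyond the available data")
    return signal[:, start:stop]


fs = 250.0
signal = np.zeros((32, 2500))   # 10 s of 32-feature data
onset = 1000                    # flash at t = 4 s
delay = int(round(0.2 * fs))    # start the window 200 ms after the flash
erp_chunk = extract_erp(signal, onset, chunk_size=50, delay=delay)
print(erp_chunk.shape)          # (32, 50)
```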

Performance Optimization

Batch Processing

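If your application can tolerate a short delay, you can buffer chunks and handle them in groups. Note that process_chunk is still called once per chunk, so grouping mainly amortizes your own per-iteration overhead (I/O, logging, UI updates) rather than model inference: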
# Instead of processing one chunk at a time
for chunk in chunks:
    result = session.process_chunk(chunk)

# Process in batches (if your application allows slight delay)
batch_size = 4
for i in range(0, len(chunks), batch_size):
    batch = chunks[i:i+batch_size]
    results = [session.process_chunk(c) for c in batch]
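
The index arithmetic above assumes `chunks` is a list. For a live generator, a small `batched` helper built on `itertools.islice` works without materializing the stream. `batched` is defined here for illustration; Python 3.12 and later also ship `itertools.batched`, which yields tuples.

```python
from itertools import islice


def batched(iterable, n):
    """Yield lists of up to n consecutive items from any iterable (including generators)."""
    it = iter(iterable)
    while True:
        batch = list(islice(it, n))
        if not batch:
            return
        yield batch


# Works on a generator, not only on a list
stream = (f"chunk{i}" for i in range(10))
for batch in batched(stream, 4):
    print(batch)
# ['chunk0', 'chunk1', 'chunk2', 'chunk3']
# ['chunk4', 'chunk5', 'chunk6', 'chunk7']
# ['chunk8', 'chunk9']
```

The last batch may be shorter than `batch_size`, so downstream code should not assume a fixed batch length.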

Pre-compute Features

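Fit feature extractors (such as CSP spatial filters) offline during calibration, so that the streaming loop only applies them:
# Fit the CSP transform once, offline (fit_csp stands in for your CSP fitting routine)
csp_transformer = fit_csp(calibration_data)

# During streaming, only apply the fitted transform
while streaming:
    raw_chunk = acquire_eeg()
    csp_chunk = csp_transformer.transform(raw_chunk)  # applying linear spatial filters is cheap
    result = session.process_chunk(csp_chunk)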
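
For intuition about what a routine like `fit_csp` computes, here is the textbook two-class CSP fit in plain NumPy: whiten the composite covariance, then diagonalize one class's whitened covariance. Filters at the two ends of the eigenvalue spectrum maximize the variance of one class relative to the other. This is an illustrative sketch, not the implementation behind `extract_csp_features` or MNE's `CSP`, and it handles two classes only; `fit_csp_two_class` is a hypothetical name.

```python
import numpy as np


def fit_csp_two_class(X1, X2, n_components=4):
    """Fit CSP spatial filters for two classes.

    X1, X2: arrays of shape (n_trials, n_channels, n_samples), band-pass filtered.
    Returns W with shape (n_components, n_channels).
    """
    def mean_cov(X):
        # Trace-normalized spatial covariance, averaged over trials
        covs = [x @ x.T / np.trace(x @ x.T) for x in X]
        return np.mean(covs, axis=0)

    C1, C2 = mean_cov(X1), mean_cov(X2)

    # Whitening transform P for the composite covariance C1 + C2
    evals, evecs = np.linalg.eigh(C1 + C2)
    P = np.diag(evals ** -0.5) @ evecs.T

    # After whitening, eigenvalues of class 1 lie in [0, 1] (ascending order)
    lam, B = np.linalg.eigh(P @ C1 @ P.T)
    W_all = B.T @ P

    # Keep filters from both ends of the spectrum
    half = n_components // 2
    idx = np.r_[np.arange(half), np.arange(len(lam) - (n_components - half), len(lam))]
    return W_all[idx]


# Offline: synthetic calibration data where class 1 has extra variance on
# channel 0 and class 2 on channel 5
rng = np.random.default_rng(0)
X1 = rng.standard_normal((20, 6, 500)); X1[:, 0] *= 3
X2 = rng.standard_normal((20, 6, 500)); X2[:, 5] *= 3
W = fit_csp_two_class(X1, X2, n_components=4)
print(W.shape)  # (4, 6)

# Online: applying the filters to one chunk is a single matrix product
chunk = X1[0][:, :125]
features = np.log(np.var(W @ chunk, axis=1))  # one log-variance per component
```

All the expensive work (covariances, eigendecompositions) happens once offline; the online step is just `W @ chunk`.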

Optimize Chunk Size

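Balance latency against accuracy:
# Smaller chunks: lower latency, but less evidence per chunk (typically less accurate)
metadata_fast = BCIMetadata(
    # ... other params ...
    chunk_size=62,  # ~250ms at 250 Hz
)

# Larger chunks: higher latency, but more evidence per chunk (typically more accurate)
metadata_accurate = BCIMetadata(
    # ... other params ...
    chunk_size=250,  # 1000ms at 250 Hz
)

# Find the sweet spot for your application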
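
Chunk size is specified in samples, so it helps to convert explicitly. A chunk cannot be processed until all of its samples have arrived, so its duration is a lower bound on per-chunk latency, before inference time is added. Note that 250 ms at 250 Hz is 62.5 samples, so `chunk_size=62` is really 248 ms. The helper names below are hypothetical.

```python
def chunk_size_for(duration_ms, sampling_rate):
    """Number of samples in a chunk of the given duration (rounded down)."""
    return int(duration_ms * sampling_rate // 1000)


def chunk_duration_ms(chunk_size, sampling_rate):
    """Duration of a chunk in milliseconds."""
    return 1000.0 * chunk_size / sampling_rate


fs = 250.0
print(chunk_size_for(500, fs))            # 125
print(chunk_size_for(250, fs))            # 62 (62.5 rounded down)
print(chunk_duration_ms(62, fs))          # 248.0
print(4000 / chunk_duration_ms(125, fs))  # 8.0 chunks per 4 s trial
```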

Error Handling

Handle streaming errors gracefully:
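import numpy as np
from nimbus_bci import StreamingSession

session = StreamingSession(clf.model_, metadata)

# n_features and chunk_size must match the values passed to BCIMetadata;
# acquire_eeg_chunk, has_artifacts, handle_result, and cleanup are your own functions
expected_shape = (n_features, chunk_size)

try:
    while streaming:
        chunk = acquire_eeg_chunk()
        
        # Validate chunk shape
        if chunk.shape != expected_shape:
            print(f"Invalid chunk shape: {chunk.shape}")
            continue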
        
        # Check for artifacts
        if has_artifacts(chunk):
            print("Artifact detected - skipping chunk")
            continue
        
        result = session.process_chunk(chunk)
        
        # Check for NaN
        if np.isnan(result.confidence):
            print("Invalid result - resetting session")
            session.reset()
            continue
        
        # Process result
        handle_result(result)
        
except KeyboardInterrupt:
    print("Streaming stopped by user")
finally:
    session.reset()
    cleanup()
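
The per-chunk checks above can be gathered into one function that reports why a chunk was rejected. In this minimal sketch, a peak-to-peak amplitude test stands in for `has_artifacts`; its 100.0 limit is an arbitrary placeholder whose meaning depends on your units and on whether chunks hold raw EEG or derived features. `validate_chunk` is a hypothetical helper.

```python
import numpy as np


def validate_chunk(chunk, expected_shape, max_ptp=100.0):
    """Return None if the chunk is usable, otherwise a short reason string.

    max_ptp is a hypothetical peak-to-peak limit used as a crude artifact check.
    """
    if chunk is None:
        return "missing chunk"
    chunk = np.asarray(chunk)
    if chunk.shape != expected_shape:
        return f"invalid shape {chunk.shape}, expected {expected_shape}"
    if not np.all(np.isfinite(chunk)):
        return "contains NaN or Inf"
    peak_to_peak = np.max(chunk, axis=1) - np.min(chunk, axis=1)
    if np.any(peak_to_peak > max_ptp):
        return "amplitude artifact"
    return None


good = np.zeros((16, 125))
bad = good.copy()
bad[3, 10] = np.nan
print(validate_chunk(good, (16, 125)))           # None
print(validate_chunk(bad, (16, 125)))            # contains NaN or Inf
print(validate_chunk(good[:, :100], (16, 125)))  # invalid shape (16, 100), expected (16, 125)
```

Returning a reason string instead of a bare boolean makes it easy to log how often each failure occurs during a session.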

Monitoring and Logging

Track streaming performance:
import time
from collections import deque

import numpy as np

class StreamingMonitor:
    def __init__(self, window_size=100):
        self.latencies = deque(maxlen=window_size)
        self.confidences = deque(maxlen=window_size)
        self.predictions = deque(maxlen=window_size)
    
    def log_chunk(self, result, latency_ms):
        self.latencies.append(latency_ms)
        self.confidences.append(result.confidence)
        self.predictions.append(result.prediction)
    
    def report(self):
        # Statistics over the most recent window_size chunks
        print(f"Mean latency: {np.mean(self.latencies):.1f}ms")
        print(f"Mean confidence: {np.mean(self.confidences):.2%}")
        # bincount assumes integer class labels 0..n_classes-1
        print(f"Prediction distribution: {np.bincount(self.predictions)}")

# Use monitor
monitor = StreamingMonitor()

for n_chunks, chunk in enumerate(stream, start=1):
    start = time.perf_counter()  # monotonic, high-resolution clock for timing
    result = session.process_chunk(chunk)
    latency_ms = (time.perf_counter() - start) * 1000
    
    monitor.log_chunk(result, latency_ms)
    
    # Count chunks separately: len(monitor.latencies) stops growing at window_size
    if n_chunks % 100 == 0:
        monitor.report()

NimbusSTS Streaming (Stateful, Adaptive)

NimbusSTS adds stateful streaming for non-stationary data: its latent state is propagated automatically from one chunk to the next.

StreamingSessionSTS

Unlike static models, NimbusSTS maintains a latent state that evolves over time. The StreamingSessionSTS class automatically handles state propagation between chunks.
from nimbus_bci import NimbusSTS
from nimbus_bci.inference import StreamingSessionSTS
from nimbus_bci.data import BCIMetadata

# Train adaptive model
clf = NimbusSTS(
    transition_cov=0.05,  # Moderate drift tracking
    num_steps=100
)
clf.fit(X_train, y_train)

# Setup streaming session
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,
    temporal_aggregation="logvar"
)

# StreamingSessionSTS automatically manages state
session = StreamingSessionSTS(clf, metadata)

# Process chunks - state propagates automatically
for chunk in eeg_stream:
    result = session.process_chunk(chunk)
    print(f"Prediction: {result.prediction}, Confidence: {result.confidence:.2%}")

# Finalize and reset
final = session.finalize_trial()
session.reset()
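
To build intuition for what state propagation and `transition_cov` mean, the sketch below runs a generic Kalman filter on a random-walk latent state: between observations the state covariance grows by the transition covariance Q, and each observation shrinks it again. This is a textbook linear-Gaussian model chosen for illustration; it is not NimbusSTS's actual model or inference procedure, and `RandomWalkKalman` and its parameters are hypothetical.

```python
import numpy as np


class RandomWalkKalman:
    """z_t = z_{t-1} + w_t, w_t ~ N(0, Q);  y_t = H z_t + v_t, v_t ~ N(0, R)."""

    def __init__(self, dim, transition_cov=0.05, obs_cov=1.0):
        self.z = np.zeros(dim)                 # state mean
        self.P = np.eye(dim)                   # state covariance
        self.Q = transition_cov * np.eye(dim)  # how fast the state may drift
        self.R = obs_cov * np.eye(dim)         # observation noise
        self.H = np.eye(dim)                   # observation matrix

    def predict(self):
        # Random walk: mean unchanged, uncertainty grows by Q
        self.P = self.P + self.Q

    def update(self, y):
        S = self.H @ self.P @ self.H.T + self.R
        K = self.P @ self.H.T @ np.linalg.inv(S)  # Kalman gain
        self.z = self.z + K @ (y - self.H @ self.z)
        self.P = (np.eye(len(self.z)) - K @ self.H) @ self.P


kf = RandomWalkKalman(dim=2, transition_cov=0.05)
for _ in range(3):      # three steps with no observations
    kf.predict()
print(np.trace(kf.P))   # approximately 2.3 (= 2.0 + 3 * 0.05 * 2)

kf.update(np.array([1.0, -1.0]))  # an observation pulls the mean and shrinks P
print(np.trace(kf.P) < 2.3)       # True
```

In this toy model, a larger `transition_cov` lets the state track drift faster but also makes its uncertainty grow faster between informative updates.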

Key Differences: STS vs Static Models

| Feature | StreamingSessionSTS (NimbusSTS) | StreamingSession (LDA/QDA/Softmax) |
|---|---|---|
| State management | ✅ Automatic state propagation | ❌ Stateless |
| Drift adaptation | ✅ Adapts to temporal changes | ❌ Fixed parameters |
| Chunk correlation | ✅ Considers temporal dependencies | ❌ Treats chunks independently |
| Use case | Long sessions, adaptive BCI | Short, stationary sessions |
| Latency | ~20-30 ms per chunk | ~10-25 ms per chunk |

Online Learning with Streaming STS

Combine streaming inference with online adaptation:
from nimbus_bci import NimbusSTS
from nimbus_bci.inference import StreamingSessionSTS
from nimbus_bci.data import BCIMetadata

# Train initial model
clf = NimbusSTS(transition_cov=0.05, learning_rate=0.1)
clf.fit(X_calibration, y_calibration)

# Setup streaming
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,
    temporal_aggregation="logvar"
)

session = StreamingSessionSTS(clf, metadata)

# Online session with delayed feedback
for trial_data in online_trials:
    # Process chunks for this trial
    for chunk in trial_data.chunks:
        result = session.process_chunk(chunk)
        print(f"Chunk prediction: {result.prediction}")
    
    # Get final prediction
    final = session.finalize_trial()
    
    # Execute BCI command
    execute_command(final.prediction)
    
    # Wait for user feedback
    true_label = wait_for_feedback()
    
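    # Update model with feedback (adapts to user).
    # aggregate_chunks is your own helper: it must turn the trial's chunks into
    # the same trial-level features used for training (e.g., log-variance of CSP signals)
    trial_features = aggregate_chunks(trial_data.chunks)
    clf.partial_fit(trial_features.reshape(1, -1), [true_label])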
    
    # Reset session for next trial
    session.reset()
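
`aggregate_chunks` has to reproduce the trial-level features the model was trained on. If training used the log-variance of CSP-filtered signals over the whole trial, one hypothetical implementation concatenates the chunks along time and takes the log-variance once. This is not the same as averaging per-chunk log-variances, so match whatever your calibration pipeline did.

```python
import numpy as np


def aggregate_chunks(chunks):
    """Concatenate (n_features, chunk_size) chunks along time and return the
    log-variance of each feature over the whole trial, shape (n_features,)."""
    trial = np.concatenate(chunks, axis=1)  # (n_features, n_chunks * chunk_size)
    return np.log(np.var(trial, axis=1))


rng = np.random.default_rng(1)
chunks = [rng.standard_normal((16, 125)) for _ in range(8)]  # 8 chunks of 500 ms
features = aggregate_chunks(chunks)
print(features.shape)                 # (16,)
print(features.reshape(1, -1).shape)  # (1, 16): the shape passed to partial_fit
```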

State Monitoring During Streaming

Monitor how the latent state evolves:
from nimbus_bci import NimbusSTS
from nimbus_bci.inference import StreamingSessionSTS
import numpy as np
import matplotlib.pyplot as plt

clf = NimbusSTS(transition_cov=0.05)
clf.fit(X_train, y_train)

# metadata as in the previous examples; trial_chunks[i] holds the chunks of trial i
session = StreamingSessionSTS(clf, metadata)

# Track state evolution
state_history = []
uncertainty_history = []

for trial_idx in range(n_trials):
    for chunk in trial_chunks[trial_idx]:
        result = session.process_chunk(chunk)
        
        # Get current state
        z_mean, z_cov = clf.get_latent_state()
        state_history.append(z_mean.copy())
        uncertainty_history.append(np.trace(z_cov))
    
    session.finalize_trial()
    session.reset()

# Visualize state evolution
state_history = np.array(state_history)
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(state_history)
plt.xlabel('Chunk')
plt.ylabel('Latent State')
plt.title('State Evolution Over Time')
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(uncertainty_history)
plt.xlabel('Chunk')
plt.ylabel('Uncertainty (trace of cov)')
plt.title('State Uncertainty')
plt.grid(True)

plt.tight_layout()
plt.show()

When to Use StreamingSessionSTS

Use StreamingSessionSTS when:
  • Sessions last more than 30 minutes (fatigue/drift effects)
  • You observe accuracy degrading over time
  • Electrode impedance changes during the session
  • You need online adaptation with delayed feedback
  • You run cross-day experiments with state transfer
Use StreamingSession (static models) when:
  • Sessions are short (<15 minutes)
  • Recording conditions are stable
  • You need the lowest possible latency (<15 ms)
  • The data is stationary

Complete Adaptive BCI Example

from nimbus_bci import NimbusSTS
from nimbus_bci.inference import StreamingSessionSTS
from nimbus_bci.data import BCIMetadata
import numpy as np

# Initial calibration
clf = NimbusSTS(
    transition_cov=0.05,  # Moderate drift
    learning_rate=0.1,
    num_steps=100,
    verbose=False
)
clf.fit(X_calibration, y_calibration)

# Setup streaming
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,
    temporal_aggregation="logvar"
)

session = StreamingSessionSTS(clf, metadata)

# Long session (e.g., 60 minutes)
accuracy_window = []
window_size = 20

for trial_idx, trial in enumerate(long_session):
    # Process trial chunks
    for chunk in trial.chunks:
        result = session.process_chunk(chunk)
    
    # Get final prediction
    final = session.finalize_trial()
    prediction = final.prediction
    
    # Execute command
    bci_command = execute_action(prediction)
    
    # Get feedback after action completes
    true_label = get_feedback(bci_command)
    
    # Track accuracy
    correct = (prediction == true_label)
    accuracy_window.append(correct)
    if len(accuracy_window) > window_size:
        accuracy_window.pop(0)
    
    recent_acc = np.mean(accuracy_window)
    
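    # Update model (online learning). aggregate_trial is your own helper that
    # rebuilds trial-level features matching the training data
    trial_features = aggregate_trial(trial.chunks)
    clf.partial_fit(trial_features.reshape(1, -1), [true_label])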
    
    # Monitor performance
    if (trial_idx + 1) % 10 == 0:
        print(f"Trial {trial_idx + 1}: Recent accuracy = {recent_acc:.1%}")
        
        # Get state info
        z_mean, z_cov = clf.get_latent_state()
        uncertainty = np.trace(z_cov)
        print(f"  State uncertainty: {uncertainty:.3f}")
    
    # Reset for next trial
    session.reset()

print("\nSession complete!")
print(f"Final accuracy (last {window_size} trials): {recent_acc:.1%}")

# Save adapted model for next session
z_final, P_final = clf.get_latent_state()
import pickle
with open("adapted_model_day1.pkl", "wb") as f:
    pickle.dump({
        'model': clf,
        'state': (z_final, P_final),
        'metadata': metadata
    }, f)
Pro Tip: For long BCI sessions, monitor the state uncertainty (the trace of the latent state covariance). If it grows well beyond its level at the end of calibration, consider a brief recalibration (5-10 trials) to re-anchor the model.

Best Practices

1. Always Reset Between Trials

# Good
session.process_chunk(chunk1)
final = session.finalize_trial()
session.reset()  # Clear state

# Bad - state leaks between trials
session.process_chunk(chunk1)
final = session.finalize_trial()
session.process_chunk(chunk2)  # Wrong!

2. Validate Chunk Shapes

expected_shape = (n_features, chunk_size)
if chunk.shape != expected_shape:
    raise ValueError(f"Expected {expected_shape}, got {chunk.shape}")

3. Monitor Confidence

if result.confidence < 0.5:
    print("Warning: Low confidence prediction")

4. Handle Missing Data

for chunk in stream:
    if chunk is None or np.any(np.isnan(chunk)):
        print("Missing data - skipping chunk")
        continue
    result = session.process_chunk(chunk)

Next Steps