
Streaming Inference

Process EEG data in real-time with chunk-by-chunk inference for online BCI applications.

Overview

Streaming inference allows you to process EEG data as it arrives, making predictions on partial trials before the full trial is complete. This is essential for real-time BCI applications where low latency is critical.

Key Features:
  • Chunk-by-chunk processing: Process data as it streams in
  • Weighted aggregation: Combine chunk predictions intelligently
  • Quality assessment: Monitor confidence and reject low-quality trials
  • Low latency: Get predictions within milliseconds

Basic Streaming Setup

1. Train a Model

First, train a classifier on your calibration data:
from nimbus_bci import NimbusLDA
import numpy as np

# Train on calibration data
clf = NimbusLDA()
clf.fit(X_calibration, y_calibration)

2. Create Streaming Session

Set up a streaming session with metadata:
from nimbus_bci import StreamingSession
from nimbus_bci.data import BCIMetadata

metadata = BCIMetadata(
    sampling_rate=250.0,          # Hz
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,               # 500ms chunks at 250 Hz
    temporal_aggregation="logvar" # For CSP features
)

session = StreamingSession(clf.model_, metadata)

3. Process Chunks

Process data chunks as they arrive:
# Simulate streaming (in practice, from real-time EEG)
for i, chunk in enumerate(eeg_stream):
    # chunk shape: (n_features, chunk_size)
    result = session.process_chunk(chunk)
    
    print(f"Chunk {i+1}:")
    print(f"  Prediction: class {result.prediction}")
    print(f"  Confidence: {result.confidence:.2%}")
    print(f"  Entropy: {result.entropy:.2f} bits")

4. Finalize Trial

After all chunks, get the final prediction:
final = session.finalize_trial(method="weighted_vote")

print(f"\nFinal Prediction: class {final.prediction}")
print(f"Final Confidence: {final.confidence:.2%}")
print(f"Chunk predictions: {final.chunk_predictions}")

5. Reset for Next Trial

Reset the session for the next trial:
session.reset()

Complete Streaming Example

from nimbus_bci import NimbusLDA, StreamingSession
from nimbus_bci.data import BCIMetadata
import numpy as np

# 1. Train model
clf = NimbusLDA()
clf.fit(X_train, y_train)

# 2. Setup streaming
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,  # 500ms chunks
    temporal_aggregation="logvar"
)

session = StreamingSession(clf.model_, metadata)

# 3. Process multiple trials
n_trials = 10
n_chunks_per_trial = 4

for trial_idx in range(n_trials):
    print(f"\n=== Trial {trial_idx + 1} ===")
    
    # Process chunks for this trial
    for chunk_idx in range(n_chunks_per_trial):
        # Get chunk from EEG stream (shape: n_features x chunk_size)
        chunk = get_next_chunk()  # Your streaming function
        
        result = session.process_chunk(chunk)
        print(f"Chunk {chunk_idx + 1}: class {result.prediction} "
              f"({result.confidence:.2%})")
    
    # Finalize trial
    final = session.finalize_trial(method="weighted_vote")
    print(f"Final: class {final.prediction} ({final.confidence:.2%})")
    
    # Check if we should reject this trial
    if final.confidence < 0.7:
        print("⚠️ Low confidence - trial rejected")
    
    # Reset for next trial
    session.reset()

Aggregation Methods

Different methods for combining chunk predictions:

Weighted Vote

Weight chunks by confidence (default):
final = session.finalize_trial(method="weighted_vote")
Chunks with higher confidence have more influence on the final prediction.
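
A minimal sketch of the idea, assuming per-chunk class labels and confidences (illustrative only; the library's internal weighting may differ):

import numpy as np

n_classes = 4
chunk_predictions = np.array([2, 2, 1, 2])         # per-chunk predicted classes
chunk_confidences = np.array([0.9, 0.8, 0.4, 0.7]) # per-chunk confidences

# Accumulate each chunk's confidence onto its predicted class
scores = np.zeros(n_classes)
for pred, conf in zip(chunk_predictions, chunk_confidences):
    scores[pred] += conf

final_prediction = int(np.argmax(scores))          # -> 2

Majority voting (below) is the special case where every chunk's weight is 1.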

Majority Vote

Simple majority voting:
final = session.finalize_trial(method="majority_vote")
Each chunk gets equal weight regardless of confidence.

Last Chunk

Use only the most recent chunk:
final = session.finalize_trial(method="last_chunk")
Useful when later chunks are more informative.

Average Probabilities

Average probability distributions:
final = session.finalize_trial(method="average_probs")
Smooth predictions across all chunks.
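
A sketch, assuming each chunk yields a class-probability vector (hypothetical values):

import numpy as np

chunk_probs = np.array([
    [0.10, 0.60, 0.20, 0.10],   # chunk 1
    [0.20, 0.50, 0.20, 0.10],   # chunk 2
    [0.05, 0.75, 0.10, 0.10],   # chunk 3
])
avg_probs = chunk_probs.mean(axis=0)
final_prediction = int(np.argmax(avg_probs))   # -> 1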

Temporal Aggregation

For time-series features, aggregate temporally before classification:

Log Variance (CSP)

For CSP features in motor imagery:
metadata = BCIMetadata(
    # ... other params ...
    temporal_aggregation="logvar"
)
Computes log variance across the time dimension; a combined sketch of all three aggregation modes follows the RMS example below.

Mean

Simple temporal average:
metadata = BCIMetadata(
    # ... other params ...
    temporal_aggregation="mean"
)

RMS (Root Mean Square)

For power-based features:
metadata = BCIMetadata(
    # ... other params ...
    temporal_aggregation="rms"
)
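
For intuition, each mode reduces a (n_features, chunk_size) chunk to one value per feature along the time axis. A sketch of what each plausibly computes (assumed semantics, not the library's exact code):

import numpy as np

chunk = np.random.randn(16, 125)             # (n_features, chunk_size)

logvar = np.log(np.var(chunk, axis=1))       # "logvar": log variance over time
mean = np.mean(chunk, axis=1)                # "mean": temporal average
rms = np.sqrt(np.mean(chunk**2, axis=1))     # "rms": root mean square

# Each result has shape (16,) - one aggregated value per feature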

Quality Assessment

Monitor and reject low-quality trials:
from nimbus_bci import should_reject_trial

# Process trial
for chunk in trial_chunks:
    result = session.process_chunk(chunk)

final = session.finalize_trial()

# Check quality
if should_reject_trial(final.confidence, threshold=0.7):
    print("Trial rejected - low confidence")
    # Ask user to repeat or skip
else:
    # Use prediction
    execute_command(final.prediction)

Adaptive Thresholds

Adjust rejection threshold based on application:
# High-stakes application (wheelchair control)
safety_threshold = 0.85

# Low-stakes application (game control)
game_threshold = 0.6

if final.confidence < safety_threshold:
    print("Confidence below safety threshold - no action taken")

Real-Time Motor Imagery BCI

Complete example for motor imagery:
from nimbus_bci import NimbusLDA, StreamingSession
from nimbus_bci.data import BCIMetadata
from nimbus_bci.compat import extract_csp_features
import mne

# 1. Calibration phase
print("=== Calibration Phase ===")

# Load calibration data
raw = mne.io.read_raw_gdf("calibration.gdf", preload=True)
events = mne.find_events(raw)
epochs = mne.Epochs(raw, events, tmin=0, tmax=4, baseline=None, preload=True)
epochs.filter(8, 30)  # Mu + Beta bands

# Extract CSP features
csp_features, csp = extract_csp_features(epochs, n_components=8)

# Train classifier
clf = NimbusLDA()
clf.fit(csp_features, epochs.events[:, 2])
print(f"Calibration complete. Accuracy: {clf.score(csp_features, epochs.events[:, 2]):.2%}")

# 2. Online phase
print("\n=== Online Phase ===")

metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=8,
    n_classes=4,
    chunk_size=125,  # 500ms chunks
    temporal_aggregation="logvar"
)

session = StreamingSession(clf.model_, metadata)

# Process real-time EEG
while True:
    # Wait for trial start cue
    wait_for_cue()
    
    # Process 4 seconds of data in chunks
    for t in range(8):  # 8 chunks of 500ms
        # Get raw EEG chunk
        raw_chunk = acquire_eeg_chunk(duration=0.5)
        
        # Apply CSP transform
        csp_chunk = csp.transform(raw_chunk)
        
        # Process chunk
        result = session.process_chunk(csp_chunk)
        print(f"Chunk {t+1}: {result.prediction} ({result.confidence:.2%})")
    
    # Finalize and execute
    final = session.finalize_trial(method="weighted_vote")
    
    if final.confidence > 0.75:
        print(f"✓ Command: {final.prediction}")
        execute_command(final.prediction)
    else:
        print("✗ Low confidence - no action")
    
    session.reset()

P300 Speller Streaming

For P300-based BCI:
from nimbus_bci import NimbusGMM, StreamingSession
from nimbus_bci.data import BCIMetadata

# Train on P300 data
clf = NimbusGMM()  # GMM works better for P300
clf.fit(erp_features, labels)  # 0=non-target, 1=target

# Streaming setup
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="p300",
    feature_type="erp",
    n_features=32,
    n_classes=2,
    chunk_size=50,  # 200ms chunks
    temporal_aggregation="mean"
)

session = StreamingSession(clf.model_, metadata)

# Process P300 epochs
for flash in flashing_sequence:
    # Get ERP response
    erp_chunk = extract_erp(flash)
    
    result = session.process_chunk(erp_chunk)
    
    if result.prediction == 1 and result.confidence > 0.8:
        print(f"Target detected: {flash.character}")

Performance Optimization

Batch Processing

Buffer several chunks and process them together when your application tolerates a small delay. Each chunk still goes through process_chunk, but batching amortizes acquisition and loop overhead:
# Instead of processing one chunk at a time
for chunk in chunks:
    result = session.process_chunk(chunk)

# Process in batches (if your application allows a slight delay)
batch_size = 4
for i in range(0, len(chunks), batch_size):
    batch = chunks[i:i+batch_size]
    results = [session.process_chunk(c) for c in batch]  # still one call per chunk

Pre-compute Features

Extract features offline when possible:
# Pre-compute CSP transform
csp_transformer = fit_csp(calibration_data)

# During streaming, just apply transform
while streaming:
    raw_chunk = acquire_eeg()
    csp_chunk = csp_transformer.transform(raw_chunk)  # Fast
    result = session.process_chunk(csp_chunk)

Optimize Chunk Size

Balance latency vs. accuracy:
# Smaller chunks = lower latency, less accurate
metadata_fast = BCIMetadata(
    # ... other params ...
    chunk_size=62,   # ~250ms at 250 Hz
)

# Larger chunks = higher latency, more accurate
metadata_accurate = BCIMetadata(
    # ... other params ...
    chunk_size=250,  # 1000ms at 250 Hz
)

# Find sweet spot for your application
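
One way to find that sweet spot is an offline sweep: replay held-out calibration trials through a StreamingSession at several chunk sizes and compare accuracy. A sketch, assuming trials are arrays of shape (n_features, n_samples) in X_val_trials with labels y_val (hypothetical names):

from nimbus_bci import StreamingSession
from nimbus_bci.data import BCIMetadata

for chunk_size in (62, 125, 250):
    metadata = BCIMetadata(
        sampling_rate=250.0,
        paradigm="motor_imagery",
        feature_type="csp",
        n_features=16,
        n_classes=4,
        chunk_size=chunk_size,
        temporal_aggregation="logvar",
    )
    session = StreamingSession(clf.model_, metadata)

    correct = 0
    for trial, label in zip(X_val_trials, y_val):
        # Split the trial into consecutive chunks and stream them
        for k in range(trial.shape[1] // chunk_size):
            session.process_chunk(trial[:, k * chunk_size:(k + 1) * chunk_size])
        final = session.finalize_trial(method="weighted_vote")
        correct += int(final.prediction == label)
        session.reset()

    print(f"chunk_size={chunk_size}: accuracy={correct / len(y_val):.2%}, "
          f"first prediction after {chunk_size / 250.0 * 1000:.0f}ms")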

Error Handling

Handle streaming errors gracefully:
from nimbus_bci import StreamingSession

session = StreamingSession(clf.model_, metadata)

try:
    while streaming:
        chunk = acquire_eeg_chunk()
        
        # Validate chunk shape
        if chunk.shape != (metadata.n_features, metadata.chunk_size):
            print(f"Invalid chunk shape: {chunk.shape}")
            continue
        
        # Check for artifacts
        if has_artifacts(chunk):  # see the sketch after this example
            print("Artifact detected - skipping chunk")
            continue
        
        result = session.process_chunk(chunk)
        
        # Check for NaN
        if np.isnan(result.confidence):
            print("Invalid result - resetting session")
            session.reset()
            continue
        
        # Process result
        handle_result(result)
        
except KeyboardInterrupt:
    print("Streaming stopped by user")
finally:
    session.reset()
    cleanup()
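
The has_artifacts check above is application-specific; a minimal amplitude-based sketch (the max_abs threshold is a placeholder to tune for your signal units):

import numpy as np

def has_artifacts(chunk, max_abs=100.0):
    """Flag chunks with non-finite values or implausibly large amplitudes."""
    return (not np.all(np.isfinite(chunk))) or float(np.max(np.abs(chunk))) > max_abs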

Monitoring and Logging

Track streaming performance:
import time
from collections import deque
import numpy as np

class StreamingMonitor:
    def __init__(self, window_size=100):
        self.latencies = deque(maxlen=window_size)
        self.confidences = deque(maxlen=window_size)
        self.predictions = deque(maxlen=window_size)
    
    def log_chunk(self, result, latency):
        self.latencies.append(latency)
        self.confidences.append(result.confidence)
        self.predictions.append(result.prediction)
    
    def report(self):
        print(f"Mean latency: {np.mean(self.latencies):.1f}ms")
        print(f"Mean confidence: {np.mean(self.confidences):.2%}")
        print(f"Prediction distribution: {np.bincount(self.predictions)}")

# Use monitor
monitor = StreamingMonitor()

for i, chunk in enumerate(stream, start=1):
    start = time.time()
    result = session.process_chunk(chunk)
    latency = (time.time() - start) * 1000  # ms
    
    monitor.log_chunk(result, latency)
    
    # Use a running counter: the deques are capped at window_size,
    # so len(monitor.latencies) stops growing after the first window
    if i % 100 == 0:
        monitor.report()

Best Practices

1. Always Reset Between Trials

# Good
session.process_chunk(chunk1)
final = session.finalize_trial()
session.reset()  # Clear state

# Bad - state leaks between trials
session.process_chunk(chunk1)
final = session.finalize_trial()
session.process_chunk(chunk2)  # Wrong!

2. Validate Chunk Shapes

expected_shape = (metadata.n_features, metadata.chunk_size)
if chunk.shape != expected_shape:
    raise ValueError(f"Expected {expected_shape}, got {chunk.shape}")

3. Monitor Confidence

if result.confidence < 0.5:
    print("Warning: Low confidence prediction")

4. Handle Missing Data

# Inside the acquisition loop:
if chunk is None or np.any(np.isnan(chunk)):
    print("Missing data - skipping chunk")
    continue
