Streaming Inference
Process EEG data in real-time with chunk-by-chunk inference for online BCI applications.
Overview
Streaming inference allows you to process EEG data as it arrives, making predictions on partial trials before the full trial is complete. This is essential for real-time BCI applications where low latency is critical.
Supported models: NimbusLDA, NimbusQDA, NimbusSoftmax, and NimbusSTS.
Key Features:
- Chunk-by-chunk processing: Process data as it streams in
- Weighted aggregation: Combine chunk predictions intelligently
- Quality assessment: Monitor confidence and reject low-quality trials
- Low latency: Get predictions within milliseconds
Basic Streaming Setup
1. Train a Model
First, train a classifier on your calibration data:
from nimbus_bci import NimbusLDA
import numpy as np
# Train on calibration data
clf = NimbusLDA()
clf.fit(X_calibration, y_calibration)
2. Create Streaming Session
Set up a streaming session with metadata:
from nimbus_bci import StreamingSession
from nimbus_bci.data import BCIMetadata
metadata = BCIMetadata(
    sampling_rate=250.0,  # Hz
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,  # 500ms chunks at 250 Hz
    temporal_aggregation="logvar"  # For CSP features
)
session = StreamingSession(clf.model_, metadata)
3. Process Chunks
Process data chunks as they arrive:
# Simulate streaming (in practice, from real-time EEG)
for i, chunk in enumerate(eeg_stream):
    # chunk shape: (n_features, chunk_size)
    result = session.process_chunk(chunk)
    print(f"Chunk {i+1}:")
    print(f"  Prediction: class {result.prediction}")
    print(f"  Confidence: {result.confidence:.2%}")
    print(f"  Entropy: {result.entropy:.2f} bits")
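The confidence and entropy printed above derive from the chunk's class-probability vector. As a plain-NumPy sketch of those two quantities (illustrative only, not nimbus_bci's internal code):

```python
import numpy as np

def chunk_confidence_and_entropy(probs):
    """Confidence = max class probability; entropy = Shannon entropy in bits."""
    probs = np.asarray(probs, dtype=float)
    confidence = probs.max()
    nonzero = probs[probs > 0]  # treat 0 * log2(0) as 0
    entropy = -np.sum(nonzero * np.log2(nonzero))
    return confidence, entropy

conf, ent = chunk_confidence_and_entropy([0.7, 0.1, 0.1, 0.1])
# confidence 0.7, entropy ~ 1.36 bits
```

Low entropy (near 0 bits) means the probability mass is concentrated on one class; the maximum for 4 classes is 2 bits (a uniform distribution).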
4. Finalize Trial
After all chunks, get the final prediction:
final = session.finalize_trial(method="weighted_vote")
print(f"\nFinal Prediction: class {final.prediction}")
print(f"Final Confidence: {final.confidence:.2%}")
print(f"Chunk predictions: {final.chunk_predictions}")
5. Reset for Next Trial
Reset the session for the next trial:
session.reset()
Complete Streaming Example
from nimbus_bci import NimbusLDA, StreamingSession
from nimbus_bci.data import BCIMetadata
import numpy as np
# 1. Train model
clf = NimbusLDA()
clf.fit(X_train, y_train)
# 2. Setup streaming
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,  # 500ms chunks
    temporal_aggregation="logvar"
)
session = StreamingSession(clf.model_, metadata)
# 3. Process multiple trials
n_trials = 10
n_chunks_per_trial = 4
for trial_idx in range(n_trials):
    print(f"\n=== Trial {trial_idx + 1} ===")

    # Process chunks for this trial
    for chunk_idx in range(n_chunks_per_trial):
        # Get chunk from EEG stream (shape: n_features x chunk_size)
        chunk = get_next_chunk()  # Your streaming function
        result = session.process_chunk(chunk)
        print(f"Chunk {chunk_idx + 1}: class {result.prediction} "
              f"({result.confidence:.2%})")

    # Finalize trial
    final = session.finalize_trial(method="weighted_vote")
    print(f"Final: class {final.prediction} ({final.confidence:.2%})")

    # Check if we should reject this trial
    if final.confidence < 0.7:
        print("⚠️ Low confidence - trial rejected")

    # Reset for next trial
    session.reset()
Aggregation Methods
Different methods for combining chunk predictions:
Weighted Vote
Weight chunks by confidence (default):
final = session.finalize_trial(method="weighted_vote")
Chunks with higher confidence have more influence on the final prediction.
Majority Vote
Simple majority voting:
final = session.finalize_trial(method="majority_vote")
Each chunk gets equal weight regardless of confidence.
Last Chunk
Use only the most recent chunk:
final = session.finalize_trial(method="last_chunk")
Useful when later chunks are more informative.
Average Probabilities
Average probability distributions:
final = session.finalize_trial(method="average_probs")
Smooths predictions across all chunks.
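The four strategies can be contrasted with a small self-contained sketch (plain NumPy; the library's implementation may differ). Given one probability vector per chunk:

```python
import numpy as np

def aggregate(chunk_probs, method="weighted_vote"):
    """Combine per-chunk class-probability vectors into one final prediction.

    chunk_probs: array-like of shape (n_chunks, n_classes).
    """
    P = np.asarray(chunk_probs, dtype=float)
    preds = P.argmax(axis=1)   # per-chunk hard predictions
    confs = P.max(axis=1)      # per-chunk confidences
    n_classes = P.shape[1]

    if method == "weighted_vote":
        # Each chunk votes for its predicted class, weighted by its confidence
        scores = np.zeros(n_classes)
        np.add.at(scores, preds, confs)
        return scores.argmax()
    if method == "majority_vote":
        return np.bincount(preds, minlength=n_classes).argmax()
    if method == "last_chunk":
        return preds[-1]
    if method == "average_probs":
        return P.mean(axis=0).argmax()
    raise ValueError(f"Unknown method: {method}")

P = [[0.6, 0.4], [0.55, 0.45], [0.1, 0.9]]
# weighted_vote and majority_vote -> class 0; last_chunk and average_probs -> class 1
```

Note how the methods can disagree: two early chunks favor class 0, but the last, most confident chunk favors class 1, and it dominates the averaged probabilities.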
Temporal Aggregation
For time-series features, aggregate temporally before classification:
Log Variance (CSP)
For CSP features in motor imagery:
metadata = BCIMetadata(
    # ... other params ...
    temporal_aggregation="logvar"
)
Computes the log variance across the time dimension.
Mean
Simple temporal average:
metadata = BCIMetadata(
    # ... other params ...
    temporal_aggregation="mean"
)
RMS (Root Mean Square)
For power-based features:
metadata = BCIMetadata(
    # ... other params ...
    temporal_aggregation="rms"
)
Quality Assessment
Monitor and reject low-quality trials:
from nimbus_bci import should_reject_trial
# Process trial
for chunk in trial_chunks:
    result = session.process_chunk(chunk)

final = session.finalize_trial()

# Check quality
if should_reject_trial(final.confidence, threshold=0.7):
    print("Trial rejected - low confidence")
    # Ask user to repeat or skip
else:
    # Use prediction
    execute_command(final.prediction)
Adaptive Thresholds
Adjust rejection threshold based on application:
# High-stakes application (wheelchair control)
safety_threshold = 0.85
# Low-stakes application (game control)
game_threshold = 0.6
if final.confidence < safety_threshold:
    print("High-confidence prediction required")
Real-Time Motor Imagery BCI
Complete example for motor imagery:
from nimbus_bci import NimbusLDA, StreamingSession
from nimbus_bci.data import BCIMetadata
from nimbus_bci.compat import extract_csp_features
import mne
# 1. Calibration phase
print("=== Calibration Phase ===")
# Load calibration data
raw = mne.io.read_raw_gdf("calibration.gdf", preload=True)
events = mne.find_events(raw)
epochs = mne.Epochs(raw, events, tmin=0, tmax=4, baseline=None, preload=True)
epochs.filter(8, 30) # Mu + Beta bands
# Extract CSP features
csp_features, csp = extract_csp_features(epochs, n_components=8)
# Train classifier
clf = NimbusLDA()
clf.fit(csp_features, epochs.events[:, 2])
print(f"Calibration complete. Accuracy: {clf.score(csp_features, epochs.events[:, 2]):.2%}")
# 2. Online phase
print("\n=== Online Phase ===")
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=8,
    n_classes=4,
    chunk_size=125,  # 500ms chunks
    temporal_aggregation="logvar"
)
session = StreamingSession(clf.model_, metadata)
# Process real-time EEG
while True:
    # Wait for trial start cue
    wait_for_cue()

    # Process 4 seconds of data in chunks
    for t in range(8):  # 8 chunks of 500ms
        # Get raw EEG chunk
        raw_chunk = acquire_eeg_chunk(duration=0.5)

        # Apply CSP transform
        # (note: mne's CSP.transform expects shape (n_epochs, n_channels, n_times))
        csp_chunk = csp.transform(raw_chunk)

        # Process chunk
        result = session.process_chunk(csp_chunk)
        print(f"Chunk {t+1}: {result.prediction} ({result.confidence:.2%})")

    # Finalize and execute
    final = session.finalize_trial(method="weighted_vote")
    if final.confidence > 0.75:
        print(f"✓ Command: {final.prediction}")
        execute_command(final.prediction)
    else:
        print("✗ Low confidence - no action")

    session.reset()
P300 Speller Streaming
For P300-based BCI:
from nimbus_bci import NimbusQDA, StreamingSession
from nimbus_bci.data import BCIMetadata
# Train on P300 data
clf = NimbusQDA() # QDA often works better for P300
clf.fit(erp_features, labels) # 0=non-target, 1=target
# Streaming setup
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="p300",
    feature_type="erp",
    n_features=32,
    n_classes=2,
    chunk_size=50,  # 200ms chunks
    temporal_aggregation="mean"
)
session = StreamingSession(clf.model_, metadata)
# Process P300 epochs
for flash in flashing_sequence:
    # Get ERP response
    erp_chunk = extract_erp(flash)
    result = session.process_chunk(erp_chunk)

    if result.prediction == 1 and result.confidence > 0.8:
        print(f"Target detected: {flash.character}")
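A full speller then maps per-flash target scores back to a character. The sketch below (a hypothetical helper, not part of nimbus_bci) accumulates scores over row and column flashes of a standard 6x6 grid and returns the character at the intersection of the best row and best column:

```python
import numpy as np

# Standard 6x6 P300 speller grid; each flash lights one row or one column
GRID = np.array([list("ABCDEF"), list("GHIJKL"), list("MNOPQR"),
                 list("STUVWX"), list("YZ1234"), list("56789_")])

def select_character(flash_scores):
    """flash_scores: iterable of (kind, index, p_target), kind in {'row', 'col'}.

    Sums the target probabilities per row and per column, then returns the
    character at the intersection of the highest-scoring row and column.
    """
    row_scores = np.zeros(6)
    col_scores = np.zeros(6)
    for kind, idx, p in flash_scores:
        (row_scores if kind == "row" else col_scores)[idx] += p
    return GRID[row_scores.argmax(), col_scores.argmax()]

flashes = [("row", 1, 0.9), ("row", 0, 0.2), ("col", 2, 0.8), ("col", 4, 0.3)]
# row 1 and column 2 win -> 'I'
```

In practice each row/column is flashed several times and scores are accumulated across repetitions before a character is selected.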
Batch Processing
Process multiple chunks at once for efficiency:
# Instead of processing one chunk at a time
for chunk in chunks:
    result = session.process_chunk(chunk)

# Process in batches (if your application allows slight delay)
batch_size = 4
for i in range(0, len(chunks), batch_size):
    batch = chunks[i:i+batch_size]
    results = [session.process_chunk(c) for c in batch]
Pre-compute Features
Extract features offline when possible:
# Pre-compute CSP transform
csp_transformer = fit_csp(calibration_data)
# During streaming, just apply transform
while streaming:
    raw_chunk = acquire_eeg()
    csp_chunk = csp_transformer.transform(raw_chunk)  # Fast
    result = session.process_chunk(csp_chunk)
Optimize Chunk Size
Balance latency vs. accuracy:
# Smaller chunks = lower latency, less accurate
metadata_fast = BCIMetadata(chunk_size=62) # 250ms
# Larger chunks = higher latency, more accurate
metadata_accurate = BCIMetadata(chunk_size=250) # 1000ms
# Find sweet spot for your application
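Chunk duration sets a hard floor on per-chunk latency: it is simply chunk_size divided by the sampling rate. A tiny helper for exploring the trade-off:

```python
def chunk_duration_ms(chunk_size, sampling_rate=250.0):
    """Duration of one chunk in milliseconds: the lower bound on per-chunk latency."""
    return 1000.0 * chunk_size / sampling_rate

# 62 samples at 250 Hz -> 248 ms; 125 -> 500 ms; 250 -> 1000 ms
```

Processing time adds on top of this floor, so the reported ~10-30 ms inference latencies matter mainly for very small chunks.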
Error Handling
Handle streaming errors gracefully:
from nimbus_bci import StreamingSession
import numpy as np

session = StreamingSession(clf.model_, metadata)
try:
    while streaming:
        chunk = acquire_eeg_chunk()

        # Validate chunk shape
        if chunk.shape != (n_features, chunk_size):
            print(f"Invalid chunk shape: {chunk.shape}")
            continue

        # Check for artifacts
        if has_artifacts(chunk):
            print("Artifact detected - skipping chunk")
            continue

        result = session.process_chunk(chunk)

        # Check for NaN
        if np.isnan(result.confidence):
            print("Invalid result - resetting session")
            session.reset()
            continue

        # Process result
        handle_result(result)
except KeyboardInterrupt:
    print("Streaming stopped by user")
finally:
    session.reset()
    cleanup()
Monitoring and Logging
Track streaming performance:
import time
import numpy as np
from collections import deque

class StreamingMonitor:
    def __init__(self, window_size=100):
        self.latencies = deque(maxlen=window_size)
        self.confidences = deque(maxlen=window_size)
        self.predictions = deque(maxlen=window_size)
        self.n_chunks = 0  # total chunks seen (the deques only keep a window)

    def log_chunk(self, result, latency):
        self.n_chunks += 1
        self.latencies.append(latency)
        self.confidences.append(result.confidence)
        self.predictions.append(result.prediction)

    def report(self):
        print(f"Mean latency: {np.mean(self.latencies):.1f}ms")
        print(f"Mean confidence: {np.mean(self.confidences):.2%}")
        print(f"Prediction distribution: {np.bincount(self.predictions)}")

# Use monitor
monitor = StreamingMonitor()
for chunk in stream:
    start = time.time()
    result = session.process_chunk(chunk)
    latency = (time.time() - start) * 1000  # ms
    monitor.log_chunk(result, latency)

    if monitor.n_chunks % 100 == 0:  # report every 100 chunks
        monitor.report()
NimbusSTS Streaming (Stateful, Adaptive)
NimbusSTS provides unique streaming capabilities for non-stationary data with automatic state propagation.
StreamingSessionSTS
Unlike static models, NimbusSTS maintains a latent state that evolves over time. The StreamingSessionSTS class automatically handles state propagation between chunks.
from nimbus_bci import NimbusSTS
from nimbus_bci.inference import StreamingSessionSTS
from nimbus_bci.data import BCIMetadata
# Train adaptive model
clf = NimbusSTS(
    transition_cov=0.05,  # Moderate drift tracking
    num_steps=100
)
clf.fit(X_train, y_train)
# Setup streaming session
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,
    temporal_aggregation="logvar"
)
# StreamingSessionSTS automatically manages state
session = StreamingSessionSTS(clf, metadata)
# Process chunks - state propagates automatically
for chunk in eeg_stream:
    result = session.process_chunk(chunk)
    print(f"Prediction: {result.prediction}, Confidence: {result.confidence:.2%}")
# Finalize and reset
final = session.finalize_trial()
session.reset()
Key Differences: STS vs Static Models
| Feature | StreamingSessionSTS (NimbusSTS) | StreamingSession (LDA/QDA/Softmax) |
|---|---|---|
| State Management | ✅ Automatic state propagation | ❌ Stateless |
| Drift Adaptation | ✅ Adapts to temporal changes | ❌ Fixed parameters |
| Chunk Correlation | ✅ Considers temporal dependencies | ❌ Treats chunks independently |
| Use Case | Long sessions, adaptive BCI | Short, stationary sessions |
| Latency | ~20-30ms per chunk | ~10-25ms per chunk |
Online Learning with Streaming STS
Combine streaming inference with online adaptation:
from nimbus_bci import NimbusSTS
from nimbus_bci.inference import StreamingSessionSTS
from nimbus_bci.data import BCIMetadata
# Train initial model
clf = NimbusSTS(transition_cov=0.05, learning_rate=0.1)
clf.fit(X_calibration, y_calibration)
# Setup streaming
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,
    temporal_aggregation="logvar"
)
session = StreamingSessionSTS(clf, metadata)
# Online session with delayed feedback
for trial_data in online_trials:
    # Process chunks for this trial
    for chunk in trial_data.chunks:
        result = session.process_chunk(chunk)
        print(f"Chunk prediction: {result.prediction}")

    # Get final prediction
    final = session.finalize_trial()

    # Execute BCI command
    execute_command(final.prediction)

    # Wait for user feedback
    true_label = wait_for_feedback()

    # Update model with feedback (adapts to user)
    # Aggregate chunks back to trial level
    trial_features = aggregate_chunks(trial_data.chunks)
    clf.partial_fit(trial_features.reshape(1, -1), [true_label])

    # Reset session for next trial
    session.reset()
State Monitoring During Streaming
Monitor how the latent state evolves:
from nimbus_bci import NimbusSTS
from nimbus_bci.inference import StreamingSessionSTS
import matplotlib.pyplot as plt
import numpy as np
clf = NimbusSTS(transition_cov=0.05)
clf.fit(X_train, y_train)
session = StreamingSessionSTS(clf, metadata)
# Track state evolution
state_history = []
uncertainty_history = []
for trial_idx in range(n_trials):
    for chunk in trial_chunks[trial_idx]:
        result = session.process_chunk(chunk)

        # Get current state
        z_mean, z_cov = clf.get_latent_state()
        state_history.append(z_mean.copy())
        uncertainty_history.append(np.trace(z_cov))

    session.finalize_trial()
    session.reset()
# Visualize state evolution
state_history = np.array(state_history)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(state_history)
plt.xlabel('Chunk')
plt.ylabel('Latent State')
plt.title('State Evolution Over Time')
plt.grid(True)
plt.subplot(1, 2, 2)
plt.plot(uncertainty_history)
plt.xlabel('Chunk')
plt.ylabel('Uncertainty (trace of cov)')
plt.title('State Uncertainty')
plt.grid(True)
plt.tight_layout()
plt.show()
When to Use StreamingSessionSTS
Use StreamingSessionSTS when:
- Sessions last >30 minutes (fatigue/drift effects)
- You observe accuracy degradation over time
- Electrode impedance changes during session
- Need online adaptation with delayed feedback
- Cross-day experiments with state transfer
Use StreamingSession (static models) when:
- Short sessions (<15 minutes)
- Stable recording conditions
- Need absolute minimum latency (<15ms)
- Data is stationary
Complete Adaptive BCI Example
from nimbus_bci import NimbusSTS
from nimbus_bci.inference import StreamingSessionSTS
from nimbus_bci.data import BCIMetadata
import numpy as np
# Initial calibration
clf = NimbusSTS(
    transition_cov=0.05,  # Moderate drift
    learning_rate=0.1,
    num_steps=100,
    verbose=False
)
clf.fit(X_calibration, y_calibration)
# Setup streaming
metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,
    temporal_aggregation="logvar"
)
session = StreamingSessionSTS(clf, metadata)
# Long session (e.g., 60 minutes)
accuracy_window = []
window_size = 20
for trial_idx, trial in enumerate(long_session):
    # Process trial chunks
    for chunk in trial.chunks:
        result = session.process_chunk(chunk)

    # Get final prediction
    final = session.finalize_trial()
    prediction = final.prediction

    # Execute command
    bci_command = execute_action(prediction)

    # Get feedback after action completes
    true_label = get_feedback(bci_command)

    # Track accuracy
    correct = (prediction == true_label)
    accuracy_window.append(correct)
    if len(accuracy_window) > window_size:
        accuracy_window.pop(0)
    recent_acc = np.mean(accuracy_window)

    # Update model (online learning)
    trial_features = aggregate_trial(trial.chunks)
    clf.partial_fit(trial_features.reshape(1, -1), [true_label])

    # Monitor performance
    if (trial_idx + 1) % 10 == 0:
        print(f"Trial {trial_idx + 1}: Recent accuracy = {recent_acc:.1%}")

        # Get state info
        z_mean, z_cov = clf.get_latent_state()
        uncertainty = np.trace(z_cov)
        print(f"  State uncertainty: {uncertainty:.3f}")

    # Reset for next trial
    session.reset()
print("\nSession complete!")
print(f"Final accuracy (last {window_size} trials): {recent_acc:.1%}")
# Save adapted model for next session
z_final, P_final = clf.get_latent_state()
import pickle
with open("adapted_model_day1.pkl", "wb") as f:
    pickle.dump({
        'model': clf,
        'state': (z_final, P_final),
        'metadata': metadata
    }, f)
Pro Tip: For long BCI sessions, monitor the state uncertainty. If it grows too large, consider a brief recalibration (5-10 trials) to re-anchor the model.
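One way to act on this tip, sketched with assumed numbers (the threshold factor and baseline are tuning choices for your setup, not library constants):

```python
def needs_recalibration(uncertainty, baseline, factor=3.0):
    """Flag recalibration once the state-covariance trace has grown past
    `factor` times its value measured right after calibration."""
    return uncertainty > factor * baseline

# baseline = np.trace(z_cov) recorded right after clf.fit(); then per trial:
# if needs_recalibration(np.trace(z_cov), baseline): run a short recalibration block
```

A multiplicative threshold adapts to each subject's baseline uncertainty, which varies with feature dimensionality and recording quality.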
Best Practices
1. Always Reset Between Trials
# Good
session.process_chunk(chunk1)
final = session.finalize_trial()
session.reset() # Clear state
# Bad - state leaks between trials
session.process_chunk(chunk1)
final = session.finalize_trial()
session.process_chunk(chunk2) # Wrong!
2. Validate Chunk Shapes
expected_shape = (n_features, chunk_size)
if chunk.shape != expected_shape:
    raise ValueError(f"Expected {expected_shape}, got {chunk.shape}")
3. Monitor Confidence
if result.confidence < 0.5:
    print("Warning: Low confidence prediction")
4. Handle Missing Data
if chunk is None or np.any(np.isnan(chunk)):
    print("Missing data - skipping chunk")
    continue
Next Steps