# Python SDK API Reference

> Full nimbus-bci Python API reference for NimbusLDA, NimbusQDA, NimbusSoftmax, NimbusSTS, and core inference utilities.

Complete reference for the nimbus-bci Python library.

## Start Here

<CardGroup cols={3}>
  <Card title="Python SDK Quickstart" icon="rocket" href="/python-sdk/quickstart">
    Build and run your first classifier before diving into full API details.
  </Card>

  <Card title="Model Selection" icon="brain" href="/model-specification">
    Compare NimbusLDA, NimbusQDA, NimbusSoftmax, and NimbusSTS use cases.
  </Card>

  <Card title="Streaming Inference" icon="activity" href="/python-sdk/streaming-inference">
    Move from batch workflows to real-time BCI inference patterns.
  </Card>
</CardGroup>

## Classifiers

### NimbusLDA

Bayesian Linear Discriminant Analysis with shared covariance.

```python theme={null}
from nimbus_bci import NimbusLDA

clf = NimbusLDA(
    mu_loc=0.0,
    mu_scale=3.0,
    wishart_df=None,
    class_prior_alpha=1.0
)
```

**Parameters:**

* `mu_loc` (float, default=0.0): Prior mean location for class means
* `mu_scale` (float, default=3.0): Prior scale for class means (> 0)
* `wishart_df` (float or None, default=None): Wishart degrees of freedom. If None, set to `n_features + 2`
* `class_prior_alpha` (float, default=1.0): Dirichlet smoothing for class priors (≥ 0)

**Methods:**

* `fit(X, y)`: Fit the model
* `predict(X)`: Predict class labels
* `predict_proba(X)`: Predict class probabilities
* `partial_fit(X, y, classes=None)`: Incremental learning
* `score(X, y)`: Return accuracy score

**Attributes:**

* `classes_`: Unique class labels
* `n_classes_`: Number of classes
* `n_features_in_`: Number of features
* `model_`: Underlying Nimbus model

**Example:**

```python theme={null}
from nimbus_bci import NimbusLDA
import numpy as np

# Create and fit
clf = NimbusLDA(mu_scale=5.0)
clf.fit(X_train, y_train)

# Predict
predictions = clf.predict(X_test)
probabilities = clf.predict_proba(X_test)

# Online learning
clf.partial_fit(X_new, y_new)
```

### NimbusQDA

Bayesian Quadratic Discriminant Analysis (QDA) with class-specific covariances.

```python theme={null}
from nimbus_bci import NimbusQDA

clf = NimbusQDA(
    mu_loc=0.0,
    mu_scale=3.0,
    wishart_df=None,
    class_prior_alpha=1.0
)
```

**Parameters:**

* Same as NimbusLDA

**Methods:**

* Same as NimbusLDA

**Example:**

```python theme={null}
from nimbus_bci import NimbusQDA

# Useful when classes differ in covariance, not only in mean (e.g., P300)
clf = NimbusQDA()
clf.fit(X_train, y_train)
probs = clf.predict_proba(X_test)
```

### NimbusSoftmax

Bayesian Multinomial Logistic Regression (Polya-Gamma VI).

Install the optional extra before using this model:

```bash theme={null}
pip install nimbus-bci[softmax]
```

```python theme={null}
from nimbus_bci import NimbusSoftmax

clf = NimbusSoftmax(
    w_loc=0.0,
    w_scale=1.0,
    b_loc=0.0,
    b_scale=1.0,
    learning_rate=0.2,
    num_steps=50,
    num_posterior_samples=50,
    rng_seed=0
)
```

**Parameters:**

* `w_loc` (float, default=0.0): Prior mean for weights
* `w_scale` (float, default=1.0): Prior scale for weights
* `b_loc` (float, default=0.0): Prior mean for biases
* `b_scale` (float, default=1.0): Prior scale for biases
* `learning_rate` (float, default=0.2): Damping factor for variational updates
* `num_steps` (int, default=50): Number of variational update sweeps
* `num_posterior_samples` (int, default=50): Number of posterior samples for prediction
* `rng_seed` (int, default=0): Random seed for reproducibility

**Methods:**

* Same as NimbusLDA

**Example:**

```python theme={null}
from nimbus_bci import NimbusSoftmax

# Discriminative alternative when class-conditional Gaussian assumptions fit poorly
clf = NimbusSoftmax()
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)
```

### NimbusSTS

Bayesian Structural Time Series classifier with Extended Kalman Filter for non-stationary data.

```python theme={null}
from nimbus_bci import NimbusSTS

clf = NimbusSTS(
    state_dim=None,
    w_loc=0.0,
    w_scale=1.0,
    transition_cov=None,
    observation_cov=1.0,
    transition_matrix=None,
    learning_rate=0.1,
    num_steps=50,
    rng_seed=0,
    verbose=False
)
```

**Parameters:**

* `state_dim` (int or None, default=None): Dimension of latent state. If None, set to `n_classes - 1`
* `w_loc` (float, default=0.0): Prior mean for feature weights
* `w_scale` (float, default=1.0): Prior scale for feature weights
* `transition_cov` (float or None, default=None): Process noise covariance Q (controls drift speed). If None, auto-estimated. Typical values:
  * 0.001: Very slow drift (multi-day stability)
  * 0.01: Moderate drift (within-session adaptation)
  * 0.1: Fast drift (rapid environmental changes)
* `observation_cov` (float, default=1.0): Observation noise covariance R
* `transition_matrix` (ndarray or None, default=None): State transition matrix A. If None, uses identity (random walk)
* `learning_rate` (float, default=0.1): Step size for parameter updates
* `num_steps` (int, default=50): Number of learning iterations
* `rng_seed` (int, default=0): Random seed for reproducibility
* `verbose` (bool, default=False): Print convergence diagnostics during training
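
To build intuition for `transition_cov`, the sketch below runs only the prediction step of a linear Kalman filter with an identity transition matrix (a random walk). It uses plain NumPy and is not the SDK's implementation: `kalman_predict` is a hypothetical helper, and the scalar `transition_cov` is assumed to act as isotropic process noise `Q = q * I`.

```python
import numpy as np

def kalman_predict(z_mean, z_cov, transition_matrix, q):
    """One prediction step: propagate mean and covariance without a measurement."""
    A = transition_matrix
    Q = q * np.eye(z_mean.shape[0])  # assumed isotropic process noise
    return A @ z_mean, A @ z_cov @ A.T + Q

state_dim = 3
z_mean = np.zeros(state_dim)
z_cov = 0.1 * np.eye(state_dim)
A = np.eye(state_dim)  # identity transition = random walk

for q in (0.001, 0.01, 0.1):
    m, P = z_mean, z_cov
    for _ in range(100):
        m, P = kalman_predict(m, P, A, q)
    print(f"q={q}: state variance after 100 steps = {P[0, 0]:.2f}")
```

Under these assumptions the state variance grows by `q` on every step without a measurement, which is why larger values let the model track faster drift at the cost of noisier state estimates.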

**Methods:**

* `fit(X, y)`: Fit the model
* `predict(X)`: Predict class labels (stateless)
* `predict_proba(X)`: Predict class probabilities (stateless)
* `partial_fit(X, y, classes=None)`: Incremental learning with EKF update
* `score(X, y)`: Return accuracy score
* `propagate_state(n_steps=1)`: Advance latent state using prior dynamics only
* `reset_state()`: Reset latent state to initial values from training
* `get_latent_state()`: Get current latent state `(z_mean, z_cov)`
* `set_latent_state(z_mean, z_cov=None)`: Set latent state manually

**Attributes:**

* `classes_`: Unique class labels
* `n_classes_`: Number of classes
* `n_features_in_`: Number of features
* `model_`: Underlying Nimbus model with state parameters

**Example - Basic Usage:**

```python theme={null}
from nimbus_bci import NimbusSTS
import numpy as np

# Create and fit
clf = NimbusSTS(transition_cov=0.05, num_steps=100)
clf.fit(X_train, y_train)

# Standard prediction (stateless)
predictions = clf.predict(X_test)
probabilities = clf.predict_proba(X_test)
```

**Example - Stateful Prediction:**

```python theme={null}
from nimbus_bci import NimbusSTS

# Train model
clf = NimbusSTS(transition_cov=0.05)
clf.fit(X_train, y_train)

# Time-ordered prediction with state propagation
for x_t in X_stream:
    clf.propagate_state()  # Advance time
    pred = clf.predict(x_t.reshape(1, -1))
    print(f"Prediction: {pred}")
```

**Example - Online Learning with Delayed Feedback:**

```python theme={null}
from nimbus_bci import NimbusSTS

# Initial training
clf = NimbusSTS(transition_cov=0.05, learning_rate=0.1)
clf.fit(X_calibration, y_calibration)

# Online session
for trial in online_session:
    # 1. Advance time (no measurement)
    clf.propagate_state()
    
    # 2. Predict using current state
    prediction = clf.predict(trial.features.reshape(1, -1))[0]
    
    # 3. Execute action
    execute_action(prediction)
    
    # 4. Get feedback after action completes
    true_label = wait_for_feedback()
    
    # 5. Update state with measurement
    clf.partial_fit(trial.features.reshape(1, -1), [true_label])
```

**Example - State Inspection and Transfer:**

```python theme={null}
from nimbus_bci import NimbusSTS

# Day 1: Train and save state
clf_day1 = NimbusSTS()
clf_day1.fit(X_day1, y_day1)
z_final, P_final = clf_day1.get_latent_state()

# Day 2: Transfer state with increased uncertainty
# (mean shrunk by half toward zero, covariance doubled)
clf_day2 = NimbusSTS()
clf_day2.fit(X_day2_calib, y_day2_calib)  # Minimal calibration
clf_day2.set_latent_state(z_final * 0.5, P_final * 2.0)

# Use with transferred state
predictions = clf_day2.predict(X_day2_test)
```

<Note>
  **Key Differences from Other Classifiers:**

  * **Stateful**: Maintains and evolves latent state over time
  * **Non-stationary**: Designed for data with temporal drift
  * **State Management**: Explicit API for time propagation and state control
  * **Use case**: Long sessions, cross-day transfer, adaptive BCI

  See [Bayesian STS Documentation](/models/rxsts) for complete usage guide.
</Note>

## Optional Riemannian Pipelines

### `make_riemann_nimbus_pipeline()`

Factory for composing pyRiemann covariance/tangent feature extraction with an existing Nimbus classifier head.

Install the optional extra before using this API:

```bash theme={null}
pip install nimbus-bci[riemann]
```

```python theme={null}
from nimbus_bci import NimbusLDA
from nimbus_bci.riemann import make_riemann_nimbus_pipeline

pipe = make_riemann_nimbus_pipeline(
    head=NimbusLDA(),
    covariance_estimator="oas",
    tangent_metric="riemann",
)
```

**Parameters:**

* `head` (object or None, default=None): Final sklearn-compatible classifier exposing `fit`, `predict`, and `predict_proba`. Defaults to `NimbusLDA()` when omitted.
* `covariance_estimator` (str, default=`"oas"`): Estimator passed to `pyriemann.estimation.Covariances`.
* `tangent_metric` (str, default=`"riemann"`): Metric passed to `pyriemann.tangentspace.TangentSpace`.

**Returns:**

* `sklearn.pipeline.Pipeline`: Pipeline with `covariances`, `tangent`, and `head` steps.

**Example:**

```python theme={null}
from nimbus_bci import NimbusLDA
from nimbus_bci.riemann import make_riemann_nimbus_pipeline

# X_epochs shape: (n_trials, n_channels, n_times)
pipe = make_riemann_nimbus_pipeline(head=NimbusLDA(mu_scale=5.0))
pipe.fit(X_epochs_train, y_train)

probs = pipe.predict_proba(X_epochs_test)
predictions = pipe.predict(X_epochs_test)
```

<Note>
  Nimbus supports Riemannian workflows by consuming tangent-space features from pyRiemann; this is not a separate Riemannian Bayesian model. The returned object is a standard sklearn `Pipeline`, so do not rely on pipeline-level `partial_fit()` for this path.
</Note>

## Data Structures

### BCIData

Container for BCI features, metadata, and labels.

```python theme={null}
from nimbus_bci.data import BCIData

data = BCIData(
    features,      # (n_features, n_samples, n_trials) or (n_features, n_samples)
    metadata,      # BCIMetadata instance
    labels=None    # Optional labels
)
```

**Parameters:**

* `features` (np.ndarray): Feature array of shape `(n_features, n_samples, n_trials)` for multiple trials or `(n_features, n_samples)` for one trial
* `metadata` (BCIMetadata): Metadata describing the data
* `labels` (np.ndarray, optional): Trial labels

**Attributes:**

* `features`: Feature array
* `metadata`: Metadata object
* `labels`: Labels (if provided)
* `n_trials`: Number of trials
* `n_samples`: Number of samples per trial
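
Many Python toolkits store epochs trial-first, as `(n_trials, n_features, n_samples)`, whereas `BCIData` expects `(n_features, n_samples, n_trials)`. A minimal NumPy sketch of the axis reordering follows; the random array is an illustrative stand-in for real features.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical trial-first array: 20 trials, 16 features, 250 samples per trial
X_trial_first = rng.standard_normal((20, 16, 250))

# Reorder axes to (n_features, n_samples, n_trials)
features = np.transpose(X_trial_first, (1, 2, 0))

print(features.shape)
```

The reordered array can then be passed as the `features` argument of `BCIData`.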

### BCIMetadata

Metadata for BCI experiments.

```python theme={null}
from nimbus_bci.data import BCIMetadata

metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=None,
    temporal_aggregation="mean"
)
```

**Parameters:**

* `sampling_rate` (float): Sampling rate in Hz
* `paradigm` (str): BCI paradigm (`"motor_imagery"`, `"p300"`, `"ssvep"`, `"erp"`, or `"custom"`)
* `feature_type` (str): Feature type (`"raw"`, `"csp"`, `"bandpower"`, `"erp_amplitude"`, or `"custom"`)
* `n_features` (int): Number of features
* `n_classes` (int): Number of classes
* `chunk_size` (int, optional): Chunk size for streaming
* `temporal_aggregation` (str, default="mean"): Aggregation method (`"mean"`, `"logvar"`, `"last"`, `"max"`, `"median"`, `"var"`, or `"std"`)
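
The aggregation names suggest per-feature reductions over the sample axis of a chunk. The sketch below shows one plausible reading in NumPy. The SDK's exact definitions are assumptions here (for example, whether `"logvar"` adds a small constant before taking the log), and `aggregate_chunk` is a hypothetical helper, not part of `nimbus_bci`.

```python
import numpy as np

def aggregate_chunk(chunk, method="mean", eps=1e-12):
    """Reduce a (n_features, n_samples) chunk to a (n_features,) vector."""
    reducers = {
        "mean": lambda c: c.mean(axis=1),
        "logvar": lambda c: np.log(c.var(axis=1) + eps),  # eps is an assumption
        "last": lambda c: c[:, -1],
        "max": lambda c: c.max(axis=1),
        "median": lambda c: np.median(c, axis=1),
        "var": lambda c: c.var(axis=1),
        "std": lambda c: c.std(axis=1),
    }
    if method not in reducers:
        raise ValueError(f"Unknown aggregation method: {method!r}")
    return reducers[method](chunk)

chunk = np.random.default_rng(0).standard_normal((16, 125))
for name in ("mean", "logvar", "last"):
    print(name, aggregate_chunk(chunk, name).shape)
```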

## Inference

### predict\_batch()

Batch inference with comprehensive diagnostics.

```python theme={null}
from nimbus_bci import predict_batch

result = predict_batch(
    model,           # Trained model
    data,            # BCIData instance
    num_posterior_samples=50,
    rng_seed=0
)
```

**Parameters:**

* `model` (NimbusModel): Trained Nimbus model
* `data` (BCIData): Data to predict on
* `num_posterior_samples` (int, default=50): Posterior samples for softmax models
* `rng_seed` (int, default=0): Random seed for softmax prediction

**Returns:**

* `BatchResult`: Result object with predictions, posteriors, entropy, and diagnostics

**Example:**

```python theme={null}
from nimbus_bci import predict_batch, NimbusLDA
from nimbus_bci.data import BCIData, BCIMetadata
import numpy as np

clf = NimbusLDA()
clf.fit(X_train, y_train)

metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4
)

# BCIData expects (n_features, n_samples, n_trials).
# Duplicate tabular test features across a short temporal axis for this example.
features = np.repeat(X_test.T[:, np.newaxis, :], 2, axis=1)
data = BCIData(features, metadata)

result = predict_batch(clf.model_, data)
print(f"Accuracy: {(result.predictions == y_test).mean():.2%}")
print(f"Mean entropy: {result.mean_entropy:.2f} bits")
```

### StreamingSession

Real-time chunk-by-chunk processing.

```python theme={null}
from nimbus_bci import StreamingSession

session = StreamingSession(
    model,      # Trained model
    metadata    # BCIMetadata with chunk_size
)
```

**Methods:**

* `process_chunk(chunk)`: Process one chunk, returns ChunkResult
* `finalize_trial(method="weighted_vote")`: Finalize trial, returns StreamingResult
* `reset()`: Reset session for new trial

**Example:**

```python theme={null}
from nimbus_bci import NimbusLDA, StreamingSession
from nimbus_bci.data import BCIMetadata

clf = NimbusLDA()
clf.fit(X_train, y_train)

metadata = BCIMetadata(
    sampling_rate=250.0,
    paradigm="motor_imagery",
    feature_type="csp",
    n_features=16,
    n_classes=4,
    chunk_size=125,  # 500ms chunks
    temporal_aggregation="logvar"
)

session = StreamingSession(clf.model_, metadata)

# Process chunks
for chunk in stream:
    result = session.process_chunk(chunk)
    print(f"Chunk: class {result.prediction} ({result.confidence:.2%})")

# Finalize
final = session.finalize_trial(method="weighted_vote")
print(f"Final: class {final.prediction}")
```
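
`finalize_trial` combines per-chunk posteriors into one trial-level decision. One plausible way to do this is a confidence-weighted average, sketched below. This only illustrates the general idea: how the SDK actually defines `"weighted_vote"` is not specified here, and `weighted_vote` is a hypothetical helper.

```python
import numpy as np

def weighted_vote(chunk_posteriors):
    """Average chunk posteriors, weighting each chunk by its max probability."""
    P = np.asarray(chunk_posteriors, dtype=float)  # (n_chunks, n_classes)
    weights = P.max(axis=1)                        # per-chunk confidence
    posterior = (weights[:, None] * P).sum(axis=0) / weights.sum()
    return int(posterior.argmax()), posterior

chunk_posteriors = [
    [0.40, 0.30, 0.20, 0.10],
    [0.10, 0.70, 0.10, 0.10],
    [0.25, 0.45, 0.20, 0.10],
]
prediction, posterior = weighted_vote(chunk_posteriors)
print(prediction, posterior.round(3))
```

In this sketch, confident chunks pull the aggregated posterior harder than uncertain ones, so a single ambiguous early chunk does not dominate the final decision.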

### ChunkResult

Result from processing a single chunk.

**Attributes:**

* `prediction` (int): Predicted class
* `confidence` (float): Confidence (max probability)
* `posterior` (np.ndarray): Class posterior probabilities
* `latency_ms` (float): Processing latency in milliseconds

### StreamingResult

Result from finalizing a trial.

**Attributes:**

* `prediction` (int): Final predicted class
* `confidence` (float): Final confidence
* `posterior` (np.ndarray): Aggregated posterior probabilities
* `chunk_posteriors` (list): Posterior from each chunk
* `entropy` (float): Final entropy
* `aggregation_method` (str): Method used for aggregation
* `n_chunks` (int): Number of chunks processed
* `latency_ms` (float): Total trial inference latency
* `chunk_latencies_ms` (list): Latency for each chunk
* `balance` (float): Class balance across chunks
* `calibration` (CalibrationMetrics or None): Calibration metrics if a label was provided

### BatchResult

Result from batch inference.

**Attributes:**

* `predictions` (np.ndarray): Predicted classes
* `confidences` (np.ndarray): Maximum posterior probability per trial
* `posteriors` (np.ndarray): Class posterior probabilities
* `entropy` (np.ndarray): Entropy per trial
* `mean_entropy` (float): Mean entropy
* `mahalanobis_distances` (np.ndarray): Distance to each class center
* `outlier_scores` (np.ndarray): Outlier score per trial
* `balance` (float): Class balance
* `latency_ms` (float): Inference latency
* `per_trial_latency_ms` (np.ndarray): Estimated latency per trial
* `calibration` (CalibrationMetrics or None): Calibration metrics if labels were provided

## Metrics

### compute\_entropy()

Compute Shannon entropy from probabilities.

```python theme={null}
from nimbus_bci import compute_entropy

entropy = compute_entropy(probabilities)  # bits
```

**Parameters:**

* `probabilities` (np.ndarray): Probability distributions

**Returns:**

* `float`: Entropy in bits. For a 2D probability matrix, this is the mean entropy across rows.
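
For reference, Shannon entropy in bits is `-sum(p * log2(p))` over the classes, with an upper bound of `log2(n_classes)`. The NumPy sketch below computes it per row and then averages; `shannon_entropy_bits` is a hypothetical helper written for illustration, not the SDK function.

```python
import numpy as np

def shannon_entropy_bits(probabilities, eps=1e-12):
    """Row-wise Shannon entropy in bits for a 1D or 2D probability array."""
    p = np.atleast_2d(np.asarray(probabilities, dtype=float))
    p = np.clip(p, eps, 1.0)  # avoid log(0)
    return -(p * np.log2(p)).sum(axis=1)

posteriors = np.array([
    [0.25, 0.25, 0.25, 0.25],  # maximally uncertain over 4 classes
    [0.97, 0.01, 0.01, 0.01],  # confident
])
per_trial = shannon_entropy_bits(posteriors)
print(per_trial, per_trial.mean())
```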

### compute\_calibration\_metrics()

Compute Expected Calibration Error (ECE) and Maximum Calibration Error (MCE).

```python theme={null}
from nimbus_bci import compute_calibration_metrics

metrics = compute_calibration_metrics(
    predictions,
    confidences,
    labels,
    n_bins=10
)
```

**Parameters:**

* `predictions` (np.ndarray): Predicted classes
* `confidences` (np.ndarray): Confidence scores
* `labels` (np.ndarray): True labels
* `n_bins` (int, default=10): Number of bins

**Returns:**

* `CalibrationMetrics`: Object with `ece` and `mce` attributes
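
As a reference for what these numbers mean, the sketch below implements the common equal-width-bin definitions: ECE is the bin-size-weighted average gap between accuracy and mean confidence, and MCE is the largest such gap. Whether the SDK uses exactly this binning is an assumption, and `calibration_errors` is a hypothetical helper.

```python
import numpy as np

def calibration_errors(predictions, confidences, labels, n_bins=10):
    """Return (ECE, MCE) using equal-width confidence bins."""
    predictions = np.asarray(predictions)
    confidences = np.asarray(confidences, dtype=float)
    correct = (predictions == np.asarray(labels)).astype(float)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Interior edges only, so bin ids run from 0 to n_bins - 1
    bin_ids = np.digitize(confidences, edges[1:-1])

    ece, mce = 0.0, 0.0
    for b in range(n_bins):
        mask = bin_ids == b
        if not mask.any():
            continue
        gap = abs(correct[mask].mean() - confidences[mask].mean())
        ece += mask.mean() * gap  # weight by fraction of trials in the bin
        mce = max(mce, gap)
    return ece, mce

preds = np.array([0, 1, 1, 0, 2, 2])
confs = np.array([0.95, 0.92, 0.55, 0.62, 0.85, 0.52])
labels = np.array([0, 1, 0, 0, 2, 1])
ece, mce = calibration_errors(preds, confs, labels)
print(f"ECE={ece:.3f}, MCE={mce:.3f}")
```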

### calculate\_itr()

Calculate Information Transfer Rate.

```python theme={null}
from nimbus_bci import calculate_itr

itr = calculate_itr(
    accuracy=0.85,
    n_classes=4,
    trial_duration=4.0
)
```

**Parameters:**

* `accuracy` (float): Classification accuracy (0-1)
* `n_classes` (int): Number of classes
* `trial_duration` (float): Trial duration in seconds

**Returns:**

* `float`: ITR in bits/minute
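
ITR in BCI work is most often computed with the Wolpaw formula, `B = log2(N) + P*log2(P) + (1-P)*log2((1-P)/(N-1))` bits per trial, scaled to bits per minute. The sketch below assumes that definition; whether `calculate_itr` matches it exactly is not stated in this reference, and `wolpaw_itr` is a hypothetical helper.

```python
import math

def wolpaw_itr(accuracy, n_classes, trial_duration):
    """Wolpaw ITR in bits/minute for accuracy in [0, 1] and n_classes >= 2."""
    p, n = accuracy, n_classes
    bits = math.log2(n)
    if p > 0.0:
        bits += p * math.log2(p)
    if p < 1.0:
        bits += (1.0 - p) * math.log2((1.0 - p) / (n - 1))
    return bits * 60.0 / trial_duration

itr = wolpaw_itr(accuracy=0.85, n_classes=4, trial_duration=4.0)
print(f"{itr:.2f} bits/min")
```

At chance level (`accuracy = 1 / n_classes`) this formula yields zero, and at perfect accuracy it yields `log2(n_classes)` bits per trial.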

### assess\_trial\_quality()

Assess quality of predictions.

```python theme={null}
from nimbus_bci import assess_trial_quality

quality = assess_trial_quality(
    features,
    confidence,
    entropy=entropy,
    confidence_threshold=0.7
)
```

**Parameters:**

* `features` (np.ndarray): Trial features to check for NaN/Inf artifacts
* `confidence` (float): Prediction confidence in `[0, 1]`
* `confidence_threshold` (float, default=0.6): Minimum confidence for accepting prediction
* `outlier_threshold` (float, default=5.0): Maximum outlier score for accepting prediction
* `entropy` (float, optional): Prediction entropy in bits
* `outlier_score` (float, optional): Mahalanobis-based outlier score
* `entropy_threshold` (float, default=1.5): Maximum entropy for accepting prediction

**Returns:**

* `TrialQuality`: Object with quality metrics

### should\_reject\_trial()

Determine if trial should be rejected based on confidence.

```python theme={null}
from nimbus_bci import should_reject_trial

reject = should_reject_trial(confidence, threshold=0.7)
```

**Parameters:**

* `confidence` (float): Confidence score
* `threshold` (float, default=0.7): Rejection threshold

**Returns:**

* `bool`: True if trial should be rejected

## Utilities

### estimate\_normalization\_params()

Estimate normalization parameters from data.

```python theme={null}
from nimbus_bci import estimate_normalization_params

params = estimate_normalization_params(
    X,
    method="zscore"  # or "minmax", "robust"
)
```

**Parameters:**

* `X` (np.ndarray): Data array
* `method` (str): Normalization method (`"zscore"`, `"minmax"`, or `"robust"`)

**Returns:**

* `NormalizationParams`: Parameters for normalization

### apply\_normalization()

Apply normalization to data.

```python theme={null}
from nimbus_bci import apply_normalization

X_normalized = apply_normalization(X, params)
```

**Parameters:**

* `X` (np.ndarray): Data to normalize
* `params` (NormalizationParams): Normalization parameters

**Returns:**

* `np.ndarray`: Normalized data
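
For intuition, z-score normalization amounts to estimating a per-feature mean and standard deviation on calibration data and reusing them on later data. The NumPy sketch below follows that pattern under two assumptions: `X` is laid out as `(n_trials, n_features)`, and a plain dictionary stands in for `NormalizationParams`, whose actual fields are not documented in this section.

```python
import numpy as np

def estimate_zscore_params(X, eps=1e-8):
    """Per-feature mean and std from a (n_trials, n_features) array."""
    return {"mean": X.mean(axis=0), "std": X.std(axis=0) + eps}

def apply_zscore(X, params):
    """Normalize X with previously estimated parameters."""
    return (X - params["mean"]) / params["std"]

rng = np.random.default_rng(0)
X_train = rng.normal(loc=5.0, scale=2.0, size=(200, 16))
X_test = rng.normal(loc=5.0, scale=2.0, size=(50, 16))

params = estimate_zscore_params(X_train)  # fit on training data only
X_train_n = apply_zscore(X_train, params)
X_test_n = apply_zscore(X_test, params)   # reuse the same parameters
print(X_train_n.mean(axis=0).round(6)[:3], X_train_n.std(axis=0).round(6)[:3])
```

Estimating the parameters once and reusing them keeps test or online data on the same scale as the data the classifier was fitted on.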

### diagnose\_preprocessing()

Diagnose preprocessing quality.

```python theme={null}
from nimbus_bci import diagnose_preprocessing

report = diagnose_preprocessing(
    bci_data
)
```

**Parameters:**

* `data` (BCIData): Data to diagnose

**Returns:**

* `PreprocessingReport`: Diagnostic report

### compute\_fisher\_score()

Compute Fisher score for feature discriminability.

```python theme={null}
from nimbus_bci import compute_fisher_score

scores = compute_fisher_score(X, y)
```

**Parameters:**

* `X` (np.ndarray): Features
* `y` (np.ndarray): Labels

**Returns:**

* `np.ndarray`: Fisher scores per feature
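
A common definition of the Fisher score is the ratio of between-class to within-class variance, computed per feature. The sketch below assumes that definition; the SDK may normalize differently, and `fisher_score` is a hypothetical helper used only for illustration.

```python
import numpy as np

def fisher_score(X, y):
    """Per-feature ratio of between-class to within-class variance."""
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    overall_mean = X.mean(axis=0)
    between = np.zeros(X.shape[1])
    within = np.zeros(X.shape[1])
    for c in np.unique(y):
        Xc = X[y == c]
        between += len(Xc) * (Xc.mean(axis=0) - overall_mean) ** 2
        within += len(Xc) * Xc.var(axis=0)
    return between / (within + 1e-12)

rng = np.random.default_rng(0)
y = np.repeat([0, 1], 100)
X = rng.standard_normal((200, 3))
X[y == 1, 0] += 3.0  # only feature 0 separates the classes

scores = fisher_score(X, y)
ranking = np.argsort(scores)[::-1]  # most discriminative first
print(scores.round(2), ranking)
```

Sorting the scores in descending order gives a feature ranking of the kind described for `rank_features_by_discriminability()`.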

### rank\_features\_by\_discriminability()

Rank features by discriminability.

```python theme={null}
from nimbus_bci import rank_features_by_discriminability

ranking = rank_features_by_discriminability(X, y)
```

**Parameters:**

* `X` (np.ndarray): Features
* `y` (np.ndarray): Labels

**Returns:**

* `np.ndarray`: Feature indices sorted by discriminability

## MNE Integration

### from\_mne\_epochs()

Convert MNE Epochs to BCIData.

```python theme={null}
from nimbus_bci.compat import from_mne_epochs

data = from_mne_epochs(
    epochs,
    paradigm="motor_imagery",
    feature_type="raw"
)
```

**Parameters:**

* `epochs` (mne.Epochs): MNE Epochs object
* `paradigm` (str): BCI paradigm
* `feature_type` (str): Feature type

**Returns:**

* `BCIData`: Converted data

### extract\_csp\_features()

Extract CSP features from MNE Epochs.

```python theme={null}
from nimbus_bci.compat import extract_csp_features

features, csp = extract_csp_features(
    epochs,
    n_components=8
)
```

**Parameters:**

* `epochs` (mne.Epochs): MNE Epochs object
* `n_components` (int): Number of CSP components

**Returns:**

* `features` (np.ndarray): CSP features
* `csp` (mne.decoding.CSP): Fitted CSP object

### extract\_bandpower\_features()

Extract bandpower features from MNE Epochs.

```python theme={null}
from nimbus_bci.compat import extract_bandpower_features

features, band_names = extract_bandpower_features(
    epochs,
    bands={"mu": (8, 12), "beta": (13, 30)},
    log_transform=True
)
```

**Parameters:**

* `epochs` (mne.Epochs): MNE Epochs object
* `bands` (dict): Frequency bands
* `log_transform` (bool, default=True): Apply log transform to band powers

**Returns:**

* `tuple[np.ndarray, list[str]]`: Bandpower features and band names

### create\_bci\_pipeline()

Create complete BCI pipeline with MNE and nimbus-bci.

```python theme={null}
from nimbus_bci.compat import create_bci_pipeline
from nimbus_bci import NimbusLDA

pipeline = create_bci_pipeline(
    NimbusLDA,
    feature_extraction="csp",
    n_csp_components=8,
)
```

**Parameters:**

* `model_class` (class): Classifier class (`NimbusLDA`, `NimbusQDA`, or `NimbusSoftmax`)
* `preprocessor` (str, default="standard"): Preprocessing method (`"standard"`, `"robust"`, or `None`)
* `feature_extraction` (str, optional): Feature extraction method (`"csp"` or `None`)
* `n_csp_components` (int, default=8): Number of CSP components
* `**model_kwargs`: Additional arguments passed to the classifier

**Returns:**

* `sklearn.pipeline.Pipeline`: Complete pipeline

## Functional API (Backward Compatible)

### LDA Functions

```python theme={null}
from nimbus_bci import (
    nimbus_lda_fit,
    nimbus_lda_predict,
    nimbus_lda_predict_proba,
    nimbus_lda_update
)
import numpy as np

# Fit
model = nimbus_lda_fit(
    X,
    y,
    n_classes=4,
    label_base=0,
    mu_loc=0.0,
    mu_scale=3.0,
    wishart_df=X.shape[1] + 2,
    wishart_scale=np.eye(X.shape[1]),
    class_prior_alpha=1.0,
)

# Predict
probs = nimbus_lda_predict_proba(model, X_test)
preds = nimbus_lda_predict(model, X_test)

# Update
model = nimbus_lda_update(model, X_new, y_new)
```

### QDA Functions

```python theme={null}
from nimbus_bci import (
    nimbus_qda_fit,
    nimbus_qda_predict,
    nimbus_qda_predict_proba,
    nimbus_qda_update
)
import numpy as np

# Fit
model = nimbus_qda_fit(
    X,
    y,
    n_classes=4,
    label_base=0,
    mu_loc=0.0,
    mu_scale=3.0,
    wishart_df=X.shape[1] + 2,
    wishart_scale=np.eye(X.shape[1]),
    class_prior_alpha=1.0,
)
probs = nimbus_qda_predict_proba(model, X_test)
```

### Softmax Functions

These functions require the optional `softmax` extra.

```python theme={null}
from nimbus_bci import (
    nimbus_softmax_fit,
    nimbus_softmax_predict,
    nimbus_softmax_predict_proba,
    nimbus_softmax_update
)

# Fit
model = nimbus_softmax_fit(
    X,
    y,
    n_classes=4,
    label_base=0,
    w_loc=0.0,
    w_scale=1.0,
    b_loc=0.0,
    b_scale=1.0,
    rng_seed=0,
    learning_rate=0.2,
    num_steps=50,
)

# Predict
probs = nimbus_softmax_predict_proba(
    model,
    X_test,
    num_posterior_samples=50,
    rng_seed=1,
)
```

### STS Functions

```python theme={null}
from nimbus_bci import (
    nimbus_sts_fit,
    nimbus_sts_predict,
    nimbus_sts_predict_proba,
    nimbus_sts_update
)

# Fit
model = nimbus_sts_fit(
    X, y,
    n_classes=4,
    label_base=0,
    state_dim=None,
    transition_cov=0.05,
    observation_cov=1.0,
    learning_rate=0.1,
    num_steps=100,
    rng_seed=0,
    verbose=False
)

# Predict (with optional state evolution)
probs = nimbus_sts_predict_proba(model, X_test, evolve_state=False)
preds = nimbus_sts_predict(model, X_test)

# Update (online learning)
model = nimbus_sts_update(model, X_new, y_new, learning_rate=0.1)
```

**Note**: The functional API for STS provides lower-level control. For most use cases, prefer the `NimbusSTS` class with its state management methods.

### Model I/O

```python theme={null}
from nimbus_bci import nimbus_save, nimbus_load

# Save model
nimbus_save(model, "model.npz")

# Load model
model = nimbus_load("model.npz")
```

## Active Learning

Active learning helpers reduce calibration cost by ranking unlabeled feature rows, deciding whether streaming trials are worth labeling, and stopping calibration when the model posterior stabilizes.

```python theme={null}
from nimbus_bci.active_learning import (
    CalibrationRound,
    CalibrationSession,
    CalibrationStatus,
    QueryResult,
    StreamingQueryDecision,
    calibration_sufficient,
    should_query,
    suggest_next_trial,
)
```

All helpers accept either a fitted Nimbus classifier (`NimbusLDA`, `NimbusQDA`, `NimbusSoftmax`, `NimbusSTS`) or a raw `NimbusModel` snapshot.

<Info>
  Active learning expects preprocessed features. Use `X_pool` shaped `(n_pool, n_features)` for pool-based ranking and stopping, and `x_new` shaped `(n_features,)` or `(1, n_features)` for streaming query decisions.
</Info>

### `CalibrationSession`

Stateful workflow object for active calibration loops. Use it when you want the SDK to manage active-pool bookkeeping, selected index history, `partial_fit()` updates, and previous model snapshots for `posterior_stability`.

```python theme={null}
CalibrationSession(
    model,
    X_pool,
    *,
    pool_strategy="bald",
    streaming_strategy="entropy",
    batch_size=1,
    stopping_criterion="posterior_stability",
    stopping_threshold=None,
    streaming_threshold=None,
    num_posterior_samples=256,
    rng_seed=0,
)
```

**Parameters:**

* `model` (`NimbusModel` or fitted Nimbus classifier): Model used for scoring. `update(...)` requires a fitted Nimbus classifier with `partial_fit()`.
* `X_pool` (`np.ndarray`): Original unlabeled feature pool with shape `(n_pool, n_features)`.
* `pool_strategy` (`"entropy"`, `"margin"`, `"least_confidence"`, or `"bald"`, default=`"bald"`): Strategy used by `suggest_next_trial()`.
* `streaming_strategy` (`"entropy"`, `"margin"`, or `"least_confidence"`, default=`"entropy"`): Strategy used by `should_query()`.
* `batch_size` (int, default=`1`): Default number of pool candidates to request per round.
* `stopping_criterion` (`"posterior_stability"` or `"expected_info_gain"`, default=`"posterior_stability"`): Criterion used by `calibration_sufficient()`.
* `stopping_threshold` (float, optional): Default stopping threshold.
* `streaming_threshold` (float, optional): Default streaming query threshold.
* `num_posterior_samples` (int, default=`256`): Default sample count forwarded to active-learning helpers.
* `rng_seed` (int, default=`0`): Default deterministic seed.

**Methods:**

* `suggest_next_trial(...) -> QueryResult`: Rank the current active pool. Returned indices are local to `remaining_pool`.
* `update(chosen_indices, y_new) -> CalibrationSession`: Map pool-local indices to original pool rows, capture the pre-update snapshot, call `partial_fit()`, and remove selected rows.
* `calibration_sufficient(...) -> CalibrationStatus`: Evaluate whether calibration can stop.
* `should_query(x_new, ...) -> StreamingQueryDecision`: Delegate streaming query decisions with session defaults.
* `get_model() -> NimbusModel`: Return the current Nimbus model snapshot.

**Properties:**

* `remaining_indices`: Original pool indices still available.
* `remaining_pool`: Active feature rows still available.
* `n_remaining`: Number of active candidates left.
* `is_exhausted`: Whether no candidates remain.
* `round_index`: Number of completed update rounds.
* `n_labeled`: Number of labels applied through `update(...)`.

```python theme={null}
session = CalibrationSession(
    clf,
    X_pool,
    pool_strategy="bald",
    batch_size=4,
    stopping_threshold=0.02,
)

ranked = session.suggest_next_trial()
global_indices = session.remaining_indices[ranked.indices]
y_new = collect_labels_for(global_indices)
session.update(ranked.indices, y_new)

status = session.calibration_sufficient()
```

<Note>
  For `posterior_stability`, `calibration_sufficient()` requires at least one prior `update(...)` unless you pass an explicit `previous` snapshot.
</Note>
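
Because `suggest_next_trial()` returns indices local to `remaining_pool`, the same local index can refer to different original rows in different rounds. The NumPy sketch below mimics that bookkeeping with plain arrays; it is an illustration of the mapping, not the session's internal code.

```python
import numpy as np

X_pool = np.arange(12, dtype=float).reshape(6, 2)  # 6 candidates, 2 features
remaining_indices = np.arange(len(X_pool))         # original row indices

# Round 1: suppose the ranking picks local rows 4 and 1
local = np.array([4, 1])
picked_round1 = remaining_indices[local]           # original rows 4 and 1
remaining_indices = np.delete(remaining_indices, local)

# Round 2: local row 1 now refers to a different original row
local = np.array([1])
picked_round2 = remaining_indices[local]
remaining_indices = np.delete(remaining_indices, local)

remaining_pool = X_pool[remaining_indices]
print(picked_round1, picked_round2, remaining_indices)
```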

### `suggest_next_trial()`

Rank an unlabeled feature pool by informativeness and return the top `n` candidates.

```python theme={null}
suggest_next_trial(
    model,
    X_pool,
    *,
    strategy="bald",
    n=1,
    num_posterior_samples=256,
    rng_seed=0,
) -> QueryResult
```

**Parameters:**

* `model` (`NimbusModel` or fitted Nimbus classifier): Model used to score candidates
* `X_pool` (`np.ndarray`): Unlabeled feature rows with shape `(n_pool, n_features)`
* `strategy` (`"entropy"`, `"margin"`, `"least_confidence"`, or `"bald"`, default=`"bald"`): Informativeness criterion
* `n` (int, default=`1`): Number of candidates to return
* `num_posterior_samples` (int, default=`256`): Posterior samples for `bald`; also forwarded to `NimbusSoftmax` probability estimates
* `rng_seed` (int, default=`0`): Deterministic seed for posterior sampling

**Returns:** `QueryResult`

* `indices`: Top-`n` indices into `X_pool`, sorted from most to least informative
* `scores`: Raw informativeness score for each row in `X_pool`
* `strategy`: Strategy used
* `n_posterior_samples`: Posterior samples used (`1` for cheap strategies)

```python theme={null}
previous = clf.get_model()
ranked = suggest_next_trial(
    clf,
    X_pool,
    strategy="bald",
    n=4,
    num_posterior_samples=256,
)

X_new = X_pool[ranked.indices]
y_new = collect_labels_for(ranked.indices)
clf.partial_fit(X_new, y_new)

status = calibration_sufficient(
    clf,
    X_pool,
    previous=previous,
    threshold=0.02,
)
```

<Note>
  `strategy="bald"` is supported for `NimbusLDA`, `NimbusQDA`, and `NimbusSoftmax`. `NimbusSTS` supports only the cheap strategies in this release.
</Note>
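
The cheap strategies can all be computed from a single matrix of predicted class probabilities. The sketch below uses textbook definitions and orients every score so that higher means more informative. The SDK's exact sign and scaling conventions are assumptions here, and `informativeness` is a hypothetical helper.

```python
import numpy as np

def informativeness(probs, strategy="entropy"):
    """Score rows of a (n_pool, n_classes) matrix; higher = more informative."""
    p = np.clip(np.asarray(probs, dtype=float), 1e-12, 1.0)
    if strategy == "entropy":
        return -(p * np.log2(p)).sum(axis=1)
    if strategy == "least_confidence":
        return 1.0 - p.max(axis=1)
    if strategy == "margin":
        top2 = np.sort(p, axis=1)[:, -2:]
        return 1.0 - (top2[:, 1] - top2[:, 0])  # small margin -> high score
    raise ValueError(f"Unknown strategy: {strategy!r}")

probs = np.array([
    [0.90, 0.05, 0.05],
    [0.40, 0.35, 0.25],
    [0.60, 0.30, 0.10],
])
for name in ("entropy", "margin", "least_confidence"):
    scores = informativeness(probs, name)
    top = np.argsort(scores)[::-1][:2]  # top-2 candidates
    print(name, scores.round(3), top)
```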

### `should_query()`

Decide whether a single arriving trial is informative enough to label.

```python theme={null}
should_query(
    model,
    x_new,
    *,
    strategy="entropy",
    threshold=...,
    num_posterior_samples=50,
    rng_seed=0,
) -> StreamingQueryDecision
```

**Parameters:**

* `model` (`NimbusModel` or fitted Nimbus classifier): Model used to score the trial
* `x_new` (`np.ndarray`): Single feature row with shape `(n_features,)` or `(1, n_features)`
* `strategy` (`"entropy"`, `"margin"`, or `"least_confidence"`, default=`"entropy"`): Cheap informativeness strategy
* `threshold` (float): Query cutoff
* `num_posterior_samples` (int, default=`50`): Forwarded to `NimbusSoftmax` probability estimates
* `rng_seed` (int, default=`0`): Deterministic seed for `NimbusSoftmax`

**Returns:** `StreamingQueryDecision`

* `should_query`: Whether the trial should be labeled
* `score`: Raw informativeness score
* `threshold`: Threshold used
* `strategy`: Strategy used

```python theme={null}
decision = should_query(
    clf,
    x_new,
    strategy="entropy",
    threshold=1.0,
)

if decision.should_query:
    request_label(x_new)
```

<Warning>
  `should_query()` does not support `strategy="bald"`. Single-point BALD is too noisy; batch trials and call `suggest_next_trial(strategy="bald")` instead.
</Warning>
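
One way to follow the batching advice above is to buffer arriving trials and rank them once enough have accumulated. The sketch below shows only the buffering step in NumPy; `TrialBuffer` is a hypothetical helper that is not part of nimbus-bci, and the `suggest_next_trial()` call is left as a comment because it needs a fitted classifier.

```python
import numpy as np


class TrialBuffer:
    """Hypothetical helper: accumulate single trials until a batch is ready."""

    def __init__(self, batch_size):
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.batch_size = batch_size
        self._rows = []

    def add(self, x_new):
        """Store one trial; return a (batch_size, n_features) array when full, else None."""
        # Accept either (n_features,) or (1, n_features), as should_query() does.
        row = np.asarray(x_new, dtype=float).reshape(-1)
        self._rows.append(row)
        if len(self._rows) < self.batch_size:
            return None
        batch = np.stack(self._rows)
        self._rows = []
        return batch


buffer = TrialBuffer(batch_size=3)
batches = []
for i in range(7):
    batch = buffer.add(np.full((1, 4), float(i)))  # stand-in for a real feature row
    if batch is not None:
        # With a fitted classifier, the batch would be ranked here, for example:
        # ranked = suggest_next_trial(clf, batch, strategy="bald", n=1)
        batches.append(batch)
```

The batch size trades labeling latency against the noise of the BALD estimate, so a suitable value depends on the paradigm and trial rate.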

### `calibration_sufficient()`

Decide whether calibration can stop, using a label-free criterion evaluated over an unlabeled pool (typically the same `X_pool` passed to `suggest_next_trial()`).

```python theme={null}
calibration_sufficient(
    model,
    X_pool,
    *,
    criterion="posterior_stability",
    previous=None,
    threshold=...,
    num_posterior_samples=64,
    rng_seed=0,
) -> CalibrationStatus
```

**Parameters:**

* `model` (`NimbusModel` or fitted Nimbus classifier): Current model snapshot
* `X_pool` (`np.ndarray`): Unlabeled feature rows with shape `(n_pool, n_features)`
* `criterion` (`"posterior_stability"` or `"expected_info_gain"`, default=`"posterior_stability"`): Stopping signal
* `previous` (`NimbusModel` or fitted Nimbus classifier, optional): Previous model snapshot, required for `posterior_stability`
* `threshold` (float): Stop when the signal is below this value
* `num_posterior_samples` (int, default=`64`): Posterior samples used for `expected_info_gain`; also forwarded to `NimbusSoftmax` probability estimates
* `rng_seed` (int, default=`0`): Deterministic seed

**Returns:** `CalibrationStatus`

* `is_sufficient`: `True` when the stopping signal is below `threshold`
* `signal`: Mean total variation for `posterior_stability`, or mean BALD in bits for `expected_info_gain`
* `threshold`: Threshold used
* `criterion`: Criterion used
* `details`: Criterion-specific diagnostics such as `max_tv`, `min_tv`, `max_bald`, or `min_bald`

```python theme={null}
previous = clf.get_model()

# ... collect labels and call clf.partial_fit(...) on the current round ...

status = calibration_sufficient(
    clf,
    X_pool,
    criterion="posterior_stability",
    previous=previous,
    threshold=0.02,
)

if status.is_sufficient:
    stop_calibration()
```

`posterior_stability` works for every Nimbus classifier, including `NimbusSTS`. `expected_info_gain` relies on BALD, so it is supported only for `NimbusLDA`, `NimbusQDA`, and `NimbusSoftmax`.
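
As a rough illustration of the `posterior_stability` signal, the sketch below computes the mean total variation distance between two sets of class probabilities over the same pool. It assumes the signal compares the predictive probabilities of the previous and current snapshots row by row, which this page does not spell out. `mean_total_variation` is a hypothetical helper, and the two arrays are made-up stand-ins for each snapshot's class probabilities on `X_pool`.

```python
import numpy as np


def mean_total_variation(probs_prev, probs_curr):
    """Return (mean TV, per-row TV) between two (n_pool, n_classes) probability arrays."""
    probs_prev = np.asarray(probs_prev, dtype=float)
    probs_curr = np.asarray(probs_curr, dtype=float)
    if probs_prev.shape != probs_curr.shape:
        raise ValueError("probability arrays must have the same shape")
    # Total variation between two discrete distributions is half their L1 distance.
    tv_per_row = 0.5 * np.abs(probs_prev - probs_curr).sum(axis=1)
    return float(tv_per_row.mean()), tv_per_row


# Illustrative probabilities for a three-row pool before and after an update.
probs_prev = np.array([[0.70, 0.30], [0.50, 0.50], [0.10, 0.90]])
probs_curr = np.array([[0.72, 0.28], [0.51, 0.49], [0.10, 0.90]])

signal, tv_per_row = mean_total_variation(probs_prev, probs_curr)
is_sufficient = signal < 0.02  # same comparison direction as `threshold` above
```

Here the per-row distances are roughly 0.02, 0.01, and 0, so the mean of about 0.01 falls below the 0.02 threshold and calibration would be considered sufficient under this reading of the criterion.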

### Strategy Units

| Quantity              | Range                         | Used by                                                                          |
| --------------------- | ----------------------------- | -------------------------------------------------------------------------------- |
| `entropy`             | `[0, log2(n_classes)]` bits   | `suggest_next_trial()`, `should_query()`                                         |
| `bald`                | `[0, log2(n_classes)]` bits   | `suggest_next_trial()`, `calibration_sufficient(criterion="expected_info_gain")` |
| `margin`              | `[0, 1]` probability gap      | `suggest_next_trial()`, `should_query()`                                         |
| `least_confidence`    | `[0, 1 - 1/n_classes]`        | `suggest_next_trial()`, `should_query()`                                         |
| `posterior_stability` | `[0, 1]` mean total variation | `calibration_sufficient()`                                                       |
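
The ranges for the cheap strategies in the table can be checked with a small NumPy sketch. It uses the textbook definitions of entropy, margin, and least confidence, which is an assumption: nimbus-bci may orient or rescale its scores differently (for example, so that a smaller margin maps to a larger informativeness score). `cheap_scores` is a hypothetical helper, not a library function.

```python
import numpy as np


def cheap_scores(probs):
    """Textbook uncertainty scores for an (n_rows, n_classes) probability array."""
    probs = np.asarray(probs, dtype=float)
    safe = np.where(probs > 0, probs, 1.0)  # avoid log2(0); those terms contribute 0
    entropy = -np.sum(probs * np.log2(safe), axis=1)   # bits
    top_two = np.sort(probs, axis=1)[:, -2:]           # [second largest, largest]
    margin = top_two[:, 1] - top_two[:, 0]             # probability gap
    least_confidence = 1.0 - probs.max(axis=1)
    return {
        "entropy": entropy,
        "margin": margin,
        "least_confidence": least_confidence,
    }


probs = np.array([
    [0.25, 0.25, 0.25, 0.25],  # maximally uncertain over 4 classes
    [1.00, 0.00, 0.00, 0.00],  # fully confident
])
scores = cheap_scores(probs)
n_classes = probs.shape[1]
```

The uniform row reaches the upper bounds for entropy (`log2(4) = 2` bits) and least confidence (`1 - 1/4 = 0.75`) with a margin of 0, while the fully confident row sits at the opposite end of each range.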

## Type Hints

All functions and classes include type hints for better IDE support:

```python theme={null}
from nimbus_bci import NimbusLDA
import numpy as np
from numpy.typing import NDArray

def train_classifier(
    X: NDArray[np.float64],
    y: NDArray[np.int64]
) -> NimbusLDA:
    clf = NimbusLDA()
    clf.fit(X, y)
    return clf
```

## API FAQ

<AccordionGroup>
  <Accordion title="Which classifier should I start with?" icon="help-circle">
    Start with `NimbusLDA` for fast baselines, especially motor imagery. Use `NimbusQDA` for overlapping distributions and `NimbusSTS` for non-stationary sessions.
  </Accordion>

  <Accordion title="When should I use StreamingSession instead of predict_batch?" icon="arrow-right">
    Use `predict_batch` for offline trials and evaluation. Use `StreamingSession` for chunk-by-chunk real-time inference where latency and incremental decisions matter.
  </Accordion>

  <Accordion title="Do I need MNE-Python to use nimbus-bci?" icon="brain">
    No. MNE integration is optional. You can use `nimbus-bci` with any preprocessing pipeline as long as you provide correctly shaped feature arrays.
  </Accordion>
</AccordionGroup>

## Next Read

<CardGroup cols={2}>
  <Card title="sklearn Integration" icon="code" href="/python-sdk/sklearn-integration">
    Advanced sklearn patterns and best practices
  </Card>

  <Card title="Streaming Inference" icon="activity" href="/python-sdk/streaming-inference">
    Real-time BCI with chunk processing
  </Card>

  <Card title="MNE Integration" icon="brain" href="/python-sdk/mne-integration">
    Complete EEG preprocessing pipeline
  </Card>

  <Card title="Examples" icon="braces" href="/examples/basic-examples">
    Working code examples
  </Card>
</CardGroup>
