NeuroDyn
A Python toolbox for fitting interpretable latent dynamics to large-scale neural recordings.
NeuroDyn fits state-space models and switching linear dynamical systems (SLDS) to recordings from thousands of simultaneously imaged or recorded neurons. Because it is built on JAX, models run on GPUs and TPUs and are differentiable end-to-end.
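To make "switching linear dynamical system" concrete, here is a minimal pure-Python sketch of the generative model under its standard formulation: a discrete state z_t follows a Markov chain with transition matrix P, a continuous latent evolves as x_t = A_{z_t} x_{t-1} + b_{z_t} + Gaussian noise, and each neuron's spike count is Poisson with rate exp(C x_t + d). The function names and arguments are hypothetical and are not NeuroDyn's API; NeuroDyn's own parameterization (noise covariance, link function, initial-state distribution) may differ. Isotropic noise and the log-rate clip are simplifications chosen for illustration.

```python
import math
import random


def matvec(A, x):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def sample_categorical(probs, rng):
    """Draw an index k with probability probs[k]."""
    u, cum = rng.random(), 0.0
    for k, p in enumerate(probs):
        cum += p
        if u < cum:
            return k
    return len(probs) - 1  # guard against rounding when probs sum to ~1


def sample_poisson(rate, rng):
    """Knuth's multiplication method; adequate for the small rates used here."""
    threshold, k, p = math.exp(-rate), 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def simulate_poisson_slds(T, P, As, bs, C, d, noise_std, seed=0):
    """Sample discrete states, continuous latents, and spike counts for T steps."""
    rng = random.Random(seed)
    K, D = len(As), len(As[0])
    z = rng.randrange(K)                               # initial discrete state
    x = [rng.gauss(0.0, noise_std) for _ in range(D)]  # initial latent
    states, latents, counts = [], [], []
    for t in range(T):
        if t > 0:
            z = sample_categorical(P[z], rng)          # Markov switch between regimes
            mean = [ax + b for ax, b in zip(matvec(As[z], x), bs[z])]
            x = [m + rng.gauss(0.0, noise_std) for m in mean]  # regime-specific linear step
        log_rates = [cx + dn for cx, dn in zip(matvec(C, x), d)]
        # Clip log-rates (an illustrative choice) so the Poisson rates stay small.
        y = [sample_poisson(math.exp(min(lr, 5.0)), rng) for lr in log_rates]
        states.append(z)
        latents.append(x)
        counts.append(y)
    return states, latents, counts


# Two hypothetical regimes: slow counter-clockwise vs clockwise rotation of a 2-D latent.
theta = 0.1
rot_ccw = [[0.99 * math.cos(theta), -0.99 * math.sin(theta)],
           [0.99 * math.sin(theta), 0.99 * math.cos(theta)]]
rot_cw = [[rot_ccw[0][0], -rot_ccw[0][1]],
          [-rot_ccw[1][0], rot_ccw[1][1]]]
P = [[0.95, 0.05], [0.05, 0.95]]          # "sticky" transitions: regimes persist
bs = [[0.0, 0.0], [0.0, 0.0]]
C = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # 3 neurons read out the 2-D latent
d = [0.0, 0.0, 0.0]
states, latents, counts = simulate_poisson_slds(
    100, P, [rot_ccw, rot_cw], bs, C, d, noise_std=0.1, seed=1)
```

Fitting an SLDS amounts to inverting this process: given only the counts, infer the discrete states and continuous latents along with P, A, b, C, and d.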
Highlights
- Pluggable observation models (Gaussian, Poisson, negative-binomial).
- Variational and Laplace inference with a unified API.
- Tools for cross-validation, model comparison, and trajectory visualization.
- Tutorials reproducing results from several published datasets.
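Each pluggable observation model corresponds to a per-neuron, per-bin log-likelihood. The sketch below writes out all three in plain Python. It assumes a mean/standard-deviation parameterization for the Gaussian and a mean/dispersion parameterization for the negative binomial (variance mean + mean²/dispersion). The helper names are hypothetical, and NeuroDyn's internal parameterization may differ.

```python
import math


def gaussian_loglik(y, mean, std):
    """log N(y | mean, std^2) for one real-valued observation (e.g. a fluorescence trace)."""
    z = (y - mean) / std
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * z * z


def poisson_loglik(y, rate):
    """log Poisson(y | rate) for a non-negative integer count y."""
    return y * math.log(rate) - rate - math.lgamma(y + 1)


def negbin_loglik(y, mean, dispersion):
    """log NegBin(y | mean, dispersion); approaches Poisson as dispersion -> infinity."""
    r = dispersion
    return (math.lgamma(y + r) - math.lgamma(r) - math.lgamma(y + 1)
            - r * math.log1p(mean / r)          # r * log(r / (r + mean)), computed stably
            + y * math.log(mean / (r + mean)))


# Compare Poisson and negative-binomial fits to a small, overdispersed set of counts.
counts = [0, 2, 5, 1, 9, 0, 3]
mean = sum(counts) / len(counts)
poisson_total = sum(poisson_loglik(y, mean) for y in counts)
negbin_total = sum(negbin_loglik(y, mean, dispersion=1.0) for y in counts)
```

These counts have a sample variance roughly three times their mean, which a Poisson model cannot capture, so the negative binomial assigns them a higher total log-likelihood. The same comparison on held-out data is the basic idea behind choosing an observation model by cross-validation.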
Why it exists
Most labs re-implement the same inference machinery for every project. NeuroDyn packages it once, tests it thoroughly, and gets out of your way so you can focus on the science.
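```python
import neurodyn as nd

# `spikes`: binned spike counts from your recording (e.g. a time bins x neurons array)
# SLDS with 4 discrete states, 8 continuous latent dimensions, and Poisson observations
model = nd.SLDS(n_states=4, n_latents=8, observations="poisson")
posterior = model.fit(spikes, n_iters=500)
nd.plot.trajectories(posterior)
```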