neurai.util package#
Submodules#
- class neurai.util.ir_collector.NpEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)#
Bases: JSONEncoder
- default(obj)#
Implement this method in a subclass such that it returns a serializable object for o, or calls the base implementation (to raise a TypeError).
For example, to support arbitrary iterators, you could implement default like this:
    def default(self, o):
        try:
            iterable = iter(o)
        except TypeError:
            pass
        else:
            return list(iterable)
        # Let the base class default method raise the TypeError
        return JSONEncoder.default(self, o)
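As an illustration of overriding default for NumPy data, here is a minimal sketch. The class name NumpyJSONEncoder and its conversion rules are assumptions made for this example; the actual NpEncoder implementation is not shown in this reference and may differ.

```python
import json

import numpy as np


class NumpyJSONEncoder(json.JSONEncoder):
    """Hypothetical encoder; not the actual NpEncoder implementation."""

    def default(self, o):
        # Most NumPy scalars are not handled by the standard encoder,
        # so convert them to plain Python scalars.
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.floating):
            return float(o)
        if isinstance(o, np.bool_):
            return bool(o)
        # Arrays become (nested) Python lists.
        if isinstance(o, np.ndarray):
            return o.tolist()
        # Let the base class raise the TypeError for anything else.
        return super().default(o)


payload = {"weights": np.arange(3, dtype=np.float32), "steps": np.int64(10)}
encoded = json.dumps(payload, cls=NumpyJSONEncoder)
print(encoded)
```

The encoder is passed through the cls argument of json.dumps, which calls default only for objects the standard encoder cannot serialize.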
- neurai.util.ir_collector.build_network_from_json(path)#
Builds a neural network from a JSON file.
- neurai.util.serialization.restore(path='/path/to/modelparams', cast_array=True, replace_keys=True, key_dict=None)#
Restore the model parameters from a saved file.
- Parameters:
path (str, optional) – The file path to the saved file.
cast_array (bool, optional) – Whether to cast the parameter arrays to jnp.Array. Default is True.
replace_keys (bool, optional) – Whether to replace the keys in the model parameters according to the provided key mapping. Default is True.
key_dict (dict, optional) – A dictionary containing key mappings for replacing keys in the parameters. Default is None.
- Returns:
dict – The restored model parameters.
Note
The saved file should be created with the corresponding ‘save’ function.
Make sure the model structure and keys are consistent during saving and restoring the parameters.
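How restore applies key_dict is not specified here. The sketch below shows one plausible interpretation, renaming keys at every level of a nested parameter dictionary. The helper rename_keys and the sample key names are hypothetical and are not part of neurai.

```python
def rename_keys(params, key_dict):
    """Return a copy of a nested dict with keys renamed per key_dict (hypothetical helper)."""
    renamed = {}
    for key, value in params.items():
        # Keys without an entry in key_dict are kept unchanged.
        new_key = key_dict.get(key, key)
        if isinstance(value, dict):
            value = rename_keys(value, key_dict)
        renamed[new_key] = value
    return renamed


saved = {"Dense_0": {"kernel": [1.0, 2.0], "bias": [0.0]}}
mapping = {"Dense_0": "fc1", "kernel": "weight"}
restored = rename_keys(saved, mapping)
print(restored)
```

Building a new dictionary rather than mutating the input leaves the loaded parameters untouched, which makes it easier to compare the keys before and after the mapping.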
- neurai.util.serialization.save(path, param=None, overwrite=True)#
Save the model parameters to a file.
- Parameters:
- Returns:
str – The file path where the parameters are saved.
Note
Make sure ‘param’ contains the valid model parameters to be saved.
The file format and serialization method should be consistent with the corresponding ‘restore’ function.
- neurai.util.trans.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=(), keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)#
Compile a function using XLA.
This function transforms a Python function into a function that runs on XLA devices (e.g. CPUs and GPUs) by just-in-time (JIT) compiling it.
- Parameters:
fun (callable) – A Python function to be compiled.
static_argnums (int or Iterable of int, optional) – Specifies which positional arguments of fun are static. A static argument is treated as a compile-time constant, so calling fun with a different value for it triggers recompilation. If multiple arguments are static, pass them as a tuple or list. Default is None.
device (str or jaxlib.xla.Device, optional) – The XLA device to compile fun for. Default is None, which means JAX will choose the default device (usually the fastest available one).
backend (str, optional) – The XLA backend to use. Default is None, which means JAX will choose the default backend.
donate_argnums (int or Iterable of int, optional) – Specifies which positional arguments of fun may have their buffers donated to the computation, so that XLA can reuse their memory for outputs and avoid a copy. A donated argument should not be used again after the call. If multiple arguments should be donated, pass them as a tuple or list. Default is ().
**compile_options – Other XLA compilation options. See the JAX documentation for details: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JIT-compilation
keep_unused (bool) –
inline (bool) –
- Return type:
- Returns:
A JIT-compiled version of fun.
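The signature above mirrors jax.jit, so the following sketch uses jax.jit directly to show the effect of static_argnums. That neurai.util.trans.jit accepts the same call is an assumption; the function scaled_sum is made up for the example.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x, n_repeats):
    # n_repeats drives a Python loop, so its value must be known at trace time.
    total = jnp.zeros_like(x)
    for _ in range(n_repeats):
        total = total + x
    return total.sum()


# Mark the second positional argument as static; a new value triggers a recompile.
fast_scaled_sum = jax.jit(scaled_sum, static_argnums=(1,))

result = fast_scaled_sum(jnp.arange(4.0), 3)
print(result)
```

Without static_argnums, n_repeats would be traced as an abstract value and range(n_repeats) would fail during tracing.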
- neurai.util.visualization.line_plot(ts, val_matrix, plot_ids=None, ax=None, xlim=None, ylim=None, xlabel=None, ylabel=None, legend=None, title=None, save=True, show=False, **kwargs)#
Plot the specified values of the given object (Neurons or Synapses).
- Parameters:
ts (float) – The simulation time.
val_matrix (jnp.ndarray) – The value matrix that records the history trajectory. It can be accessed by specifying the monitors (neurai.monitor) of NeuGroup/SynConn: neu/syn = NeuGroup/SynConn(..., monitors=[k1, k2])
plot_ids (None, int, tuple, a_list) – The index of the value to plot.
ax (None, Axes) – The axes to draw the plot on.
xlabel (str) – The x-axis label.
ylabel (str) – The y-axis label.
legend (str) – The prefix of the legend entries.
save (bool) – Whether to save the figure.
show (bool) – Whether to show the figure.
Examples
    Visualize.line_plot(
        run_manager.monitor.mon["ts"],
        neuron0,
        xlabel="Time(ms)",
        ylabel="LIF0.V(mv)",
        show=True,
        save=True,
        title="LIF0.v")
- neurai.util.visualization.raster_plot(ts, sp_matrix, ax=None, marker='.', markersize=2, color='k', xlabel='Time (ms)', ylabel='Neuron index', xlim=None, ylim=None, title=None, save=True, show=False, **kwargs)#
Show the raster plot of the spikes.
- Parameters:
ts (jnp.ndarray) – The simulation time.
sp_matrix (jnp.ndarray) – The spike matrix that records the spike information. It can be accessed by specifying the monitors (neurai.monitor) of NeuGroup: neu = NeuGroup(..., monitors=['spike'])
ax (Axes) – The axes to draw the plot on.
markersize (int) – The size of the marker.
color (str) – The color of the marker.
xlabel (str) – The x-axis label.
ylabel (str) – The y-axis label.
save (bool) – Whether to save the figure.
show (bool) – Whether to show the figure.
Examples
    Visualize.raster_plot(
        run_manager.monitor.mon["ts"],
        run_manager.monitor.mon['LIF0.spike'],
        show=True,
        save=True,
        title="LIF0.spike")
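To make the relationship between ts, sp_matrix and the plotted points concrete, the sketch below extracts (time, neuron index) pairs from a small spike matrix. It assumes rows are time steps and columns are neurons, which this reference does not state; spike_coordinates is a hypothetical helper and NumPy stands in for jnp.

```python
import numpy as np


def spike_coordinates(ts, sp_matrix):
    """Return (spike_times, neuron_indices) for every nonzero entry (hypothetical helper)."""
    # Assumed layout: rows are time steps, columns are neurons.
    step_idx, neuron_idx = np.nonzero(sp_matrix)
    return ts[step_idx], neuron_idx


ts = np.array([0.0, 0.1, 0.2, 0.3])
sp_matrix = np.array([
    [0, 1, 0],
    [0, 0, 0],
    [1, 0, 1],
    [0, 1, 0],
])
times, neurons = spike_coordinates(ts, sp_matrix)
print(list(zip(times.tolist(), neurons.tolist())))
```

Each pair corresponds to one marker in a raster plot, with the time on the x-axis and the neuron index on the y-axis.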
- class neurai.util.write.DebugWriteFormatChecker#
Bases: Formatter
Custom format checker class to ensure correct usage of formatting strings in write_file.
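One way a Formatter subclass can enforce correct usage is to reject arguments that the format string never references. The sketch below uses the standard string.Formatter hook check_unused_args for this. StrictFormatChecker is a hypothetical name, and whether DebugWriteFormatChecker performs this particular check is an assumption.

```python
import string


class StrictFormatChecker(string.Formatter):
    """Hypothetical checker; not the actual DebugWriteFormatChecker."""

    def check_unused_args(self, used_args, args, kwargs):
        # used_args holds the positional indices and keyword names
        # that the format string actually consumed.
        unused_positional = [i for i in range(len(args)) if i not in used_args]
        unused_keywords = [k for k in kwargs if k not in used_args]
        if unused_positional or unused_keywords:
            raise ValueError(
                f"Unused arguments: positions {unused_positional}, "
                f"keywords {unused_keywords}"
            )


checker = StrictFormatChecker()
ok = checker.format("hello {x}", x=1)

try:
    checker.format("hello {x}", x=1, y=2)
    error = None
except ValueError as exc:
    error = str(exc)

print(ok)
print(error)
```

Running the check before any computation is staged out surfaces a mismatched format string early instead of at write time.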
- neurai.util.write.convert2nest_spike(file_path, dt)#
Convert spike data to NEST format and write to a file.
- neurai.util.write.convert2nest_voltage(file_path, dt)#
Convert voltage data to NEST format and write to a file.
- neurai.util.write.load_data(file_path)#
Load data from a file and extract timestamps and data.
- Parameters:
file_path (str) – Path to the file containing the data.
- Return type:
array
- Returns:
np.array – Timestamps array.
np.array – Data array.
- neurai.util.write.write_file(fmt, path_file, *args, ordered=False, **kwargs)#
Write formatted values to a file in real time under JIT. Note: This function does not work with f-strings because the formatting is done lazily.
- Parameters:
fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments.
*args – A list of positional arguments to be formatted.
ordered (bool) – A keyword-only argument indicating whether the staged-out computation will enforce ordering of this write_file w.r.t. other ordered write_file calls.
**kwargs – Additional keyword arguments to be formatted.
path_file (str) –
- Return type:
- Returns:
None
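The note about lazy formatting can be illustrated with jax.debug.callback, which runs a host function with concrete values at execution time. This is a sketch under assumptions: write_line and _append_line are hypothetical stand-ins, and how write_file is actually implemented is not shown in this reference.

```python
import functools
import os
import tempfile

import jax
import jax.numpy as jnp


def _append_line(path, fmt, *args, **kwargs):
    # Runs on the host with concrete values, so str.format sees real numbers.
    with open(path, "a") as f:
        f.write(fmt.format(*args, **kwargs) + "\n")


def write_line(fmt, path, *args, ordered=False, **kwargs):
    """Hypothetical stand-in for write_file, built on jax.debug.callback."""
    jax.debug.callback(
        functools.partial(_append_line, path, fmt), *args, ordered=ordered, **kwargs
    )


log_path = os.path.join(tempfile.mkdtemp(), "trace.txt")


@jax.jit
def step(x):
    y = x * 2
    # An f-string here would be evaluated at trace time and embed a tracer repr,
    # so the template and its values are passed separately instead.
    write_line("y = {y}", log_path, y=y)
    return y


step(jnp.float32(1.5)).block_until_ready()
jax.effects_barrier()  # wait for pending host callbacks to finish

with open(log_path) as f:
    contents = f.read()
print(contents)
```

Passing the template and the values separately lets the formatting happen when the compiled computation runs, which is why an f-string evaluated during tracing cannot be used.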
- neurai.util.write.write_spike(file_path, data, ts, dt)#
Write spike data to a file.