neurai.util package

Submodules

class neurai.util.ir_collector.NpEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)

Bases: JSONEncoder

default(obj)

Implement this method in a subclass such that it returns a serializable object for obj, or calls the base implementation (to raise a TypeError).

For example, to support arbitrary iterators, you could implement default like this:

def default(self, o):
    try:
        iterable = iter(o)
    except TypeError:
        pass
    else:
        return list(iterable)
    # Let the base class default method raise the TypeError
    return JSONEncoder.default(self, o)
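
This page does not show which types NpEncoder converts. As a minimal sketch, assuming it targets NumPy-style values, an encoder can duck-type on tolist(), which both NumPy arrays and NumPy scalars provide. The class name DuckArrayEncoder is hypothetical and is not part of neurai.

```python
import json


class DuckArrayEncoder(json.JSONEncoder):
    """Hypothetical NumPy-aware encoder (a sketch, not neurai's NpEncoder).

    NumPy arrays and NumPy scalars both expose ``tolist()``, which returns
    plain Python lists or scalars that ``json`` already knows how to encode.
    """

    def default(self, obj):
        tolist = getattr(obj, "tolist", None)
        if callable(tolist):
            return tolist()
        # Let the base class raise TypeError for unsupported objects.
        return super().default(obj)
```

Any JSONEncoder subclass, including NpEncoder, is used through the cls argument, e.g. json.dumps({"v": np.arange(3)}, cls=DuckArrayEncoder), which produces '{"v": [0, 1, 2]}'.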

neurai.util.ir_collector.build_network_from_json(path)

Builds a neural network from a JSON file.

Parameters:

path (str) – The path to a local JSON file, or an HTTP URL.

Return type:

Tuple[dict, str, List]

Returns:

Tuple[dict, str, List] – A tuple containing a dict, the generated code (str), and a list of monitors.
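
The path parameter accepts either a local file or an HTTP address. As an illustrative sketch only (the helper load_json_source is hypothetical, and neurai's actual loading code is not shown on this page), a function can dispatch on the URL scheme using the standard library:

```python
import json
import urllib.request
from urllib.parse import urlparse


def load_json_source(path):
    """Hypothetical helper: read JSON from a local path or an HTTP(S) URL."""
    if urlparse(path).scheme in ("http", "https"):
        # Remote source: download the document and decode it as UTF-8.
        with urllib.request.urlopen(path) as response:
            return json.loads(response.read().decode("utf-8"))
    # Anything else is treated as a local file path.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
```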

neurai.util.serialization.restore(path='/path/to/modelparams', cast_array=True, replace_keys=True, key_dict=None)

Restore the model parameters from a saved file.

Parameters:
  • path (str, optional) – The path to the saved file. Default is '/path/to/modelparams'.

  • cast_array (bool, optional) – Whether to cast the parameter arrays to JAX arrays (jax.Array). Default is True.

  • replace_keys (bool, optional) – Whether to replace the keys in the model parameters according to the provided key mapping. Default is True.

  • key_dict (dict, optional) – A dictionary containing key mappings for replacing keys in the parameters. Default is None.

Returns:

dict – The restored model parameters.

Note

  • The file should have been created with the corresponding ‘save’ function.

  • Make sure the model structure and parameter keys are the same when saving and when restoring the parameters.
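
The replace_keys and key_dict options are only described briefly. A minimal sketch of such a key mapping, assuming the parameters are a nested dict and that key_dict maps old key names to new ones at every nesting level (both are assumptions; rename_keys is a hypothetical helper, not part of neurai):

```python
def rename_keys(params, key_dict):
    """Hypothetical helper: recursively rename dict keys using key_dict.

    Keys missing from key_dict are kept unchanged; non-dict values
    (e.g. parameter arrays) are returned as-is.
    """
    if not isinstance(params, dict):
        return params
    return {
        key_dict.get(key, key): rename_keys(value, key_dict)
        for key, value in params.items()
    }
```

For example, rename_keys({"layer0": {"w": 1.0, "b": 0.0}}, {"w": "weight"}) returns {"layer0": {"weight": 1.0, "b": 0.0}}.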

neurai.util.serialization.save(path, param=None, overwrite=True)

Save the model parameters to a file.

Parameters:
  • path (str) – The file path to save the parameters.

  • param (Any, optional) – The model parameters to be saved. Default is None.

  • overwrite (bool, optional) – Whether to overwrite the file if it already exists. Default is True.

Returns:

str – The file path where the parameters are saved.

Note

  • Make sure ‘param’ contains valid model parameters.

  • The file format and serialization method should be consistent with the corresponding ‘restore’ function.
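
The notes above require save and restore to agree on a file format. The format neurai actually uses is not stated on this page; the sketch below assumes pickle purely to illustrate the contract (overwrite check, returned path, round trip). save_params and restore_params are hypothetical names.

```python
import os
import pickle


def save_params(path, param=None, overwrite=True):
    """Hypothetical pickle-based save; returns the path it wrote to."""
    if os.path.exists(path) and not overwrite:
        raise FileExistsError(f"{path} already exists and overwrite=False")
    with open(path, "wb") as f:
        pickle.dump(param, f)
    return path


def restore_params(path):
    """Hypothetical counterpart to save_params: reads the same format back."""
    with open(path, "rb") as f:
        return pickle.load(f)
```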

neurai.util.trans.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=(), keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)

Compile a function using XLA.

This function transforms a Python function into a function that runs on XLA devices (e.g. CPUs and GPUs) by just-in-time (JIT) compiling it.

Parameters:
  • fun (callable) – A Python function to be compiled.

  • static_argnums (int or Iterable of int, optional) – Indices of the positional arguments of fun to treat as static, i.e. as compile-time constants. Static arguments must be hashable, and calling the compiled function with a new value for a static argument triggers recompilation. To mark several arguments as static, pass a tuple or list. Default is None.

  • device (str or jaxlib.xla.Device, optional) – The XLA device to compile fun for. Default is None, which means JAX will choose the default device (usually the fastest available one).

  • backend (str, optional) – The XLA backend to use. Default is None, which means JAX will choose the default backend.

  • donate_argnums (int or Iterable of int, optional) – Indices of the positional arguments of fun whose buffers are donated to the computation. XLA may reuse the memory of a donated input for an output, avoiding an extra allocation; a donated array must not be used after the call. To donate several arguments, pass a tuple or list. Default is ().

  • **compile_options – Other XLA compilation options. See the JAX documentation for details: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JIT-compilation

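  • static_argnames (Union[str, Iterable[str], None]) – Names of the arguments of fun to treat as static, as an alternative to static_argnums.

  • keep_unused (bool) – If False (the default), arguments that fun does not use may be dropped from the compiled computation; if True, they are kept.

  • inline (bool) – Whether fun should be inlined into enclosing traced computations instead of being kept as a separate call. Default is False.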

  • abstracted_axes (Optional[Any]) –

Return type:

Wrapped

Returns:

A JIT-compiled version of fun.
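
Because this signature mirrors jax.jit, the sketch below calls jax.jit directly to show static_argnums and static_argnames; it assumes, without confirmation from this page, that neurai.util.trans.jit accepts the same arguments. The function add_n_times is a made-up example.

```python
import jax
import jax.numpy as jnp


def add_n_times(x, n):
    # n drives a Python loop, so its value must be known at trace time;
    # marking it static makes JAX treat it as a compile-time constant.
    for _ in range(n):
        x = x + 1.0
    return x


# Mark the second positional argument (index 1) as static.
add_fast = jax.jit(add_n_times, static_argnums=1)

# Equivalent, selecting the static argument by name instead of position.
add_fast_named = jax.jit(add_n_times, static_argnames="n")

y = add_fast(jnp.zeros(3), 4)          # compiled for n=4
z = add_fast_named(jnp.zeros(3), n=2)  # compiled separately for n=2
```

Each distinct value of n produces a separate compilation, so static arguments suit values that change rarely, such as loop counts or configuration flags.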

neurai.util.visualization.line_plot(ts, val_matrix, plot_ids=None, ax=None, xlim=None, ylim=None, xlabel=None, ylabel=None, legend=None, title=None, save=True, show=False, **kwargs)

Plot a recorded value of the given object (neurons or synapses) against time.

Parameters:
  • ts (jnp.ndarray) – The simulation time points.

  • val_matrix (jnp.ndarray) –

    The value matrix that records the history trajectory of the variable. It can be obtained by specifying monitors (neurai.monitor) when creating a NeuGroup or SynConn:

    neu/syn = NeuGroup/SynConn(..., monitors=[k1, k2])


  • plot_ids (None, int, tuple, list) – The index or indices of the values to plot.

  • ax (None, Axes) – The Axes to draw on.

  • xlim (list, tuple) – The x-axis limits.

  • ylim (list, tuple) – The y-axis limits.

  • xlabel (str) – The x-axis label.

  • ylabel (str) – The y-axis label.

  • legend (str) – The prefix for the legend labels.

  • title (str) – The figure title.

  • save (bool) – Whether to save the figure.

  • show (bool) – Whether to show the figure.

Examples

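from neurai.util import visualization as Visualize

# run_manager and neuron0 come from an earlier simulation run; neuron0 is the
# recorded value matrix to plot (e.g. the monitored membrane potential of LIF0).
Visualize.line_plot(
  run_manager.monitor.mon["ts"], neuron0,
  xlabel="Time (ms)", ylabel="LIF0.V (mV)",
  show=True, save=True, title="LIF0.v")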

neurai.util.visualization.raster_plot(ts, sp_matrix, ax=None, marker='.', markersize=2, color='k', xlabel='Time (ms)', ylabel='Neuron index', xlim=None, ylim=None, title=None, save=True, show=False, **kwargs)

Show the raster plot of the spikes.

Parameters:
  • ts (jnp.ndarray) – The simulation time points.

  • sp_matrix (jnp.ndarray) –

    The spike matrix that records the spike information. It can be obtained by specifying monitors (neurai.monitor) when creating a NeuGroup:

    neu = NeuGroup(..., monitors=['spike'])


  • ax (Axes) – The Axes to draw on.

  • marker (str) – The marker style.

  • markersize (int) – The size of the marker.

  • color (str) – The color of the marker.

  • xlim (list, tuple) – The x-axis limits.

  • ylim (list, tuple) – The y-axis limits.

  • xlabel (str) – The x-axis label.

  • ylabel (str) – The y-axis label.

  • title (str) – The figure title.

  • save (bool) – Whether to save the figure.

  • show (bool) – Whether to show the figure.

Examples

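from neurai.util import visualization as Visualize

# run_manager comes from an earlier simulation run in which LIF0 spikes were monitored.
Visualize.raster_plot(
    run_manager.monitor.mon["ts"],
    run_manager.monitor.mon['LIF0.spike'],
    show=True, save=True, title="LIF0.spike")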
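
A raster plot draws one marker per spike at (time, neuron index). As a sketch of that conversion, assuming sp_matrix is a boolean or 0/1 array shaped (num_time_steps, num_neurons) and aligned row by row with ts (the layout is an assumption, not stated on this page), NumPy's nonzero yields the coordinates directly. spike_coordinates is a hypothetical helper.

```python
import numpy as np


def spike_coordinates(ts, sp_matrix):
    """Hypothetical helper: return (spike_times, neuron_ids) for a raster plot."""
    ts = np.asarray(ts)
    sp_matrix = np.asarray(sp_matrix)
    # Row index = time step, column index = neuron.
    step_ids, neuron_ids = np.nonzero(sp_matrix)
    return ts[step_ids], neuron_ids
```

The resulting pair could then be drawn with Matplotlib, e.g. ax.scatter(times, ids, marker=".", c="k").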

class neurai.util.write.DebugWriteFormatChecker

Bases: Formatter

Custom format checker class to ensure correct usage of formatting strings in write_file.
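
The checker's behavior is only summarized here. One plausible approach, similar in spirit to the argument validation jax.debug.print performs and offered as an assumption rather than neurai's code, is a string.Formatter subclass that formats nothing and rejects arguments the format string never references. UnusedArgChecker is a hypothetical name.

```python
import string


class UnusedArgChecker(string.Formatter):
    """Hypothetical checker: validates a format string without formatting values."""

    def format_field(self, value, format_spec):
        # Skip real formatting; only the field structure is being checked.
        return ""

    def check_unused_args(self, used_args, args, kwargs):
        # used_args holds the positional indices and keyword names that the
        # format string referenced.
        unused_pos = [i for i in range(len(args)) if i not in used_args]
        unused_kw = [k for k in kwargs if k not in used_args]
        if unused_pos or unused_kw:
            raise ValueError(
                f"Unused arguments: positional {unused_pos}, keyword {unused_kw}"
            )
```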

neurai.util.write.convert2nest_spike(file_path, dt)

Convert spike data to NEST format and write to a file.

Parameters:
  • file_path (str) – Path to the input spike data file.

  • dt (float) – The simulation time step.

Return type:

None

Returns:

None

neurai.util.write.convert2nest_voltage(file_path, dt)

Convert voltage data to NEST format and write to a file.

Parameters:
  • file_path (str) – Path to the input voltage data file.

  • dt (float) – The simulation time step.

Return type:

None

Returns:

None

neurai.util.write.load_data(file_path)

Load data from a file and extract timestamps and data.

Parameters:

file_path (str) – Path to the file containing the data.

Return type:

array

Returns:

  • np.array – Timestamps array.

  • np.array – Data array.

neurai.util.write.write_file(fmt, path_file, *args, ordered=False, **kwargs)

Write formatted values to a file in real time, including inside JIT-compiled functions. Note: this function does not work with f-strings, because the formatting is done lazily.

Parameters:
  • fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments.

  • path_file (str) – The path of the file to write to.

  • *args – Positional arguments to be formatted.

  • ordered (bool) – A keyword-only argument indicating whether the staged-out computation enforces the ordering of this write_file call with respect to other ordered write_file calls.

  • **kwargs – Additional keyword arguments to be formatted.

Return type:

None

Returns:

None
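
The description matches the pattern of jax.debug.print, which stages a host callback so values are emitted while a jitted function runs. A minimal sketch of a file writer built on jax.debug.callback follows; write_file_sketch, step, and LOG_PATH are hypothetical, and neurai's actual implementation may differ (for example in how it validates fmt or opens the file).

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical log location used only for this sketch.
LOG_PATH = os.path.join(tempfile.gettempdir(), "write_file_sketch.log")


def write_file_sketch(fmt, path_file, *args, ordered=False, **kwargs):
    """Hypothetical writer: append fmt.format(...) to path_file, even under jit."""

    def _append(*host_args, **host_kwargs):
        # Runs on the host with concrete (NumPy) values, so str.format sees
        # real numbers. An f-string would instead be formatted once at trace
        # time and show a tracer, which is why fmt must be a plain string.
        with open(path_file, "a", encoding="utf-8") as f:
            f.write(fmt.format(*host_args, **host_kwargs) + "\n")

    # ordered=True asks JAX to run this callback in program order relative
    # to other ordered callbacks.
    jax.debug.callback(_append, *args, ordered=ordered, **kwargs)


@jax.jit
def step(x):
    write_file_sketch("x = {x}", LOG_PATH, x=x)
    return x * 2.0
```

Because callbacks may run asynchronously, calling jax.effects_barrier() before reading the file ensures every pending write has completed.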

neurai.util.write.write_spike(file_path, data, ts, dt)

Write spike data to a file.

Parameters:
  • file_path (str) – Path to the file to write the data.

  • data (np.array) – Spike data array.

  • ts (np.array) – Timestamps array.

  • dt (float) – The simulation time step.

Return type:

None

Returns:

None

neurai.util.write.write_voltage(file_path, data, ts, dt)

Write voltage data to a file.

Parameters:
  • file_path (str) – Path to the file to write the data.

  • data (np.array) – Voltage data array.

  • ts (np.array) – Timestamps array.

  • dt (float) – The simulation time step.

Return type:

None

Returns:

None