liblaf.jarp.tree.prelude
PyTree-aware wrappers for callables and transparent object proxies.
This subpackage contains helper wrappers such as Partial and PyTreeProxy. Importing liblaf.jarp.tree also imports this package's private prelude module, which registers bound methods and warp.array with JAX before the public tree helpers are used.
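Registering an extra type with JAX generally goes through the public jax.tree_util.register_pytree_node. The sketch below shows how a bound method could be registered under one assumed scheme: the instance (`__self__`) becomes the PyTree child and the underlying function (`__func__`) is kept as static auxiliary data. It is not the prelude module's actual code, and the Scale class is hypothetical.

```python
import types
from typing import NamedTuple

import jax
import jax.numpy as jnp


def _flatten_method(method):
    # Assumed scheme: the bound instance is the child, the plain function is static.
    return (method.__self__,), method.__func__


def _unflatten_method(func, children):
    (instance,) = children
    # Rebind the static function to the (possibly updated) instance.
    return types.MethodType(func, instance)


jax.tree_util.register_pytree_node(types.MethodType, _flatten_method, _unflatten_method)


class Scale(NamedTuple):  # hypothetical example class; NamedTuples are PyTrees
    factor: jax.Array

    def apply(self, x):
        return self.factor * x


method = Scale(jnp.array(2.0)).apply
leaves, treedef = jax.tree.flatten(method)
bumped = jax.tree.map(lambda leaf: leaf + 1, method)
```

After registration, the bound method flattens to the array stored on its instance, and jax.tree.map returns a new bound method whose instance carries the updated array.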
Classes:
- Partial – Store a partially applied callable as a PyTree-aware proxy.
- PyTreeProxy – Wrap an arbitrary object and flatten the wrapped value as a PyTree.
Functions:
- partial – Partially apply a callable and keep bound values visible to JAX trees.
Partial
Bases: PartialCallableObjectProxy
Store a partially applied callable as a PyTree-aware proxy.
Bound arguments and keyword arguments flatten as PyTree children, while the wrapped callable itself is partitioned between dynamic data and static metadata when needed.
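The flattening rule described above can be sketched with jax.tree_util.register_pytree_node_class. The SketchPartial class below is hypothetical and is not the library's Partial (which builds on wrapt's PartialCallableObjectProxy); it assumes that "dynamic data" means JAX arrays and that every other leaf of the callable's own tree is static metadata.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


def _is_array(x):
    return isinstance(x, jax.Array)


@register_pytree_node_class
class SketchPartial:
    """Hypothetical PyTree-aware partial (not liblaf.jarp's implementation)."""

    def __init__(self, func, /, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        # Bound positional args come first; call-time keywords override bound ones.
        return self.func(*self.args, *args, **{**self.kwargs, **kwargs})

    def tree_flatten(self):
        # Partition the callable's own leaves: arrays are dynamic, the rest static.
        func_leaves, func_def = jax.tree.flatten(self.func)
        dynamic = [leaf if _is_array(leaf) else None for leaf in func_leaves]
        static = tuple(None if _is_array(leaf) else leaf for leaf in func_leaves)
        # Bound args and kwargs are always children; None entries hold no leaves.
        return (dynamic, self.args, self.kwargs), (func_def, static)

    @classmethod
    def tree_unflatten(cls, aux, children):
        func_def, static = aux
        dynamic, args, kwargs = children
        # Merge both halves back into the callable's original leaf order.
        leaves = [d if s is None else s for d, s in zip(dynamic, static)]
        return cls(jax.tree.unflatten(func_def, leaves), *args, **kwargs)


class Affine(NamedTuple):  # hypothetical callable PyTree
    weight: jax.Array

    def __call__(self, x, bias):
        return self.weight * x + bias


def add(left, right):
    return left + right


plain = SketchPartial(add, jnp.array([1, 2]))
nested = SketchPartial(Affine(jnp.array(2.0)), jnp.array(3.0), bias=jnp.array(1.0))
scaled = jax.tree.map(lambda leaf: leaf * 10, nested)
```

With a plain function such as add, the callable contributes no children and only the bound array is a leaf. With the callable Affine tuple, its weight is split out as dynamic data too, so jax.tree.map updates it together with the bound values.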
Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> def add(left, right):
...     return left + right
>>> part = jarp.partial(add, jnp.array([1, 2]))
>>> leaves, _treedef = jax.tree.flatten(part)
>>> [leaf.tolist() for leaf in leaves]
[[1, 2]]
>>> part(jnp.array([3, 4])).tolist()
[4, 6]
Methods:
- __call__
Attributes:
- __wrapped__ (Callable[..., T])
Source code in src/liblaf/jarp/tree/prelude/_partial.py
PyTreeProxy
Bases: BaseObjectProxy
Wrap an arbitrary object and flatten the wrapped value as a PyTree.
The proxy itself stays transparent while JAX sees the wrapped object's PyTree structure.
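A transparent PyTree proxy can be approximated by forwarding attribute access to the wrapped object and registering a flatten rule that exposes the wrapped value as the single child. The SketchProxy class below is hypothetical and far simpler than a wrapt BaseObjectProxy: it only forwards ordinary attribute lookups.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class SketchProxy:
    """Hypothetical minimal proxy; the real PyTreeProxy builds on wrapt."""

    def __init__(self, wrapped):
        self.__wrapped__ = wrapped

    def __getattr__(self, name):
        # Called only when normal lookup fails: forward to the wrapped object.
        if name == "__wrapped__":
            raise AttributeError(name)
        return getattr(self.__wrapped__, name)

    def tree_flatten(self):
        # The wrapped value is the only child, so JAX recurses into its structure.
        return (self.__wrapped__,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (wrapped,) = children
        return cls(wrapped)


proxy = SketchProxy({"x": jnp.arange(3), "y": jnp.ones(2)})
doubled = jax.tree.map(lambda leaf: leaf * 2, proxy)
```

Special-method lookups such as proxy["x"] go through the type rather than `__getattr__`, so this sketch does not forward them; covering those cases is one reason to build on a full proxy library such as wrapt.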
Attributes:
- __wrapped__ (T)
partial
Partially apply a callable and keep bound values visible to JAX trees.
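Keeping bound values visible matters most under transformations such as jax.jit. JAX's own jax.tree_util.Partial serves as a stand-in below; it is not jarp.partial, and it always keeps the wrapped callable static instead of partitioning it. Because the partial is a PyTree, it can be passed straight into a jit-compiled function, where the bound array is traced like any other argument.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(left, right):
    return left + right


@jax.jit
def apply(fn, x):
    # `fn` is a PyTree: its bound array is traced, while `add` itself is static.
    return fn(x)


part = Partial(add, jnp.array([1, 2]))
result = apply(part, jnp.array([3, 4]))
```

A plain functools.partial in the same position is an opaque leaf rather than a valid JAX type, so jax.jit would reject it as an argument.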