
liblaf.jarp.tree.prelude

PyTree-aware wrappers for callables and transparent object proxies.

This subpackage contains helper wrappers such as Partial and PyTreeProxy. Importing liblaf.jarp.tree also imports this package's private prelude module, which registers bound methods and warp.array with JAX before the public tree helpers are used.

Classes:

  • Partial

    Store a partially applied callable as a PyTree-aware proxy.

  • PyTreeProxy

    Wrap an arbitrary object and flatten the wrapped value as a PyTree.

Functions:

  • partial

    Partially apply a callable and keep bound values visible to JAX trees.

Partial

Partial(
    func: Callable[..., T], /, *args: Any, **kwargs: Any
)

Bases: PartialCallableObjectProxy



Store a partially applied callable as a PyTree-aware proxy.

Bound positional and keyword arguments flatten as PyTree children. The wrapped callable itself is partitioned into dynamic data and static metadata when needed.

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from liblaf import jarp
>>> def add(left, right):
...     return left + right
>>> part = jarp.partial(add, jnp.array([1, 2]))
>>> leaves, _treedef = jax.tree.flatten(part)
>>> [leaf.tolist() for leaf in leaves]
[[1, 2]]
>>> part(jnp.array([3, 4])).tolist()
[4, 6]
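The sketch below illustrates why it helps to keep bound values visible as PyTree children: a partially applied callable can then cross a `jax.jit` boundary as an ordinary argument, with its bound arrays traced like any other leaf. It uses `jax.tree_util.Partial`, which ships with JAX, as a stand-in; that `jarp.partial` behaves the same way under `jit` is an assumption, and `scale_add` and `apply` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial  # stand-in for jarp.partial (assumption)


def scale_add(scale, x):
    """Hypothetical helper: scale the input and add one."""
    return scale * x + 1.0


# Bind the first argument; the bound array becomes a PyTree leaf.
bound = Partial(scale_add, jnp.array(2.0))
leaves = jax.tree_util.tree_leaves(bound)


@jax.jit
def apply(fn, x):
    # fn is a PyTree, so jit traces its bound array instead of
    # requiring the callable to be marked static.
    return fn(x)


result = apply(bound, jnp.array([1.0, 2.0]))
```

Because the bound value is a leaf rather than hidden state, changing it does not require treating the callable as a static argument.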

Methods:

  • __call__

Attributes:

  • __wrapped__
Source code in src/liblaf/jarp/tree/prelude/_partial.py
def __init__(self, func: Callable[..., T], /, *args: Any, **kwargs: Any) -> None:
    """Create a proxy that records bound arguments for PyTree flattening."""
    super().__init__(func, *args, **kwargs)
    self._self_args = args
    self._self_kwargs = kwargs

__wrapped__ instance-attribute

__wrapped__: Callable[..., T]

__call__

__call__(*args: P.args, **kwargs: P.kwargs) -> T
Source code in src/liblaf/jarp/tree/prelude/_partial.py
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

PyTreeProxy

Bases: BaseObjectProxy



Wrap an arbitrary object and flatten the wrapped value as a PyTree.

The proxy itself stays transparent while JAX sees the wrapped object's PyTree structure.
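As a rough illustration of the idea, the following sketch hand-rolls a tiny proxy and registers it with JAX so that tree utilities see through it to the wrapped value. It is not the `PyTreeProxy` implementation: `TinyProxy` is a hypothetical class written for this example, and it relies only on `jax.tree_util.register_pytree_node_class` and plain attribute delegation.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class TinyProxy:
    """Hypothetical minimal proxy; not the library's PyTreeProxy."""

    def __init__(self, wrapped):
        self._wrapped = wrapped

    def __getattr__(self, name):
        # Called only when normal lookup fails: delegate to the wrapped object.
        if name == "_wrapped":
            raise AttributeError(name)
        return getattr(self._wrapped, name)

    def tree_flatten(self):
        # The wrapped value is the single child; there is no static metadata.
        return (self._wrapped,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0])


proxy = TinyProxy({"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)})

# JAX flattens through the proxy into the wrapped dict's leaves.
leaves = jax.tree_util.tree_leaves(proxy)

# tree_map rebuilds a proxy around the transformed dict.
doubled = jax.tree_util.tree_map(lambda x: x * 2, proxy)

# Attribute access is forwarded to the wrapped dict.
keys = sorted(proxy.keys())
```

The wrapped object is exposed as a child rather than as auxiliary data, so its own PyTree structure determines the leaves that JAX sees.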

Attributes:

  • __wrapped__

__wrapped__ instance-attribute

__wrapped__: T

partial

partial[T](
    func: Callable[..., T], /, *args: Any, **kwargs: Any
) -> Partial[..., T]

Partially apply a callable and keep bound values visible to JAX trees.

Source code in src/liblaf/jarp/tree/prelude/_partial.py
def partial[T](func: Callable[..., T], /, *args: Any, **kwargs: Any) -> Partial[..., T]:
    """Partially apply a callable and keep bound values visible to JAX trees."""
    return Partial(func, *args, **kwargs)