Release 0.8.0
New features
JAX-compatible functions that run on classical accelerators, such as GPUs, via catalyst.accelerate now support autodifferentiation. (#920)

For example,

import catalyst
import jax
import jax.numpy as jnp
from catalyst import qjit, grad

@qjit
@grad
def f(x):
    expm = catalyst.accelerate(jax.scipy.linalg.expm)
    return jnp.sum(expm(jnp.sin(x)) ** 2)

>>> x = jnp.array([[0.1, 0.2], [0.3, 0.4]])
>>> f(x)
Array([[2.80120452, 1.67518663],
       [1.61605839, 4.42856163]], dtype=float64)
Assertions can now be raised at runtime via the catalyst.debug_assert function. (#925)

Python-based exceptions (via raise) and assertions (via assert) will always be evaluated at program capture time, before certain runtime information may be available.

Use debug_assert to instead raise assertions at runtime, including assertions that depend on values of dynamic variables.

For example,

from catalyst import debug_assert

@qjit
def f(x):
    debug_assert(x < 5, "x was greater than 5")
    return x * 8

>>> f(4)
Array(32, dtype=int64)
>>> f(6)
RuntimeError: x was greater than 5

Assertions can be disabled globally for a qjit-compiled function via the disable_assertions keyword argument:

@qjit(disable_assertions=True)
def g(x):
    debug_assert(x < 5, "x was greater than 5")
    return x * 8

>>> g(6)
Array(48, dtype=int64)
Mid-circuit measurement results when using lightning.qubit and lightning.kokkos can now be seeded via the new seed argument of the qjit decorator. (#936)

The seed argument accepts an unsigned 32-bit integer, which is used to initialize the pseudo-random state at the beginning of each execution of the compiled function. Therefore, different qjit objects with the same seed (including repeated calls to the same qjit) will always return the same sequence of mid-circuit measurement results.

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    m = measure(0)

    if m:
        qml.Hadamard(0)

    return qml.probs()

@qjit(seed=37, autograph=True)
def workflow(x):
    return jnp.stack([circuit(x) for i in range(4)])

Repeatedly calling the workflow function above will always result in the same values:

>>> workflow(1.8)
Array([[1. , 0. ],
       [1. , 0. ],
       [1. , 0. ],
       [0.5, 0.5]], dtype=float64)
>>> workflow(1.8)
Array([[1. , 0. ],
       [1. , 0. ],
       [1. , 0. ],
       [0.5, 0.5]], dtype=float64)

Note that setting the seed will not avoid shot-noise stochasticity in terminal measurement statistics such as sample or expval:

dev = qml.device("lightning.qubit", wires=1, shots=10)

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    m = measure(0)

    if m:
        qml.Hadamard(0)

    return qml.expval(qml.PauliZ(0))

@qjit(seed=37, autograph=True)
def workflow(x):
    return jnp.stack([circuit(x) for i in range(4)])

>>> workflow(1.8)
Array([1. , 1. , 1. , 0.4], dtype=float64)
>>> workflow(1.8)
Array([ 1. ,  1. ,  1. , -0.2], dtype=float64)
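The reseeding behaviour of mid-circuit measurements can be illustrated without Catalyst. The following is a minimal sketch in plain Python, assuming a hypothetical run_workflow function that stands in for a seeded compiled function. It uses the standard library's random.Random rather than Catalyst's actual pseudo-random generator, and only shows why re-initializing the state at the start of each execution yields identical sequences.

```python
import random

def run_workflow(seed, n_measurements=4):
    """Hypothetical stand-in for a seeded compiled function."""
    # Re-initialize the pseudo-random state at the start of every execution.
    rng = random.Random(seed)
    # Each simulated "mid-circuit measurement" yields 0 or 1.
    return [rng.randint(0, 1) for _ in range(n_measurements)]

first = run_workflow(seed=37)
second = run_workflow(seed=37)
assert first == second  # same seed, same sequence on every call
```

If the generator were created once and shared across calls, successive calls would continue the same stream and generally return different sequences.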
Exponential fitting is now a supported method of zero-noise extrapolation when performing error mitigation in Catalyst using mitigate_with_zne. (#953)

This new functionality fits the data from noise-scaled circuits with an exponential function, and returns the zero-noise value:

from pennylane.transforms import exponential_extrapolate
from catalyst import mitigate_with_zne

dev = qml.device("lightning.qubit", wires=2, shots=100000)

@qml.qnode(dev)
def circuit(weights):
    qml.StronglyEntanglingLayers(weights, wires=[0, 1])
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

@qjit
def workflow(weights, s):
    zne_circuit = mitigate_with_zne(circuit, scale_factors=s, extrapolate=exponential_extrapolate)
    return zne_circuit(weights)

>>> weights = jnp.ones([3, 2, 3])
>>> scale_factors = jnp.array([1, 2, 3])
>>> workflow(weights, scale_factors)
Array(-0.19946598, dtype=float64)
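To make the idea of exponential extrapolation concrete, here is a minimal plain-Python sketch. It assumes the simple model y = a * exp(b * s) with strictly positive data and fits it by linear least squares on log(y). The helper name exponential_fit_at_zero is hypothetical, and PennyLane's exponential_extrapolate may treat signs and offsets differently.

```python
import math

def exponential_fit_at_zero(scale_factors, values):
    """Fit y = a * exp(b * s) and return the zero-noise value a.

    Assumes all values are strictly positive, so log(y) is defined.
    """
    n = len(scale_factors)
    logs = [math.log(v) for v in values]
    mean_s = sum(scale_factors) / n
    mean_l = sum(logs) / n
    # Ordinary least-squares slope and intercept of log(y) against s.
    cov = sum((s - mean_s) * (l - mean_l) for s, l in zip(scale_factors, logs))
    var = sum((s - mean_s) ** 2 for s in scale_factors)
    slope = cov / var
    intercept = mean_l - slope * mean_s
    # At s = 0 the model evaluates to exp(intercept) = a.
    return math.exp(intercept)

# Synthetic data generated from y = 0.8 * exp(-0.3 * s).
scales = [1.0, 2.0, 3.0]
noisy = [0.8 * math.exp(-0.3 * s) for s in scales]
zero_noise = exponential_fit_at_zero(scales, noisy)
```

On this noiseless synthetic data the fit recovers the prefactor 0.8 up to floating-point error; with real shot noise the result would only approximate the zero-noise value.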
A new module is available, catalyst.passes, which provides Python decorators for enabling and configuring Catalyst MLIR compiler passes. (#911) (#1037)

The first pass available is catalyst.passes.cancel_inverses, which enables the -remove-chained-self-inverse MLIR pass that cancels two neighbouring Hadamard gates.

from catalyst.debug import get_compilation_stage
from catalyst.passes import cancel_inverses

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(x: float):
    qml.RX(x, wires=0)
    qml.Hadamard(wires=0)
    qml.Hadamard(wires=0)
    return qml.expval(qml.PauliZ(0))

@qjit(keep_intermediate=True)
def workflow(x):
    optimized_circuit = cancel_inverses(circuit)
    return circuit(x), optimized_circuit(x)
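The effect of such a cancellation pass can be sketched on a toy circuit representation. The example below is not the MLIR pass itself: it assumes a hypothetical list-of-tuples gate format and a hand-picked set of self-inverse gates, and only illustrates the peephole idea of removing neighbouring pairs of identical self-inverse gates acting on the same wires.

```python
# Gates that are their own inverse (an illustrative, non-exhaustive set).
SELF_INVERSE = {"Hadamard", "PauliX", "PauliY", "PauliZ"}

def cancel_adjacent_inverses(gates):
    """Remove neighbouring pairs of identical self-inverse gates.

    `gates` is a list of (name, wires) tuples in program order.
    """
    stack = []
    for gate in gates:
        name, _ = gate
        if stack and stack[-1] == gate and name in SELF_INVERSE:
            stack.pop()  # the pair multiplies to the identity
        else:
            stack.append(gate)
    return stack

circuit = [("RX", (0,)), ("Hadamard", (0,)), ("Hadamard", (0,))]
optimized = cancel_adjacent_inverses(circuit)
```

Using a stack means that cancelling one pair can expose another, so a sequence such as PauliX, Hadamard, Hadamard, PauliX on one wire reduces to nothing in a single sweep.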
Catalyst now has debug functions get_compilation_stage and replace_ir to acquire and recompile the IR from a given pipeline pass for functions compiled with keep_intermediate=True. (#981)

For example, consider the following function:

@qjit(keep_intermediate=True)
def f(x):
    return x**2

>>> f(2.0)
4.0

Here we use get_compilation_stage to acquire the IR, and then modify %2 = arith.mulf %in, %in_0 : f64 to turn the square function into a cubic one via replace_ir:

from catalyst.debug import get_compilation_stage, replace_ir

old_ir = get_compilation_stage(f, "HLOLoweringPass")
new_ir = old_ir.replace(
    "%2 = arith.mulf %in, %in_0 : f64\n",
    "%t = arith.mulf %in, %in_0 : f64\n %2 = arith.mulf %t, %in_0 : f64\n"
)
replace_ir(f, "HLOLoweringPass", new_ir)

The recompilation starts after the given checkpoint stage:

>>> f(2.0)
8.0
Either function can also be used independently of the other. Note that get_compilation_stage replaces the print_compilation_stage function; please see the Breaking Changes section for more details.

Catalyst now supports generating executables from compiled functions for the native host architecture using catalyst.debug.compile_executable. (#1003)

>>> @qjit
... def f(x):
...     y = x * x
...     catalyst.debug.print_memref(y)
...     return y
>>> f(5)
MemRef: base@ = 0x31ac22580 rank = 0 offset = 0 sizes = [] strides = [] data = 25
Array(25, dtype=int64)

We can use compile_executable to compile this function to a binary:

>>> from catalyst.debug import compile_executable
>>> binary = compile_executable(f, 5)
>>> print(binary)
/path/to/executable

Executing this function from a shell environment:

$ /path/to/executable
MemRef: base@ = 0x64fc9dd5ffc0 rank = 0 offset = 0 sizes = [] strides = [] data = 25
Improvements
Catalyst has been updated to work with JAX v0.4.28 (exact version match required). (#931) (#995)
Catalyst now supports keyword arguments for qjit-compiled functions. (#1004)
>>> @qjit
... @grad
... def f(x, y):
...     return x * y
>>> f(3., y=2.)
Array(2., dtype=float64)

Note that the static_argnums argument to the qjit decorator is not supported when passing argument values as keyword arguments.

Support has been added for the jax.numpy.argsort function within qjit-compiled functions. (#901)

Autograph now supports in-place array assignments with static slices. (#843)

For example,

@qjit(autograph=True)
def f(x, y):
    y[1:10:2] = x
    return y

>>> f(jnp.ones(5), jnp.zeros(10))
Array([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.], dtype=float64)
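As a quick check of which elements a static slice touches, the same assignment can be reproduced with plain Python lists, which also support extended-slice assignment. This sketch does not involve Autograph or JAX arrays; it only shows the indices selected by 1:10:2.

```python
y = [0.0] * 10
x = [1.0] * 5

# The slice 1:10:2 selects indices 1, 3, 5, 7 and 9.
selected = list(range(10))[1:10:2]

# Extended-slice assignment requires len(x) == len(selected).
y[1:10:2] = x
```

The resulting list alternates 0.0 and 1.0, matching the pattern of the array returned by the qjit-compiled function.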
Autograph now works when qjit is applied to a function decorated with vmap, cond, for_loop or while_loop. Previously, stacking the autograph-enabled qjit decorator directly on top of other Catalyst decorators would lead to errors. (#835) (#938) (#942)

from catalyst import vmap, qjit

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

>>> x = jnp.array([0.1, 0.2, 0.3])
>>> qjit(vmap(circuit), autograph=True)(x)
Array([0.99500417, 0.98006658, 0.95533649], dtype=float64)
Runtime memory usage and compilation complexity have been reduced by eliminating some scalar tensors from the IR. This has been done by adding a linalg-detensorize pass at the end of the HLO lowering pipeline. (#1010)

Program verification is extended to confirm that the measurements included in QNodes are compatible with the specified device and settings. (#945) (#962)

>>> dev = qml.device("lightning.qubit", wires=2, shots=None)
>>> @qjit
... @qml.qnode(dev)
... def circuit(params):
...     qml.RX(params[0], wires=0)
...     qml.RX(params[1], wires=1)
...     return {
...         "sample": qml.sample(wires=[0, 1]),
...         "expval": qml.expval(qml.PauliZ(0))
...     }
>>> circuit([0.1, 0.2])
CompileError: Sample-based measurements like sample(wires=[0, 1]) cannot work with shots=None. Please specify a finite number of shots.
On devices that support it, initial state preparation routines qml.StatePrep and qml.BasisState are no longer decomposed when using Catalyst, improving compilation and runtime performance. (#955) (#1047) (#1062) (#1073)

Improved type validation and error messaging have been added to both the catalyst.jvp and catalyst.vjp functions to ensure that the (co)tangent and parameter types are compatible. (#1020) (#1030) (#1031)

For example, providing an integer tangent for a function with float64 parameters will result in an error:

>>> f = lambda x: (2 * x, x * x)
>>> f_jvp = lambda x: catalyst.jvp(f, params=(x,), tangents=(1,))
>>> qjit(f_jvp)(0.5)
TypeError: function params and tangents arguments to catalyst.jvp do not match; dtypes must be equal. Got function params dtype float64 and so expected tangent dtype float64, but got tangent dtype int64 instead.

Ensuring that the types match will resolve the error:

>>> f_jvp = lambda x: catalyst.jvp(f, params=(x,), tangents=(1.0,))
>>> qjit(f_jvp)(0.5)
((Array(1., dtype=float64), Array(0.25, dtype=float64)), (Array(2., dtype=float64), Array(1., dtype=float64)))
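The primal and tangent values in the jvp output can be reproduced by hand with forward-mode differentiation. Below is a minimal sketch using a hypothetical Dual number class in plain Python (an illustration of the technique, not Catalyst's implementation): each value carries a tangent, and multiplication propagates both via the product rule.

```python
class Dual:
    """A value paired with its tangent (forward-mode derivative)."""

    def __init__(self, value, tangent):
        self.value = value
        self.tangent = tangent

    def __mul__(self, other):
        if not isinstance(other, Dual):
            other = Dual(other, 0.0)  # constants have zero tangent
        # Product rule: (u * v)' = u' * v + u * v'
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)

    __rmul__ = __mul__

def f(x):
    return (2 * x, x * x)

# Evaluate f at 0.5 with tangent 1.0, mirroring params=(x,) and tangents=(1.0,).
outputs = f(Dual(0.5, 1.0))
primals = tuple(o.value for o in outputs)
tangents = tuple(o.tangent for o in outputs)
```

Here primals holds the function values (1.0, 0.25) and tangents holds the directional derivatives (2.0, 1.0), the same pairs that appear in the Catalyst output.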
Add a script for setting up a Frontend-Only Development Environment that does not require compilation, as it uses the TestPyPI wheel shared libraries. (#1022)
Breaking changes
The argnum keyword argument in the grad, jacobian, value_and_grad, vjp, and jvp functions has been renamed to argnums to better match JAX. (#1036)

Return values of qjit-compiled functions that were previously numpy.ndarray are now of type jax.Array instead. This should have minimal impact, but code that depends on the output of qjit-compiled functions being NumPy arrays will need to be updated. (#895)

The print_compilation_stage function has been renamed get_compilation_stage. It no longer prints the IR to the standard output; instead, it simply returns the IR as a string. (#981)

>>> @qjit(keep_intermediate=True)
... def func(x: float):
...     return x
>>> print(get_compilation_stage(func, "HLOLoweringPass"))
module @func {
  func.func public @jit_func(%arg0: tensor<f64>) -> tensor<f64> attributes {llvm.emit_c_interface} {
    return %arg0 : tensor<f64>
  }
  func.func @setup() {
    quantum.init
    return
  }
  func.func @teardown() {
    quantum.finalize
    return
  }
}
Support for TOML files in Schema 1 has been disabled. (#960)
The mitigate_with_zne function no longer accepts a degree parameter for polynomial fitting and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation function is valid. Keyword arguments can be passed to this function using the extrapolate_kwargs keyword argument in mitigate_with_zne. (#806)

The QuantumDevice API now includes the functions SetState and SetBasisState for simulators that may benefit from instructions that directly set the state. Implementing these methods is optional, and device support can be indicated via the initial_state_prep flag in the TOML configuration file. (#955)
Bug fixes
Catalyst no longer silently converts complex parameters to floats where floats are expected; instead, an error is raised. (#1008)
Fixes a bug where dynamic one-shot did not work when no mid-circuit measurements are present and when the return type is an iterable. (#1060)
Fixes a bug finding the quantum function jaxpr when using quantum primitives with dynamic one-shot. (#1041)
Fixes a bug where the number of shots of a LegacyDevice was not correctly extracted when using the LegacyDeviceFacade. (#1035)
Catalyst no longer generates a QubitUnitary operation during decomposition if a device doesn’t support it. Instead, the operation that would lead to a QubitUnitary is either decomposed or raises an error. (#1002)

Catalyst now correctly raises an error when a user calls qml.density_matrix. (#1118)

Catalyst now preserves output PyTrees in QNodes executed with mcm_method="one-shot". (#957)

For example:

dev = qml.device("lightning.qubit", wires=1, shots=20)

@qml.qjit
@qml.qnode(dev, mcm_method="one-shot")
def func(x):
    qml.RX(x, wires=0)
    m_0 = catalyst.measure(0, postselect=1)
    return {"hi": qml.expval(qml.Z(0))}

>>> func(0.9)
{'hi': Array(-1., dtype=float64)}
Fixes a bug where scatter did not work correctly with list indices. (#982)

A = jnp.ones([3, 3]) * 2

def update(A):
    A = A.at[[0, 1], :].set(jnp.ones([2, 3]), indices_are_sorted=True, unique_indices=True)
    return A

>>> print(qjit(update)(A))
[[1. 1. 1.]
 [1. 1. 1.]
 [2. 2. 2.]]
Static arguments can now be passed through a QNode when specified with the static_argnums keyword argument. (#932)

dev = qml.device("lightning.qubit", wires=1)

@qjit(static_argnums=(1,))
@qml.qnode(dev)
def circuit(x, c):
    print("Inside QNode:", c)
    qml.RY(c, 0)
    qml.RX(x, 0)
    return qml.expval(qml.PauliZ(0))

When executing the qjit-compiled function above, c will be a static variable with value known at compile time:

>>> circuit(0.5, 0.5)
"Inside QNode: 0.5"
Array(0.77015115, dtype=float64)

Changing the value of c will result in re-compilation:

>>> circuit(0.5, 0.8)
"Inside QNode: 0.8"
Array(0.61141766, dtype=float64)
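The recompilation behaviour for static arguments can be pictured as a cache keyed on the static argument values. The sketch below is a simplified model in plain Python, assuming a hypothetical ToyJit wrapper; it does not reflect how Catalyst's caching is actually implemented.

```python
class ToyJit:
    """Toy model of a JIT wrapper with static arguments."""

    def __init__(self, fn, static_argnums=()):
        self.fn = fn
        self.static_argnums = tuple(static_argnums)
        self.cache = {}
        self.compile_count = 0

    def __call__(self, *args):
        # Static argument values form the cache key.
        key = tuple(args[i] for i in self.static_argnums)
        if key not in self.cache:
            self.compile_count += 1  # a new static value triggers "compilation"
            static = dict(zip(self.static_argnums, key))
            # Specialize the function on the static values.
            self.cache[key] = lambda *dyn: self.fn(*self._merge(dyn, static))
        dynamic = [a for i, a in enumerate(args) if i not in self.static_argnums]
        return self.cache[key](*dynamic)

    def _merge(self, dynamic, static):
        # Rebuild the full positional argument list in the original order.
        dyn = iter(dynamic)
        total = len(dynamic) + len(static)
        return [static[i] if i in static else next(dyn) for i in range(total)]

circuit = ToyJit(lambda x, c: x + c, static_argnums=(1,))
circuit(0.5, 0.5)
circuit(1.5, 0.5)   # same static value: the cached version is reused
circuit(0.5, 0.8)   # new static value: compiled again
```

After these three calls, compile_count is 2: changing the dynamic argument x reuses the cached entry, while changing the static argument c adds a new one.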
Fixes a bug where Catalyst would fail to apply quantum transforms and preserve QNode configuration settings when Autograph was enabled. (#900)
pure_callback will no longer cause a crash in the compiler if the return type signature is declared incorrectly and the callback function is differentiated. (#916)

Instead, this is caught early and a useful error message returned:

@catalyst.pure_callback
def callback_fn(x) -> jax.ShapeDtypeStruct((2,), jnp.float32):
    return np.array([np.sin(x), np.cos(x)])

callback_fn.fwd(lambda x: (callback_fn(x), x))
callback_fn.bwd(lambda x, dy: (jnp.array([jnp.cos(x), -jnp.sin(x)]) @ dy,))

@qjit
@catalyst.grad
def f(x):
    return jnp.sum(callback_fn(jnp.sin(x)))

>>> f(0.54)
TypeError: Callback callback_fn expected type ShapedArray(float32[2]) but observed ShapedArray(float64[2]) in its return value
AutoGraph will now correctly convert conditional statements where the condition is a non-boolean static value. (#944)
Internally, statically known non-boolean predicates (such as 1) will be converted to bool:

@qml.qjit(autograph=True)
def workflow(x):
    n = 1

    if n:
        y = x ** 2
    else:
        y = x

    return y

value_and_grad will now correctly differentiate functions with multiple arguments. Previously, attempting to differentiate functions with multiple arguments, or to pass the argnums argument, would result in an error. (#1034)

@qjit
def g(x, y, z):
    def f(x, y, z):
        return x * y ** 2 * jnp.sin(z)
    return catalyst.value_and_grad(f, argnums=[1, 2])(x, y, z)

>>> g(0.4, 0.2, 0.6)
(Array(0.00903428, dtype=float64), (Array(0.0903428, dtype=float64), Array(0.01320537, dtype=float64)))
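The value and gradient returned by g can be checked by hand, since x * y ** 2 * sin(z) has simple partial derivatives. The following plain-Python sketch evaluates the function and its analytic derivatives with respect to y and z (arguments 1 and 2). It is an independent check rather than a use of Catalyst, and the helper name grad_f_yz is illustrative.

```python
import math

def f(x, y, z):
    return x * y ** 2 * math.sin(z)

def grad_f_yz(x, y, z):
    """Analytic partial derivatives of f with respect to y and z."""
    df_dy = 2 * x * y * math.sin(z)
    df_dz = x * y ** 2 * math.cos(z)
    return df_dy, df_dz

value = f(0.4, 0.2, 0.6)
df_dy, df_dz = grad_f_yz(0.4, 0.2, 0.6)
```

To the eight digits printed, these agree with the output of g: roughly 0.00903428 for the value, and 0.0903428 and 0.01320537 for the two partial derivatives.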
Fixes a bug in catalyst.debug.get_cmain to support multi-dimensional arrays as function inputs. (#1003)

Fixes a bug that occurred when parameter annotations return strings. (#1078)

In certain cases, jax.scipy.linalg.expm may return incorrect numerical results when used within a qjit-compiled function. A warning will now be raised when jax.scipy.linalg.expm is used to inform of this issue.

In the meantime, we strongly recommend using the catalyst.accelerate function within qjit-compiled functions to call jax.scipy.linalg.expm directly.

@qjit
def f(A):
    B = catalyst.accelerate(jax.scipy.linalg.expm)(A)
    return B

Note that this PR doesn’t actually fix the aforementioned numerical errors, and just raises a warning. (#1082)
Documentation
A page has been added to the documentation, listing devices that are Catalyst compatible. (#966)
Internal changes
Adds catalyst.from_plxpr.from_plxpr for converting a PennyLane variant jaxpr into a Catalyst variant jaxpr. (#837)

Catalyst now uses Enzyme v0.0.130. (#898)

When memrefs have no identity layout, memref copy operations are replaced by the linalg copy operation. It does not use a runtime function but instead lowers to scf and standard dialects. It also ensures better compatibility with Enzyme. (#917)

LLVM’s O2 optimization pipeline and Enzyme’s AD transformations are now only run in the presence of gradients, significantly improving compilation times for programs without derivatives. Similarly, LLVM’s coroutine lowering passes only run when async_qnodes is enabled in the QJIT decorator. (#968)

The function inactive_callback was renamed __catalyst_inactive_callback. (#899)

The function __catalyst_inactive_callback has the nofree attribute. (#898)

catalyst.dynamic_one_shot uses postselect_mode="pad-invalid-samples" in favour of interface="jax" when processing results. (#956)

Callbacks now have nicer identifiers in their MLIR representation. The identifiers include the name of the Python function being called back into. (#919)

Fix tracing of SProd operations to bring Catalyst in line with PennyLane v0.38. (#935)

After some changes in PennyLane, SProd.terms() returns the terms as leaves instead of a tree. This means that we need to manually trace each term and finally multiply it with the coefficients to create a Hamiltonian.

The function mitigate_with_zne accommodates a folding input argument for specifying the type of circuit folding technique to be used by the error-mitigation routine (only the global value is supported to date). (#946)

Catalyst’s implementation of the Lightning Kokkos plugin has been removed in favor of Lightning’s one. (#974)

The validate_device_capabilities function is considered obsolete. Hence, it has been removed. (#1045)
Contributors
This release contains contributions from (in alphabetical order):
Joey Carter, Alessandro Cosentino, Lillian M. A. Frederiksen, David Ittah, Josh Izaac, Christina Lee, Kunwar Maheep Singh, Mehrdad Malekmohammadi, Romain Moyard, Erick Ochoa Lopez, Mudit Pandey, Nate Stemen, Raul Torres, Tzung-Han Juang, Paul Haochen Wang.