CudaGraphManager

class qlip.inference.cuda_graph.CudaGraphManager(method, method_name=None, warmup_steps=1, skip_filters=[], reset_filters=[])

Bases: object

Manages the CUDA graphs for a model’s method.

Examples

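>>> import torch
>>> # no_cuda_graph is assumed to be importable from the same module
>>> from qlip.inference.cuda_graph import CudaGraphManager, no_cuda_graph
>>> model = torch.nn.Linear(10, 10)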
>>> CudaGraphManager.compile(model)
>>> model(torch.randn(1, 10))
>>> model(torch.randn(3, 10))
>>> # Temporarily disable cuda graph
>>> with no_cuda_graph():
...     model(torch.randn(1, 10))
>>> # Restore model by removing the cuda graph manager
>>> CudaGraphManager.restore(model)
>>> model(torch.randn(1, 10))

classmethod compile(model, method_name='forward', skip_filters=[], reset_filters=[])

Compile a model’s method with a CUDA graph.

Parameters

  • model (torch.nn.Module) – Model to compile.

  • method_name (str) – Method to compile. Default is “forward”.

  • skip_filters (List[Callable[..., bool]]) – Filters that decide whether a call bypasses the CUDA graph. Each filter takes a dictionary of all inputs and returns a boolean; the CUDA graph is skipped when a filter returns True.

  • reset_filters (List[Callable[..., bool]]) – Filters that trigger a reset of the CUDA graph. Each filter takes a dictionary of all inputs and returns a boolean; the graph is reset when a filter returns True. Independently of these filters, CudaGraphManager resets automatically when non-tensor inputs change or when the new input size exceeds the memory allocated by previous calls.
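The filter contract described above can be illustrated with a minimal sketch. The filter names, the batch-size threshold, and the `force_reset` key are hypothetical, and the assumption that the input dictionary is keyed by argument name is not confirmed by this page. Simple stand-in objects with a `shape` attribute replace real tensors so the filters can be exercised without a GPU.

```python
from types import SimpleNamespace


def skip_large_batch(inputs: dict) -> bool:
    """Skip filter (hypothetical): return True when any tensor-like input
    has a leading dimension greater than 8."""
    for value in inputs.values():
        shape = getattr(value, "shape", None)
        if shape is not None and len(shape) > 0 and shape[0] > 8:
            return True
    return False


def reset_on_flag(inputs: dict) -> bool:
    """Reset filter (hypothetical): return True when the assumed
    `force_reset` input is present and truthy."""
    return bool(inputs.get("force_reset", False))


# Stand-ins for tensors; only the `shape` attribute is inspected.
small = {"input": SimpleNamespace(shape=(1, 10))}
large = {"input": SimpleNamespace(shape=(32, 10)), "force_reset": True}

print(skip_large_batch(small), skip_large_batch(large))
print(reset_on_flag(small), reset_on_flag(large))

# Assumed wiring, following the compile() signature documented above:
# CudaGraphManager.compile(
#     model, skip_filters=[skip_large_batch], reset_filters=[reset_on_flag]
# )
```

Keeping filters as small pure functions of the input dictionary makes them easy to test in isolation before passing them to compile.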

classmethod restore(model_or_method, method_name='forward')

Restore the original method and remove the CUDA graph.

Parameters

  • model_or_method (torch.nn.Module | Callable) – Model or method to restore.

  • method_name (str) – Method to restore if model_or_method is a model. Default is “forward”.
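The compile/restore pairing can be pictured as swapping a method for a callable wrapper and later putting the original back. The sketch below is an illustration of that general pattern only; `MethodWrapper` and `Toy` are hypothetical stand-ins, no CUDA graph is captured, and nothing here is taken from qlip's actual implementation.

```python
class MethodWrapper:
    """Hypothetical stand-in for a manager that wraps one method."""

    def __init__(self, method):
        self.method = method  # the original bound method
        self.calls = 0        # how many times the wrapper was invoked

    def __call__(self, *args, **kwargs):
        self.calls += 1
        return self.method(*args, **kwargs)

    @classmethod
    def compile(cls, model, method_name="forward"):
        # Shadow the class-level method with a wrapper stored on the instance.
        original = getattr(model, method_name)
        setattr(model, method_name, cls(original))

    @classmethod
    def restore(cls, model, method_name="forward"):
        # Drop the instance attribute so the class-level method is visible again.
        if isinstance(getattr(model, method_name), cls):
            delattr(model, method_name)


class Toy:
    """Plain Python object standing in for a model."""

    def forward(self, x):
        return x * 2


toy = Toy()
MethodWrapper.compile(toy)
wrapped_result = toy.forward(3)   # goes through the wrapper
calls_seen = toy.forward.calls
MethodWrapper.restore(toy)
restored_result = toy.forward(3)  # original method again
```

Storing the wrapper on the instance rather than the class keeps other instances untouched, which is one reason a restore step can be a simple attribute removal in this sketch.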

reset_warmup()

Override this method to perform any reset needed after the warmup steps of the graph-capturing process.

reset()

Reset the CUDA graph.

classmethod enable_cuda_graph(enable)

Enable or disable CUDA graphs globally.