CudaGraphManager
- class qlip.inference.cuda_graph.CudaGraphManager(method, method_name=None, warmup_steps=1, skip_filters=[], reset_filters=[])
Bases: object
Manages the CUDA graphs for a model’s method.
Examples
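>>> import torch
>>> from qlip.inference.cuda_graph import CudaGraphManager
>>> model = torch.nn.Linear(10, 10).cuda()  # CUDA graphs capture GPU work
>>> CudaGraphManager.compile(model)
>>> model(torch.randn(1, 10, device="cuda"))
>>> model(torch.randn(3, 10, device="cuda"))
>>> # Temporarily disable the CUDA graph
>>> with no_cuda_graph():
...     model(torch.randn(1, 10, device="cuda"))
>>> # Restore the original method by removing the CUDA graph manager
>>> CudaGraphManager.restore(model)
>>> model(torch.randn(1, 10, device="cuda"))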
- classmethod compile(model, method_name='forward', skip_filters=[], reset_filters=[])
Compile a model’s method with a CUDA graph.
Parameters
  - model (torch.nn.Module) – Model to compile.
  - method_name (str) – Method to compile. Default is “forward”.
  - skip_filters (List[Callable[..., bool]]) – Filters that decide whether to bypass the CUDA graph. Each filter takes a dictionary of all inputs and returns a boolean; the graph is skipped when a filter returns True.
  - reset_filters (List[Callable[..., bool]]) – Filters that decide whether to reset the CUDA graph. Each filter takes a dictionary of all inputs and returns a boolean; the graph is reset when a filter returns True. Independently of these filters, CudaGraphManager resets automatically when non-tensor inputs change or when a new input is larger than the memory allocated by previous calls.
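The skip and reset rules above can be modelled in a few lines of plain Python. This is a minimal sketch under stated assumptions, not qlip’s implementation: `ToyGraphPolicy` and `FakeTensor` are hypothetical names, filters are assumed to receive inputs keyed by argument name, and the strings "capture", "replay" and "skip" only label what a manager would do; they are not values the library returns.

```python
from dataclasses import dataclass
from math import prod
from typing import Any, Callable, Dict, List, Optional

Filter = Callable[[Dict[str, Any]], bool]


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only the shape matters here."""
    shape: tuple

    def numel(self) -> int:
        return prod(self.shape)


class ToyGraphPolicy:
    """Toy model of the documented skip/reset rules (not qlip's actual code)."""

    def __init__(self, skip_filters: Optional[List[Filter]] = None,
                 reset_filters: Optional[List[Filter]] = None) -> None:
        self.skip_filters = skip_filters or []
        self.reset_filters = reset_filters or []
        self.static_inputs: Optional[Dict[str, Any]] = None  # non-tensor inputs at capture
        self.capacity: Dict[str, int] = {}  # elements allocated per tensor input

    def decide(self, inputs: Dict[str, Any]) -> str:
        # Any skip filter returning True bypasses the graph for this call only.
        if any(f(inputs) for f in self.skip_filters):
            return "skip"
        tensors = {k: v for k, v in inputs.items() if isinstance(v, FakeTensor)}
        static = {k: v for k, v in inputs.items() if not isinstance(v, FakeTensor)}
        needs_reset = (
            self.static_inputs is None  # nothing captured yet
            or any(f(inputs) for f in self.reset_filters)  # user-requested reset
            or static != self.static_inputs  # non-tensor inputs changed
            or any(t.numel() > self.capacity.get(k, 0) for k, t in tensors.items())  # too large
        )
        if needs_reset:
            self.static_inputs = static
            self.capacity = {k: t.numel() for k, t in tensors.items()}
            return "capture"
        return "replay"


policy = ToyGraphPolicy(skip_filters=[lambda inputs: inputs.get("debug") is True])
print(policy.decide({"x": FakeTensor((1, 10)), "debug": False}))  # capture
print(policy.decide({"x": FakeTensor((1, 10)), "debug": False}))  # replay
print(policy.decide({"x": FakeTensor((3, 10)), "debug": False}))  # capture: larger input
print(policy.decide({"x": FakeTensor((2, 10)), "debug": False}))  # replay: fits in 30 elements
print(policy.decide({"x": FakeTensor((2, 10)), "debug": True}))   # skip
```

Note that a smaller input (2 × 10 after 3 × 10) replays without a reset, which mirrors the documented rule that only inputs larger than the existing allocation force a reset.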
- classmethod restore(model_or_method, method_name='forward')
Restore the original method and remove the CUDA graph.
Parameters
  - model_or_method (torch.nn.Module | Callable) – Model or method to restore.
  - method_name (str) – Method to restore if model_or_method is a model. Default is “forward”.
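Wrapping a method and later restoring it is commonly done by shadowing the class method with an instance attribute. The sketch below shows that general Python pattern; it assumes, without confirmation from this page, that CudaGraphManager works in a similar way. `TinyModel`, `toy_compile` and `toy_restore` are hypothetical names, and only the model form of `model_or_method` is covered.

```python
import functools
from typing import Any


class TinyModel:
    def forward(self, x: int) -> int:
        return 2 * x


def toy_compile(model: Any, method_name: str = "forward") -> None:
    """Shadow model.<method_name> with a wrapper stored on the instance."""
    original = getattr(model, method_name)  # bound method

    @functools.wraps(original)  # also sets wrapper.__wrapped__ = original
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # A real manager would capture or replay a CUDA graph here.
        wrapper.calls += 1
        return original(*args, **kwargs)

    wrapper.calls = 0
    setattr(model, method_name, wrapper)  # instance attribute wins over the class method


def toy_restore(model: Any, method_name: str = "forward") -> None:
    """Drop the instance-level wrapper so lookups fall back to the class method."""
    if method_name in vars(model):
        delattr(model, method_name)


model = TinyModel()
toy_compile(model)
print(model.forward(3), model.forward.calls)  # 6 1
toy_restore(model)
print(model.forward(3), "forward" in vars(model))  # 6 False
```

Deleting the instance attribute, rather than reassigning the original bound method, leaves the object exactly as it was before compilation.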
- reset_warmup()
Override this method in a subclass to reset any state that must be restored after the warmup steps of the graph-capture process.
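A typical reason to override such a hook is state that the warmup runs mutate, such as a step counter or cache position. Everything in this sketch is hypothetical: `ToyManager`, `Counter` and `CounterManager` are invented names, and the call order (eager warmup calls, then `reset_warmup()`, then capture) is an assumption rather than documented qlip behaviour.

```python
from typing import Any, Callable


class ToyManager:
    """Hypothetical base: runs `warmup_steps` eager calls, then 'captures'."""

    def __init__(self, fn: Callable[..., Any], warmup_steps: int = 1) -> None:
        self.fn = fn
        self.warmup_steps = warmup_steps
        self.captured = False

    def reset_warmup(self) -> None:
        """Hook called once warmup finishes; the default does nothing."""

    def __call__(self, *args: Any) -> Any:
        if not self.captured:
            for _ in range(self.warmup_steps):
                self.fn(*args)  # warmup runs may mutate external state
            self.reset_warmup()  # undo those side effects before capture
            self.captured = True  # stand-in for the real capture step
        return self.fn(*args)


class Counter:
    def __init__(self) -> None:
        self.steps = 0

    def step(self, x: int) -> int:
        self.steps += 1
        return x + self.steps


class CounterManager(ToyManager):
    """Override reset_warmup so warmup calls do not advance the counter."""

    def __init__(self, counter: Counter, warmup_steps: int = 1) -> None:
        super().__init__(counter.step, warmup_steps)
        self.counter = counter

    def reset_warmup(self) -> None:
        self.counter.steps = 0


counter = Counter()
managed = CounterManager(counter, warmup_steps=2)
print(managed(10))  # 11: the two warmup increments were rewound
```

Without the override, the same call would return 13, because the two warmup runs would still count as steps.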
- reset()
Reset the CUDA graph.
- classmethod enable_cuda_graph(enable)
Enable or disable CUDA graphs globally.
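The class example pairs this global switch with a `no_cuda_graph()` context manager. One common way to build such a pair is a module-level flag plus a context manager that saves and restores it. This is a minimal sketch; all names are hypothetical, and it is not claimed to be how qlip implements `enable_cuda_graph` or `no_cuda_graph`.

```python
from contextlib import contextmanager
from typing import Iterator

_CUDA_GRAPH_ENABLED = True  # hypothetical module-level switch


def toy_enable_cuda_graph(enable: bool) -> None:
    """Toy global on/off switch."""
    global _CUDA_GRAPH_ENABLED
    _CUDA_GRAPH_ENABLED = bool(enable)


def toy_cuda_graph_enabled() -> bool:
    return _CUDA_GRAPH_ENABLED


@contextmanager
def toy_no_cuda_graph() -> Iterator[None]:
    """Disable graphs inside the block and restore the previous setting on exit."""
    previous = _CUDA_GRAPH_ENABLED
    toy_enable_cuda_graph(False)
    try:
        yield
    finally:
        toy_enable_cuda_graph(previous)  # runs even if the block raises


with toy_no_cuda_graph():
    print(toy_cuda_graph_enabled())  # False
print(toy_cuda_graph_enabled())  # True
```

Restoring the previous value instead of forcing it back to True keeps nested blocks and an explicit global disable intact.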