Source code for speechbrain.utils.profiling

"""Polymorphic decorators to handle PyTorch profiling and benchmarking.

Author:
    * Andreas Nautsch 2022
"""
import numpy as np
from copy import deepcopy
from torch import profiler
from itertools import chain
from functools import wraps
from torch.autograd.profiler_util import (  # pytorch v1.10.1
    EventList,
    FunctionEvent,
    _format_time,
    _format_memory,
)
from typing import Any, Callable, Iterable, Optional, List


def set_profiler_attr(func: object, set_attr: str, handler: Callable):
    """Sets handler for profiler: scheduler or trace export.

    The behaviour is polymorphic and depends on what ``func`` is.

    Arguments
    ---------
    func : object
        * ``None``: not used as a decorator; ``handler`` is returned as is.
        * a callable: a wrapper around ``func`` is returned (see Returns).
        * anything else: treated as an object carrying a ``profiler``
          attribute; that profiler is edited in place.
    set_attr : str
        Name of the profiler attribute to set; either ``"on_trace_ready"``
        or ``"schedule"``.
    handler : callable
        The value assigned to the profiler attribute ``set_attr``.

    Returns
    -------
    ``handler`` if ``func`` is ``None``; a wrapper if ``func`` is callable;
    otherwise ``None`` (in-place edit).

    Raises
    ------
    AssertionError
        If ``set_attr`` is not one of the supported attribute names.
    """
    assert set_attr in [
        "on_trace_ready",
        "schedule",
    ], "Needs to be a callable profiler attribute."

    # Polymorph: not used as decorator; func is used as e.g.: trace_export()
    if func is None:
        return handler

    # Polymorph: decorates a decorator of function/class constructor
    if callable(func):

        @wraps(func)
        def wrapper(*args, **kwargs):
            """Wrapper implementation."""
            # dir() of a class does not list the metaclass' __call__, so this
            # is true for class constructors unless the class defines __call__.
            if "__call__" not in dir(func):
                # Decorator for class constructor (directly)
                result = func(*args, **kwargs)
                setattr(result.profiler, set_attr, handler)
                return result  # not tested
            else:
                # Return as additional argument.
                kwargs[set_attr] = handler
                return func(*args, **kwargs)

        return wrapper

    # Polymorph: func is assumed to be an instance of speechbrain.core.Brain
    # No return: in-place edit.
    # An identity check against the class (``is profiler.profile``) can never
    # match a profiler instance, so an isinstance check is used instead.
    if isinstance(getattr(func, "profiler", None), profiler.profile):
        setattr(func.profiler, set_attr, handler)
def schedule(
    func: Optional[object] = None,
    wait: int = 2,
    warmup: int = 2,
    active: int = 2,
    repeat: int = 1,
    skip_first: int = 0,
):
    """Wrapper to create a ```torch.profiler.schedule``` (sets default parameters for warm-up).

    Arguments
    ---------
    func : object
        Forwarded to ``set_profiler_attr``, which decides what is returned.
    wait, warmup, active, repeat, skip_first : int
        Passed unchanged to ``torch.profiler.schedule``.
    """
    scheduler_options = dict(
        wait=wait,
        warmup=warmup,
        active=active,
        repeat=repeat,
        skip_first=skip_first,
    )
    step_scheduler = profiler.schedule(**scheduler_options)

    # Curious which action a default scheduler suggests at which
    # profiler.step() ?
    #     [step_scheduler(x) for x in range(10)]
    #
    #     00 = {ProfilerAction} ProfilerAction.NONE
    #     01 = {ProfilerAction} ProfilerAction.NONE
    #     02 = {ProfilerAction} ProfilerAction.WARMUP
    #     03 = {ProfilerAction} ProfilerAction.WARMUP
    #     04 = {ProfilerAction} ProfilerAction.RECORD
    #     05 = {ProfilerAction} ProfilerAction.RECORD_AND_SAVE
    #     06 = {ProfilerAction} ProfilerAction.NONE
    #     07 = {ProfilerAction} ProfilerAction.NONE
    #     08 = {ProfilerAction} ProfilerAction.NONE
    #     09 = {ProfilerAction} ProfilerAction.NONE

    return set_profiler_attr(func, "schedule", step_scheduler)
def export(
    func: Optional[object] = None,
    dir_name: str = "./log/",
    worker_name: Optional[str] = None,
    use_gzip: bool = False,
):
    """Exports current and aggregated traces for:
    - Chrome tensorboard
    - FlameGraph
    (and sets default parameters for log file folder/filenames).

    Arguments
    ---------
    func : object
        Forwarded to ``set_profiler_attr``, which decides what is returned.
    dir_name : str
        Folder the trace files are written to.
    worker_name : str
        Prefix of the trace file names; when empty, it is derived from the
        host name and the process id on first use.
    use_gzip : bool
        Passed to the tensorboard handler; also appends ``.gz`` to the file
        names produced for the stack exports.
    """
    import os
    import socket
    import time

    # Chrome export (default handler); inspired the log_file() function below.
    tensorboard_handler = profiler.tensorboard_trace_handler(
        dir_name=dir_name, worker_name=worker_name, use_gzip=use_gzip
    )

    def trace_handler(prof: profiler.profile):
        """Handler to be used as ``on_trace_ready`` of a profiler."""

        def log_file(export_chrome: bool = False, info: str = ""):
            """Builds a time-stamped log file path inside ``dir_name``."""
            nonlocal worker_name
            if not worker_name:
                worker_name = "{}_{}".format(
                    socket.gethostname(), str(os.getpid())
                )
            ext = "pt.trace.json" if export_chrome else "txt"
            if info:
                # Doubled braces survive this first format() call and are
                # filled in by the second one below.
                pattern = "{{}}.{{}}_{}.{{}}".format(info)
            else:
                pattern = "{}.{}.{}"
            file_name = pattern.format(
                worker_name, int(time.time() * 1000), ext
            )
            if use_gzip:
                file_name = file_name + ".gz"
            return os.path.join(dir_name, file_name)

        def export_stacks(log_path: str, metric: str):
            """Writes the stack traces for ``metric`` to ``log_path``."""
            # The given path is used (it was previously ignored, so the
            # "aggregated" tag never made it into the file name).
            prof.export_stacks(log_path, metric)

        def export_traces(aggregated_traces: bool = False):
            """Runs the Chrome export and, if applicable, the stack exports."""
            if not aggregated_traces:
                # Chrome export (also checks for dir_name existing).
                tensorboard_handler(prof)

            # FlameGraph exports.
            if prof.with_stack or aggregated_traces:
                info = "aggregated" if aggregated_traces else ""
                # A fresh path per metric, so that the CUDA export does not
                # reuse the path of the CPU export.
                export_stacks(
                    log_path=log_file(info=info), metric="self_cpu_time_total"
                )
                if prof.profiler is not None:
                    if prof.profiler.use_cuda:
                        export_stacks(
                            log_path=log_file(info=info),
                            metric="self_cuda_time_total",
                        )

        # export last logged trace - skip if events are empty
        # (e.g., profiler created w/o any torch.nn call)
        if prof.events():
            export_traces()

    return set_profiler_attr(
        func=func, set_attr="on_trace_ready", handler=trace_handler
    )
def prepare_profiler_for_brain(prof: profiler.profile):
    """Sets up a ``torch.profiler.profile`` to also (a) aggregate traces issued from various interactions
    with ``speechbrain.core.Brain``:s and (b) hooks a method to ``merge_traces``.

    Arguments
    ---------
    prof : torch.profiler.profile
        The profiler to augment; its ``start`` and ``stop`` are replaced by
        wrapped versions and a ``merge_traces`` attribute is added.

    Returns
    -------
    The same ``prof`` object, so it can be readily assigned elsewhere.
    """
    # Alternative re-design quirks: make trace aggregator a GLOBAL -or- create another profiler class.
    traces_attr = "speechbrain_event_traces"

    # Brain functions will be called independently -> traces will be segregated, so we aggregate them.
    prof.speechbrain_event_traces = []

    # Preparing the profiler to be re-used during Brain:s' lifecycles.
    def hook_profiler_stop(stop: Callable):
        """Wraps ``stop`` so parsed events are recorded after stopping."""

        @wraps(stop)
        def stop_wrapper():
            """Stops the profiler and stores its parsed kineto events."""
            kineto = prof.profiler
            if kineto is None:
                return stop()  # will be: None

            outcome = stop()
            if prof.events():
                # kineto events are not aggregatable (sticking with parsed kineto events)
                # see: torch.autograd.profiler.__exit__
                parsed = kineto._parse_kineto_results(kineto.kineto_results)
                # add to trace record
                prof.speechbrain_event_traces.append(deepcopy(parsed))
                # set flag to disable the profiler
                kineto.enabled = False
            return outcome

        return stop_wrapper

    # Preparing the profiler to be re-started during Brain:s' lifecycles.
    def hook_profiler_start(start: Callable):
        """Wraps ``start`` so the step counter is reset before starting."""

        @wraps(start)
        def start_wrapper():
            """Resets step/action state, re-enables the profiler, starts it."""
            prof.step_num = 0
            prof.current_action = prof.schedule(prof.step_num)
            kineto = prof.profiler
            # check flag if profiler is disabled (i.e. as of stop_wrapper); prevents entering its __init__ twice
            if kineto is not None and not kineto.enabled:
                # reset kineto profiler (otherwise, one obtains the same traces over & over again)
                kineto.enabled = True
            return start()

        return start_wrapper

    # It's currently designed as hiding an Easter Egg.
    def merge_traces():
        """Merges all recorded traces into a single ``EventList``."""
        if prof.profiler is None:
            return []
        if not (traces_attr in dir(prof) and prof.events()):
            return prof.events()  # not tested

        recorded = getattr(prof, traces_attr)
        # clear all assigned parents/children (from previous mergers & trees)
        for trace in recorded:
            for event in trace:
                event.cpu_parent = None
                event.cpu_children = []
        # assemble new list
        merged = EventList(
            list(chain.from_iterable(recorded)),
            use_cuda=prof.profiler.use_cuda,
            profile_memory=prof.profiler.profile_memory,
            with_flops=prof.profiler.with_flops,
        )
        merged._build_tree()
        return merged

    # Augment torch's profiler.
    prof.start = hook_profiler_start(prof.start)
    prof.stop = hook_profiler_stop(prof.stop)
    prof.merge_traces = merge_traces

    # Return so it can be readily assigned elsewhere :)
    return prof
def hook_brain_methods(
    func: object,
    prof: profiler.profile,
    class_hooks: Optional[Iterable[str]] = None,
):
    """For instances of ``speechbrain.core.Brain``, critical functions are hooked to profiler start/stop methods.

    Arguments
    ---------
    func : object
        Object whose methods are replaced, in place, by profiled versions.
    prof : torch.profiler.profile
        Profiler that is started before and stopped after each hooked call.
    class_hooks : iterable
        Names of the methods to hook; names ``func`` does not have are
        skipped. Default: ``['fit', 'evaluate']``.
    """

    # Prepare additional hook decorators for methods of Brain:s.
    def hook_brain(f: Callable):
        """Wraps ``f`` between ``prof.start()`` and ``prof.stop()``."""

        @wraps(f)
        def hook(*f_args, **f_kwargs):
            """Runs ``f`` while the profiler is recording."""
            # The profiler stopped after __init__ so we need to get it up again and stop it manually also.
            prof.start()
            try:
                return f(*f_args, **f_kwargs)
            finally:
                # Stop even when f raises, so the profiler is not left running.
                prof.stop()

        return hook

    # Hook the crucial Brain methods.
    if class_hooks is None:
        class_hooks = ["fit", "evaluate"]
    for method in class_hooks:
        if method in dir(func):  # func is an instance of Brain
            setattr(func, method, hook_brain(getattr(func, method)))
def profile(
    func: Optional[object] = None,
    class_hooks: Optional[Iterable[str]] = None,
    activities: Optional[Iterable[profiler.ProfilerActivity]] = None,
    schedule: Optional[Callable[[int], profiler.ProfilerAction]] = None,
    on_trace_ready: Optional[Callable[..., Any]] = None,
    record_shapes: bool = False,
    profile_memory: bool = False,
    with_stack: bool = False,
    with_flops: bool = False,
    with_modules: bool = False,
) -> object:
    """Wrapper to create a PyTorch profiler to benchmark training/inference of speechbrain.core.Brain instances.
    See ``torch.profiler.profile`` documentation for details (brief summary below).

    Arguments
    ---------
    func : object
        ``speechbrain.core.Brain``:s or a (train/eval) function to be profiled.
    class_hooks : iterable
        List of method/function names of ``speechbrain.core.Brain``:s that should be profiled also.
        Otherwise, only the __init__ constructor will be profiled when decorating a Brain class.
        Default: ``['fit', 'evaluate']`` for classes, and ``None`` for functions.
    activities : iterable
        List of activity groups.
        Default: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA.
        (Default value should be ok for most cases.)
    schedule : callable
        Waits a specified amount of steps for PyTorch to warm-up; see the above ``schedule`` decorator.
        Default: ``ProfilerAction.RECORD`` (immediately starts recording).
    on_trace_ready : callable
        Specifies what benchmark record should be saved (after each scheduled step);
        see above ``trace_handler`` decorator.
        Default: ``None`` (pick up collected reporting once profiling ended, but not details per step).
    record_shapes : bool
        Save input shapes of operations (enables to group benchmark data by after profiling).
        Default: ``False``.
    profile_memory : bool
        Track tensor memory allocation/deallocation.
        Default: ``False``.
    with_stack : bool
        Record source information (file and line number).
        Default: ``False``.
    with_flops: bool
        Estimate the number of FLOPs.
        Default: ``False``.
    with_modules: bool
        Record module hierarchy (including function names)
        Default: ``False``

    Returns
    -------
    * ``func`` is ``None``: a prepared profiler.
    * ``func`` has ``HPARAMS_NEEDED`` or is not callable: ``func`` itself,
      with a ``profiler`` attribute and hooked methods.
    * otherwise: a wrapper around ``func``. For a function, the wrapper
      returns the profiler if the function returned ``None``, and
      ``(result, profiler)`` otherwise; for a class constructor it returns
      the constructed object.

    Example
    -------
    >>> import torch
    >>> @profile
    ... def run(x : torch.Tensor):
    ...     y = x ** 2
    ...     z = y ** 3
    ...     return y.backward()  # y.backward() returns None --> return value is substituted with profiler
    >>> data = torch.randn((1, 1), requires_grad=True)
    >>> prof = run(data)
    >>> out = [len(prof.events()), len(prof.key_averages()), prof.profiler.total_average().count]
    """
    if func is None:  # return a profiler; not tested
        return prepare_profiler_for_brain(
            profiler.profile(
                activities=activities,
                schedule=schedule,
                on_trace_ready=on_trace_ready,
                record_shapes=record_shapes,
                profile_memory=profile_memory,
                with_stack=with_stack,
                with_flops=with_flops,
                with_modules=with_modules,
            )
        )

    # Polymorph: func is pretrained or an instance of Brain (assumed case)
    if hasattr(func, "HPARAMS_NEEDED") or not callable(func):
        with profiler.profile(
            activities=activities,
            # scheduler needs to be set directly (fetching is here not possible as for wrappers)
            schedule=schedule,
            on_trace_ready=on_trace_ready,
            record_shapes=record_shapes,
            profile_memory=profile_memory,
            with_stack=with_stack,
            with_flops=with_flops,
            with_modules=with_modules,
        ) as prof:
            func.profiler = prepare_profiler_for_brain(prof)
            hook_brain_methods(func=func, class_hooks=class_hooks, prof=prof)
            # no need to return anything; all done in-place; but if needs to be readily assigned elsewhere
            return func

    # callable(func) - polymorph: __init__ Brain constructor -or- function to be wrapped
    @wraps(func)
    def wrapper(*args, **kwargs):
        """Runs ``func`` inside a freshly created profiler."""
        # Check if there are nested decorators: they hand their handlers over
        # as keyword arguments. This is resolved per call into locals;
        # overwriting the enclosing variables would stop the kwargs from being
        # popped on later calls, and they would leak into ``func``.
        call_schedule = schedule
        if call_schedule is None:
            call_schedule = kwargs.pop("schedule", None)
        call_on_trace_ready = on_trace_ready
        if call_on_trace_ready is None:
            call_on_trace_ready = kwargs.pop("on_trace_ready", None)

        with profiler.profile(
            activities=activities,
            schedule=call_schedule,
            on_trace_ready=call_on_trace_ready,
            record_shapes=record_shapes,
            profile_memory=profile_memory,
            with_stack=with_stack,
            with_flops=with_flops,
            with_modules=with_modules,
        ) as prof:
            # Preserves profiler as class attribute if func is not a function (implies: speechbrain.core.Brain).
            if "__call__" not in dir(func):
                # Passing the profiler to Brain:s' __init__ constructor as an additional argument.
                kwargs["profiler"] = prepare_profiler_for_brain(prof)
                hook_brain_methods(
                    func=func, class_hooks=class_hooks, prof=prof
                )

            # Run & trace to benchmark.
            result = func(*args, **kwargs)

            # Prof is about to be lost at return.
            if "__call__" in dir(func):
                if result is None:
                    return prof  # for void function, simply return profiling data
                else:  # not tested - returns both
                    return result, prof

            return result

    return wrapper
def profile_analyst(
    func: Optional[object] = None,
    class_hooks: Optional[Iterable[str]] = None,
):  # to diverge, define parameters from scratch: @schedule; @export & @profile
    """Pre-configured profiling for a fully detailed benchmark - analyst perspective.

    Creating this analyst view will create overheads (disabling some PyTorch optimisations);
    use @profile_optimiser to take benefits of optimisations and further optimise your modules, accordingly.

    Arguments
    ---------
    func : object
        Passed on to ``profile`` as the object to be profiled.
    class_hooks : iterable
        Passed on to ``profile``.
    """
    wrapped_func = profile(
        func,
        schedule=schedule(),
        on_trace_ready=None,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        # only for: matrix multiplication & 2D conv; see: torch.autograd.profiler.profile
        with_flops=True,
        with_modules=True,
        class_hooks=class_hooks,
    )

    # Polymorph: func is pretrained or an instance of Brain (assumed case)
    returns_object = hasattr(func, "HPARAMS_NEEDED") or not callable(func)
    if returns_object:
        return wrapped_func

    # callable(func) - polymorph: __init__ Brain constructor -or- function to be wrapped
    @wraps(func)
    def wrapper(*args, **kwargs):
        """Delegates the call to the profiled function."""
        return wrapped_func(*args, **kwargs)

    return wrapper
def profile_optimiser(
    func: Optional[object] = None,
    class_hooks: Optional[Iterable[str]] = None,
):  # to diverge, define parameters from scratch: @schedule; @export & @profile
    """Pre-configured profiling for a detailed benchmark (better suitable for speed-optimisation than @profile_analyst).

    Arguments
    ---------
    func : object
        Passed on to ``profile`` as the object to be profiled.
    class_hooks : iterable
        Passed on to ``profile``.
    """
    wrapped_func = profile(
        func,
        schedule=schedule(),
        on_trace_ready=None,
        record_shapes=False,  # avoid: overheads
        profile_memory=True,
        with_stack=False,  # avoid: overheads
        # only for: matrix multiplication & 2D conv; see: torch.autograd.profiler.profile
        with_flops=False,
        with_modules=True,
        class_hooks=class_hooks,
    )

    # Polymorph: func is pretrained or an instance of Brain (assumed case)
    returns_object = hasattr(func, "HPARAMS_NEEDED") or not callable(func)
    if returns_object:
        return wrapped_func

    # callable(func) - polymorph: __init__ Brain constructor -or- function to be wrapped
    @wraps(func)
    def wrapper(*args, **kwargs):
        """Delegates the call to the profiled function."""
        return wrapped_func(*args, **kwargs)

    return wrapper
def profile_report(  # not part of unittests
    func: Optional[object] = None,
    class_hooks: Optional[Iterable[str]] = None,
):
    """Pre-configured profiling for a reporting benchmark (changed scheduler to @profile_optimiser).

    Arguments
    ---------
    func : object
        Passed on to ``profile`` as the object to be profiled.
    class_hooks : iterable
        Passed on to ``profile``.
    """
    # gives #active, avg:ed of #repeat
    report_schedule = schedule(
        wait=1, warmup=2, active=7, repeat=1, skip_first=0,
    )
    wrapped_func = profile(
        func,
        schedule=report_schedule,
        on_trace_ready=None,
        record_shapes=False,  # avoid: overheads
        profile_memory=True,
        with_stack=False,  # avoid: overheads
        # only for: matrix multiplication & 2D conv; see: torch.autograd.profiler.profile
        with_flops=False,
        with_modules=True,
        class_hooks=class_hooks,
    )

    # Polymorph: func is pretrained or an instance of Brain (assumed case)
    returns_object = hasattr(func, "HPARAMS_NEEDED") or not callable(func)
    if returns_object:
        return wrapped_func

    # callable(func) - polymorph: __init__ Brain constructor -or- function to be wrapped
    @wraps(func)
    def wrapper(*args, **kwargs):
        """Delegates the call to the profiled function."""
        return wrapped_func(*args, **kwargs)

    return wrapper
def events_diff(
    a: EventList, b: EventList, filter_by: str = "count",
):
    """Takes two ``EventList``:s in, filters events of equal value (default: by the count of events).

    The purpose of the results of this diff are for visualisation only (to see the difference between implementations).

    Arguments
    ---------
    a, b : EventList
        The event lists to compare; both are left untouched.
    filter_by : str
        Name of the event attribute whose value is compared.

    Returns
    -------
    Copies of ``a`` and ``b`` from which the events sharing both key and
    ``filter_by`` value have been removed.
    """
    # Working on copies might be slower than editing in place, but it preserves the originals.
    aa, bb = deepcopy(a), deepcopy(b)

    # Maps: function name -> (value, position); for repeated keys the last occurrence wins.
    a_index = {
        event.key: (getattr(event, filter_by), position)
        for position, event in enumerate(aa)
    }
    b_index = {
        event.key: (getattr(event, filter_by), position)
        for position, event in enumerate(bb)
    }

    # Keys present on both sides with an equal value are filtered out.
    matching_keys = [
        key
        for key in a_index
        if key in b_index and a_index[key][0] == b_index[key][0]
    ]

    # Since EventLists are lists: removing items from the back keeps the
    # remaining positions valid.
    for events, index in ((aa, a_index), (bb, b_index)):
        positions = sorted(
            (index[key][1] for key in matching_keys), reverse=True
        )
        for position in positions:
            events.remove(events[position])

    return aa, bb
def report_time(events: object, verbose=False, upper_control_limit=False):
    """Summary reporting of total time - see: torch.autograd.profiler_util

    Arguments
    ---------
    events : object
        One of: an ``EventList``; a single ``FunctionEvent``; a
        ``torch.profiler.profile``; or an object with a ``profiler``
        attribute (assumed to be a ``speechbrain.core.Brain``).
    verbose : bool
        If set, the times are printed (CUDA time only if greater than zero).
    upper_control_limit : bool
        If set, report ``mean + 3 * std`` over the events whose key is
        ``"ProfilerStep*"`` instead of the totals over all events.

    Returns
    -------
    Tuple of ``(cpu_time, cuda_time)``.

    Raises
    ------
    TypeError
        If ``events`` is none of the supported kinds.
    """
    # Resolve the input to a list of events.
    if isinstance(events, EventList):
        function_events = events
    elif isinstance(events, FunctionEvent):
        # A single event offers neither iteration nor total_average().
        function_events = EventList([events])
    elif isinstance(events, profiler.profile):
        function_events = events.events()
    elif hasattr(events, "profiler"):  # assumes speechbrain.core.Brain
        function_events = events.profiler.events()
    else:
        raise TypeError(
            "Expected an EventList; a FunctionEvent; profiler.profile, or a SpeechBrain."
        )

    # Aggregate CPU & CUDA time.
    if upper_control_limit:
        # discerns top-level event (among others) aten:zeros which is in the avg range of 10-20ms on laptop CPU
        step_events = [e for e in function_events if e.key == "ProfilerStep*"]
        cpu_data = np.array([e.cpu_time for e in step_events])
        cuda_data = np.array([e.cuda_time for e in step_events])
        cpu_time = cpu_data.mean() + 3 * cpu_data.std()
        cuda_time = cuda_data.mean() + 3 * cuda_data.std()
    else:
        total = function_events.total_average()
        cpu_time = total.self_cpu_time_total
        cuda_time = total.self_cuda_time_total

    if verbose:
        print("CPU time: {}".format(_format_time(cpu_time)))
        if cuda_time > 0:
            print("CUDA time: {}".format(_format_time(cuda_time)))

    return cpu_time, cuda_time
def report_memory(handler: object, verbose=False):
    """Summary reporting of peak memory - see: torch.autograd.profiler_util

    For every event without CPU children, the memory usage of the event and
    of all its CPU parents is summed up; the largest such sum is reported.

    Arguments
    ---------
    handler : object
        One of: an ``EventList``; a single ``FunctionEvent``; a
        ``torch.profiler.profile``; or an object with a ``profiler``
        attribute (assumed to be a ``speechbrain.core.Brain``).
    verbose : bool
        If set, the peaks are printed (CUDA peak only if greater than zero).

    Returns
    -------
    Tuple of ``(cpu_mem, cuda_mem)``.

    Raises
    ------
    TypeError
        If ``handler`` is none of the supported kinds.
    """
    # Resolve the input to a list of events.
    if isinstance(handler, EventList):
        events = handler
    elif isinstance(handler, FunctionEvent):
        # A single event cannot be iterated over.
        events = EventList([handler])
    elif isinstance(handler, profiler.profile):
        events = handler.events()
    elif hasattr(handler, "profiler"):  # assumes speechbrain.core.Brain
        events = handler.profiler.events()
    else:
        raise TypeError(
            "Expected an EventList; a FunctionEvent; profiler.profile, or a SpeechBrain."
        )

    # Alternative (not used): memory allocation during each time step is of
    # relevance, e.g. for visualisation - time intensive for lots events:
    #     mem_times = np.unique(
    #         [[x.time_range.start, x.time_range.end] for x in events]
    #     )
    #     cpu_memory = np.zeros_like(mem_times)
    #     cuda_memory = np.zeros_like(mem_times)
    #     for x in events:
    #         idx = (x.time_range.start <= mem_times) & (
    #             x.time_range.end >= mem_times
    #         )
    #         cpu_memory[idx] += x.cpu_memory_usage
    #         cuda_memory[idx] += x.cuda_memory_usage
    #     # variable names instead of labeling pandas' columns
    #     cpu_mem = np.max(cpu_memory)
    #     cuda_mem = np.max(cuda_memory)

    cpu_mem = cuda_mem = 0
    for e in events:
        if len(e.cpu_children) == 0:
            # Leaf event: accumulate memory along the chain of its parents.
            leaf_cpu_mem = e.cpu_memory_usage
            leaf_cuda_mem = e.cuda_memory_usage
            parent = e.cpu_parent
            while parent is not None:
                leaf_cpu_mem += parent.cpu_memory_usage
                leaf_cuda_mem += parent.cuda_memory_usage
                parent = parent.cpu_parent
            if leaf_cpu_mem > cpu_mem:
                cpu_mem = leaf_cpu_mem
            if leaf_cuda_mem > cuda_mem:
                cuda_mem = leaf_cuda_mem

    if verbose:
        print("Peak CPU Mem: {}".format(_format_memory(cpu_mem)))
        if cuda_mem > 0:
            print("Peak CUDA Mem: {}".format(_format_memory(cuda_mem)))

    return cpu_mem, cuda_mem