from typing import Any, Callable, Dict, List, TYPE_CHECKING

from labml.internal.util.strings import is_pattern_match

if TYPE_CHECKING:
    from torch import nn
def _transform_key_part(part):
try:
if str(int(part)) == part:
return f'{int(part):09d}'
except ValueError as e:
pass
return part
def sort_keys(keys: List[str]):
    """
    Return a new list of ``keys`` in "natural" order.

    Each key is split on ``/`` and then on ``.``; every integer component is
    zero-padded before comparison, so ``layers.2`` sorts before ``layers.10``.
    The sort is stable, so keys with identical sort keys keep their input order.
    """
    def natural_key(key: str) -> str:
        padded_segments = []
        for segment in key.split('/'):
            padded_segments.append('.'.join(map(_transform_key_part, segment.split('.'))))
        return '/'.join(padded_segments)

    return sorted(keys, key=natural_key)
class ForwardHook:
    """
    Callable registered with ``register_forward_hook``. Each call hands the
    layer's input and output to ``save_callback``, tagged with ``name``.
    """

    def __init__(self, name: str, save_callback: Callable):
        self.name = name
        self.save_callback = save_callback

    def __call__(self, module, i, o):
        # ``module`` is unused: ``name`` already identifies the layer.
        # Intentionally returns None; per the PyTorch docs a non-None return
        # from a forward hook would replace the module's output.
        callback = self.save_callback
        callback(self.name, i, o)
class BackwardHook:
    """
    Callable wrapper for a module backward hook. Each call hands ``i`` and
    ``o`` to ``save_callback``, tagged with ``name``. With PyTorch's
    ``register_full_backward_hook`` these are ``grad_input`` and ``grad_output``.
    """

    def __init__(self, name: str, save_callback: Callable):
        self.name = name
        self.save_callback = save_callback

    def __call__(self, module, i, o):
        # ``module`` is unused; returns None so gradients are never replaced.
        callback = self.save_callback
        callback(self.name, i, o)
class ModelProbe:
    """
    Registers hooks on every sub-module of a model to record layer inputs and
    outputs, and keeps references to the model's parameters.

    Only the latest call per module is kept: each hook call overwrites the
    entry stored under that module's name.

    Arguments:
        model (nn.Module): model to probe
        name (str): key used for the root module, whose ``named_modules`` name is ``''``
        add_forward_hooks (bool): record inputs/outputs of the forward pass
        add_backward_hooks (bool): record gradients of the backward pass
    """

    def __init__(self, model: 'nn.Module', name: str = 'model', *,
                 add_forward_hooks=True,
                 add_backward_hooks=False):
        self.model = model

        # Create storage before registering any hook, so no callback can ever
        # run against a half-initialised probe.
        self._forward_output = {}
        self._forward_input = {}
        self._backward_output = {}
        self._backward_input = {}
        # Handles returned by the register_* calls, kept so hooks can be removed
        self._hook_handles = []

        for n, module in model.named_modules():
            module: 'nn.Module'
            if n == '':
                n = name
            if add_forward_hooks:
                handle = module.register_forward_hook(
                    ForwardHook(n, self._add_forward_tensor))
                self._hook_handles.append(handle)
            if add_backward_hooks:
                handle = module.register_full_backward_hook(
                    BackwardHook(n, self._add_backward_tensor))
                self._hook_handles.append(handle)

        # Parameters are captured once, at construction time
        self._parameters = dict(model.named_parameters())

    def remove_hooks(self):
        """
        Detach every hook registered by this probe. Already recorded values are kept.
        """
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def _add_forward_tensor(self, name: str, inp: Any, outp: Any):
        self._forward_input[name] = inp
        self._forward_output[name] = outp

    def _add_backward_tensor(self, name: str, inp: Any, outp: Any):
        self._backward_input[name] = inp
        self._backward_output[name] = outp

    @staticmethod
    def _collection(values: Dict[str, Any]) -> 'ValueCollection':
        # Wrap ``values`` (by reference, not copied) with naturally sorted keys
        return ValueCollection(values, sort_keys(list(values.keys())))

    @property
    def parameters(self):
        """
        All the model parameters as a :class:`ValueCollection`
        """
        return self._collection(self._parameters)

    @property
    def forward_input(self):
        """
        Inputs to layers in the forward pass as a :class:`ValueCollection`
        """
        return self._collection(self._forward_input)

    @property
    def forward_output(self):
        """
        Outputs of layers in the forward pass as a :class:`ValueCollection`
        """
        return self._collection(self._forward_output)

    @property
    def backward_input(self):
        """
        Inputs (gradients) to layers in the backward pass as a :class:`ValueCollection`
        """
        return self._collection(self._backward_input)

    @property
    def backward_output(self):
        """
        Output (gradients) of layers in the backward pass as a :class:`ValueCollection`
        """
        return self._collection(self._backward_output)
class ValueCollection:
    """
    An ordered view over a dictionary of values.

    Arguments:
        values (Dict[str, Any]): backing mapping; it may hold keys not in ``keys``
            (filtered views created by ``__getitem__`` share the same mapping)
        keys (List[str]): keys exposed by this collection, in order
    """

    def __init__(self, values: Dict[str, Any], keys: List[str]):
        self._keys = keys
        self._values = values

    @staticmethod
    def _expand_value(prefix: str, value: Any, separator: str = '.') -> List[str]:
        """
        Flatten a (possibly nested) tuple/list/dict into the keys of its leaves.

        ``separator`` joins ``prefix`` to the first level only; deeper levels
        always use ``.``. An empty container yields no keys.
        """
        if isinstance(value, (tuple, list)):
            children = ((f'{prefix}{separator}{i}', v) for i, v in enumerate(value))
        elif isinstance(value, dict):
            children = ((f'{prefix}{separator}{k}', v) for k, v in value.items())
        else:
            return [prefix]

        # Accumulate with extend: ``sum(lists, [])`` is quadratic in the leaf count
        keys = []
        for child_prefix, child in children:
            keys.extend(ValueCollection._expand_value(child_prefix, child))
        return keys

    def get_value(self, key: str):
        """
        Get a value by key

        Arguments:
            key (str): Key of the value
        """
        return self._values[key]

    def get_list(self):
        """
        Get a list of values, in key order
        """
        return [self.get_value(f) for f in self._keys]

    def get_dict(self):
        """
        Get a dictionary of values, keyed by this collection's keys
        """
        return {f: self.get_value(f) for f in self._keys}

    def deep(self):
        """
        Get a :class:`DeepValueCollection` by expanding the tree of values.

        Nested elements get keys of the form ``name/i.j.k``.
        """
        if isinstance(self, DeepValueCollection):
            return self
        keys = []
        for k in self._keys:
            keys.extend(ValueCollection._expand_value(f'{k}', self._values[k], '/'))
        return DeepValueCollection(self._values, sort_keys(keys))

    def __str__(self):
        return str(self._keys)

    def __repr__(self):
        return repr(self._keys)

    def __getitem__(self, item: str):
        # Filter keys with ``is_pattern_match``; the result keeps this class
        # (so a deep collection stays deep) and shares the backing mapping.
        keys = [k for k in self._keys if is_pattern_match(k, item)]
        return self.__class__(self._values, keys)

    def __len__(self):
        return len(self._keys)

    def keys(self):
        return self._keys
class DeepValueCollection(ValueCollection):
    """
    A :class:`ValueCollection` whose keys may address nested elements.

    Keys are either ``name`` or ``name/path``. Here ``path`` is a
    ``.``-separated sequence of list/tuple indices and dict keys, the format
    produced by :meth:`ValueCollection.deep`.
    """

    def get_value(self, key: str):
        """
        Get a value by key. Use ``/`` to navigate into child elements.

        Only the first ``/`` separates the top-level name from the path, so a
        dict key that itself contains ``/`` can still be addressed.

        Arguments:
            key (str): Key of the value, e.g. ``'model'`` or ``'model/0.1'``

        Lookup errors from the underlying containers (e.g. ``KeyError``,
        ``IndexError``, or ``ValueError`` for a non-integer sequence index)
        propagate to the caller.
        """
        name, sep, path = key.partition('/')
        value = self._values[name]
        if not sep:
            return value
        for part in path.split('.'):
            if isinstance(value, (tuple, list)):
                value = value[int(part)]
            else:
                value = value[part]
        return value