import collections
import logging
import pickle
import re
import gtar
from tensorflow import keras
# Module-level logger; used below to warn when saved layer classes fail to unpickle.
logger = logging.getLogger(__name__)
def all_layers(model):
    """Yield every non-model layer reachable from a keras model.

    Walks ``model.layers`` in order. Any entry that is itself a
    ``keras.models.Model`` is descended into (depth-first) instead of being
    yielded, so only the leaf layers are produced.

    :param model: Object exposing a ``layers`` iterable (a keras model)
    """
    for child in model.layers:
        if not isinstance(child, keras.models.Model):
            yield child
            continue
        # Nested model: emit its layers in place of the model itself.
        yield from all_layers(child)
class Trajectory:
    """Interface to save and load models from a GTAR trajectory.

    The model architecture is written to ``keras/model.json``, the pickled
    layer classes to ``keras/layer_classes.pkl`` (plus one pickle per class
    under ``keras/vars/layer_class.pkl/<name>``), and each weight array to
    ``keras/layer/<i>/weight/<j>/frames/<frame>/weight.<dtype>.uni`` with its
    shape next to it in ``shape.u32.uni``. All of these paths are prefixed
    with ``group`` when one is given.

    Instances can be used as context managers; the underlying GTAR handle is
    closed on exit.

    .. note::
        Consistent with the GTAR schema, model weights are saved with a
        dynamic index, which is a string and could indicate a "timestep"
        or other time. When accessed via the `load` function, however,
        the frame is a simple integer index, beginning at 0.

    :param filename: File to save or load from
    :param mode: File open mode: 'r' (read-only), 'w' (overwrite), or 'a' (append)
    :param group: GTAR group prefix to use to organize multiple sub-trajectories within the same GTAR file, if given
    """

    def __init__(self, filename, mode='r', group=None):
        self.filename = filename
        self.mode = mode
        self.handle = gtar.GTAR(filename, mode)
        self.group = group

    def __enter__(self):
        return self

    def __exit__(self, typ, val, trace):
        # Returns None, so any in-flight exception keeps propagating.
        self.close()

    def __len__(self):
        """Number of frames that have weights stored."""
        return len(self.frames)

    def close(self):
        """Close the underlying GTAR file handle."""
        self.handle.close()

    @property
    def frames(self):
        """Frame indices (as stored in the file) that hold 'weight' records."""
        (_, frames) = self.handle.framesWithRecordsNamed('weight', group_prefix=self.group)
        return frames

    def _get_path(self, name):
        """Return `name` prefixed with ``<group>/`` when a group is configured."""
        if self.group is None:
            return name
        return '{}/{}'.format(self.group, name)

    def get_weights(self, frame=-1):
        """Returns a list of weight arrays for a model stored at the given frame index

        Weights are ordered by layer number, then by weight number within
        the layer, as encoded in the record group names.

        :param frame: integer index of the step to load. Can be negative to count from the end.
        :raises IndexError: if `frame` is out of range for the stored frames
        :raises KeyError: if a layer's weight numbers are not contiguous from 0, or a weight has no matching shape record
        """
        # Translate the positional index into the stored frame index.
        frame_index = self.frames[frame]

        # layer number -> {weight number -> record}
        weight_records = collections.defaultdict(dict)
        shape_records = collections.defaultdict(dict)
        weight_pattern = re.compile(r'keras/layer/(?P<layer>\d+)/weight/(?P<weight>\d+)')

        for rec in self.handle.getRecordTypes():
            group = rec.getGroup()

            # NOTE(review): this is a bare string-prefix test and the pattern
            # below is unanchored, so a sibling group that merely shares the
            # prefix (e.g. 'run1' vs 'run10') would also be picked up --
            # confirm whether that is intended before tightening.
            if self.group is not None and not group.startswith(self.group):
                continue

            match = weight_pattern.search(group)
            if not match:
                continue

            layer = int(match.group('layer'))
            weight = int(match.group('weight'))
            name = rec.getName()

            if name == 'weight':
                weight_records[layer][weight] = rec
            elif name == 'shape':
                shape_records[layer][weight] = rec

        all_weights = []
        for (i, records) in sorted(weight_records.items()):
            # Index by 0..len-1 (not dict order) so weights come out in
            # numeric order and a gap in the numbering fails loudly.
            for weight_index in range(len(records)):
                weight_rec = records[weight_index]
                shape_rec = shape_records[i][weight_index]

                shape = self.handle.getRecord(shape_rec, frame_index)
                weight = self.handle.getRecord(weight_rec, frame_index)
                # Only non-empty arrays are reshaped; empty ones are
                # passed through exactly as read.
                if weight.size:
                    weight = weight.reshape(shape)
                all_weights.append(weight)

        return all_weights

    def load(self, frame=-1, extra_classes=None):
        """Loads a model stored at the given frame index

        Layer classes are gathered from three sources, later ones taking
        precedence: the combined pickle saved with the model, the per-class
        pickles saved with the model, and finally `extra_classes`.

        .. warning::
            Layer classes are restored with :mod:`pickle`, which can execute
            arbitrary code. Only load trajectory files from trusted sources.

        :param frame: integer index of the step to load. Can be negative to count from the end.
        :param extra_classes: Dictionary of additional (name: Class) values to use when initializing model.
        :returns: the keras model rebuilt from the stored JSON, with the weights of `frame` applied
        """
        # A None default avoids sharing one dict object between calls.
        given_extra_classes = {} if extra_classes is None else extra_classes

        model_description = self.handle.readStr(self._get_path('keras/model.json'))
        assert model_description

        # SECURITY: everything below unpickles bytes read from the file.
        pickled_classes = self.handle.readBytes(self._get_path('keras/layer_classes.pkl'))
        try:
            extra_classes = pickle.loads(pickled_classes) if pickled_classes else {}
        except (AttributeError, ModuleNotFoundError):
            # Best effort: the defining module may be unavailable here.
            logger.warning('Failed to load saved layer classes. '
                           'Custom layers may not load.', exc_info=True)
            extra_classes = {}

        # Fall back to the individually pickled classes for anything the
        # combined pickle did not provide; the class name is the record index.
        (extra_class_rec, extra_class_names) = self.handle.framesWithRecordsNamed(
            'layer_class.pkl', group_prefix=self.group)
        for name in extra_class_names:
            if name in extra_classes:
                continue
            try:
                content = self.handle.getRecord(extra_class_rec, name)
                extra_classes[name] = pickle.loads(content)
            except (AttributeError, ModuleNotFoundError):
                logger.warning(
                    'Failed to load saved layer class for %s', name,
                    exc_info=True)

        # Caller-supplied classes override anything read from the file.
        extra_classes.update(given_extra_classes)

        model = keras.models.model_from_json(model_description, extra_classes)
        all_weights = self.get_weights(frame)
        model.set_weights(all_weights)
        return model

    def save(self, model, frame=None, only_weights=False):
        """Save a model description and/or current state

        :param model: Keras Model object to save
        :param frame: Frame index (string) to save as. If not given (``None`` or ``''``), do not save weights. ``0`` is a valid frame.
        :param only_weights: If True, only save the current model weights, not the model architecture.
        :raises ValueError: if `only_weights` is True but no frame is given
        :raises KeyError: if a weight array has a dtype other than float32 or float64
        """
        # Explicit "not given" test: plain truthiness would also reject the
        # perfectly valid integer frame 0.
        has_frame = frame is not None and frame != ''

        if only_weights:
            if not has_frame:
                # A real exception rather than an assert, which `python -O`
                # would strip and turn into a silent no-op.
                raise ValueError(
                    'Trying to save only the weights of a model without a frame given')
        else:
            model_json = model.to_json()
            layer_classes = {type(layer).__name__: type(layer) for layer in all_layers(model)}
            layer_classes_dump = pickle.dumps(layer_classes)

            self.handle.writeStr(self._get_path('keras/model.json'), model_json)
            self.handle.writeBytes(self._get_path('keras/layer_classes.pkl'), layer_classes_dump)

            # Also store each class on its own so that one unloadable class
            # does not prevent the others from being restored.
            for (name, cls) in layer_classes.items():
                path = self._get_path('keras/vars/layer_class.pkl/{}'.format(name))
                self.handle.writeBytes(path, pickle.dumps(cls))

        if not has_frame:
            return

        # numpy dtype name -> format suffix used in the record path
        dtypes = {'float32': 'f32',
                  'float64': 'f64'}

        # Weights are indexed by position in the top-level `model.layers`.
        for (i, layer) in enumerate(model.layers):
            for (j, weight) in enumerate(layer.get_weights()):
                dtype_string = dtypes[weight.dtype.name]
                group = self._get_path('keras/layer/{}/weight/{}'.format(i, j))
                self.handle.writePath('{}/frames/{}/weight.{}.uni'.format(group, frame, dtype_string), weight)
                self.handle.writePath('{}/shape.u32.uni'.format(group), weight.shape)

    def save_weights(self, model, frame):
        """Save (only) the current model weights.

        :param model: Keras Model object containing weights to save
        :param frame: Frame index (string) to save as
        :raises ValueError: if `frame` is ``None`` or ``''``
        """
        return self.save(model, frame, only_weights=True)

    def save_model(self, model):
        """Save (only) the current model architecture.

        :param model: Keras Model object to save
        """
        return self.save(model)