Problem overview
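A model trained as a Multi-GPU Model and saved to h5 (hdf5) fails to load later, with an error saying the number of layers does not match.
This is an already well-known bug.
Solution overview
Set the weights trained with the Multi-GPU Model on the Base Model (the one not compiled for multi-GPU), then save that model to h5 (hdf5).
If the model only needs to be saved once with model.save() after all epochs have finished, something like the following is enough: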
...
parallel_model, base_model = get_compiled_model(...)
parallel_model.fit(...)
base_model.set_weights(parallel_model.get_weights())
base_model.save(filepath)
...
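As a minimal sketch, the copy-then-save step above could be wrapped in a small helper. The name `save_via_base_model` and the `weights_only` flag are hypothetical and not part of Keras; the helper only assumes both objects expose the Keras model methods `get_weights`, `set_weights`, `save`, and `save_weights`.

```python
def save_via_base_model(parallel_model, base_model, filepath, weights_only=False):
    """Copy trained weights into the single-GPU base model, then save that model.

    Hypothetical helper: both arguments are assumed to be Keras models
    (or objects with the same get_weights/set_weights/save/save_weights methods).
    """
    # The parallel model holds the trained weights; the base model is the one
    # whose layer structure matches what will be rebuilt at load time.
    base_model.set_weights(parallel_model.get_weights())
    if weights_only:
        base_model.save_weights(filepath, overwrite=True)
    else:
        base_model.save(filepath, overwrite=True)
    return filepath
```

With such a helper, the last two lines of the snippet above would become `save_via_base_model(parallel_model, base_model, filepath)`.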
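However, ModelCheckpoint, one of the Callbacks that Keras provides, offers no simple way to save through the base model instead.
So a custom checkpoint callback based on ModelCheckpoint is needed, modified slightly along the lines of the snippet above. See the links below.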
https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L275
https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L633
These point to the Callback and ModelCheckpoint code, in that order.
https://github.com/keras-team/keras/issues/8123#issuecomment-348976624
The link above is the reference for the custom ModelCheckpoint.
The resulting ModelCheckpoint for a Multi-GPU Model is shown below.
import warnings

import numpy as np
from keras.callbacks import Callback


class MultiGPUModelCheckpoint(Callback):
    """
    A model saved from a multi-GPU model fails to load afterwards,
    so checkpoints are saved through base_model instead.
    """

    def __init__(self, filepath, base_model, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(MultiGPUModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0
        # Single-GPU model that is actually written to disk.
        self.base_model = base_model

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        # np.inf is used instead of np.Inf, which newer NumPy releases removed.
        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            # auto: accuracy-like metrics should increase, everything else decrease.
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch + 1, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        # self.model is the multi-GPU model being trained;
                        # copy its weights into base_model before saving.
                        if self.save_weights_only:
                            self.base_model.set_weights(self.model.get_weights())
                            self.base_model.save_weights(filepath, overwrite=True)
                        else:
                            self.base_model.set_weights(self.model.get_weights())
                            self.base_model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s did not improve from %0.5f' %
                                  (epoch + 1, self.monitor, self.best))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
                if self.save_weights_only:
                    self.base_model.set_weights(self.model.get_weights())
                    self.base_model.save_weights(filepath, overwrite=True)
                else:
                    self.base_model.set_weights(self.model.get_weights())
                    self.base_model.save(filepath, overwrite=True)
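The `mode` handling in `__init__` decides whether a smaller or a larger value of the monitored metric counts as an improvement. Below is a standalone sketch of that decision. It uses `operator.lt`/`operator.gt` and `math.inf` as stand-ins for the NumPy comparison functions and infinity so that it runs without NumPy, and the function name `resolve_monitor` is hypothetical.

```python
import math
import operator


def resolve_monitor(monitor, mode="auto"):
    """Return (compare_fn, initial_best) following the callback's mode rules."""
    if mode not in ("auto", "min", "max"):
        mode = "auto"  # unknown modes fall back to auto, as in the callback
    if mode == "min":
        return operator.lt, math.inf
    if mode == "max":
        return operator.gt, -math.inf
    # auto: metrics that look like accuracy should go up, the rest should go down
    if "acc" in monitor or monitor.startswith("fmeasure"):
        return operator.gt, -math.inf
    return operator.lt, math.inf


compare, best = resolve_monitor("val_acc")
print(compare(0.91, best))  # True: any finite accuracy beats -inf

compare, best = resolve_monitor("val_loss")
print(compare(0.35, best))  # True: any finite loss beats +inf
```

Because the initial best value is plus or minus infinity, the first epoch that reports the metric always counts as an improvement when `save_best_only` is enabled.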
I confirmed that the model is saved to an hdf5 file correctly and that it can be loaded again.
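Checkpoint file names come from expanding `filepath` with `str.format`, using the epoch number and the `logs` dict, as `on_epoch_end` does. A small sketch of that expansion follows; the template and the log values are made-up assumptions for illustration.

```python
# Hypothetical template and log values, for illustration only.
filepath_template = "weights.{epoch:02d}-{val_loss:.2f}.h5"
logs = {"loss": 0.4213, "val_loss": 0.5178}
epoch = 4  # on_epoch_end receives a zero-based epoch index

# Same expression as in on_epoch_end: the epoch is shifted to one-based,
# and every entry of logs is available as a named placeholder.
filepath = filepath_template.format(epoch=epoch + 1, **logs)
print(filepath)  # weights.05-0.52.h5
```

Entries of `logs` that the template does not mention are simply ignored, but a placeholder that is missing from `logs` makes `str.format` raise `KeyError`.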