
The problem of an h5 (hdf5) file saved from a Multi-GPU Model failing to load

2019. 8. 13.

Problem overview

After training with a Multi-GPU Model and saving it as h5 (hdf5), loading the file later fails with an error saying the number of layers does not match.

This is already a well-known bug.
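A rough way to see where the layer-count mismatch comes from, assuming TensorFlow 2.x. multi_gpu_model itself is not available in current TensorFlow 2.x releases, so the hypothetical `wrap()` helper below nests the base model as a single layer, the way multi_gpu_model places the template model inside the parallel model. As far as I know, Keras 2.x's HDF5 weight loader compares the number of layers that own weights, so a file holding 1 such layer cannot be loaded into a model with 2 of them, even though the underlying weight arrays are the same.

```python
import tensorflow as tf


def build_base_model():
    # Small base (template) model: Input -> Dense -> Dense
    inputs = tf.keras.Input(shape=(4,))
    x = tf.keras.layers.Dense(8, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    return tf.keras.Model(inputs, outputs)


def wrap(model):
    # Hypothetical stand-in for multi_gpu_model: the base model is nested as ONE layer
    inputs = tf.keras.Input(shape=(4,))
    return tf.keras.Model(inputs, model(inputs))


def weighted_layer_count(model):
    # Count only layers that own weights directly (InputLayer has none)
    return len([layer for layer in model.layers if layer.weights])


base_model = build_base_model()
parallel_model = wrap(base_model)

print(weighted_layer_count(base_model))      # 2 (two Dense layers)
print(weighted_layer_count(parallel_model))  # 1 (the nested base model)
print(len(base_model.get_weights()) == len(parallel_model.get_weights()))  # True
```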

 

Solution overview

Copy the weights learned by the Multi-GPU Model into the Base Model (the one that was not compiled for multi-GPU), then save that Base Model as h5 (hdf5).
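The weight copy works because the parallel model's weight list has the same order and shapes as the base model's. A minimal sketch, assuming TensorFlow 2.x and the same hypothetical `wrap()` stand-in for multi_gpu_model: a separately built base model receives the parallel model's weights and then produces identical predictions.

```python
import numpy as np
import tensorflow as tf


def build_base_model():
    inputs = tf.keras.Input(shape=(4,))
    x = tf.keras.layers.Dense(8, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    return tf.keras.Model(inputs, outputs)


def wrap(model):
    # Hypothetical stand-in for multi_gpu_model (nests the model as one layer)
    inputs = tf.keras.Input(shape=(4,))
    return tf.keras.Model(inputs, model(inputs))


# Pretend this is the trained multi-GPU model
parallel_model = wrap(build_base_model())

# A separately built base model: same architecture, its own random initial weights
base_model = build_base_model()
base_model.set_weights(parallel_model.get_weights())

x = np.random.rand(3, 4).astype("float32")
same = np.allclose(base_model.predict(x, verbose=0), parallel_model.predict(x, verbose=0))
print(same)  # True
```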

 

Once all epochs have finished, saving with model.save() can be done roughly like this, but...

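...
# get_compiled_model(...) is a helper whose body is omitted here; it returns
# the multi-GPU (parallel) model and the plain base (template) model
parallel_model, base_model = get_compiled_model(...)

parallel_model.fit(...)

# Copy the trained weights into the base model and save that model instead
base_model.set_weights(parallel_model.get_weights())
base_model.save(filepath)
...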
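An end-to-end version of the snippet above, as a sketch assuming TensorFlow 2.x on CPU: the random data, the two epochs and the temporary path are made up for illustration, and the parallel model is again a hypothetical single-layer wrapper rather than a real multi_gpu_model. After the weight copy, the base model is saved to .h5 and loaded back with `tf.keras.models.load_model`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_base_model():
    inputs = tf.keras.Input(shape=(4,))
    x = tf.keras.layers.Dense(8, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    return tf.keras.Model(inputs, outputs)


base_model = build_base_model()
inputs = tf.keras.Input(shape=(4,))
# Hypothetical stand-in for multi_gpu_model(base_model, gpus=...)
parallel_model = tf.keras.Model(inputs, base_model(inputs))
parallel_model.compile(optimizer="adam", loss="mse")

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
parallel_model.fit(x, y, epochs=2, verbose=0)

# Copy weights into the base model and save THAT model as HDF5
base_model.set_weights(parallel_model.get_weights())
filepath = os.path.join(tempfile.mkdtemp(), "base_model.h5")
base_model.save(filepath)

# Reload without compiling; only the architecture and weights are needed here
restored = tf.keras.models.load_model(filepath, compile=False)
print(np.allclose(restored.predict(x, verbose=0), base_model.predict(x, verbose=0)))  # True
```

Because the base model is the one being saved, the file's layer structure matches what a plain base model expects when it is loaded again.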

However, if you want to use ModelCheckpoint, one of the callbacks Keras provides, there is no simple way to make it save like this...

So I had to take ModelCheckpoint's code and modify it slightly in the same way as above (the version below subclasses Callback and reuses ModelCheckpoint's logic). See the links below.

 

https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L275

https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L633

The Callback and ModelCheckpoint source code, in that order.

 

https://github.com/keras-team/keras/issues/8123#issuecomment-348976624

Reference for customizing ModelCheckpoint

 

So the ModelCheckpoint adapted for a Multi-GPU Model looks like this.

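import warnings

import numpy as np
from keras.callbacks import Callback


class MultiGPUModelCheckpoint(Callback):
    """
    ModelCheckpoint variant for multi-GPU training.
    A model saved directly from the multi-GPU (parallel) model cannot be loaded
    back, so the weights are copied into base_model and saved from there.
    """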
    def __init__(self, filepath, base_model, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(MultiGPUModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0
        self.base_model = base_model

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf
    
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch + 1, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.base_model.set_weights(self.model.get_weights())
                            self.base_model.save_weights(filepath, overwrite=True)
                        else:
                            self.base_model.set_weights(self.model.get_weights())
                            self.base_model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s did not improve from %0.5f' %
                                  (epoch + 1, self.monitor, self.best))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
                if self.save_weights_only:
                    self.base_model.set_weights(self.model.get_weights())
                    self.base_model.save_weights(filepath, overwrite=True)
                else:
                    self.base_model.set_weights(self.model.get_weights())
                    self.base_model.save(filepath, overwrite=True)
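To sanity-check what `save_best_only`, `mode` and `period` do in the class above, here is a Keras-free sketch that replays only the decision logic of `on_epoch_end` on a hand-written list of `logs` dicts. The function names are hypothetical and nothing is written to disk; it just returns the file paths that would be saved.

```python
import numpy as np


def monitor_op_and_initial_best(monitor, mode="auto"):
    # Same mode rules as MultiGPUModelCheckpoint.__init__
    if mode not in ("auto", "min", "max"):
        mode = "auto"  # the class warns, then falls back to auto
    if mode == "max" or (
        mode == "auto" and ("acc" in monitor or monitor.startswith("fmeasure"))
    ):
        return np.greater, -np.inf
    return np.less, np.inf


def simulate_best_only_saves(history, filepath, monitor="val_loss", mode="auto", period=1):
    """Return the paths a save_best_only=True run would write.

    history: list of per-epoch logs dicts, like the ones Keras passes to on_epoch_end.
    """
    monitor_op, best = monitor_op_and_initial_best(monitor, mode)
    written = []
    epochs_since_last_save = 0
    for epoch, logs in enumerate(history):
        epochs_since_last_save += 1
        if epochs_since_last_save < period:
            continue
        epochs_since_last_save = 0
        current = logs.get(monitor)
        if current is None:
            continue  # the class warns and skips this epoch
        if monitor_op(current, best):
            best = current
            written.append(filepath.format(epoch=epoch + 1, **logs))
    return written


history = [{"val_loss": 0.9}, {"val_loss": 0.7}, {"val_loss": 0.8}, {"val_loss": 0.6}]
print(simulate_best_only_saves(history, "model.{epoch:02d}-{val_loss:.2f}.h5"))
# ['model.01-0.90.h5', 'model.02-0.70.h5', 'model.04-0.60.h5']
```

With `period=2` only epochs 2 and 4 are checked, and a monitor name containing `acc` switches `auto` mode to "higher is better".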

I confirmed that the model is saved to an hdf5 file correctly and can be loaded back.
