
fchollet/deep-learning-models


Trained image classification models for Keras

THIS REPOSITORY IS DEPRECATED. USE THE MODULE keras.applications INSTEAD.

Pull requests will not be reviewed or merged. Direct any PRs to keras.applications. Issues are not monitored either.


This repository contains code for the following Keras models:

  • VGG16
  • VGG19
  • ResNet50
  • Inception v3
  • CRNN for music tagging

All architectures are compatible with both TensorFlow and Theano. Upon instantiation, each model is built according to the image dimension ordering set in your Keras configuration file at ~/.keras/keras.json. For instance, if you have set image_dim_ordering=tf, then any model loaded from this repository will be built according to the TensorFlow dimension ordering convention, "Height-Width-Depth" (channels last).
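The ordering setting only changes the shape a model expects for each input image. The sketch below assumes a keras.json that contains an image_dim_ordering key, as described above. The helpers read_dim_ordering and input_shape are hypothetical and are not part of this repository.

```python
import json

def read_dim_ordering(config_text):
    """Return 'tf' or 'th' from the text of a keras.json file (hypothetical helper)."""
    config = json.loads(config_text)
    return config["image_dim_ordering"]

def input_shape(rows, cols, channels, ordering):
    """Shape of a single image tensor under the given ordering (hypothetical helper)."""
    if ordering == "tf":
        # TensorFlow convention: Height-Width-Depth (channels last).
        return (rows, cols, channels)
    if ordering == "th":
        # Theano convention: Depth-Height-Width (channels first).
        return (channels, rows, cols)
    raise ValueError("unknown ordering: %r" % ordering)

ordering = read_dim_ordering('{"image_dim_ordering": "tf", "backend": "tensorflow"}')
print(input_shape(224, 224, 3, ordering))  # (224, 224, 3)
print(input_shape(224, 224, 3, "th"))      # (3, 224, 224)
```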

Pre-trained weights can be loaded automatically upon instantiation: pass weights='imagenet' to the model constructor for all image models, or weights='msd' for the music tagging model. Weights are downloaded automatically if necessary and cached locally in ~/.keras/models/.
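Keras performs the actual download itself. The standard-library sketch below only illustrates the caching idea: a weight file is fetched only when it is not already present under ~/.keras/models/. The helpers weights_path and needs_download are hypothetical, and "example_weights.h5" is a placeholder rather than one of the repository's real weight file names.

```python
import os
import tempfile

def weights_path(fname, cache_root=None):
    """Local path where a weight file would be cached (hypothetical helper)."""
    root = cache_root if cache_root is not None else os.path.expanduser("~/.keras")
    return os.path.join(root, "models", fname)

def needs_download(fname, cache_root=None):
    """True if the weight file is not cached yet and would have to be fetched."""
    return not os.path.exists(weights_path(fname, cache_root))

# Usage against a temporary directory instead of the real ~/.keras cache.
with tempfile.TemporaryDirectory() as tmp:
    name = "example_weights.h5"  # placeholder file name
    print(needs_download(name, tmp))  # True: nothing cached yet
    os.makedirs(os.path.join(tmp, "models"))
    open(weights_path(name, tmp), "wb").close()  # simulate a completed download
    print(needs_download(name, tmp))  # False: the cached copy would be reused
```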

Examples

Classify images

import numpy as np
from resnet50 import ResNet50
from keras.preprocessing import image
from imagenet_utils import preprocess_input, decode_predictions

model = ResNet50(weights='imagenet')

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)  # add a batch dimension of size 1
x = preprocess_input(x)

preds = model.predict(x)
print('Predicted:', decode_predictions(preds))
# prints: [[u'n02504458', u'African_elephant']]
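decode_predictions turns the model's 1000-way probability vector into readable ImageNet classes. The sketch below shows the underlying idea: rank the class indices by score and look the best ones up in a class index. It uses a hypothetical three-entry index in place of the real ImageNet mapping. decode_top_k is not the repository's function, and its return format may differ from decode_predictions.

```python
def decode_top_k(preds, class_index, top=5):
    """Return the top-scoring (wnid, label, score) tuples for each row of preds."""
    results = []
    for row in preds:
        # Sort class indices by descending score and keep the first `top`.
        best = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:top]
        results.append([(class_index[i][0], class_index[i][1], row[i]) for i in best])
    return results

# Hypothetical 3-entry index; the real one maps 1000 indices to WordNet IDs and labels.
class_index = {
    0: ("n01440764", "tench"),
    1: ("n02504458", "African_elephant"),
    2: ("n02123045", "tabby"),
}
preds = [[0.05, 0.90, 0.05]]  # scores for one image in the batch
print(decode_top_k(preds, class_index, top=1))
# [[('n02504458', 'African_elephant', 0.9)]]
```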

Extract features from images

import numpy as np
from vgg16 import VGG16
from keras.preprocessing import image
from imagenet_utils import preprocess_input

# include_top=False drops the fully-connected classifier layers
model = VGG16(weights='imagenet', include_top=False)

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)  # add a batch dimension of size 1
x = preprocess_input(x)

features = model.predict(x)
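With include_top=False, VGG16 returns the output of its last pooling layer instead of class probabilities. Each of its five convolutional blocks ends in a 2x2 max pooling with stride 2. A 224x224 input therefore shrinks by a factor of 32 to 7x7, with 512 channels from the last block.

The sketch below works through that arithmetic. vgg16_feature_shape is a hypothetical helper. The shape it returns includes the batch axis added by np.expand_dims and, by default, assumes TensorFlow dimension ordering.

```python
def vgg16_feature_shape(rows, cols, batch=1, ordering="tf"):
    """Expected shape of the VGG16(include_top=False) output for a given input size."""
    for _ in range(5):
        # Each block ends with 2x2 max pooling, stride 2 ('valid' padding floors odd sizes).
        rows, cols = rows // 2, cols // 2
    channels = 512  # the last convolutional block has 512 filters
    if ordering == "tf":
        return (batch, rows, cols, channels)
    return (batch, channels, rows, cols)

shape = vgg16_feature_shape(224, 224)
print(shape)  # (1, 7, 7, 512)
# Flattening gives a 7 * 7 * 512 = 25088-dimensional feature vector per image.
```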

Extract features from an arbitrary intermediate layer

import numpy as np
from vgg19 import VGG19
from keras.preprocessing import image
from imagenet_utils import preprocess_input
from keras.models import Model

base_model = VGG19(weights='imagenet')
# New model mapping the VGG19 input to the output of the 'block4_pool' layer
model = Model(base_model.input, base_model.get_layer('block4_pool').output)

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)  # add a batch dimension of size 1
x = preprocess_input(x)

block4_pool_features = model.predict(x)
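get_layer looks layers up by name. The sketch below assumes the blockN_convM / blockN_pool naming used by the Keras VGG19 definition. Blocks 1 and 2 have two convolutional layers each, and blocks 3 to 5 have four, with 64, 128, 256, 512 and 512 filters respectively. It lists the valid layer names and the shape each pooling layer would output for a 224x224 input in TensorFlow ordering. vgg19_layer_names and vgg19_pool_shapes are hypothetical helpers.

```python
# (number of conv layers, number of filters) for each VGG19 block.
VGG19_BLOCKS = [(2, 64), (2, 128), (4, 256), (4, 512), (4, 512)]

def vgg19_layer_names():
    """Names of the convolutional and pooling layers, in order."""
    names = []
    for b, (n_conv, _) in enumerate(VGG19_BLOCKS, start=1):
        names += ["block%d_conv%d" % (b, c) for c in range(1, n_conv + 1)]
        names.append("block%d_pool" % b)
    return names

def vgg19_pool_shapes(rows=224, cols=224):
    """Map each pooling layer name to its output shape (batch of 1, channels last)."""
    shapes = {}
    for b, (_, filters) in enumerate(VGG19_BLOCKS, start=1):
        rows, cols = rows // 2, cols // 2  # 2x2 max pooling, stride 2
        shapes["block%d_pool" % b] = (1, rows, cols, filters)
    return shapes

print(len(vgg19_layer_names()))            # 21 layers: 16 conv + 5 pool
print(vgg19_pool_shapes()["block4_pool"])  # (1, 14, 14, 512)
```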

References

Additionally, don't forget to cite Keras if you use these models.

License