ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • cat and dog1
    카테고리 없음 2022. 4. 9. 14:55
    from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D
    from keras.layers import Dense, Activation, Dropout, Flatten
    from keras import optimizers
    from keras.models import Sequential
    from keras.preprocessing.image import ImageDataGenerator
    import numpy as np 
    
    # step 1: load data
    
    img_width = 150
    img_height = 150
    train_data_dir = 'data/train'
    valid_data_dir = 'data/validation'
    
    datagen = ImageDataGenerator(rescale = 1./255)
    
    train_generator = datagen.flow_from_directory(directory=train_data_dir,
    											   target_size=(img_width,img_height),
    											   classes=['dogs','cats'],
    											   class_mode='binary',
    											   batch_size=16)
    
    validation_generator = datagen.flow_from_directory(directory=valid_data_dir,
    											   target_size=(img_width,img_height),
    											   classes=['dogs','cats'],
    											   class_mode='binary',
    											   batch_size=32)
    
    
    # step-2 : build model
    
    model =Sequential()
    
    model.add(Conv2D(32,(3,3), input_shape=(img_width, img_height, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2,2)))
    
    model.add(Conv2D(32,(3,3), input_shape=(img_width, img_height, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2,2)))
    
    model.add(Conv2D(64,(3,3), input_shape=(img_width, img_height, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2,2)))
    
    model.add(Flatten())
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    
    model.compile(loss='binary_crossentropy',optimizer='rmsprop',metrics=['accuracy'])
    
    print('model complied!!')
    
    print('starting training....')
    training = model.fit_generator(generator=train_generator, steps_per_epoch=2048 // 16,epochs=20,validation_data=validation_generator,validation_steps=832//16)
    
    print('training finished!!')
    
    print('saving weights to simple_CNN.h5')
    
    model.save_weights('models/simple_CNN.h5')
    
    print('all weights saved successfully !!')
    #models.load_weights('models/simple_CNN.h5')
Designed by Tistory.