Upload 3 files
Browse files- LICENSE +21 -0
- main.py +163 -0
- requirements.txt +10 -0
    	
        LICENSE
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            MIT License
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Copyright (c) 2024 Mikhail Filippov
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            Permission is hereby granted, free of charge, to any person obtaining a copy
         | 
| 6 | 
            +
            of this software and associated documentation files (the "Software"), to deal
         | 
| 7 | 
            +
            in the Software without restriction, including without limitation the rights
         | 
| 8 | 
            +
            to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         | 
| 9 | 
            +
            copies of the Software, and to permit persons to whom the Software is
         | 
| 10 | 
            +
            furnished to do so, subject to the following conditions:
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            The above copyright notice and this permission notice shall be included in all
         | 
| 13 | 
            +
            copies or substantial portions of the Software.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         | 
| 16 | 
            +
            IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         | 
| 17 | 
            +
            FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         | 
| 18 | 
            +
            AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         | 
| 19 | 
            +
            LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         | 
| 20 | 
            +
            OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         | 
| 21 | 
            +
            SOFTWARE.
         | 
    	
        main.py
    ADDED
    
    | @@ -0,0 +1,163 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import tensorflow as tf
         | 
| 3 | 
            +
            from tensorflow import keras
         | 
| 4 | 
            +
            from keras import layers
         | 
| 5 | 
            +
            from tensorflow.keras.preprocessing.image import ImageDataGenerator
         | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            def create_model(input_shape=(32, 32, 3)):
         | 
| 10 | 
            +
                """Create and return a CNN model for binary image classification."""
         | 
| 11 | 
            +
                model = keras.Sequential([
         | 
| 12 | 
            +
                    layers.Input(shape=input_shape),  # Proper input layer specification
         | 
| 13 | 
            +
                    layers.Conv2D(32, (3, 3), activation='relu'),
         | 
| 14 | 
            +
                    layers.MaxPooling2D((2, 2)),
         | 
| 15 | 
            +
                    layers.Conv2D(64, (3, 3), activation='relu'),
         | 
| 16 | 
            +
                    layers.MaxPooling2D((2, 2)),
         | 
| 17 | 
            +
                    layers.Conv2D(128, (3, 3), activation='relu'),
         | 
| 18 | 
            +
                    layers.MaxPooling2D((2, 2)),
         | 
| 19 | 
            +
                    layers.Flatten(),
         | 
| 20 | 
            +
                    layers.Dense(128, activation='relu'),
         | 
| 21 | 
            +
                    layers.Dense(1, activation='sigmoid')
         | 
| 22 | 
            +
                ])
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                # Compile the model
         | 
| 25 | 
            +
                model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
         | 
| 26 | 
            +
                return model
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            def train_model(batch_size=32, epochs=8):
         | 
| 29 | 
            +
                """Train the model and save it."""
         | 
| 30 | 
            +
                # Generate data for training and validation
         | 
| 31 | 
            +
                datagen = ImageDataGenerator(
         | 
| 32 | 
            +
                    rescale=1.0 / 255,
         | 
| 33 | 
            +
                    validation_split=0.2,
         | 
| 34 | 
            +
                    rotation_range=20,      # Add data augmentation
         | 
| 35 | 
            +
                    width_shift_range=0.2,
         | 
| 36 | 
            +
                    height_shift_range=0.2,
         | 
| 37 | 
            +
                    shear_range=0.2,
         | 
| 38 | 
            +
                    zoom_range=0.2,
         | 
| 39 | 
            +
                    horizontal_flip=True
         | 
| 40 | 
            +
                )
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                train_generator = datagen.flow_from_directory(
         | 
| 43 | 
            +
                    directory='archive/train',
         | 
| 44 | 
            +
                    target_size=(32, 32),
         | 
| 45 | 
            +
                    batch_size=batch_size,
         | 
| 46 | 
            +
                    class_mode='binary',
         | 
| 47 | 
            +
                    subset='training'
         | 
| 48 | 
            +
                )
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                validation_generator = datagen.flow_from_directory(
         | 
| 51 | 
            +
                    directory='archive/train',
         | 
| 52 | 
            +
                    target_size=(32, 32),
         | 
| 53 | 
            +
                    batch_size=batch_size,
         | 
| 54 | 
            +
                    class_mode='binary',
         | 
| 55 | 
            +
                    subset='validation'
         | 
| 56 | 
            +
                )
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                # Create model
         | 
| 59 | 
            +
                model = create_model()
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                # Add early stopping to prevent overfitting
         | 
| 62 | 
            +
                early_stopping = keras.callbacks.EarlyStopping(
         | 
| 63 | 
            +
                    monitor='val_loss',
         | 
| 64 | 
            +
                    patience=3,
         | 
| 65 | 
            +
                    restore_best_weights=True
         | 
| 66 | 
            +
                )
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                # Train the model
         | 
| 69 | 
            +
                history = model.fit(
         | 
| 70 | 
            +
                    train_generator,
         | 
| 71 | 
            +
                    validation_data=validation_generator,
         | 
| 72 | 
            +
                    epochs=epochs,
         | 
| 73 | 
            +
                    callbacks=[early_stopping]
         | 
| 74 | 
            +
                )
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                # Evaluate the model
         | 
| 77 | 
            +
                test_loss, test_acc = model.evaluate(validation_generator)
         | 
| 78 | 
            +
                print(f'Test accuracy: {test_acc:.4f}')
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                # Save the model
         | 
| 81 | 
            +
                model.save('trained_model.keras')
         | 
| 82 | 
            +
                print("Model saved as 'trained_model.keras'")
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                return model, history
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            def load_and_preprocess_image(image_path, target_size=(32, 32)):
         | 
| 87 | 
            +
                """Load and preprocess an image for prediction."""
         | 
| 88 | 
            +
                try:
         | 
| 89 | 
            +
                    img = Image.open(image_path)
         | 
| 90 | 
            +
                    img = img.resize(target_size)
         | 
| 91 | 
            +
                    img = img.convert('RGB')
         | 
| 92 | 
            +
                    img_array = np.array(img) / 255.0
         | 
| 93 | 
            +
                    return np.expand_dims(img_array, axis=0)
         | 
| 94 | 
            +
                except Exception as e:
         | 
| 95 | 
            +
                    print(f"Error processing image: {e}")
         | 
| 96 | 
            +
                    return None
         | 
| 97 | 
            +
             | 
| 98 | 
            +
            def test_model(model_path='trained_model.keras'):
         | 
| 99 | 
            +
                """Load a trained model and use it to classify an image."""
         | 
| 100 | 
            +
                try:
         | 
| 101 | 
            +
                    # Load the trained model
         | 
| 102 | 
            +
                    model = tf.keras.models.load_model(model_path)
         | 
| 103 | 
            +
                except Exception as e:
         | 
| 104 | 
            +
                    print(f"Error loading model: {e}")
         | 
| 105 | 
            +
                    return
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                # Path to the image to test
         | 
| 108 | 
            +
                image_path = input('Enter the path to the image you want to test: ')
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                if not os.path.isfile(image_path):
         | 
| 111 | 
            +
                    print("Invalid path, please enter a valid path to an image.")
         | 
| 112 | 
            +
                    return
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                # Load and preprocess the image
         | 
| 115 | 
            +
                input_image = load_and_preprocess_image(image_path)
         | 
| 116 | 
            +
                if input_image is None:
         | 
| 117 | 
            +
                    return
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                # Predict the class of the image
         | 
| 120 | 
            +
                prediction = model.predict(input_image, verbose=0)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                # Define the threshold for classification
         | 
| 123 | 
            +
                threshold = 0.5
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                # Classify the image
         | 
| 126 | 
            +
                classification = "REAL" if prediction[0][0] > threshold else "FAKE"
         | 
| 127 | 
            +
                confidence = prediction[0][0] if prediction[0][0] > threshold else 1 - prediction[0][0]
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                # Print the result
         | 
| 130 | 
            +
                print(f"Classification: {classification}")
         | 
| 131 | 
            +
                print(f"Confidence: {confidence * 100:.2f}%")
         | 
| 132 | 
            +
                print(f"Raw prediction value: {prediction[0][0]:.4f}")
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            def main():
         | 
| 135 | 
            +
                """Main function to run the program."""
         | 
| 136 | 
            +
                # Set memory growth to avoid memory allocation errors
         | 
| 137 | 
            +
                gpus = tf.config.experimental.list_physical_devices('GPU')
         | 
| 138 | 
            +
                if gpus:
         | 
| 139 | 
            +
                    try:
         | 
| 140 | 
            +
                        for gpu in gpus:
         | 
| 141 | 
            +
                            tf.config.experimental.set_memory_growth(gpu, True)
         | 
| 142 | 
            +
                    except RuntimeError as e:
         | 
| 143 | 
            +
                        print(f"Error setting memory growth: {e}")
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                # Define hyperparameters
         | 
| 146 | 
            +
                batch_size = 32
         | 
| 147 | 
            +
                epochs = 10
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                while True:
         | 
| 150 | 
            +
                    activation_mode = input('Select mode (train/test/exit): ').lower()
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    if activation_mode == 'train':
         | 
| 153 | 
            +
                        train_model(batch_size, epochs)
         | 
| 154 | 
            +
                    elif activation_mode == 'test':
         | 
| 155 | 
            +
                        test_model()
         | 
| 156 | 
            +
                    elif activation_mode == 'exit':
         | 
| 157 | 
            +
                        print("Exiting program.")
         | 
| 158 | 
            +
                        break
         | 
| 159 | 
            +
                    else:
         | 
| 160 | 
            +
                        print('Invalid mode, please select "train", "test", or "exit"')
         | 
| 161 | 
            +
             | 
| 162 | 
            +
            if __name__ == "__main__":
         | 
| 163 | 
            +
                main()
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,10 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # TensorFlow for deep learning
         | 
| 2 | 
            +
            tensorflow==2.19.0
         | 
| 3 | 
            +
            # NumPy for numerical operations
         | 
| 4 | 
            +
            numpy==2.1.3
         | 
| 5 | 
            +
            # Keras for building neural network models
         | 
| 6 | 
            +
            keras==3.10.0
         | 
| 7 | 
            +
            # Pillow for image processing
         | 
| 8 | 
            +
            pillow==11.2.1
         | 
| 9 | 
            +
            # SciPy for scientific computing
         | 
| 10 | 
            +
            scipy==1.11.4
         |