MNIST Digit Recognition Neural Network

A PyTorch neural network for handwritten digit recognition with 98.20% accuracy

Completed: March 2025
Version: 1.0
Tags: Python, PyTorch, Neural Networks, Computer Vision, Machine Learning
Live demo of the digit recognition application

Project Overview

This project implements a neural network for recognizing handwritten digits using the MNIST dataset. The model achieves 98.20% accuracy on the test set and is deployed in a simple drawing application that allows users to draw digits and see real-time predictions.

The MNIST dataset is a large collection of handwritten digits that is commonly used for training various image processing systems. It's like the "Hello World" of machine learning - a perfect starting point for exploring neural networks and computer vision.

Neural Network Architecture

I designed a multi-layer neural network with carefully calibrated dropout to prevent overfitting:

Neural network architecture visualization (each node in the hidden layers represents approximately 20 neurons, while the output layer shows the actual 10 neurons)

The actual architecture consists of:

  • Input Layer: 784 neurons (representing a flattened 28×28 pixel image)
  • Hidden Layer 1: 512 neurons with ReLU activation and Dropout rate of 0.3
  • Hidden Layer 2: 256 neurons with ReLU activation and Dropout rate of 0.2
  • Hidden Layer 3: 128 neurons with ReLU activation and Dropout rate of 0.1
  • Hidden Layer 4: 60 neurons with ReLU activation and Dropout rate of 0.05
  • Output Layer: 10 neurons (one for each digit from 0-9)

For visualization purposes, the diagram divides the actual hidden-layer neuron counts by 20 (so, for example, the 512-neuron first hidden layer appears as roughly 26 nodes), while the output layer shows its actual 10 neurons.

The architecture implementation in PyTorch:

import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        # Flatten the 28x28 image to a 784-dimensional vector
        self.flatten = nn.Flatten()
        self.linear_stack = nn.Sequential(
            # layer 1
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            # layer 2
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            # layer 3
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            # layer 4
            nn.Linear(128, 60),
            nn.ReLU(),
            nn.Dropout(0.05),
            # output layer
            nn.Linear(60, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.linear_stack(x)

Technologies & Libraries

This project leverages several powerful machine learning technologies:

  • PyTorch: For creating and training the neural network model
  • Torchvision: For accessing the MNIST dataset and transformations (a data-loading sketch follows this list)
  • PIL (Python Imaging Library): For image manipulation in the drawing app
  • Tkinter: For creating the interactive GUI application
  • NumPy: For efficient numerical computations
  • Matplotlib: For visualizing model performance and predictions
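
The data-loading step itself isn't reproduced in the snippets below, so here is a minimal sketch of how the MNIST splits might be pulled in with torchvision. The ./data path and the plain ToTensor() transform are assumptions on my part, chosen to match the 0-1 pixel scaling used in the drawing app's predict() method shown later.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ToTensor() scales pixel values to the 0-1 range, matching the preprocessing
# used in the drawing app's predict() method
transform = transforms.ToTensor()

# Download the MNIST training and test splits (path is an assumption)
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Batch size 64 matches the training parameters listed below
trainloader = DataLoader(train_set, batch_size=64, shuffle=True)
testloader = DataLoader(test_set, batch_size=64, shuffle=False)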

Training Process

The neural network was trained on the MNIST dataset with the following parameters:

  • Training Epochs: 10
  • Batch Size: 64
  • Optimizer: Adam
  • Learning Rate: 0.001
  • Loss Function: Cross Entropy Loss

The training process involved:

import torch.nn as nn
import torch.optim as optim

def train_model(model, trainloader, epochs=10):
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            # Reset gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            # Print statistics every 100 mini-batches
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}], Loss: {running_loss/100:.4f}')
                running_loss = 0.0
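
The 98.20% figure quoted above comes from evaluating the trained model on the held-out test set with dropout disabled. The project's exact evaluation code isn't reproduced here, but a minimal sketch, assuming a testloader built the same way as the trainloader, could look like this:

def evaluate_model(model, testloader):
    # Switch to eval mode so the dropout layers are disabled
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            outputs = model(inputs)
            predicted = torch.argmax(outputs, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    accuracy = 100.0 * correct / total
    print(f"Test accuracy: {accuracy:.2f}%")
    return accuracy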

Interactive User Interface

One of the most exciting aspects of this project is the interactive drawing application that allows users to test the model. The application is built with Tkinter and provides a simple canvas where users can draw digits and get real-time predictions from the trained model. The drawing application's user interface was developed with assistance from Claude 3.7 Sonnet.

Key features of the application include:

  • Canvas for drawing digits with the mouse
  • Real-time prediction of drawn digits
  • Confidence score for predictions
  • Clear button to reset the canvas

The core prediction method converts the canvas drawing into a 28×28 tensor and runs it through the model:

def predict(self):
    # Make sure the model was loaded successfully
    if not hasattr(self, 'model'):
        self.result.config(text="No model loaded")
        return

    # Create a PIL image from the canvas
    image = Image.new("L", (280, 280), "black")
    draw = ImageDraw.Draw(image)

    # Draw all lines from the canvas onto the PIL image
    for item in self.canvas.find_all():
        coords = self.canvas.coords(item)
        if len(coords) == 4:  # Line has 4 coordinates: x1, y1, x2, y2
            draw.line(coords, fill="white", width=15)

    # Resize to MNIST format (28x28 pixels)
    image = image.resize((28, 28), Image.Resampling.LANCZOS)

    # Convert to numpy array and normalize to 0-1 range
    img_array = np.array(image) / 255.0

    # Convert to PyTorch tensor with dimensions [batch, channel, height, width]
    tensor = torch.tensor(img_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    try:
        # Make prediction
        with torch.no_grad():
            outputs = self.model(tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            predicted = torch.argmax(outputs, dim=1).item()
            confidence = probabilities[0][predicted].item() * 100

        # Update result label
        self.result.config(text=f"Prediction: {predicted} ({confidence:.1f}%)")
    except Exception as e:
        print(f"Error during prediction: {e}")
        self.result.config(text="Error making prediction")
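
The predict() method relies on self.canvas, self.result, and self.model being created elsewhere in the app. The sketch below shows one way that wiring could look; the class and widget names are assumptions chosen to match the snippet above, not the project's exact code.

import tkinter as tk

class DigitApp:
    """Minimal drawing window; predict() from the snippet above would be a method of this class."""

    def __init__(self, root, model):
        self.model = model
        self.last_x, self.last_y = None, None

        # 280x280 canvas: 10x the 28x28 MNIST resolution, downscaled before prediction
        self.canvas = tk.Canvas(root, width=280, height=280, bg="black")
        self.canvas.pack()
        self.canvas.bind("<B1-Motion>", self.draw_line)
        self.canvas.bind("<ButtonRelease-1>", self.reset_position)

        # Label that predict() updates with the prediction and confidence
        self.result = tk.Label(root, text="Draw a digit")
        self.result.pack()

        # "Predict" calls the predict() method shown above; "Clear" wipes the canvas
        tk.Button(root, text="Predict", command=lambda: self.predict()).pack()
        tk.Button(root, text="Clear", command=lambda: self.canvas.delete("all")).pack()

    def draw_line(self, event):
        # Draw a white segment from the previous mouse position to the current one,
        # so each canvas item carries the 4 coordinates predict() expects
        if self.last_x is not None:
            self.canvas.create_line(self.last_x, self.last_y, event.x, event.y,
                                    fill="white", width=15)
        self.last_x, self.last_y = event.x, event.y

    def reset_position(self, event):
        self.last_x, self.last_y = None, None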

Key Learnings

Through this project, I gained valuable experience in:

  • Neural Network Design: Understanding how layer sizes, activation functions, and dropout rates affect model performance
  • PyTorch Workflows: Structuring ML projects using the PyTorch framework
  • Regularization Techniques: Using dropout to prevent overfitting
  • Image Processing: Converting user drawings to the format expected by the model

Challenges & Solutions

During this project, I encountered several challenges:

  • Preventing Overfitting: Initial models achieved near-perfect accuracy on training data but performed poorly on test data. I solved this by implementing a progressive dropout strategy, with higher dropout rates in earlier layers.
  • Balancing Model Complexity: Finding the right architecture that was complex enough to learn patterns but simple enough to generalize well required extensive experimentation.
  • User Interface Design: Creating a responsive drawing interface that accurately captured user input required careful calibration of line widths and image processing steps.
  • Model Deployment: Ensuring the trained model could be loaded correctly in the application required robust error handling and path management.

Future Improvements

This project could be extended in several ways:

  • Implementing batch normalization to improve training stability
  • Exploring convolutional neural networks (CNNs) for improved accuracy (see the small sketch after this list)
  • Adding data augmentation to improve model robustness
  • Creating a web-based version of the application using Flask or Django
  • Extending the model to recognize more than just digits (letters, symbols, etc.)
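
As a concrete starting point for the CNN idea above, a small convolutional variant of the model could look like the following. This is purely an illustrative sketch of the direction, not code from this project.

import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # 1x28x28 input -> 32 feature maps
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 32x14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),   # 64x7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))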

Acknowledgments

I'd like to acknowledge the following contributions to this project:

  • Claude 3.7 Sonnet AI for assistance in developing the drawing application UI
  • The PyTorch team for their excellent deep learning framework
  • The creators of the MNIST dataset for providing a standardized benchmark for image recognition

Getting Started with This Project

If you're interested in trying out this project yourself:

Prerequisites

  • Python 3.8+
  • PyTorch
  • NumPy
  • Matplotlib
  • Pillow (PIL)
  • Tkinter (usually comes with Python)

Installation

pip install torch torchvision numpy matplotlib pillow

Running the Application

python digit_app.py

Make sure the path to the model is correct for your system!
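
The app expects a trained weights file on disk before it can make predictions. Here is a hedged sketch of that load step, assuming the training script saved the weights with torch.save(model.state_dict(), "mnist_model.pth"); the filename and path are assumptions, not the project's actual values.

import os
import torch

MODEL_PATH = "mnist_model.pth"  # adjust this path for your system

def load_model():
    # NeuralNetwork is the class defined in the architecture section above
    model = NeuralNetwork()
    if not os.path.exists(MODEL_PATH):
        print(f"Model file not found: {MODEL_PATH}")
        return None
    # map_location="cpu" lets a GPU-trained model load on a CPU-only machine
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    model.eval()  # disable dropout for inference
    return model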

Conclusion

This MNIST digit recognition project demonstrates the power of neural networks for image recognition tasks. By achieving 98.20% accuracy with a relatively simple architecture, it shows how effective modern deep learning techniques can be, even without the complexity of convolutional networks.

The interactive drawing application brings the neural network to life, allowing users to experience AI in action as it recognizes their handwritten digits in real-time. This combination of backend ML engineering and frontend user experience design makes for a compelling demonstration of applied machine learning.