{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### **Reinforcement Learning in PyTorch: A Tutorial from Scratch**\n", "\n", "This tutorial introduces reinforcement learning (RL) using PyTorch. We'll focus on Q-Learning with Deep Q-Networks (DQN) to teach an agent how to navigate a simple environment. For demonstration, we'll use the `CartPole-v1` environment from OpenAI's Gym.\n", "\n", "---\n", "\n", "### **What is Reinforcement Learning?**\n", "Reinforcement Learning is a framework where an agent interacts with an environment to learn a policy \\( \\pi(s) \\), mapping states (\\(s\\)) to actions (\\(a\\)), by maximizing cumulative rewards.\n", "\n", "Key terms:\n", "- **State**: The current representation of the environment.\n", "- **Action**: The agent's decision.\n", "- **Reward**: Feedback from the environment.\n", "- **Policy**: Strategy for choosing actions.\n", "\n", "---\n", "\n", "### **Step-by-Step Implementation**\n", "#### **1. Setup**\n", "Install required libraries:\n", "```bash\n", "pip install gym torch matplotlib\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "#### **2. Imports**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gym\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import numpy as np\n", "from collections import deque\n", "import random\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "---\n", "\n", "#### **3. Create the Neural Network**\n", "The DQN approximates the Q-value function \\( Q(s, a) \\).\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class DQN(nn.Module):\n", " def __init__(self, state_dim, action_dim):\n", " super(DQN, self).__init__()\n", " self.fc1 = nn.Linear(state_dim, 128)\n", " self.fc2 = nn.Linear(128, 128)\n", " self.fc3 = nn.Linear(128, action_dim)\n", "\n", " def forward(self, x):\n", " x = F.relu(self.fc1(x))\n", " x = F.relu(self.fc2(x))\n", " x = self.fc3(x) # Q-values for all actions\n", " return x\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "---\n", "\n", "#### **4. Replay Buffer**\n", "Experience replay stores past transitions to improve sample efficiency." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ReplayBuffer:\n", " def __init__(self, capacity):\n", " self.buffer = deque(maxlen=capacity)\n", "\n", " def push(self, state, action, reward, next_state, done):\n", " self.buffer.append((state, action, reward, next_state, done))\n", "\n", " def sample(self, batch_size):\n", " batch = random.sample(self.buffer, batch_size)\n", " states, actions, rewards, next_states, dones = zip(*batch)\n", " return (np.array(states), np.array(actions), np.array(rewards),\n", " np.array(next_states), np.array(dones))\n", "\n", " def __len__(self):\n", " return len(self.buffer)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "#### **5. Epsilon-Greedy Policy**\n", "The agent explores the environment while balancing exploration and exploitation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def epsilon_greedy_policy(state, epsilon, model, action_dim):\n", " if random.random() < epsilon:\n", " return random.randint(0, action_dim - 1) # Explore\n", " state = torch.FloatTensor(state).unsqueeze(0)\n", " with torch.no_grad():\n", " q_values = model(state)\n", " return q_values.argmax().item() # Exploit" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "#### **6. Train the Agent**\n", "Define the main training loop for DQN." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def train_dqn(env, num_episodes=500, batch_size=64, gamma=0.99, epsilon_decay=0.995, min_epsilon=0.01):\n", " state_dim = env.observation_space.shape[0]\n", " action_dim = env.action_space.n\n", " \n", " # Initialize DQN and target networks\n", " dqn = DQN(state_dim, action_dim)\n", " target_dqn = DQN(state_dim, action_dim)\n", " target_dqn.load_state_dict(dqn.state_dict())\n", " optimizer = optim.Adam(dqn.parameters(), lr=0.001)\n", "\n", " replay_buffer = ReplayBuffer(capacity=10000)\n", "\n", " epsilon = 1.0\n", " rewards_per_episode = []\n", "\n", " for episode in range(num_episodes):\n", " state = env.reset()\n", " episode_reward = 0\n", " \n", " while True:\n", " # Choose an action using epsilon-greedy policy\n", " action = epsilon_greedy_policy(state, epsilon, dqn, action_dim)\n", " next_state, reward, done, _ = env.step(action)\n", " episode_reward += reward\n", "\n", " # Store transition in replay buffer\n", " replay_buffer.push(state, action, reward, next_state, done)\n", " state = next_state\n", "\n", " # Train the model if enough data is available\n", " if len(replay_buffer) >= batch_size:\n", " states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)\n", "\n", " states = torch.FloatTensor(states)\n", " actions = torch.LongTensor(actions).unsqueeze(1)\n", " rewards = torch.FloatTensor(rewards)\n", " next_states = torch.FloatTensor(next_states)\n", " dones = torch.FloatTensor(dones)\n", "\n", " # Compute target Q-values\n", " with torch.no_grad():\n", " target_q_values = rewards + gamma * (1 - dones) * target_dqn(next_states).max(1)[0]\n", "\n", " # Compute current Q-values\n", " current_q_values = dqn(states).gather(1, actions).squeeze()\n", "\n", " # Loss: MSE\n", " loss = F.mse_loss(current_q_values, target_q_values)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " if done:\n", " break\n", "\n", " # Update target network periodically\n", " if episode % 10 == 0:\n", " target_dqn.load_state_dict(dqn.state_dict())\n", "\n", " # Decay epsilon\n", " epsilon = max(min_epsilon, epsilon * epsilon_decay)\n", "\n", " rewards_per_episode.append(episode_reward)\n", " print(f\"Episode {episode}, Reward: {episode_reward}, Epsilon: {epsilon:.2f}\")\n", "\n", " return rewards_per_episode, dqn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "#### **7. 
Visualize Training**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "# Train the model\n", "env = gym.make('CartPole-v1')\n", "rewards, trained_dqn = train_dqn(env)\n", "\n", "# Plot rewards\n", "plt.plot(rewards)\n", "plt.xlabel(\"Episode\")\n", "plt.ylabel(\"Reward\")\n", "plt.title(\"Training Rewards\")\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "\n", "---\n", "\n", "### **Results**\n", "Run the code to observe:\n", "- A graph of rewards improving over episodes.\n", "- The agent balancing the pole for increasing durations.\n", "\n", "---\n", "\n", "### **Exercises**\n", "1. Experiment with different network architectures.\n", "2. Change hyperparameters like learning rate or replay buffer size.\n", "3. Test the trained model in the environment using a greedy policy.\n", "\n", "---\n", "\n", "This tutorial introduces RL with a simple DQN implementation in PyTorch, laying the foundation for more complex algorithms like Double DQN or Policy Gradient methods." ] } ], "metadata": { "language_info": { "name": "python" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }