Upload folder using huggingface_hub
Browse files- README.md +104 -3
- dino.py +276 -0
- env.py +145 -0
- models/ppo_dino_1000k.zip +3 -0
- models/ppo_dino_100k.zip +3 -0
- models/ppo_dino_200k.zip +3 -0
- models/ppo_dino_300k.zip +3 -0
- models/ppo_dino_400k.zip +3 -0
- models/ppo_dino_500k.zip +3 -0
- models/ppo_dino_600k.zip +3 -0
- models/ppo_dino_700k.zip +3 -0
- models/ppo_dino_800k.zip +3 -0
- models/ppo_dino_900k.zip +3 -0
- requirements.txt +11 -0
- rf.py +14 -0
README.md
CHANGED
|
@@ -1,3 +1,104 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AIDino
|
| 2 |
+
|
| 3 |
+
[](https://youtu.be/jsjFVBiZG3I)
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
AIDino is a reinforcement learning project that automates gameplay of the Chrome Dino game using deep reinforcement learning. The system connects to Chrome's DevTools Protocol via WebSocket, captures the game state through efficient screenshot processing, and makes intelligent decisions using a trained Proximal Policy Optimization (PPO) model.
|
| 7 |
+
|
| 8 |
+
## Features
|
| 9 |
+
- Chrome DevTools Protocol integration for direct browser communication
|
| 10 |
+
- Efficient screenshot capture and processing using MSS
|
| 11 |
+
- Custom OpenAI Gym environment for reinforcement learning
|
| 12 |
+
- PPO model implementation via Stable Baselines3
|
| 13 |
+
- Image preprocessing with Sobel edge detection
|
| 14 |
+
- Progressive model checkpoints for tracking training progress
|
| 15 |
+
|
| 16 |
+
## Technical Architecture
|
| 17 |
+
|
| 18 |
+
### Chrome Integration
|
| 19 |
+
The system connects to Chrome's DevTools Protocol via WebSocket, allowing for:
|
| 20 |
+
- Programmatic control of the Dino game
|
| 21 |
+
- Precise input simulation (keyboard events)
|
| 22 |
+
- Real-time game state monitoring
|
| 23 |
+
- Screenshot capture of the game area
|
| 24 |
+
|
| 25 |
+
### Custom Gym Environment
|
| 26 |
+
The project implements a custom OpenAI Gym environment (`DinoEnv`) that:
|
| 27 |
+
- Defines a discrete action space (jump, duck, do nothing)
|
| 28 |
+
- Processes screenshots into suitable observations
|
| 29 |
+
- Provides appropriate rewards based on game progress
|
| 30 |
+
- Handles game reset and initialization
|
| 31 |
+
|
| 32 |
+
### Reinforcement Learning
|
| 33 |
+
The training process uses:
|
| 34 |
+
- Proximal Policy Optimization (PPO) algorithm
|
| 35 |
+
- MLP policy network architecture
|
| 36 |
+
- Reward function that encourages longer survival
|
| 37 |
+
- Periodic model checkpoints to track training progress
|
| 38 |
+
|
| 39 |
+
## Model Checkpoints
|
| 40 |
+
The repository includes 10 progressive model checkpoints from training:
|
| 41 |
+
- ppo_dino_100k.zip through ppo_dino_1000k.zip
|
| 42 |
+
- Each checkpoint represents increased training (100k to 1M timesteps)
|
| 43 |
+
|
| 44 |
+
## Getting Started
|
| 45 |
+
|
| 46 |
+
### Environment Setup
|
| 47 |
+
1. Create and activate a virtual environment:
|
| 48 |
+
```bash
|
| 49 |
+
# Create virtual environment
|
| 50 |
+
python -m venv venv
|
| 51 |
+
|
| 52 |
+
# Activate on Linux/macOS
|
| 53 |
+
source venv/bin/activate
|
| 54 |
+
|
| 55 |
+
# Activate on Windows
|
| 56 |
+
# venv\Scripts\activate
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
2. Install dependencies:
|
| 60 |
+
```bash
|
| 61 |
+
pip install -r requirements.txt
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
### Running the Project
|
| 65 |
+
1. Launch Chrome with remote debugging:
|
| 66 |
+
```bash
|
| 67 |
+
# Linux
|
| 68 |
+
google-chrome --remote-debugging-port=1234
|
| 69 |
+
|
| 70 |
+
# Windows
|
| 71 |
+
# chrome.exe --remote-debugging-port=1234
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
2. Train the model:
|
| 75 |
+
```bash
|
| 76 |
+
python rf.py
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
3. Use a pre-trained model by modifying the load path in `rf.py`:
|
| 80 |
+
```python
|
| 81 |
+
# Uncomment and modify this line in rf.py
|
| 82 |
+
# model = PPO.load("models/ppo_dino_1000k", env=env)
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## How It Works
|
| 86 |
+
1. The environment connects to Chrome via DevTools Protocol
|
| 87 |
+
2. The game is initialized and the initial state is captured
|
| 88 |
+
3. For each timestep:
|
| 89 |
+
- The current game state is captured through screenshots
|
| 90 |
+
- Images are processed using edge detection
|
| 91 |
+
- The model selects an action (jump, duck, nothing)
|
| 92 |
+
- The action is executed through the Chrome connection
|
| 93 |
+
- Rewards are calculated based on survival and game progress
|
| 94 |
+
4. Training continues until the model achieves optimal performance
|
| 95 |
+
|
| 96 |
+
## Performance
|
| 97 |
+
The model demonstrates progressive improvement across training checkpoints, with later checkpoints showing significantly better game performance and higher average scores.
|
| 98 |
+
|
| 99 |
+
## Contributing
|
| 100 |
+
Feel free to contribute to this project! Here are some ways you can help:
|
| 101 |
+
- Improve the reinforcement learning model or try different algorithms
|
| 102 |
+
- Enhance the image processing for better feature detection
|
| 103 |
+
- Optimize the performance for faster training
|
| 104 |
+
- Report bugs or suggest features by opening an issue
|
dino.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import websockets
|
| 4 |
+
import requests
|
| 5 |
+
import base64
|
| 6 |
+
import time
|
| 7 |
+
import mss
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from io import BytesIO
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import pyautogui
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Dino:
|
| 16 |
+
def __init__(self, class_name):
|
| 17 |
+
self.class_name = class_name
|
| 18 |
+
self.ws_url = self.get_ws_url()
|
| 19 |
+
self.websocket = None
|
| 20 |
+
self.command_id = 1
|
| 21 |
+
|
| 22 |
+
@staticmethod
|
| 23 |
+
def get_ws_url():
|
| 24 |
+
response = requests.get('http://localhost:1234/json')
|
| 25 |
+
data = response.json()
|
| 26 |
+
return data[0]['webSocketDebuggerUrl']
|
| 27 |
+
|
| 28 |
+
async def connect(self):
|
| 29 |
+
self.websocket = await websockets.connect(self.ws_url)
|
| 30 |
+
# Enable necessary domains
|
| 31 |
+
await self.send_command("DOM.enable", {})
|
| 32 |
+
await self.send_command("CSS.enable", {})
|
| 33 |
+
await self.send_command("Page.enable", {})
|
| 34 |
+
await self.send_command("Runtime.enable", {})
|
| 35 |
+
|
| 36 |
+
async def send_command(self, method, params):
|
| 37 |
+
command = {
|
| 38 |
+
"id": self.command_id,
|
| 39 |
+
"method": method,
|
| 40 |
+
"params": params
|
| 41 |
+
}
|
| 42 |
+
await self.websocket.send(json.dumps(command))
|
| 43 |
+
self.command_id += 1
|
| 44 |
+
|
| 45 |
+
while True:
|
| 46 |
+
response = await self.websocket.recv()
|
| 47 |
+
response_data = json.loads(response)
|
| 48 |
+
if response_data.get("id") == command["id"]:
|
| 49 |
+
return response_data
|
| 50 |
+
|
| 51 |
+
async def capture_screenshot(self):
|
| 52 |
+
try:
|
| 53 |
+
# Get document root
|
| 54 |
+
root = await self.send_command("DOM.getDocument", {"depth": -1})
|
| 55 |
+
root_node_id = root["result"]["root"]["nodeId"]
|
| 56 |
+
|
| 57 |
+
# Get the node ID of the element with the specified class name
|
| 58 |
+
search = await self.send_command("DOM.querySelector", {"nodeId": root_node_id, "selector": f".{self.class_name}"})
|
| 59 |
+
node_id = search["result"]["nodeId"]
|
| 60 |
+
|
| 61 |
+
# Get the box model of the element
|
| 62 |
+
box_model = await self.send_command("DOM.getBoxModel", {"nodeId": node_id})
|
| 63 |
+
content_box = box_model["result"]["model"]["content"]
|
| 64 |
+
|
| 65 |
+
# Capture screenshot of the area
|
| 66 |
+
screenshot = await self.send_command("Page.captureScreenshot", {
|
| 67 |
+
"clip": {
|
| 68 |
+
"x": content_box[0],
|
| 69 |
+
"y": content_box[1],
|
| 70 |
+
"width": content_box[2] - content_box[0],
|
| 71 |
+
"height": content_box[5] - content_box[1],
|
| 72 |
+
"scale": 1
|
| 73 |
+
}
|
| 74 |
+
})
|
| 75 |
+
|
| 76 |
+
# Decode the base64 screenshot data
|
| 77 |
+
screenshot_data = base64.b64decode(screenshot["result"]["data"])
|
| 78 |
+
image = Image.open(BytesIO(screenshot_data))
|
| 79 |
+
|
| 80 |
+
resized_image = image.resize((image.width//5, image.height//5))
|
| 81 |
+
|
| 82 |
+
# Get the current date and time
|
| 83 |
+
#current_time = datetime.now()
|
| 84 |
+
|
| 85 |
+
# Format the date and time as a string
|
| 86 |
+
#timestamp_string = current_time.strftime('%H:%M:%S')
|
| 87 |
+
|
| 88 |
+
cropped_image = resized_image.crop((52, 0, 82, resized_image.height))
|
| 89 |
+
final_image = cropped_image.resize((30, 92))
|
| 90 |
+
return final_image
|
| 91 |
+
|
| 92 |
+
except Exception as e:
|
| 93 |
+
print(f"An error occurred: {e}")
|
| 94 |
+
|
| 95 |
+
async def get_window_name(self):
|
| 96 |
+
try:
|
| 97 |
+
# Evaluate JavaScript to get the window name
|
| 98 |
+
response = await self.send_command("Runtime.evaluate", {
|
| 99 |
+
"expression": "window.name"
|
| 100 |
+
})
|
| 101 |
+
#print(response)
|
| 102 |
+
window_name = response["result"]["result"]["value"]
|
| 103 |
+
|
| 104 |
+
print(f"Window name: {window_name}")
|
| 105 |
+
return window_name
|
| 106 |
+
except Exception as e:
|
| 107 |
+
print(f"An error occurred while getting window name: {e}")
|
| 108 |
+
return None
|
| 109 |
+
|
| 110 |
+
async def enable_all_obstacles(self):
|
| 111 |
+
try:
|
| 112 |
+
# Evaluate JavaScript to get the window name
|
| 113 |
+
response = await self.send_command("Runtime.evaluate", {
|
| 114 |
+
"expression": "spriteDefinitionByType.original.OBSTACLES[2].minSpeed = 0"
|
| 115 |
+
})
|
| 116 |
+
|
| 117 |
+
print(f"Enabled all obstacles")
|
| 118 |
+
return True
|
| 119 |
+
except Exception as e:
|
| 120 |
+
print(f"An error occurred while enabling obstacles: {e}")
|
| 121 |
+
return None
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
async def open_dino(self):
|
| 125 |
+
try:
|
| 126 |
+
response = await self.send_command("Page.navigate", {
|
| 127 |
+
"url": "chrome://dino/"
|
| 128 |
+
})
|
| 129 |
+
|
| 130 |
+
return True
|
| 131 |
+
except Exception as e:
|
| 132 |
+
print(f"An error occurred while opening game: {e}")
|
| 133 |
+
return None
|
| 134 |
+
|
| 135 |
+
async def send_key_event(self, key, code, key_code):
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
|
| 139 |
+
response1 = await self.send_command("Input.dispatchKeyEvent", {
|
| 140 |
+
"type": "rawKeyDown",
|
| 141 |
+
"key": key,
|
| 142 |
+
"code": code,
|
| 143 |
+
"keyCode": key_code,
|
| 144 |
+
"windowsVirtualKeyCode": key_code,
|
| 145 |
+
"nativeVirtualKeyCode": key_code,
|
| 146 |
+
"modifiers": 0
|
| 147 |
+
})
|
| 148 |
+
|
| 149 |
+
if key_code == 40: time.sleep(0.4)
|
| 150 |
+
|
| 151 |
+
response = await self.send_command("Input.dispatchKeyEvent", {
|
| 152 |
+
"type": "keyUp",
|
| 153 |
+
"key": key,
|
| 154 |
+
"code": code,
|
| 155 |
+
"keyCode": key_code,
|
| 156 |
+
"windowsVirtualKeyCode": key_code,
|
| 157 |
+
"nativeVirtualKeyCode": key_code,
|
| 158 |
+
"modifiers": 0
|
| 159 |
+
})
|
| 160 |
+
|
| 161 |
+
return True
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print(f"An error occurred while sending key event: {e}")
|
| 164 |
+
return None
|
| 165 |
+
|
| 166 |
+
async def send_key_event2(self, key):
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
|
| 170 |
+
pyautogui.press(key)
|
| 171 |
+
|
| 172 |
+
return True
|
| 173 |
+
except Exception as e:
|
| 174 |
+
print(f"An error occurred while sending key event: {e}")
|
| 175 |
+
return None
|
| 176 |
+
|
| 177 |
+
async def check_status(self):
|
| 178 |
+
try:
|
| 179 |
+
crashed = await self.send_command("Runtime.evaluate", {
|
| 180 |
+
"expression": "Runner.instance_.crashed"
|
| 181 |
+
})
|
| 182 |
+
score = 0.0
|
| 183 |
+
try:
|
| 184 |
+
score = await self.send_command("Runtime.evaluate", {
|
| 185 |
+
"expression": "Runner.instance_.distanceRan"
|
| 186 |
+
})
|
| 187 |
+
|
| 188 |
+
score = float(score['result']['result']['value']) // 10
|
| 189 |
+
except:
|
| 190 |
+
pass
|
| 191 |
+
|
| 192 |
+
return {
|
| 193 |
+
"crashed": crashed['result']['result']['value'],
|
| 194 |
+
"score": score
|
| 195 |
+
}
|
| 196 |
+
except Exception as e:
|
| 197 |
+
print(f"An error occurred while checking status: {e}")
|
| 198 |
+
return None
|
| 199 |
+
|
| 200 |
+
async def complete_action(self):
|
| 201 |
+
try:
|
| 202 |
+
crashed = await self.send_command("Runtime.evaluate", {
|
| 203 |
+
"expression": "Runner.instance_.crashed"
|
| 204 |
+
})
|
| 205 |
+
crashed = crashed['result']['result']['value']
|
| 206 |
+
while not crashed:
|
| 207 |
+
jumping = await self.send_command("Runtime.evaluate", {
|
| 208 |
+
"expression": "Runner.instance_.tRex.jumping"
|
| 209 |
+
})
|
| 210 |
+
jumping = jumping['result']['result']['value']
|
| 211 |
+
|
| 212 |
+
ducking = await self.send_command("Runtime.evaluate", {
|
| 213 |
+
"expression": "Runner.instance_.tRex.ducking"
|
| 214 |
+
})
|
| 215 |
+
ducking = ducking['result']['result']['value']
|
| 216 |
+
|
| 217 |
+
crashed = await self.send_command("Runtime.evaluate", {
|
| 218 |
+
"expression": "Runner.instance_.crashed"
|
| 219 |
+
})
|
| 220 |
+
crashed = crashed['result']['result']['value']
|
| 221 |
+
|
| 222 |
+
if (not jumping) and (not ducking): break
|
| 223 |
+
|
| 224 |
+
except Exception as e:
|
| 225 |
+
print(f"An error occurred while selecting action: {e}")
|
| 226 |
+
return None
|
| 227 |
+
|
| 228 |
+
async def capture_screenshot2(self):
|
| 229 |
+
try:
|
| 230 |
+
with mss.mss() as sct:
|
| 231 |
+
# Define the region to capture
|
| 232 |
+
monitor = {
|
| 233 |
+
"top": 245,
|
| 234 |
+
"left": 730,
|
| 235 |
+
"width": 200,
|
| 236 |
+
"height": 45,
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
# Capture the screenshot
|
| 240 |
+
screenshot = sct.grab(monitor)
|
| 241 |
+
|
| 242 |
+
# Convert the raw bytes data to a numpy array
|
| 243 |
+
img = np.array(screenshot)
|
| 244 |
+
|
| 245 |
+
# Convert the BGRA image to RGB
|
| 246 |
+
img = img[:, :, :3]
|
| 247 |
+
img = img[..., ::-1]
|
| 248 |
+
|
| 249 |
+
# Convert the numpy array to a PIL image
|
| 250 |
+
image = Image.fromarray(img)
|
| 251 |
+
|
| 252 |
+
#resized_image = image.resize((100, 80))
|
| 253 |
+
|
| 254 |
+
# Get the current date and time
|
| 255 |
+
#current_time = datetime.now()
|
| 256 |
+
|
| 257 |
+
# Format the date and time as a string
|
| 258 |
+
#timestamp_string = current_time.strftime('%H:%M:%S')
|
| 259 |
+
|
| 260 |
+
#resized_image.save(timestamp_string + 'resized_image.png')
|
| 261 |
+
|
| 262 |
+
return image
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
except Exception as e:
|
| 266 |
+
print(f"An error occurred while opening game: {e}")
|
| 267 |
+
return None
|
| 268 |
+
|
| 269 |
+
async def start(self):
|
| 270 |
+
await self.connect()
|
| 271 |
+
|
| 272 |
+
# Get the window name once
|
| 273 |
+
#await self.get_window_name()
|
| 274 |
+
|
| 275 |
+
#await self.capture_screenshot()
|
| 276 |
+
|
env.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium import spaces
|
| 3 |
+
import numpy as np
|
| 4 |
+
from dino import Dino
|
| 5 |
+
import asyncio
|
| 6 |
+
import time
|
| 7 |
+
from scipy.ndimage import convolve
|
| 8 |
+
import numpy as np
|
| 9 |
+
import gymnasium as gym
|
| 10 |
+
from tensorboardX import SummaryWriter
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
class DinoEnv(gym.Env):
|
| 16 |
+
def __init__(self):
|
| 17 |
+
super(DinoEnv, self).__init__()
|
| 18 |
+
|
| 19 |
+
# Define the action and observation space
|
| 20 |
+
# Actions: 0=up, 1=down, 2 do nothing
|
| 21 |
+
self.action_space = spaces.Discrete(3)
|
| 22 |
+
|
| 23 |
+
self.sobel_kernel = np.array([
|
| 24 |
+
[-1, 0, 1],
|
| 25 |
+
[-2, 0, 2],
|
| 26 |
+
[-1, 0, 1]
|
| 27 |
+
])
|
| 28 |
+
|
| 29 |
+
self.im_size = 9000
|
| 30 |
+
|
| 31 |
+
# Define the observation space: 11316 integer pixels
|
| 32 |
+
self.observation_space = spaces.Box(low=0, high=255, shape=(self.im_size,), dtype=np.int32)
|
| 33 |
+
|
| 34 |
+
# Initialize the state (e.g., all pixels set to 0)
|
| 35 |
+
self.state = np.zeros((self.im_size,), dtype=np.int32)
|
| 36 |
+
|
| 37 |
+
# Initialize other variables (e.g., a variable to keep track of the score)
|
| 38 |
+
self.score = 0
|
| 39 |
+
|
| 40 |
+
self.current_step = 0
|
| 41 |
+
|
| 42 |
+
def step(self, action):
|
| 43 |
+
print( str(action) + ' : ' + str(self.score))
|
| 44 |
+
self.current_step += 1
|
| 45 |
+
if action == 0: # up
|
| 46 |
+
#asyncio.get_event_loop().run_until_complete(self.dino.send_key_event2("down"))
|
| 47 |
+
asyncio.get_event_loop().run_until_complete(self.dino.send_key_event("ArrowUp", "ArrowUp", 38))
|
| 48 |
+
asyncio.get_event_loop().run_until_complete(self.dino.complete_action())
|
| 49 |
+
elif action == 1: # down
|
| 50 |
+
#asyncio.get_event_loop().run_until_complete(self.dino.send_key_event2("down"))
|
| 51 |
+
asyncio.get_event_loop().run_until_complete(self.dino.send_key_event("ArrowDown", "ArrowDown", 40))
|
| 52 |
+
asyncio.get_event_loop().run_until_complete(self.dino.complete_action())
|
| 53 |
+
elif action == 2: # do nothing
|
| 54 |
+
time.sleep(0.1)
|
| 55 |
+
pass
|
| 56 |
+
|
| 57 |
+
self.state = self.get_screenshot()
|
| 58 |
+
|
| 59 |
+
status = asyncio.get_event_loop().run_until_complete(self.dino.check_status())
|
| 60 |
+
|
| 61 |
+
if not status or status['crashed']:
|
| 62 |
+
with open('scores.txt', 'a') as f:
|
| 63 |
+
now = datetime.now() # current date and time
|
| 64 |
+
date_time = now.strftime("%m/%d/%Y, %H:%M:%S")
|
| 65 |
+
print("---------date and time:",date_time)
|
| 66 |
+
f.write(date_time + ' - '+ str(self.score) + '\n')
|
| 67 |
+
reward = -100
|
| 68 |
+
done = True
|
| 69 |
+
if not status:
|
| 70 |
+
asyncio.get_event_loop().run_until_complete(self.dino.open_dino())
|
| 71 |
+
else:
|
| 72 |
+
reward = 2
|
| 73 |
+
#if action == 1: reward = 3
|
| 74 |
+
done = False
|
| 75 |
+
|
| 76 |
+
self.score+=reward
|
| 77 |
+
|
| 78 |
+
info = {}
|
| 79 |
+
|
| 80 |
+
#observation, reward, terminated, truncated, info
|
| 81 |
+
return self.state, reward, done, False, info
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def reset(self, seed=None, options=None):
|
| 85 |
+
super().reset(seed=seed)
|
| 86 |
+
|
| 87 |
+
self.dino = Dino('runner-canvas')
|
| 88 |
+
|
| 89 |
+
asyncio.get_event_loop().run_until_complete(self.dino.start())
|
| 90 |
+
|
| 91 |
+
asyncio.get_event_loop().run_until_complete(self.dino.open_dino())
|
| 92 |
+
|
| 93 |
+
time.sleep(2)
|
| 94 |
+
|
| 95 |
+
asyncio.get_event_loop().run_until_complete(self.dino.send_key_event(" ", "Space", 32))
|
| 96 |
+
|
| 97 |
+
time.sleep(2)
|
| 98 |
+
|
| 99 |
+
asyncio.get_event_loop().run_until_complete(self.dino.capture_screenshot2())
|
| 100 |
+
|
| 101 |
+
self.state = self.get_screenshot()
|
| 102 |
+
|
| 103 |
+
info = {}
|
| 104 |
+
|
| 105 |
+
self.current_step = 0
|
| 106 |
+
|
| 107 |
+
self.score = 0
|
| 108 |
+
|
| 109 |
+
asyncio.get_event_loop().run_until_complete(self.dino.enable_all_obstacles())
|
| 110 |
+
|
| 111 |
+
return self.state, info
|
| 112 |
+
|
| 113 |
+
def get_screenshot(self):
|
| 114 |
+
image = asyncio.get_event_loop().run_until_complete(self.dino.capture_screenshot2())
|
| 115 |
+
|
| 116 |
+
gray_image = image.convert('L')
|
| 117 |
+
|
| 118 |
+
image_array = np.array(gray_image)
|
| 119 |
+
|
| 120 |
+
convolved_image = convolve(image_array, self.sobel_kernel)
|
| 121 |
+
# Get the current date and time
|
| 122 |
+
current_time = datetime.now()
|
| 123 |
+
|
| 124 |
+
# Format the date and time as a string
|
| 125 |
+
timestamp_string = current_time.strftime('%H:%M:%S')
|
| 126 |
+
|
| 127 |
+
image = Image.fromarray(convolved_image)
|
| 128 |
+
|
| 129 |
+
# Save the image
|
| 130 |
+
#image.save('im/' + timestamp_string + 'conv_image.png')
|
| 131 |
+
|
| 132 |
+
flattened_array = convolved_image.flatten()
|
| 133 |
+
|
| 134 |
+
return flattened_array
|
| 135 |
+
|
| 136 |
+
def render(self, mode='human'):
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# Register the custom environment
|
| 141 |
+
gym.envs.registration.register(
|
| 142 |
+
id='DinoEnv-v0',
|
| 143 |
+
entry_point=__name__ + ':DinoEnv',
|
| 144 |
+
)
|
| 145 |
+
|
models/ppo_dino_1000k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:262d370a4c6af4ce1de1fc87ce618e8f2f79a65cef5756001d5f3c61cd3a2dbc
|
| 3 |
+
size 14129512
|
models/ppo_dino_100k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2275e46dc4d90d5a3ab3b7348a28a1af0f61234dd78b6a9f745ca7b62bbd39b
|
| 3 |
+
size 14129459
|
models/ppo_dino_200k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cd376e9c53002769c61e7d4464bdcc864b6b6f0b2695390779af87e52198ad4a
|
| 3 |
+
size 14129479
|
models/ppo_dino_300k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d7f0099ccd3a945d7ed7e72176bb5dc10f6213dfceff5a3596d9dac7944465a2
|
| 3 |
+
size 14129480
|
models/ppo_dino_400k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f43c84fc9a2506bfd22632e325dc355701812d1539a4ac73f9a9328169c3fb4d
|
| 3 |
+
size 14129484
|
models/ppo_dino_500k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b6cedc8fbf789ccc67d0d374f5ec7d421419b8627ed5f7abc7ed240e4adbfffe
|
| 3 |
+
size 14129492
|
models/ppo_dino_600k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:92d5465e28ce410f8ecc34b8779e08fd4326212c0ed18baf8ca237ef2854ab12
|
| 3 |
+
size 14129480
|
models/ppo_dino_700k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e877d70cf77fc234f0ef9c44360de26c20d68d5355d586296cecbc33508b6ed8
|
| 3 |
+
size 14129488
|
models/ppo_dino_800k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f50e439a5ac9a6afa674cd267c98ce7921eadbca8df2c5e507ce686147b9c393
|
| 3 |
+
size 14129496
|
models/ppo_dino_900k.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1ad4fa130ef150c2f715ae7aae5563168f543c51c51986b10530d0d7b423374b
|
| 3 |
+
size 14129496
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
asyncio
|
| 2 |
+
websockets
|
| 3 |
+
requests
|
| 4 |
+
mss
|
| 5 |
+
numpy
|
| 6 |
+
Pillow
|
| 7 |
+
pyautogui
|
| 8 |
+
gymnasium
|
| 9 |
+
scipy
|
| 10 |
+
stable-baselines3
|
| 11 |
+
tensorboardX
|
rf.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
import env
|
| 3 |
+
from stable_baselines3 import PPO
|
| 4 |
+
env = gym.make("DinoEnv-v0")
|
| 5 |
+
model = PPO("MlpPolicy", env, verbose=1)
|
| 6 |
+
|
| 7 |
+
# Load the saved PPO model (Optional)
|
| 8 |
+
#model = PPO.load("models/ppo_dino_backup8", env=env)
|
| 9 |
+
for i in range (100):
|
| 10 |
+
model.learn(total_timesteps=100000)
|
| 11 |
+
model.save("models/ppo_dino")
|
| 12 |
+
model.save("models/ppo_dino_backup" + str(i))
|
| 13 |
+
|
| 14 |
+
env.close()
|