shanaka95 committed on
Commit 3e9e3ce · verified · 1 Parent(s): 3f7e178

Upload folder using huggingface_hub
README.md CHANGED
@@ -1,3 +1,104 @@
- ---
- license: mit
- ---
+ # AIDino
+
+ [![AIDino Demo Video](https://img.shields.io/badge/Watch%20Demo-YouTube-red?style=for-the-badge&logo=youtube)](https://youtu.be/jsjFVBiZG3I)
+
+ ## Overview
+ AIDino is a reinforcement learning project that automates gameplay of the Chrome Dino game. The system connects to Chrome's DevTools Protocol via WebSocket, captures the game state through efficient screenshot processing, and makes decisions with a trained Proximal Policy Optimization (PPO) model.
+
+ ## Features
+ - Chrome DevTools Protocol integration for direct browser communication
+ - Efficient screenshot capture and processing using MSS
+ - Custom Gymnasium (OpenAI Gym) environment for reinforcement learning
+ - PPO model implementation via Stable Baselines3
+ - Image preprocessing with Sobel edge detection
+ - Progressive model checkpoints for tracking training progress
+
+ ## Technical Architecture
+
+ ### Chrome Integration
+ The system connects to Chrome's DevTools Protocol via WebSocket (a minimal sketch follows the list), allowing:
+ - Programmatic control of the Dino game
+ - Precise input simulation (keyboard events)
+ - Real-time game state monitoring
+ - Screenshot capture of the game area
+
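+ A rough sketch of that connection, mirroring what `dino.py` does (port 1234 is the debugging port assumed throughout this project):
+
+ ```python
+ import asyncio
+ import json
+ import requests
+ import websockets
+
+ async def main():
+     # Find the DevTools WebSocket endpoint of the first open tab
+     ws_url = requests.get("http://localhost:1234/json").json()[0]["webSocketDebuggerUrl"]
+     async with websockets.connect(ws_url) as ws:
+         # Every DevTools command is a JSON object with an id, a method and params
+         await ws.send(json.dumps({"id": 1, "method": "Page.navigate",
+                                   "params": {"url": "chrome://dino/"}}))
+         print(await ws.recv())  # the response carries the same id
+
+ asyncio.run(main())
+ ```
+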
+ ### Custom Gym Environment
+ The project implements a custom Gymnasium environment (`DinoEnv`, skeleton below) that:
+ - Defines a discrete action space (jump, duck, do nothing)
+ - Processes screenshots into suitable observations
+ - Provides rewards based on game progress
+ - Handles game reset and initialization
+
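+ An illustrative skeleton of that interface, using the shapes from `env.py` (three discrete actions, a flattened 9000-pixel observation); the real environment fills these methods with DevTools calls:
+
+ ```python
+ import gymnasium as gym
+ import numpy as np
+ from gymnasium import spaces
+
+ class DinoEnvSkeleton(gym.Env):
+     def __init__(self):
+         super().__init__()
+         self.action_space = spaces.Discrete(3)  # 0 = jump, 1 = duck, 2 = do nothing
+         self.observation_space = spaces.Box(low=0, high=255, shape=(9000,), dtype=np.int32)
+
+     def reset(self, seed=None, options=None):
+         super().reset(seed=seed)
+         # The real environment restarts the game and returns a processed screenshot
+         return np.zeros(9000, dtype=np.int32), {}
+
+     def step(self, action):
+         # The real environment sends a key event, screenshots the game and checks for a crash
+         obs = np.zeros(9000, dtype=np.int32)
+         reward, terminated = 2.0, False  # +2 per surviving step, -100 on crash
+         return obs, reward, terminated, False, {}
+ ```
+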
+ ### Reinforcement Learning
+ The training process (condensed below from `rf.py`) uses:
+ - The Proximal Policy Optimization (PPO) algorithm
+ - An MLP policy network architecture
+ - A reward function that encourages longer survival
+ - Periodic model checkpoints to track training progress
+
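+ Condensed from `rf.py`, the training loop looks roughly like this (variable names differ slightly from the script):
+
+ ```python
+ import gymnasium as gym
+ import env  # importing env.py registers DinoEnv-v0
+ from stable_baselines3 import PPO
+
+ dino_env = gym.make("DinoEnv-v0")
+ model = PPO("MlpPolicy", dino_env, verbose=1)  # MLP policy over the flattened pixel vector
+
+ for i in range(100):
+     model.learn(total_timesteps=100_000)       # 100k timesteps per iteration
+     model.save(f"models/ppo_dino_backup{i}")   # periodic checkpoint
+ ```
+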
+ ## Model Checkpoints
+ The repository includes 10 progressive model checkpoints from training:
+ - ppo_dino_100k.zip through ppo_dino_1000k.zip
+ - Each checkpoint corresponds to the number of training timesteps in its name (100k to 1M)
+
+ ## Getting Started
+
+ ### Environment Setup
+ 1. Create and activate a virtual environment:
+ ```bash
+ # Create virtual environment
+ python -m venv venv
+
+ # Activate on Linux/macOS
+ source venv/bin/activate
+
+ # Activate on Windows
+ # venv\Scripts\activate
+ ```
+
+ 2. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### Running the Project
+ 1. Launch Chrome with remote debugging:
+ ```bash
+ # Linux
+ google-chrome --remote-debugging-port=1234
+
+ # Windows
+ # chrome.exe --remote-debugging-port=1234
+ ```
+
+ 2. Train the model:
+ ```bash
+ python rf.py
+ ```
+
+ 3. Use a pre-trained model by modifying the load path in `rf.py`:
+ ```python
+ # Uncomment and modify this line in rf.py
+ # model = PPO.load("models/ppo_dino_1000k", env=env)
+ ```
+
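+ To watch a loaded checkpoint play rather than continue training, a small evaluation loop along these lines should work (illustrative; not part of the repository):
+
+ ```python
+ import gymnasium as gym
+ import env  # registers DinoEnv-v0
+ from stable_baselines3 import PPO
+
+ dino_env = gym.make("DinoEnv-v0")
+ model = PPO.load("models/ppo_dino_1000k", env=dino_env)
+
+ obs, _ = dino_env.reset()
+ terminated = truncated = False
+ while not (terminated or truncated):
+     action, _ = model.predict(obs, deterministic=True)
+     obs, reward, terminated, truncated, info = dino_env.step(action)
+ ```
+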
+ ## How It Works
+ 1. The environment connects to Chrome via the DevTools Protocol
+ 2. The game is initialized and the initial state is captured
+ 3. For each timestep:
+    - The current game state is captured through screenshots
+    - Images are processed with Sobel edge detection (sketched below)
+    - The model selects an action (jump, duck, or do nothing)
+    - The action is executed through the Chrome connection
+    - Rewards are calculated based on survival and game progress
+ 4. Training continues for the configured number of timesteps, saving checkpoints along the way
+
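+ The preprocessing step, as implemented in `env.py`, amounts to:
+
+ ```python
+ import numpy as np
+ from PIL import Image
+ from scipy.ndimage import convolve
+
+ SOBEL_X = np.array([[-1, 0, 1],
+                     [-2, 0, 2],
+                     [-1, 0, 1]])
+
+ def preprocess(screenshot: Image.Image) -> np.ndarray:
+     gray = np.array(screenshot.convert("L"))   # grayscale pixel array
+     edges = convolve(gray, SOBEL_X)            # horizontal-gradient edge map
+     return edges.flatten()                     # flat vector fed to the MLP policy
+ ```
+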
+ ## Performance
+ The model demonstrates progressive improvement across training checkpoints, with later checkpoints showing significantly better game performance and higher average scores.
+
+ ## Contributing
+ Feel free to contribute to this project! Here are some ways you can help:
+ - Improve the reinforcement learning model or try different algorithms
+ - Enhance the image processing for better feature detection
+ - Optimize performance for faster training
+ - Report bugs or suggest features by opening an issue
dino.py ADDED
@@ -0,0 +1,276 @@
+ import asyncio
+ import json
+ import websockets
+ import requests
+ import base64
+ import time
+ import mss
+ import numpy as np
+ from PIL import Image
+ from io import BytesIO
+ from datetime import datetime
+ import pyautogui
+
+
+ class Dino:
+     def __init__(self, class_name):
+         self.class_name = class_name
+         self.ws_url = self.get_ws_url()
+         self.websocket = None
+         self.command_id = 1
+
+     @staticmethod
+     def get_ws_url():
+         response = requests.get('http://localhost:1234/json')
+         data = response.json()
+         return data[0]['webSocketDebuggerUrl']
+
+     async def connect(self):
+         self.websocket = await websockets.connect(self.ws_url)
+         # Enable necessary domains
+         await self.send_command("DOM.enable", {})
+         await self.send_command("CSS.enable", {})
+         await self.send_command("Page.enable", {})
+         await self.send_command("Runtime.enable", {})
+
+     async def send_command(self, method, params):
+         command = {
+             "id": self.command_id,
+             "method": method,
+             "params": params
+         }
+         await self.websocket.send(json.dumps(command))
+         self.command_id += 1
+
+         while True:
+             response = await self.websocket.recv()
+             response_data = json.loads(response)
+             if response_data.get("id") == command["id"]:
+                 return response_data
+
+     async def capture_screenshot(self):
+         try:
+             # Get document root
+             root = await self.send_command("DOM.getDocument", {"depth": -1})
+             root_node_id = root["result"]["root"]["nodeId"]
+
+             # Get the node ID of the element with the specified class name
+             search = await self.send_command("DOM.querySelector", {"nodeId": root_node_id, "selector": f".{self.class_name}"})
+             node_id = search["result"]["nodeId"]
+
+             # Get the box model of the element
+             box_model = await self.send_command("DOM.getBoxModel", {"nodeId": node_id})
+             content_box = box_model["result"]["model"]["content"]
+
+             # Capture screenshot of the area
+             screenshot = await self.send_command("Page.captureScreenshot", {
+                 "clip": {
+                     "x": content_box[0],
+                     "y": content_box[1],
+                     "width": content_box[2] - content_box[0],
+                     "height": content_box[5] - content_box[1],
+                     "scale": 1
+                 }
+             })
+
+             # Decode the base64 screenshot data
+             screenshot_data = base64.b64decode(screenshot["result"]["data"])
+             image = Image.open(BytesIO(screenshot_data))
+
+             resized_image = image.resize((image.width//5, image.height//5))
+
+             # Get the current date and time
+             #current_time = datetime.now()
+
+             # Format the date and time as a string
+             #timestamp_string = current_time.strftime('%H:%M:%S')
+
+             cropped_image = resized_image.crop((52, 0, 82, resized_image.height))
+             final_image = cropped_image.resize((30, 92))
+             return final_image
+
+         except Exception as e:
+             print(f"An error occurred: {e}")
+
+     async def get_window_name(self):
+         try:
+             # Evaluate JavaScript to get the window name
+             response = await self.send_command("Runtime.evaluate", {
+                 "expression": "window.name"
+             })
+             #print(response)
+             window_name = response["result"]["result"]["value"]
+
+             print(f"Window name: {window_name}")
+             return window_name
+         except Exception as e:
+             print(f"An error occurred while getting window name: {e}")
+             return None
+
+     async def enable_all_obstacles(self):
+         try:
+             # Drop the pterodactyl's minimum speed so every obstacle type can spawn
+             response = await self.send_command("Runtime.evaluate", {
+                 "expression": "spriteDefinitionByType.original.OBSTACLES[2].minSpeed = 0"
+             })
+
+             print("Enabled all obstacles")
+             return True
+         except Exception as e:
+             print(f"An error occurred while enabling obstacles: {e}")
+             return None
+
+
+     async def open_dino(self):
+         try:
+             response = await self.send_command("Page.navigate", {
+                 "url": "chrome://dino/"
+             })
+
+             return True
+         except Exception as e:
+             print(f"An error occurred while opening game: {e}")
+             return None
+
+     async def send_key_event(self, key, code, key_code):
+
+         try:
+
+             response1 = await self.send_command("Input.dispatchKeyEvent", {
+                 "type": "rawKeyDown",
+                 "key": key,
+                 "code": code,
+                 "keyCode": key_code,
+                 "windowsVirtualKeyCode": key_code,
+                 "nativeVirtualKeyCode": key_code,
+                 "modifiers": 0
+             })
+
+             if key_code == 40: time.sleep(0.4)
+
+             response = await self.send_command("Input.dispatchKeyEvent", {
+                 "type": "keyUp",
+                 "key": key,
+                 "code": code,
+                 "keyCode": key_code,
+                 "windowsVirtualKeyCode": key_code,
+                 "nativeVirtualKeyCode": key_code,
+                 "modifiers": 0
+             })
+
+             return True
+         except Exception as e:
+             print(f"An error occurred while sending key event: {e}")
+             return None
+
+     async def send_key_event2(self, key):
+
+         try:
+
+             pyautogui.press(key)
+
+             return True
+         except Exception as e:
+             print(f"An error occurred while sending key event: {e}")
+             return None
+
+     async def check_status(self):
+         try:
+             crashed = await self.send_command("Runtime.evaluate", {
+                 "expression": "Runner.instance_.crashed"
+             })
+             score = 0.0
+             try:
+                 score = await self.send_command("Runtime.evaluate", {
+                     "expression": "Runner.instance_.distanceRan"
+                 })
+
+                 score = float(score['result']['result']['value']) // 10
+             except Exception:
+                 # distanceRan may not be available before the game has started
+                 pass
+
+             return {
+                 "crashed": crashed['result']['result']['value'],
+                 "score": score
+             }
+         except Exception as e:
+             print(f"An error occurred while checking status: {e}")
+             return None
+
+     async def complete_action(self):
+         try:
+             crashed = await self.send_command("Runtime.evaluate", {
+                 "expression": "Runner.instance_.crashed"
+             })
+             crashed = crashed['result']['result']['value']
+             while not crashed:
+                 jumping = await self.send_command("Runtime.evaluate", {
+                     "expression": "Runner.instance_.tRex.jumping"
+                 })
+                 jumping = jumping['result']['result']['value']
+
+                 ducking = await self.send_command("Runtime.evaluate", {
+                     "expression": "Runner.instance_.tRex.ducking"
+                 })
+                 ducking = ducking['result']['result']['value']
+
+                 crashed = await self.send_command("Runtime.evaluate", {
+                     "expression": "Runner.instance_.crashed"
+                 })
+                 crashed = crashed['result']['result']['value']
+
+                 if (not jumping) and (not ducking): break
+
+         except Exception as e:
+             print(f"An error occurred while selecting action: {e}")
+             return None
+
+     async def capture_screenshot2(self):
+         try:
+             with mss.mss() as sct:
+                 # Define the region to capture
+                 monitor = {
+                     "top": 245,
+                     "left": 730,
+                     "width": 200,
+                     "height": 45,
+                 }
+
+                 # Capture the screenshot
+                 screenshot = sct.grab(monitor)
+
+                 # Convert the raw bytes data to a numpy array
+                 img = np.array(screenshot)
+
+                 # Convert the BGRA image to RGB
+                 img = img[:, :, :3]
+                 img = img[..., ::-1]
+
+                 # Convert the numpy array to a PIL image
+                 image = Image.fromarray(img)
+
+                 #resized_image = image.resize((100, 80))
+
+                 # Get the current date and time
+                 #current_time = datetime.now()
+
+                 # Format the date and time as a string
+                 #timestamp_string = current_time.strftime('%H:%M:%S')
+
+                 #resized_image.save(timestamp_string + 'resized_image.png')
+
+                 return image
+
+
+         except Exception as e:
+             print(f"An error occurred while capturing screenshot: {e}")
+             return None
+
+     async def start(self):
+         await self.connect()
+
+         # Get the window name once
+         #await self.get_window_name()
+
+         #await self.capture_screenshot()
+
env.py ADDED
@@ -0,0 +1,145 @@
+ import gymnasium as gym
+ from gymnasium import spaces
+ import numpy as np
+ from dino import Dino
+ import asyncio
+ import time
+ from scipy.ndimage import convolve
+ from tensorboardX import SummaryWriter
+ from datetime import datetime
+ from PIL import Image
+
+ class DinoEnv(gym.Env):
+     def __init__(self):
+         super(DinoEnv, self).__init__()
+
+         # Define the action and observation space
+         # Actions: 0=up, 1=down, 2=do nothing
+         self.action_space = spaces.Discrete(3)
+
+         self.sobel_kernel = np.array([
+             [-1, 0, 1],
+             [-2, 0, 2],
+             [-1, 0, 1]
+         ])
+
+         self.im_size = 9000
+
+         # Define the observation space: im_size flattened integer pixels
+         self.observation_space = spaces.Box(low=0, high=255, shape=(self.im_size,), dtype=np.int32)
+
+         # Initialize the state (e.g., all pixels set to 0)
+         self.state = np.zeros((self.im_size,), dtype=np.int32)
+
+         # Initialize other variables (e.g., a variable to keep track of the score)
+         self.score = 0
+
+         self.current_step = 0
+
+     def step(self, action):
+         print(str(action) + ' : ' + str(self.score))
+         self.current_step += 1
+         if action == 0:  # up
+             #asyncio.get_event_loop().run_until_complete(self.dino.send_key_event2("down"))
+             asyncio.get_event_loop().run_until_complete(self.dino.send_key_event("ArrowUp", "ArrowUp", 38))
+             asyncio.get_event_loop().run_until_complete(self.dino.complete_action())
+         elif action == 1:  # down
+             #asyncio.get_event_loop().run_until_complete(self.dino.send_key_event2("down"))
+             asyncio.get_event_loop().run_until_complete(self.dino.send_key_event("ArrowDown", "ArrowDown", 40))
+             asyncio.get_event_loop().run_until_complete(self.dino.complete_action())
+         elif action == 2:  # do nothing
+             time.sleep(0.1)
+             pass
+
+         self.state = self.get_screenshot()
+
+         status = asyncio.get_event_loop().run_until_complete(self.dino.check_status())
+
+         if not status or status['crashed']:
+             with open('scores.txt', 'a') as f:
+                 now = datetime.now()  # current date and time
+                 date_time = now.strftime("%m/%d/%Y, %H:%M:%S")
+                 print("---------date and time:", date_time)
+                 f.write(date_time + ' - ' + str(self.score) + '\n')
+             reward = -100
+             done = True
+             if not status:
+                 asyncio.get_event_loop().run_until_complete(self.dino.open_dino())
+         else:
+             reward = 2
+             #if action == 1: reward = 3
+             done = False
+
+         self.score += reward
+
+         info = {}
+
+         # observation, reward, terminated, truncated, info
+         return self.state, reward, done, False, info
+
+
+     def reset(self, seed=None, options=None):
+         super().reset(seed=seed)
+
+         self.dino = Dino('runner-canvas')
+
+         asyncio.get_event_loop().run_until_complete(self.dino.start())
+
+         asyncio.get_event_loop().run_until_complete(self.dino.open_dino())
+
+         time.sleep(2)
+
+         asyncio.get_event_loop().run_until_complete(self.dino.send_key_event(" ", "Space", 32))
+
+         time.sleep(2)
+
+         asyncio.get_event_loop().run_until_complete(self.dino.capture_screenshot2())
+
+         self.state = self.get_screenshot()
+
+         info = {}
+
+         self.current_step = 0
+
+         self.score = 0
+
+         asyncio.get_event_loop().run_until_complete(self.dino.enable_all_obstacles())
+
+         return self.state, info
+
+     def get_screenshot(self):
+         image = asyncio.get_event_loop().run_until_complete(self.dino.capture_screenshot2())
+
+         gray_image = image.convert('L')
+
+         image_array = np.array(gray_image)
+
+         convolved_image = convolve(image_array, self.sobel_kernel)
+         # Get the current date and time
+         current_time = datetime.now()
+
+         # Format the date and time as a string
+         timestamp_string = current_time.strftime('%H:%M:%S')
+
+         image = Image.fromarray(convolved_image)
+
+         # Save the image
+         #image.save('im/' + timestamp_string + 'conv_image.png')
+
+         flattened_array = convolved_image.flatten()
+
+         return flattened_array
+
+     def render(self, mode='human'):
+         pass
+
+
+ # Register the custom environment
+ gym.envs.registration.register(
+     id='DinoEnv-v0',
+     entry_point=__name__ + ':DinoEnv',
+ )
+
models/ppo_dino_1000k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:262d370a4c6af4ce1de1fc87ce618e8f2f79a65cef5756001d5f3c61cd3a2dbc
+ size 14129512
models/ppo_dino_100k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b2275e46dc4d90d5a3ab3b7348a28a1af0f61234dd78b6a9f745ca7b62bbd39b
+ size 14129459
models/ppo_dino_200k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cd376e9c53002769c61e7d4464bdcc864b6b6f0b2695390779af87e52198ad4a
+ size 14129479
models/ppo_dino_300k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7f0099ccd3a945d7ed7e72176bb5dc10f6213dfceff5a3596d9dac7944465a2
+ size 14129480
models/ppo_dino_400k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f43c84fc9a2506bfd22632e325dc355701812d1539a4ac73f9a9328169c3fb4d
+ size 14129484
models/ppo_dino_500k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b6cedc8fbf789ccc67d0d374f5ec7d421419b8627ed5f7abc7ed240e4adbfffe
+ size 14129492
models/ppo_dino_600k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92d5465e28ce410f8ecc34b8779e08fd4326212c0ed18baf8ca237ef2854ab12
+ size 14129480
models/ppo_dino_700k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e877d70cf77fc234f0ef9c44360de26c20d68d5355d586296cecbc33508b6ed8
+ size 14129488
models/ppo_dino_800k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f50e439a5ac9a6afa674cd267c98ce7921eadbca8df2c5e507ce686147b9c393
+ size 14129496
models/ppo_dino_900k.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1ad4fa130ef150c2f715ae7aae5563168f543c51c51986b10530d0d7b423374b
+ size 14129496
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ asyncio
+ websockets
+ requests
+ mss
+ numpy
+ Pillow
+ pyautogui
+ gymnasium
+ scipy
+ stable-baselines3
+ tensorboardX
rf.py ADDED
@@ -0,0 +1,14 @@
+ import gymnasium as gym
+ import env
+ from stable_baselines3 import PPO
+ env = gym.make("DinoEnv-v0")
+ model = PPO("MlpPolicy", env, verbose=1)
+
+ # Load the saved PPO model (Optional)
+ #model = PPO.load("models/ppo_dino_backup8", env=env)
+ for i in range(100):
+     model.learn(total_timesteps=100000)
+     model.save("models/ppo_dino")
+     model.save("models/ppo_dino_backup" + str(i))
+
+ env.close()