fix multi-turn

This commit is contained in:
Hussein Mozannar
2025-11-25 13:36:24 -08:00
parent 38c58ca681
commit e6a0174662
2 changed files with 62 additions and 67 deletions

View File

@@ -5,7 +5,6 @@ import json
import ast
import io
import os
import base64
from PIL import Image
from typing import List, Tuple, Dict
from urllib.parse import quote_plus
@@ -76,8 +75,8 @@ class FaraAgent:
assert False, "downloads_folder must be set if save_screenshots is True"
self.save_screenshots = save_screenshots
self._facts = []
self._action_history = []
self._task_summary = None
self._num_actions = 0
self.logger = logger or logging.getLogger(__name__)
self._mlm_width = 1440
self._mlm_height = 900
@@ -200,7 +199,11 @@ class FaraAgent:
def maybe_remove_old_screenshots(
self, history: List[LLMMessage], includes_current: bool = False
) -> List[LLMMessage]:
"""Remove old screenshots from the chat history. Assuming we have not yet added the current screenshot message."""
"""Remove old screenshots from the chat history. Assuming we have not yet added the current screenshot message.
Note: Original user messages (marked with is_original=True) are NEVER removed from history,
only boilerplate messages added during multi-round processing can be removed.
"""
if self.max_n_images <= 0:
return history
@@ -210,6 +213,10 @@ class FaraAgent:
for i in range(len(history) - 1, -1, -1):
msg = history[i]
is_original_user_message = isinstance(msg, UserMessage) and getattr(
msg, "is_original", False
)
if i == 0 and n_images >= max_n_images:
# First message is always the task so we keep it and remove the screenshot if necessary
msg = self.remove_screenshot_from_message(msg)
@@ -224,12 +231,14 @@ class FaraAgent:
has_image = True
break
if has_image:
if n_images < max_n_images:
if is_original_user_message or n_images < max_n_images:
new_history.append(msg)
n_images += 1
else:
new_history.append(msg)
elif isinstance(msg.content, ImageObj) and n_images < max_n_images:
elif isinstance(msg.content, ImageObj) and (
is_original_user_message or n_images < max_n_images
):
new_history.append(msg)
n_images += 1
else:
@@ -239,6 +248,13 @@ class FaraAgent:
return new_history
async def _get_scaled_screenshot(self) -> Image.Image:
    """Capture the current page and return the screenshot scaled for the model.

    The bytes returned by the Playwright controller are decoded into a PIL
    image and passed through ``_get_system_message``; only the image half of
    that method's ``(system_messages, image)`` result is returned.
    """
    raw_bytes = await self._playwright_controller.get_screenshot(self._page)
    page_image = Image.open(io.BytesIO(raw_bytes))
    # The system messages produced here are not needed; keep the image only.
    _unused_system_messages, scaled_image = self._get_system_message(page_image)
    return scaled_image
def _get_system_message(
self, screenshot: ImageObj | Image.Image
) -> Tuple[List[SystemMessage], Image.Image]:
@@ -315,8 +331,24 @@ class FaraAgent:
# Ensure page is ready after initialization
assert self._page is not None, "Page should be initialized"
# Add user message to chat history
self._chat_history.append(UserMessage(content=user_message))
# Get initial screenshot and add user message with image to chat history
scaled_screenshot = await self._get_scaled_screenshot()
if self.save_screenshots:
await self._playwright_controller.get_screenshot(
self._page,
path=os.path.join(
self.downloads_folder, f"screenshot{self._num_actions}.png"
),
)
self._chat_history.append(
UserMessage(
content=[ImageObj.from_pil(scaled_screenshot), user_message],
is_original=True,
)
)
all_actions = []
all_observations = []
final_answer = "<no_answer>"
@@ -333,16 +365,10 @@ class FaraAgent:
raise RuntimeError(
"Captcha timed out, unable to proceed with web surfing."
)
if is_first_round and self.save_screenshots:
_ = await self._playwright_controller.get_screenshot(
self._page,
path=os.path.join(
self.downloads_folder,
f"screenshot{len(self._action_history)}.png",
),
)
function_call, raw_response = await self.generate_model_call(is_first_round)
function_call, raw_response = await self.generate_model_call(
is_first_round, scaled_screenshot if is_first_round else None
)
assert isinstance(raw_response, str)
all_actions.append(raw_response)
# Print the model response
@@ -362,39 +388,16 @@ class FaraAgent:
return final_answer, all_actions, all_observations
async def generate_model_call(
self, is_first_round: bool
self, is_first_round: bool, first_screenshot: Image.Image | None = None
) -> Tuple[List[FunctionCall], str]:
history: List[LLMMessage] = []
action_turn = 0
for i in range(len(self._chat_history)):
m = self._chat_history[i]
if isinstance(m, AssistantMessage) and m.source == "assistant":
if action_turn >= len(self._action_history):
raise RuntimeError(
f"OUT OF SYNC: Action history is shorter than chat history agent turns.\n\nAction history: {self._action_history}\n\n"
f"Chat history: {self._chat_history}"
)
else:
history.append(self._action_history[action_turn])
action_turn += 1
else:
history.append(m)
history = self.maybe_remove_old_screenshots(history)
history = self.maybe_remove_old_screenshots(self._chat_history)
screenshot_for_system = first_screenshot
if not is_first_round:
# Get screenshot and add new user message for subsequent rounds
scaled_screenshot = await self._get_scaled_screenshot()
screenshot_for_system = scaled_screenshot
# Get screenshot
screenshot = await self._playwright_controller.get_screenshot(self._page)
screenshot = Image.open(io.BytesIO(screenshot))
system_message, scaled_screenshot = self._get_system_message(screenshot)
if is_first_round:
# Assumes first message is always a string
text_prompt = self._chat_history[-1].content
assert isinstance(text_prompt, str)
self._chat_history[-1].content = [
ImageObj.from_pil(scaled_screenshot),
text_prompt,
]
history[-1].content = self._chat_history[-1].content
else:
text_prompt = self.USER_MESSAGE
curr_url = await self._playwright_controller.get_page_url(self._page)
trimmed_url = get_trimmed_url(curr_url, max_len=self.max_url_chars)
@@ -406,14 +409,14 @@ class FaraAgent:
self._chat_history.append(curr_message)
history.append(curr_message)
# Generate system message using the screenshot
system_message, _ = self._get_system_message(screenshot_for_system)
history = system_message + history
response = await self._make_model_call(
history, extra_create_args={"temperature": 0}
)
message = response.content
self._action_history.append(AssistantMessage(content=message))
# Record the model's reply in the chat history so later rounds include it
self._chat_history.append(AssistantMessage(content=message))
thoughts, action = self._parse_thoughts_and_action(message)
action["arguments"]["thoughts"] = thoughts
@@ -563,29 +566,15 @@ class FaraAgent:
raise ValueError(f"Unknown tool: {args['action']}")
await self._playwright_controller.wait_for_load_state(self._page)
await self._playwright_controller.sleep(
self._page, 3
) # There's a 2s sleep below too
# Handle downloads
if self._last_download is not None and self.downloads_folder is not None:
fname = os.path.join(
self.downloads_folder, self._last_download.suggested_filename
)
await self._last_download.save_as(fname) # type: ignore
page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{self._last_download.suggested_filename}' to local path:<br><br>{fname}</h1></body></html>"
await self._playwright_controller.visit_page(
self._page,
"data:text/html;base64,"
+ base64.b64encode(page_body.encode("utf-8")).decode("utf-8"),
)
await self._playwright_controller.sleep(self._page, 3)
# Get new screenshot after action
self._num_actions += 1
if self.save_screenshots:
new_screenshot = await self._playwright_controller.get_screenshot(
self._page,
path=os.path.join(
self.downloads_folder, f"screenshot{len(self._action_history)}.png"
self.downloads_folder, f"screenshot{self._num_actions}.png"
),
)
else:

View File

@@ -20,9 +20,15 @@ class SystemMessage(LLMMessage):
@dataclass
class UserMessage(LLMMessage):
def __init__(self, content: str | List[Dict[str, Any]], source: str = "user"):
def __init__(
self,
content: str | List[Dict[str, Any]],
source: str = "user",
is_original: bool = False,
):
self.content = content
self.source = source
self.is_original = is_original
@dataclass