mirror of
https://github.com/microsoft/fara.git
synced 2026-06-10 02:54:01 +08:00
fix multi-turn
This commit is contained in:
@@ -5,7 +5,6 @@ import json
|
||||
import ast
|
||||
import io
|
||||
import os
|
||||
import base64
|
||||
from PIL import Image
|
||||
from typing import List, Tuple, Dict
|
||||
from urllib.parse import quote_plus
|
||||
@@ -76,8 +75,8 @@ class FaraAgent:
|
||||
assert False, "downloads_folder must be set if save_screenshots is True"
|
||||
self.save_screenshots = save_screenshots
|
||||
self._facts = []
|
||||
self._action_history = []
|
||||
self._task_summary = None
|
||||
self._num_actions = 0
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self._mlm_width = 1440
|
||||
self._mlm_height = 900
|
||||
@@ -200,7 +199,11 @@ class FaraAgent:
|
||||
def maybe_remove_old_screenshots(
|
||||
self, history: List[LLMMessage], includes_current: bool = False
|
||||
) -> List[LLMMessage]:
|
||||
"""Remove old screenshots from the chat history. Assuming we have not yet added the current screenshot message."""
|
||||
"""Remove old screenshots from the chat history. Assuming we have not yet added the current screenshot message.
|
||||
|
||||
Note: Original user messages (marked with is_original=True) are NEVER removed from history,
|
||||
only boilerplate messages added during multi-round processing can be removed.
|
||||
"""
|
||||
if self.max_n_images <= 0:
|
||||
return history
|
||||
|
||||
@@ -210,6 +213,10 @@ class FaraAgent:
|
||||
for i in range(len(history) - 1, -1, -1):
|
||||
msg = history[i]
|
||||
|
||||
is_original_user_message = isinstance(msg, UserMessage) and getattr(
|
||||
msg, "is_original", False
|
||||
)
|
||||
|
||||
if i == 0 and n_images >= max_n_images:
|
||||
# First message is always the task so we keep it and remove the screenshot if necessary
|
||||
msg = self.remove_screenshot_from_message(msg)
|
||||
@@ -224,12 +231,14 @@ class FaraAgent:
|
||||
has_image = True
|
||||
break
|
||||
if has_image:
|
||||
if n_images < max_n_images:
|
||||
if is_original_user_message or n_images < max_n_images:
|
||||
new_history.append(msg)
|
||||
n_images += 1
|
||||
else:
|
||||
new_history.append(msg)
|
||||
elif isinstance(msg.content, ImageObj) and n_images < max_n_images:
|
||||
elif isinstance(msg.content, ImageObj) and (
|
||||
is_original_user_message or n_images < max_n_images
|
||||
):
|
||||
new_history.append(msg)
|
||||
n_images += 1
|
||||
else:
|
||||
@@ -239,6 +248,13 @@ class FaraAgent:
|
||||
|
||||
return new_history
|
||||
|
||||
async def _get_scaled_screenshot(self) -> Image.Image:
|
||||
"""Get current screenshot and scale it for the model."""
|
||||
screenshot = await self._playwright_controller.get_screenshot(self._page)
|
||||
screenshot = Image.open(io.BytesIO(screenshot))
|
||||
_, scaled_screenshot = self._get_system_message(screenshot)
|
||||
return scaled_screenshot
|
||||
|
||||
def _get_system_message(
|
||||
self, screenshot: ImageObj | Image.Image
|
||||
) -> Tuple[List[SystemMessage], Image.Image]:
|
||||
@@ -315,8 +331,24 @@ class FaraAgent:
|
||||
# Ensure page is ready after initialization
|
||||
assert self._page is not None, "Page should be initialized"
|
||||
|
||||
# Add user message to chat history
|
||||
self._chat_history.append(UserMessage(content=user_message))
|
||||
# Get initial screenshot and add user message with image to chat history
|
||||
scaled_screenshot = await self._get_scaled_screenshot()
|
||||
|
||||
if self.save_screenshots:
|
||||
await self._playwright_controller.get_screenshot(
|
||||
self._page,
|
||||
path=os.path.join(
|
||||
self.downloads_folder, f"screenshot{self._num_actions}.png"
|
||||
),
|
||||
)
|
||||
|
||||
self._chat_history.append(
|
||||
UserMessage(
|
||||
content=[ImageObj.from_pil(scaled_screenshot), user_message],
|
||||
is_original=True,
|
||||
)
|
||||
)
|
||||
|
||||
all_actions = []
|
||||
all_observations = []
|
||||
final_answer = "<no_answer>"
|
||||
@@ -333,16 +365,10 @@ class FaraAgent:
|
||||
raise RuntimeError(
|
||||
"Captcha timed out, unable to proceed with web surfing."
|
||||
)
|
||||
if is_first_round and self.save_screenshots:
|
||||
_ = await self._playwright_controller.get_screenshot(
|
||||
self._page,
|
||||
path=os.path.join(
|
||||
self.downloads_folder,
|
||||
f"screenshot{len(self._action_history)}.png",
|
||||
),
|
||||
)
|
||||
|
||||
function_call, raw_response = await self.generate_model_call(is_first_round)
|
||||
function_call, raw_response = await self.generate_model_call(
|
||||
is_first_round, scaled_screenshot if is_first_round else None
|
||||
)
|
||||
assert isinstance(raw_response, str)
|
||||
all_actions.append(raw_response)
|
||||
# Print the model response
|
||||
@@ -362,39 +388,16 @@ class FaraAgent:
|
||||
return final_answer, all_actions, all_observations
|
||||
|
||||
async def generate_model_call(
|
||||
self, is_first_round: bool
|
||||
self, is_first_round: bool, first_screenshot: Image.Image | None = None
|
||||
) -> Tuple[List[FunctionCall], str]:
|
||||
history: List[LLMMessage] = []
|
||||
action_turn = 0
|
||||
for i in range(len(self._chat_history)):
|
||||
m = self._chat_history[i]
|
||||
if isinstance(m, AssistantMessage) and m.source == "assistant":
|
||||
if action_turn >= len(self._action_history):
|
||||
raise RuntimeError(
|
||||
f"OUT OF SYNC: Action history is shorter than chat history agent turns.\n\nAction history: {self._action_history}\n\n"
|
||||
f"Chat history: {self._chat_history}"
|
||||
)
|
||||
else:
|
||||
history.append(self._action_history[action_turn])
|
||||
action_turn += 1
|
||||
else:
|
||||
history.append(m)
|
||||
history = self.maybe_remove_old_screenshots(history)
|
||||
history = self.maybe_remove_old_screenshots(self._chat_history)
|
||||
|
||||
screenshot_for_system = first_screenshot
|
||||
if not is_first_round:
|
||||
# Get screenshot and add new user message for subsequent rounds
|
||||
scaled_screenshot = await self._get_scaled_screenshot()
|
||||
screenshot_for_system = scaled_screenshot
|
||||
|
||||
# Get screenshot
|
||||
screenshot = await self._playwright_controller.get_screenshot(self._page)
|
||||
screenshot = Image.open(io.BytesIO(screenshot))
|
||||
system_message, scaled_screenshot = self._get_system_message(screenshot)
|
||||
if is_first_round:
|
||||
# Assumes first message is always a string
|
||||
text_prompt = self._chat_history[-1].content
|
||||
assert isinstance(text_prompt, str)
|
||||
self._chat_history[-1].content = [
|
||||
ImageObj.from_pil(scaled_screenshot),
|
||||
text_prompt,
|
||||
]
|
||||
history[-1].content = self._chat_history[-1].content
|
||||
else:
|
||||
text_prompt = self.USER_MESSAGE
|
||||
curr_url = await self._playwright_controller.get_page_url(self._page)
|
||||
trimmed_url = get_trimmed_url(curr_url, max_len=self.max_url_chars)
|
||||
@@ -406,14 +409,14 @@ class FaraAgent:
|
||||
self._chat_history.append(curr_message)
|
||||
history.append(curr_message)
|
||||
|
||||
# Generate system message using the screenshot
|
||||
system_message, _ = self._get_system_message(screenshot_for_system)
|
||||
history = system_message + history
|
||||
response = await self._make_model_call(
|
||||
history, extra_create_args={"temperature": 0}
|
||||
)
|
||||
message = response.content
|
||||
|
||||
self._action_history.append(AssistantMessage(content=message))
|
||||
# I ADDED LINE BELOW
|
||||
self._chat_history.append(AssistantMessage(content=message))
|
||||
thoughts, action = self._parse_thoughts_and_action(message)
|
||||
action["arguments"]["thoughts"] = thoughts
|
||||
@@ -563,29 +566,15 @@ class FaraAgent:
|
||||
raise ValueError(f"Unknown tool: {args['action']}")
|
||||
|
||||
await self._playwright_controller.wait_for_load_state(self._page)
|
||||
await self._playwright_controller.sleep(
|
||||
self._page, 3
|
||||
) # There's a 2s sleep below too
|
||||
|
||||
# Handle downloads
|
||||
if self._last_download is not None and self.downloads_folder is not None:
|
||||
fname = os.path.join(
|
||||
self.downloads_folder, self._last_download.suggested_filename
|
||||
)
|
||||
await self._last_download.save_as(fname) # type: ignore
|
||||
page_body = f"<html><head><title>Download Successful</title></head><body style=\"margin: 20px;\"><h1>Successfully downloaded '{self._last_download.suggested_filename}' to local path:<br><br>{fname}</h1></body></html>"
|
||||
await self._playwright_controller.visit_page(
|
||||
self._page,
|
||||
"data:text/html;base64,"
|
||||
+ base64.b64encode(page_body.encode("utf-8")).decode("utf-8"),
|
||||
)
|
||||
await self._playwright_controller.sleep(self._page, 3)
|
||||
|
||||
# Get new screenshot after action
|
||||
self._num_actions += 1
|
||||
if self.save_screenshots:
|
||||
new_screenshot = await self._playwright_controller.get_screenshot(
|
||||
self._page,
|
||||
path=os.path.join(
|
||||
self.downloads_folder, f"screenshot{len(self._action_history)}.png"
|
||||
self.downloads_folder, f"screenshot{self._num_actions}.png"
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -20,9 +20,15 @@ class SystemMessage(LLMMessage):
|
||||
|
||||
@dataclass
|
||||
class UserMessage(LLMMessage):
|
||||
def __init__(self, content: str | List[Dict[str, Any]], source: str = "user"):
|
||||
def __init__(
|
||||
self,
|
||||
content: str | List[Dict[str, Any]],
|
||||
source: str = "user",
|
||||
is_original: bool = False,
|
||||
):
|
||||
self.content = content
|
||||
self.source = source
|
||||
self.is_original = is_original
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user