mirror of
https://github.com/microsoft/fara.git
synced 2026-06-10 02:54:01 +08:00
fixes
This commit is contained in:
116
README.md
116
README.md
@@ -17,7 +17,7 @@
|
||||
|
||||
**Fara-7B** is Microsoft's first **agentic small language model (SLM)** designed specifically for computer use. With only 7 billion parameters, Fara-7B is an ultra-compact Computer Use Agent (CUA) that achieves state-of-the-art performance within its size class and is competitive with larger, more resource-intensive agentic systems.
|
||||
|
||||
Try Fara-7B locally as follows (see [Installation](##Installation) for detailed instructions) or via Magentic-UI:
|
||||
Try Fara-7B locally as follows (see [Installation](#Installation) for detailed instructions on Windows ) or via Magentic-UI:
|
||||
|
||||
```bash
|
||||
# 1. Clone repository
|
||||
@@ -44,7 +44,7 @@ To try Fara-7B inside Magentic-UI, please follow the instructions here [Magentic
|
||||
|
||||
|
||||
Notes:
|
||||
- If you're using Windows, we highly recommend using WSL2 (Windows Subsystem for Linux).
|
||||
- If you're using Windows, we highly recommend using WSL2 (Windows Subsystem for Linux). Please the Windows instructions in the [Installation](#Installation) section.
|
||||
- You might need to do `--tensor-parallel-size 2` with vllm command if you run out of memory
|
||||
|
||||
<table>
|
||||
@@ -156,27 +156,45 @@ Our evaluation setup leverages:
|
||||
|
||||
---
|
||||
|
||||
## Installation
|
||||
# Installation
|
||||
|
||||
Install the package using either UV or pip:
|
||||
|
||||
## Linux
|
||||
|
||||
The following instructions are for Linux systems, see the Windows section below for Windows instructions.
|
||||
|
||||
Install the package using pip and set up the environment with Playwright:
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
# 1. Clone repository
|
||||
git clone https://github.com/microsoft/fara.git
|
||||
cd fara
|
||||
|
||||
or
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then install Playwright browsers:
|
||||
|
||||
```bash
|
||||
# 2. Setup environment
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -e .[vllm]
|
||||
playwright install
|
||||
```
|
||||
|
||||
---
|
||||
Note: If you plan on hosting with Azure Foundry only, you can skip the `[vllm]` and just do `pip install -e .`
|
||||
|
||||
|
||||
## Windows
|
||||
|
||||
For Windows, we highly recommend using WSL2 (Windows Subsystem for Linux) to provide a Linux-like environment. However, if you prefer to run natively on Windows, follow these steps:
|
||||
|
||||
```bash
|
||||
# 1. Clone repository
|
||||
git clone https://github.com/microsoft/fara.git
|
||||
cd fara
|
||||
|
||||
# 2. Setup environment
|
||||
python3 -m venv .venv
|
||||
.venv\Scripts\activate
|
||||
pip install -e .
|
||||
python3 -m playwright install
|
||||
```
|
||||
|
||||
## Hosting the Model
|
||||
|
||||
@@ -189,11 +207,10 @@ Deploy Fara-7B on [Azure Foundry](https://ai.azure.com/explore/models/Fara-7B/ve
|
||||
**Setup:**
|
||||
|
||||
1. Deploy the Fara-7B model on Azure Foundry and obtain your endpoint URL and API key
|
||||
2. Add your endpoint details to the existing `endpoint_configs/` directory (example configs are already provided):
|
||||
|
||||
```bash
|
||||
# Edit one of the existing config files or create a new one
|
||||
# endpoint_configs/fara-7b-hosting-ansrz.json (example format):
|
||||
Then create a endpoint configuration JSON file (e.g., `azure_foundry_config.json`):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "Fara-7B",
|
||||
"base_url": "https://your-endpoint.inference.ml.azure.com/",
|
||||
@@ -201,61 +218,54 @@ Deploy Fara-7B on [Azure Foundry](https://ai.azure.com/explore/models/Fara-7B/ve
|
||||
}
|
||||
```
|
||||
|
||||
3. Run the Fara agent:
|
||||
Then you can run Fara-7B using this endpoint configuration.
|
||||
|
||||
2. Run the Fara agent:
|
||||
|
||||
```bash
|
||||
fara-cli --task "how many pages does wikipedia have" --start_page "https://www.bing.com"
|
||||
fara-cli --task "how many pages does wikipedia have" --endpoint_config azure_foundry_config.json [--headful]
|
||||
```
|
||||
|
||||
Note: you can also specify the endpoint config with the args `--base_url [your_base_url] --api_key [your_api_key] --model [your_model_name]` instead of using a config JSON file.
|
||||
|
||||
Note: If you see an error that the `fara-cli` command is not found, then try:
|
||||
|
||||
```bash
|
||||
python -m fara.run_fara --task "what is the weather in new york now"
|
||||
```
|
||||
|
||||
That's it! No GPU or model downloads required.
|
||||
|
||||
### Self-hosting with VLLM
|
||||
### Self-hosting with vLLM or LM Studio / Ollama
|
||||
|
||||
If you have access to GPU resources, you can self-host Fara-7B using VLLM. This requires a GPU machine with sufficient VRAM.
|
||||
**If you have access to GPU resources, you can self-host Fara-7B using vLLM. This requires a GPU machine with sufficient VRAM (e.g., 24GB or more).**
|
||||
|
||||
All that is required is to run the following command to start the VLLM server:
|
||||
Only on Linux: all that is required is to run the following command to start the VLLM server:
|
||||
|
||||
```bash
|
||||
vllm serve "microsoft/Fara-7B" --port 5000 --dtype auto
|
||||
```
|
||||
For quantized models or lower VRAM GPUs, please see [Fara-7B GGUF on HuggingFace](https://huggingface.co/bartowski/microsoft_Fara-7B-GGUF).
|
||||
|
||||
### Testing the Fara Agent
|
||||
** For Windows/Mac, vLLM is not natively supported. You can use WSL2 on Windows to run the above command or LM Studio / Ollama as described below. **
|
||||
|
||||
Otherwise, you can use [LM Studio](https://lmstudio.ai/) or [Ollama](https://ollama.com/) to host the model locally. We currently recommend the following GGUF versions of our models [Fara-7B GGUF on HuggingFace](https://huggingface.co/bartowski/microsoft_Fara-7B-GGUF) for use with LM Studio or Ollama. Select the largest model that fits your GPU. Please ensure that context length is set to at least 15000 tokens and temperature to 0 for best results.
|
||||
|
||||
Then you can run Fara-7B pointing to your local server:
|
||||
|
||||
Run the test script to see Fara in action:
|
||||
|
||||
```bash
|
||||
fara-cli --task "how many pages does wikipedia have" --start_page "https://www.bing.com" --endpoint_config endpoint_configs/azure_foundry_config.json [--headful] [--downloads_folder "/path/to/downloads"] [--save_screenshots] [--max_rounds 100] [--browserbase]
|
||||
fara-cli --task "what is the weather in new york now"
|
||||
```
|
||||
|
||||
In self-hosting scenario the `endpoint_config` points to `endpoint_configs/vllm_config.json` from the VLLM server above.
|
||||
If you didn't use vLLM to host, please specify the correct `--base_url [your_base_url] --api_key [your_api_key] --model [your_model_name]`
|
||||
|
||||
If you set `--browserbase`, export environment variables for the API key and project ID.
|
||||
|
||||
#### Expected Output
|
||||
If you see an error that the `fara-cli` command is not found, then try:
|
||||
|
||||
```bash
|
||||
python -m fara.run_fara --task "what is the weather in new york now"
|
||||
```
|
||||
Initializing Browser...
|
||||
Browser Running... Starting Fara Agent...
|
||||
##########################################
|
||||
Task: how many pages does wikipedia have
|
||||
##########################################
|
||||
Running Fara...
|
||||
|
||||
|
||||
Thought #1: To find the current number of Wikipedia pages, I'll search for the latest Wikipedia page count statistics.
|
||||
Action #1: executing tool 'web_search' with arguments {"action": "web_search", "query": "Wikipedia total number of articles"}
|
||||
Observation#1: I typed 'Wikipedia total number of articles' into the browser search bar.
|
||||
|
||||
Thought #2: Wikipedia currently has 7,095,446 articles.
|
||||
Action #2: executing tool 'terminate' with arguments {"action": "terminate", "status": "success"}
|
||||
Observation#2: Wikipedia currently has 7,095,446 articles.
|
||||
|
||||
Final Answer: Wikipedia currently has 7,095,446 articles.
|
||||
|
||||
Enter another task (or press Enter to exit):
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# Reproducibility
|
||||
|
||||
|
||||
@@ -32,14 +32,20 @@ dependencies = [
|
||||
"pyyaml",
|
||||
"jsonschema",
|
||||
"browserbase",
|
||||
"vllm>=0.10.0"
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/microsoft/fara"
|
||||
Repository = "https://github.com/microsoft/fara"
|
||||
Issues = "https://github.com/microsoft/fara/issues"
|
||||
|
||||
[project.optional-dependencies]
|
||||
vllm = ["vllm>=0.10.0"]
|
||||
lmstudio = ["lmstudio"]
|
||||
ollama = ["ollama"]
|
||||
|
||||
[project.scripts]
|
||||
fara-cli = "fara.run_fara:main"
|
||||
|
||||
@@ -6,7 +6,7 @@ import signal
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
import platform
|
||||
import browserbase
|
||||
from browserbase import Browserbase
|
||||
from playwright.async_api import (
|
||||
@@ -48,7 +48,7 @@ class BrowserBB:
|
||||
self.single_tab_mode = single_tab_mode
|
||||
self.use_browser_base = use_browser_base
|
||||
self.logger = logger or logging.getLogger("browser_manager")
|
||||
|
||||
self.is_linux = platform.system() == "Linux"
|
||||
self._viewport_height = viewport_height
|
||||
self._viewport_width = viewport_width
|
||||
|
||||
@@ -194,7 +194,8 @@ class BrowserBB:
|
||||
|
||||
async def _init_regular_browser(self, channel: str = "chromium") -> None:
|
||||
"""Initialize regular browser according to the specified channel."""
|
||||
if not self.headless:
|
||||
if not self.headless and self.is_linux:
|
||||
print("STARTING XVFB")
|
||||
self.start_xvfb()
|
||||
|
||||
launch_args: Dict[str, Any] = {"headless": self.headless}
|
||||
@@ -218,7 +219,7 @@ class BrowserBB:
|
||||
|
||||
async def _init_persistent_browser(self) -> None:
|
||||
"""Initialize persistent browser with data directory."""
|
||||
if not self.headless:
|
||||
if not self.headless and self.is_linux:
|
||||
self.start_xvfb()
|
||||
|
||||
launch_args: Dict[str, Any] = {"headless": self.headless}
|
||||
|
||||
@@ -15,7 +15,7 @@ from playwright.async_api import BrowserContext
|
||||
import asyncio
|
||||
from .browser.playwright_controller import PlaywrightController
|
||||
from ._prompts import get_computer_use_system_prompt
|
||||
from .types import (
|
||||
from .fara_types import (
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
@@ -379,15 +379,20 @@ class FaraAgent:
|
||||
thoughts, action_dict = self._parse_thoughts_and_action(raw_response)
|
||||
action_args = action_dict.get("arguments", {})
|
||||
action = action_args["action"]
|
||||
self.logger.info(f"\nThought #{i+1}: {thoughts}\nAction #{i+1}: executing tool '{action}' with arguments {json.dumps(action_args)}")
|
||||
|
||||
self.logger.debug(
|
||||
f"\nThought #{i+1}: {thoughts}\nAction #{i+1}: executing tool '{action}' with arguments {json.dumps(action_args)}"
|
||||
)
|
||||
print(
|
||||
f"\nThought #{i+1}: {thoughts}\nAction #{i+1}: executing tool '{action}' with arguments {json.dumps(action_args)}"
|
||||
)
|
||||
(
|
||||
is_stop_action,
|
||||
new_screenshot,
|
||||
action_description,
|
||||
) = await self.execute_action(function_call)
|
||||
all_observations.append(action_description)
|
||||
self.logger.info(f"Observation#{i+1}: {action_description}")
|
||||
self.logger.debug(f"Observation#{i+1}: {action_description}")
|
||||
print(f"Observation#{i+1}: {action_description}")
|
||||
if is_stop_action:
|
||||
final_answer = thoughts
|
||||
break
|
||||
@@ -564,7 +569,7 @@ class FaraAgent:
|
||||
elif args["action"] == "pause_and_memorize_fact":
|
||||
fact = str(args.get("fact"))
|
||||
self._facts.append(fact)
|
||||
action_description= f"I memorized the following fact: {fact}"
|
||||
action_description = f"I memorized the following fact: {fact}"
|
||||
elif args["action"] == "stop" or args["action"] == "terminate":
|
||||
action_description = args.get("thoughts")
|
||||
is_stop_action = True
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import argparse
|
||||
import os
|
||||
from fara import FaraAgent
|
||||
from fara.browser.browser_bb import BrowserBB
|
||||
from .fara_agent import FaraAgent
|
||||
from .browser.browser_bb import BrowserBB
|
||||
import logging
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
@@ -11,8 +11,8 @@ import json
|
||||
|
||||
# Configure logging to only show logs from fara.fara_agent
|
||||
logging.basicConfig(
|
||||
level=logging.CRITICAL,
|
||||
format="%(message)s",
|
||||
level=logging.CRITICAL,
|
||||
format="%(message)s",
|
||||
)
|
||||
|
||||
# Enable INFO level only for fara.fara_agent
|
||||
@@ -159,21 +159,51 @@ def main():
|
||||
default=None,
|
||||
help="Path to the endpoint configuration JSON file. By default, tries local vllm on 5000 port",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api_key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for the model endpoint (overrides endpoint_config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for the model endpoint (overrides endpoint_config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model name to use (overrides endpoint_config)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.browserbase:
|
||||
assert os.environ.get("BROWSERBASE_API_KEY"), (
|
||||
"BROWSERBASE_API_KEY environment variable must be set to use browserbase"
|
||||
)
|
||||
assert os.environ.get("BROWSERBASE_PROJECT_ID"), (
|
||||
"BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID environment variables must be set to use browserbase"
|
||||
)
|
||||
assert os.environ.get(
|
||||
"BROWSERBASE_API_KEY"
|
||||
), "BROWSERBASE_API_KEY environment variable must be set to use browserbase"
|
||||
assert os.environ.get(
|
||||
"BROWSERBASE_PROJECT_ID"
|
||||
), "BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID environment variables must be set to use browserbase"
|
||||
|
||||
endpoint_config = DEFAULT_ENDPOINT_CONFIG
|
||||
if args.endpoint_config:
|
||||
with open(args.endpoint_config, "r") as f:
|
||||
endpoint_config = json.load(f)
|
||||
assert (
|
||||
"api_key" in endpoint_config
|
||||
and "base_url" in endpoint_config
|
||||
and "model" in endpoint_config
|
||||
), "endpoint_config file must contain api_key, base_url, and model fields"
|
||||
# Override with command-line arguments if provided
|
||||
if args.api_key:
|
||||
endpoint_config["api_key"] = args.api_key
|
||||
if args.base_url:
|
||||
endpoint_config["base_url"] = args.base_url
|
||||
if args.model:
|
||||
endpoint_config["model"] = args.model
|
||||
|
||||
asyncio.run(
|
||||
run_fara_agent(
|
||||
|
||||
@@ -28,10 +28,15 @@ DEFAULT_HF_MODEL_ID = "microsoft/Fara-7B"
|
||||
|
||||
|
||||
def _is_azure_blob_url(model_path: str) -> bool:
|
||||
return model_path.startswith(("https://", "http://")) and "blob.core.windows.net" in model_path
|
||||
return (
|
||||
model_path.startswith(("https://", "http://"))
|
||||
and "blob.core.windows.net" in model_path
|
||||
)
|
||||
|
||||
|
||||
def _download_model_from_hf(output_dir: Path, model_id: str = DEFAULT_HF_MODEL_ID) -> str:
|
||||
def _download_model_from_hf(
|
||||
output_dir: Path, model_id: str = DEFAULT_HF_MODEL_ID
|
||||
) -> str:
|
||||
"""Download model from HuggingFace Hub if not already present."""
|
||||
if snapshot_download is None:
|
||||
raise ImportError(
|
||||
@@ -63,13 +68,15 @@ def _download_model_from_hf(output_dir: Path, model_id: str = DEFAULT_HF_MODEL_I
|
||||
|
||||
def _extract_model_name(model_url: str) -> str:
|
||||
"""Extract model name from URL for consistent naming."""
|
||||
url_parts = model_url.rstrip('/').split('/')
|
||||
url_parts = model_url.rstrip("/").split("/")
|
||||
return url_parts[-1] if url_parts else model_url
|
||||
|
||||
|
||||
def _cache_model(model_url: str) -> str:
|
||||
if AzFolder is None:
|
||||
raise RuntimeError("Azure support not available. Install aztool or run without --cache.")
|
||||
raise RuntimeError(
|
||||
"Azure support not available. Install aztool or run without --cache."
|
||||
)
|
||||
|
||||
cache_root = Path(args.cache_dir or os.path.expanduser("~/.cache/vllm_models"))
|
||||
cache_root.mkdir(parents=True, exist_ok=True)
|
||||
@@ -120,8 +127,18 @@ def _prepare_cached_model(model_url: str) -> str:
|
||||
raise FileNotFoundError(f"Local model directory not found: {model_url}")
|
||||
return str(model_path.resolve())
|
||||
|
||||
|
||||
class AzVllm:
|
||||
def __init__(self, model_url, port, device_id, max_n_images, dtype='auto', enforce_eager=False, use_external_endpoint=False):
|
||||
def __init__(
|
||||
self,
|
||||
model_url,
|
||||
port,
|
||||
device_id,
|
||||
max_n_images,
|
||||
dtype="auto",
|
||||
enforce_eager=False,
|
||||
use_external_endpoint=False,
|
||||
):
|
||||
self.model_az = None
|
||||
self.local_model_path = None
|
||||
self.vllm = None
|
||||
@@ -141,7 +158,9 @@ class AzVllm:
|
||||
if not model_path.exists():
|
||||
# Auto-download from HuggingFace if path doesn't exist
|
||||
logging.warning(f"Local model directory not found: {model_url}")
|
||||
logging.info(f"Attempting to download {DEFAULT_HF_MODEL_ID} from HuggingFace...")
|
||||
logging.info(
|
||||
f"Attempting to download {DEFAULT_HF_MODEL_ID} from HuggingFace..."
|
||||
)
|
||||
self.local_model_path = _download_model_from_hf(model_path)
|
||||
else:
|
||||
self.local_model_path = str(model_path.resolve())
|
||||
@@ -150,7 +169,7 @@ class AzVllm:
|
||||
def __enter__(self):
|
||||
# No-op if using external endpoint
|
||||
if self.use_external_endpoint:
|
||||
print('Using external endpoint, skipping VLLM startup')
|
||||
print("Using external endpoint, skipping VLLM startup")
|
||||
return self
|
||||
|
||||
if self.model_az:
|
||||
@@ -162,33 +181,35 @@ class AzVllm:
|
||||
for file in files:
|
||||
print(f"\t{os.path.join(root, file)}")
|
||||
self.vllm = VLLM(
|
||||
model_path = self.context.path,
|
||||
port = self.port,
|
||||
device_id = self.device_id,
|
||||
max_n_images = self.max_n_images,
|
||||
dtype = self.dtype,
|
||||
enforce_eager = self.enforce_eager
|
||||
model_path=self.context.path,
|
||||
port=self.port,
|
||||
device_id=self.device_id,
|
||||
max_n_images=self.max_n_images,
|
||||
dtype=self.dtype,
|
||||
enforce_eager=self.enforce_eager,
|
||||
)
|
||||
self.vllm.start()
|
||||
print('VLLM has started')
|
||||
print("VLLM has started")
|
||||
elif self.local_model_path:
|
||||
print(f"VLLM using on-disk model at path {self.local_model_path}, contents:")
|
||||
print(
|
||||
f"VLLM using on-disk model at path {self.local_model_path}, contents:"
|
||||
)
|
||||
### sometimes need to ls the directory or else huggingface will complain a config.json doesn't exist
|
||||
for root, dirs, files in os.walk(self.local_model_path):
|
||||
for file in files:
|
||||
print(f"\t{os.path.join(root, file)}")
|
||||
self.vllm = VLLM(
|
||||
model_path = self.local_model_path,
|
||||
port = self.port,
|
||||
device_id = self.device_id,
|
||||
max_n_images = self.max_n_images,
|
||||
dtype = self.dtype,
|
||||
enforce_eager = self.enforce_eager
|
||||
model_path=self.local_model_path,
|
||||
port=self.port,
|
||||
device_id=self.device_id,
|
||||
max_n_images=self.max_n_images,
|
||||
dtype=self.dtype,
|
||||
enforce_eager=self.enforce_eager,
|
||||
)
|
||||
self.vllm.start()
|
||||
print('VLLM has started')
|
||||
print("VLLM has started")
|
||||
return self
|
||||
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.vllm:
|
||||
if self.vllm and (self.vllm.status == Status.Running):
|
||||
@@ -196,6 +217,7 @@ class AzVllm:
|
||||
if self.context:
|
||||
self.context.unmount()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
cached_vllm: Optional[VLLM] = None
|
||||
@@ -214,17 +236,18 @@ async def lifespan(app: FastAPI):
|
||||
device_id=args.device_id,
|
||||
max_n_images=args.max_n_images,
|
||||
dtype=args.dtype,
|
||||
enforce_eager=args.enforce_eager
|
||||
enforce_eager=args.enforce_eager,
|
||||
)
|
||||
cached_vllm.start()
|
||||
else:
|
||||
az_vllm = AzVllm(
|
||||
model_url = args.model_url,
|
||||
port = args.vllm_port,
|
||||
device_id = args.device_id,
|
||||
max_n_images = args.max_n_images,
|
||||
dtype = args.dtype,
|
||||
enforce_eager = args.enforce_eager)
|
||||
model_url=args.model_url,
|
||||
port=args.vllm_port,
|
||||
device_id=args.device_id,
|
||||
max_n_images=args.max_n_images,
|
||||
dtype=args.dtype,
|
||||
enforce_eager=args.enforce_eager,
|
||||
)
|
||||
az_vllm.__enter__()
|
||||
app.state.resolved_model_path = args.model_url
|
||||
app.state.model_name = _extract_model_name(args.model_url)
|
||||
@@ -239,7 +262,7 @@ async def lifespan(app: FastAPI):
|
||||
app.state.model_name = None
|
||||
|
||||
|
||||
app = FastAPI(lifespan = lifespan)
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
@@ -247,21 +270,18 @@ async def post_v1_chat_completions(request: Request):
|
||||
body = await request.body()
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f'http://localhost:{args.vllm_port}/v1/chat/completions',
|
||||
f"http://localhost:{args.vllm_port}/v1/chat/completions",
|
||||
content=body,
|
||||
headers=dict(request.headers),
|
||||
timeout=None
|
||||
timeout=None,
|
||||
)
|
||||
return Response(
|
||||
content=resp.content,
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers
|
||||
content=resp.content, status_code=resp.status_code, headers=resp.headers
|
||||
)
|
||||
|
||||
|
||||
@app.get("/model")
|
||||
async def get_model():
|
||||
|
||||
return {"model": _extract_model_name(args.model_url), "model_url": args.model_url}
|
||||
|
||||
|
||||
@@ -271,12 +291,37 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--port", type=int, default=5000, help="port")
|
||||
parser.add_argument("--vllm_port", type=int, default=5001, help="vllm port")
|
||||
parser.add_argument("--device_id", type=str, default="0", help="device id")
|
||||
parser.add_argument("--max_n_images", type=int, default=3, help="Maximum number of images to process")
|
||||
parser.add_argument('--dtype', type=str, choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], default='auto', help='Data type for VLLM model (default: auto)')
|
||||
parser.add_argument('--enforce_eager', action='store_true', help='Enforce eager execution mode for compatibility')
|
||||
parser.add_argument('--cache', action='store_true', help='Enable caching / local path serving instead of Azure mount')
|
||||
parser.add_argument('--cache_dir', type=str, default=None, help='Directory to cache downloaded models (default: ~/.cache/vllm_models)')
|
||||
parser.add_argument(
|
||||
"--max_n_images",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of images to process",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
|
||||
default="auto",
|
||||
help="Data type for VLLM model (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enforce_eager",
|
||||
action="store_true",
|
||||
help="Enforce eager execution mode for compatibility",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache",
|
||||
action="store_true",
|
||||
help="Enable caching / local path serving instead of Azure mount",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to cache downloaded models (default: ~/.cache/vllm_models)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
|
||||
@@ -5,32 +5,39 @@ import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
NotStarted = 0
|
||||
Running = 1
|
||||
Stopped = 2
|
||||
|
||||
|
||||
class VLLM:
|
||||
cmd_template = ' '.join([
|
||||
"python -O -u -m vllm.entrypoints.openai.api_server",
|
||||
"--host={host}",
|
||||
"--port={port}",
|
||||
"--model={model_dir}",
|
||||
"--served-model-name {model_name}",
|
||||
"--tensor-parallel-size {tensor_parallel_size}",
|
||||
"--gpu-memory-utilization 0.95",
|
||||
"--trust-remote-code",
|
||||
"--dtype {dtype}"
|
||||
])
|
||||
def __init__(self,
|
||||
model_path,
|
||||
max_n_images,
|
||||
device_id = "0",
|
||||
host = "0.0.0.0",
|
||||
port = 5000,
|
||||
model_name = "gpt-4o-mini-2024-07-18",
|
||||
dtype = "auto",
|
||||
enforce_eager = False):
|
||||
cmd_template = " ".join(
|
||||
[
|
||||
"python -O -u -m vllm.entrypoints.openai.api_server",
|
||||
"--host={host}",
|
||||
"--port={port}",
|
||||
"--model={model_dir}",
|
||||
"--served-model-name {model_name}",
|
||||
"--tensor-parallel-size {tensor_parallel_size}",
|
||||
"--gpu-memory-utilization 0.95",
|
||||
"--trust-remote-code",
|
||||
"--dtype {dtype}",
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
max_n_images,
|
||||
device_id="0",
|
||||
host="0.0.0.0",
|
||||
port=5000,
|
||||
model_name="gpt-4o-mini-2024-07-18",
|
||||
dtype="auto",
|
||||
enforce_eager=False,
|
||||
):
|
||||
self.model_path = model_path
|
||||
self.device_id = device_id
|
||||
self.host = host
|
||||
@@ -43,10 +50,12 @@ class VLLM:
|
||||
# new versions of vllm require dictionary-like arguments for this
|
||||
# see https://docs.vllm.ai/en/latest/configuration/engine_args.html#multimodalconfig
|
||||
self.cmd += f" --limit-mm-per-prompt.image {self.max_n_images}"
|
||||
if enforce_eager: # Most helpful for float32 cases when attention backends are incompatible
|
||||
if (
|
||||
enforce_eager
|
||||
): # Most helpful for float32 cases when attention backends are incompatible
|
||||
self.cmd += " --enforce-eager"
|
||||
self.model_name = model_name
|
||||
self.tensor_parallel_size = len(str(device_id).split(','))
|
||||
self.tensor_parallel_size = len(str(device_id).split(","))
|
||||
self.status = Status.NotStarted
|
||||
self.process = None
|
||||
self.logs = []
|
||||
@@ -57,12 +66,13 @@ class VLLM:
|
||||
|
||||
def start(self):
|
||||
def _drain(pipe):
|
||||
for line in iter(pipe.readline, ''):
|
||||
for line in iter(pipe.readline, ""):
|
||||
self.logs.append(line)
|
||||
print(line, end='')
|
||||
print(line, end="")
|
||||
|
||||
env = os.environ.copy()
|
||||
env['CUDA_VISIBLE_DEVICES'] = self.device_id
|
||||
env['NCCL_DEBUG'] = "TRACE"
|
||||
env["CUDA_VISIBLE_DEVICES"] = self.device_id
|
||||
env["NCCL_DEBUG"] = "TRACE"
|
||||
self.process = subprocess.Popen(
|
||||
self.cmd.format(
|
||||
host=self.host,
|
||||
@@ -70,13 +80,13 @@ class VLLM:
|
||||
model_dir=self.model_path,
|
||||
model_name=self.model_name,
|
||||
tensor_parallel_size=self.tensor_parallel_size,
|
||||
dtype=self.dtype
|
||||
dtype=self.dtype,
|
||||
).split(),
|
||||
stdout = subprocess.PIPE,
|
||||
stderr = subprocess.STDOUT,
|
||||
text = True,
|
||||
shell = False,
|
||||
env = env
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
shell=False,
|
||||
env=env,
|
||||
)
|
||||
t = threading.Thread(target=_drain, args=(self.process.stdout,), daemon=True)
|
||||
t.start()
|
||||
@@ -88,9 +98,9 @@ class VLLM:
|
||||
if "Application startup complete." in line:
|
||||
logging.info("VLLM process started successfully.")
|
||||
self.status = Status.Running
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
def stop(self):
|
||||
if self.process:
|
||||
self.process.terminate()
|
||||
self.status = Status.Stopped
|
||||
self.status = Status.Stopped
|
||||
|
||||
Reference in New Issue
Block a user