# oproxy/oproxy.py — 241 lines, 7.8 KiB, Python, executable file
# Version info: Python 3.12+ / FastAPI 0.115.0 / uvicorn 0.30.0 / httpx 0.28.0
import argparse
import asyncio
import json
import os
import sys
import threading
import unicodedata
from contextlib import asynccontextmanager
from datetime import datetime
import httpx
import uvicorn
from fastapi import FastAPI, Request
from starlette.responses import StreamingResponse
# ANSI escape sequences for colored console output.
C_GRAY = "\033[90m"
C_CYAN = "\033[96m"
C_GREEN = "\033[92m"
C_YELLOW = "\033[93m"
C_RED = "\033[91m"
C_WHITE = "\033[97m"
C_RESET = "\033[0m"

# Mutable runtime state shared by the HTTP handlers and the stdin thread.
CONFIG = {
    "remote_port": 11430,  # port of the upstream server (overridden in __main__)
    "url": "http://127.0.0.1:11430",
    "timeout": httpx.Timeout(None),  # no timeout: streamed generations can be slow
    "loop": None,  # running event loop, filled in by lifespan()
    "models_cache": [],  # cached snapshot of /api/tags results
}
def get_ts():
    """Return a gray "[HH:MM:SS.mmm] [:port]" prefix for console lines."""
    stamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]  # trim to milliseconds
    port = CONFIG["remote_port"]
    return f"{C_GRAY}[{stamp}] [:{port}]{C_RESET}"
def get_width(text):
    """Return the display width of *text* in terminal columns.

    East-Asian Fullwidth ("F"), Wide ("W") and Ambiguous ("A") characters
    count as 2 columns; every other character counts as 1.
    """
    # Explicit set membership instead of the original substring test
    # `in "FWA"`: east_asian_width also returns "Na", "H" and "N", and the
    # substring check only behaves correctly by accident of those letters.
    return sum(
        2 if unicodedata.east_asian_width(ch) in {"F", "W", "A"} else 1
        for ch in text
    )
def pad_text(text, target_width):
    """Right-pad *text* with spaces up to *target_width* display columns."""
    deficit = target_width - get_width(text)
    if deficit <= 0:
        return text
    return text + " " * deficit
def pulse(char, color=C_RESET):
    """Write a single colored progress character, no newline, flushed."""
    sys.stdout.write(f"{color}{char}{C_RESET}")
    sys.stdout.flush()
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup hook: remember the running loop so the blocking stdin thread
    # can schedule coroutines onto it, then kick off a background refresh of
    # the model cache (fire-and-forget; errors are handled inside
    # update_model_cache itself).
    CONFIG["loop"] = asyncio.get_running_loop()
    asyncio.create_task(update_model_cache())
    yield
app = FastAPI(lifespan=lifespan)
# --- Logic: fetch model list ---
async def update_model_cache():
    """Refresh CONFIG["models_cache"] from the upstream Ollama server.

    Fetches /api/tags for the model list and probes /api/show per model to
    guess tool-calling support.  Best-effort: any network or parse error
    leaves the previous cache untouched.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            res = await client.get(f"{CONFIG['url']}/api/tags")
            if res.status_code != 200:
                return
            new_data = []
            for m in res.json().get("models", []):
                # Crude tool-support heuristic: look for "tool"/"functions"
                # anywhere in the model's template/details text.
                has_tool = False
                try:
                    s = await client.post(
                        f"{CONFIG['url']}/api/show", json={"name": m["name"]}
                    )
                    if s.status_code == 200:
                        info = s.json()
                        blob = str(info.get("template", "")) + str(
                            info.get("details", "")
                        )
                        has_tool = any(
                            w in blob.lower() for w in ["tool", "functions"]
                        )
                except Exception:
                    # Per-model probe failed; report "no tools" for it.
                    pass
                new_data.append(
                    {
                        "name": m["name"],
                        "size": m["size"] / (1024**3),  # bytes -> GiB
                        "tool": has_tool,
                    }
                )
            CONFIG["models_cache"] = new_data
    except Exception:
        # Upstream offline or unexpected payload: keep the previous cache.
        # (Narrowed from a bare `except:` which also swallowed CancelledError.)
        pass
def show_help():
    """Print the one-line interactive command reference."""
    banner = ">>> h:HELP l:LIST ll:DETAIL s:VRAM [digit]:PORT q:EXIT <<<"
    print(f"\n{get_ts()} {C_WHITE}{banner}{C_RESET}", flush=True)
def display_models(full=False, short=False):
    """Print the cached model list to the console.

    full=True emits a two-line entry per model; otherwise one padded row.
    NOTE(review): short=True (never passed by current callers) takes the
    "cache empty" early exit below — looks unintentional, confirm intent.
    """
    if not CONFIG["models_cache"] or short:
        print(
            f"\n{get_ts()} {C_YELLOW}Cache is empty. Ollama may be offline.{C_RESET}",
            flush=True,
        )
        return
    mode = "Detailed" if full else "Short"
    print(f"\n{get_ts()} {C_GREEN}--- Models ({mode}) ---{C_RESET}", flush=True)
    NAME_W = 55
    for entry in CONFIG["models_cache"]:
        # NOTE(review): the first two icon alternatives are empty strings —
        # possibly emoji lost in transit; only the no-tool warning survives.
        if entry["size"] > 16.8:
            icon = ""
        elif entry["tool"]:
            icon = ""
        else:
            icon = "⚠️"
        tag = f"{C_CYAN}[T]{C_RESET}" if entry["tool"] else f"{C_GRAY}[-]{C_RESET}"
        if full:
            print(f"{get_ts()} {icon} {tag} {C_WHITE}{entry['name']}{C_RESET}")
            print(f"{get_ts()} {C_GRAY}└─ {entry['size']:>6.1f} GiB{C_RESET}")
            continue
        shown = entry["name"]
        if get_width(shown) > NAME_W:
            # Drop leading characters until "..." + name fits the column.
            while get_width("..." + shown) > NAME_W:
                shown = shown[1:]
            shown = "..." + shown
        print(
            f"{get_ts()} {icon} {tag} {C_WHITE}{pad_text(shown, NAME_W)}{C_RESET} {C_CYAN}{entry['size']:>6.1f} GiB{C_RESET}"
        )
    print(f"{get_ts()} {C_GREEN}--- End ---{C_RESET}\n", flush=True)
# --- Proxy core ---
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def sticky_proxy(path: str, request: Request):
    """Stream any request through to the configured upstream server.

    The response body is relayed chunk-by-chunk while colored progress
    characters are printed to the console.
    NOTE(review): the upstream status code and content-type are NOT
    forwarded — clients always see 200 with the default media type.
    """
    print(f"\n{get_ts()} {C_WHITE}/{path}{C_RESET} ", end="", flush=True)
    body = await request.body()
    # One "^" per 256 bytes of request body, capped at 5 characters.
    for _ in range(min(len(body) // 256 + 1, 5)):
        pulse("^", C_CYAN)
    pulse("|", C_YELLOW)
    async def stream_response():
        # A fresh client per request avoids stale-connection errors.
        # Fixed: the previous hard-coded base_url="http://127.0.0.1:11432"
        # was ignored (target_url is absolute) and disagreed with
        # CONFIG["url"] after a runtime port switch — removed.
        async with httpx.AsyncClient(timeout=CONFIG["timeout"]) as client:
            try:
                # Destination pinned to the configured 127.0.0.1 upstream.
                target_url = f"{CONFIG['url']}/{path}"
                async with client.stream(
                    request.method,
                    target_url,
                    content=body,
                    headers={
                        k: v
                        for k, v in request.headers.items()
                        # Host/Content-Length are regenerated by httpx.
                        if k.lower() not in ["host", "content-length"]
                    },
                ) as response:
                    pulse("v", C_GREEN)
                    async for chunk in response.aiter_bytes():
                        pulse("v", C_GREEN)
                        yield chunk
                    pulse("*", C_YELLOW)
            except Exception as e:
                # Best-effort proxy: log the failure and end the stream early.
                print(f" {C_RED}[Err] {type(e).__name__}: {e}{C_RESET}", flush=True)
            finally:
                print("", flush=True)
    return StreamingResponse(stream_response())
def input_thread():
    """Blocking stdin command loop, run in a daemon thread.

    Commands: q quit, h help, l list, ll detailed list, s VRAM usage,
    any all-digit string switches the upstream port.
    """
    while True:
        try:
            line = sys.stdin.readline()
            if not line:  # EOF: stdin closed, nothing more to read
                break
            cmd = line.strip().lower()
            if cmd == "q":
                # Hard exit: uvicorn occupies the main thread, so a normal
                # return from this daemon thread would not stop the server.
                os._exit(0)
            elif cmd == "h":
                show_help()
            elif cmd == "l":
                display_models(False)
            elif cmd == "ll":
                display_models(full=True)
            elif cmd == "s":
                async def ps():
                    # Best-effort VRAM report; without this guard a refused
                    # connection dumped a traceback via the loop handler.
                    try:
                        async with httpx.AsyncClient() as c:
                            r = await c.get(f"{CONFIG['url']}/api/ps")
                            if r.status_code == 200:
                                print(f"\n{get_ts()} {C_CYAN}--- VRAM ---{C_RESET}")
                                for m in r.json().get("models", []):
                                    print(
                                        f"{get_ts()} {m['name']:<25} {m['size_vram'] / (1024**3):.1f}G"
                                    )
                    except Exception:
                        pass
                # Schedule onto the server's event loop from this thread.
                if CONFIG["loop"]:
                    asyncio.run_coroutine_threadsafe(ps(), CONFIG["loop"])
            elif cmd.isdigit():
                p = int(cmd)
                CONFIG["remote_port"], CONFIG["url"] = p, f"http://127.0.0.1:{p}"
                print(
                    f"\n{get_ts()} {C_YELLOW}Switch Target -> {CONFIG['url']}{C_RESET}"
                )
                if CONFIG["loop"]:
                    asyncio.run_coroutine_threadsafe(
                        update_model_cache(), CONFIG["loop"]
                    )
        except Exception:
            # Narrowed from a bare `except:` (which also caught SystemExit /
            # KeyboardInterrupt); any console error ends this thread.
            break
if __name__ == "__main__":
    # CLI: -r/--remote = upstream port to proxy to, -l/--local = port to
    # listen on.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r", "--remote", type=int, default=11432
    )  # default upstream port is 11432
    parser.add_argument("-l", "--local", type=int, default=11434)
    args = parser.parse_args()
    # Point the proxy at the requested upstream before serving.
    CONFIG["remote_port"] = args.remote
    CONFIG["url"] = f"http://127.0.0.1:{args.remote}"
    # Console command loop; daemon so it dies with the main thread.
    threading.Thread(target=input_thread, daemon=True).start()
    print(
        f"\n{get_ts()} {C_CYAN}oproxy Start (L:{args.local} -> R:{args.remote}){C_RESET}"
    )
    show_help()
    # Blocks the main thread until shutdown (or os._exit via "q").
    uvicorn.run(app, host="127.0.0.1", port=args.local, log_level="error")