# Version info: Python 3.12+ / FastAPI 0.115.0 / uvicorn 0.30.0 / httpx 0.28.0
import argparse
import asyncio
import json
import os
import sys
import threading
import unicodedata
from contextlib import asynccontextmanager
from datetime import datetime

import httpx
import uvicorn
from fastapi import FastAPI, Request
from starlette.responses import StreamingResponse

# --- ANSI colors ---
C_GRAY, C_CYAN, C_GREEN, C_YELLOW, C_RED, C_WHITE, C_RESET = \
    "\033[90m", "\033[96m", "\033[92m", "\033[93m", "\033[91m", "\033[97m", "\033[0m"

# Mutable runtime state shared between the uvicorn event loop and the
# stdin command thread:
#   remote_port  - port of the upstream Ollama server
#   url          - full base URL of the upstream server
#   timeout      - None: streamed LLM responses can take arbitrarily long
#   loop         - the running asyncio loop, captured in lifespan()
#   models_cache - list of {"name": str, "size": float (GiB), "tool": bool}
CONFIG = {
    "remote_port": 11430,
    "url": "http://127.0.0.1:11430",
    "timeout": httpx.Timeout(None),
    "loop": None,
    "models_cache": []
}


def get_ts():
    """Return a colored '[HH:MM:SS.mmm] [:port]' log prefix."""
    ts = datetime.now().strftime('%H:%M:%S.%f')[:-3]
    return f"{C_GRAY}[{ts}] [:{CONFIG['remote_port']}]{C_RESET}"


def get_width(text):
    """Visual terminal width of *text*: East-Asian F/W/A chars count as 2 columns."""
    return sum(2 if unicodedata.east_asian_width(c) in 'FWA' else 1 for c in text)


def pad_text(text, target_width):
    """Right-pad *text* with spaces up to *target_width* visual columns."""
    return text + (" " * max(0, target_width - get_width(text)))


def pulse(char, color=C_RESET):
    """Emit one colored progress character without a newline."""
    print(f"{color}{char}{C_RESET}", end="", flush=True)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """On startup: capture the event loop and warm the model cache."""
    CONFIG["loop"] = asyncio.get_running_loop()
    # Keep a reference to the task: asyncio holds only weak references,
    # so an anonymous create_task() result may be garbage-collected mid-run.
    app.state.cache_task = asyncio.create_task(update_model_cache())
    yield


app = FastAPI(lifespan=lifespan)


# --- Logic: fetch the model list ---
async def update_model_cache():
    """Refresh CONFIG['models_cache'] from the upstream /api/tags endpoint.

    Best-effort: any network failure leaves the previous cache untouched.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            res = await client.get(f"{CONFIG['url']}/api/tags")
            if res.status_code != 200:
                return
            new_data = []
            for m in res.json().get('models', []):
                # Heuristic tool-support check: look for tool/function hints
                # in the model's template/details metadata.
                has_tool = False
                try:
                    s = await client.post(f"{CONFIG['url']}/api/show",
                                          json={"name": m['name']})
                    info = s.json()
                    details = str(info.get("template", "")) + str(info.get("details", ""))
                    has_tool = any(w in details.lower() for w in ["tool", "functions"])
                except Exception:
                    pass  # metadata lookup is optional; default to "no tools"
                new_data.append({
                    "name": m['name'],
                    "size": m['size'] / (1024 ** 3),  # bytes -> GiB
                    "tool": has_tool,
                })
            CONFIG["models_cache"] = new_data
    except Exception:
        pass  # upstream offline: keep serving the stale cache
def show_help():
    """Print the one-line command cheat sheet for the stdin shell."""
    menu = f"\n{get_ts()} {C_WHITE}>>> h:HELP l:LIST ll:DETAIL s:VRAM [digit]:PORT q:EXIT <<<{C_RESET}"
    print(menu, flush=True)
def display_models(full=False):
    """Pretty-print the cached model list.

    full=False: one aligned line per model; full=True: a two-line detail view.
    """
    if not CONFIG["models_cache"]:
        print(f"\n{get_ts()} {C_YELLOW}Cache is empty. Ollama may be offline.{C_RESET}", flush=True)
        return
    print(f"\n{get_ts()} {C_GREEN}--- Models ({'Detailed' if full else 'Short'}) ---{C_RESET}", flush=True)
    NAME_W = 55
    for m in CONFIG["models_cache"]:
        # ❌ too large for the 16.8 GiB memory budget, ✅ fits and has tool
        # support, ⚠️ fits but no tool support detected.
        icon = "❌" if m['size'] > 16.8 else ("✅" if m['tool'] else "⚠️")
        tag = f"{C_CYAN}[T]{C_RESET}" if m['tool'] else f"{C_GRAY}[-]{C_RESET}"
        if full:
            print(f"{get_ts()} {icon} {tag} {C_WHITE}{m['name']}{C_RESET}")
            print(f"{get_ts()} {C_GRAY}└─ {m['size']:>6.1f} GiB{C_RESET}")
        else:
            n = m['name']
            if get_width(n) > NAME_W:
                # Truncate from the left so the distinctive tail stays visible.
                while get_width("..." + n) > NAME_W:
                    n = n[1:]
                n = "..." + n
            print(f"{get_ts()} {icon} {tag} {C_WHITE}{pad_text(n, NAME_W)}{C_RESET} {C_CYAN}{m['size']:>6.1f} GiB{C_RESET}")
    print(f"{get_ts()} {C_GREEN}--- End ---{C_RESET}\n", flush=True)


# --- Proxy core ---
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def sticky_proxy(path: str, request: Request):
    """Forward any request to the configured upstream and stream the reply back."""
    print(f"\n{get_ts()} {C_WHITE}/{path}{C_RESET} ", end="", flush=True)
    body = await request.body()
    # Emit 1-5 '^' marks proportional to the request size.
    for _ in range(min(len(body) // 256 + 1, 5)):
        pulse("^", C_CYAN)
    pulse("|", C_YELLOW)

    async def stream_response():
        # A fresh client per request avoids reusing a connection to a target
        # that may have been switched via the stdin shell.
        # BUG FIX: the previous code passed base_url="http://127.0.0.1:11432",
        # which hard-coded one port, contradicted the runtime port-switch
        # feature, and was silently ignored anyway because target_url below
        # is absolute.  The upstream is determined by CONFIG['url'] alone.
        async with httpx.AsyncClient(timeout=CONFIG["timeout"]) as client:
            try:
                target_url = f"{CONFIG['url']}/{path}"
                # Drop headers that httpx must regenerate for the new hop.
                fwd_headers = {k: v for k, v in request.headers.items()
                               if k.lower() not in ["host", "content-length"]}
                async with client.stream(request.method, target_url,
                                         content=body, headers=fwd_headers) as response:
                    pulse("v", C_GREEN)
                    async for chunk in response.aiter_bytes():
                        pulse("v", C_GREEN)
                        yield chunk
                    pulse("*", C_YELLOW)
            except Exception as e:
                print(f" {C_RED}[Err] {type(e).__name__}: {e}{C_RESET}", flush=True)
            finally:
                print("", flush=True)

    return StreamingResponse(stream_response())
def input_thread():
    """Blocking stdin command shell; runs in a daemon thread.

    Commands: q quit, h help, l short list, ll detailed list, s VRAM usage,
    any run of digits switches the upstream port and refreshes the cache.
    """
    while True:
        try:
            line = sys.stdin.readline()
            if not line:  # EOF: stdin closed
                break
            cmd = line.strip().lower()
            if cmd == 'q':
                os._exit(0)  # hard exit: uvicorn owns the main thread
            elif cmd == 'h':
                show_help()
            elif cmd == 'l':
                display_models(False)
            elif cmd == 'll':
                display_models(full=True)
            elif cmd == 's':
                async def ps():
                    # BUG FIX: errors here used to vanish because the Future
                    # returned by run_coroutine_threadsafe is never consumed;
                    # report failures explicitly instead.
                    try:
                        async with httpx.AsyncClient() as c:
                            r = await c.get(f"{CONFIG['url']}/api/ps")
                            if r.status_code == 200:
                                print(f"\n{get_ts()} {C_CYAN}--- VRAM ---{C_RESET}")
                                for m in r.json().get("models", []):
                                    print(f"{get_ts()} {m['name']:<25} {m['size_vram']/(1024**3):.1f}G")
                    except Exception as e:
                        print(f"\n{get_ts()} {C_RED}[Err] {type(e).__name__}: {e}{C_RESET}", flush=True)
                if CONFIG["loop"]:
                    asyncio.run_coroutine_threadsafe(ps(), CONFIG["loop"])
            elif cmd.isdigit():
                p = int(cmd)
                CONFIG["remote_port"], CONFIG["url"] = p, f"http://127.0.0.1:{p}"
                print(f"\n{get_ts()} {C_YELLOW}Switch Target -> {CONFIG['url']}{C_RESET}")
                if CONFIG["loop"]:
                    asyncio.run_coroutine_threadsafe(update_model_cache(), CONFIG["loop"])
        except Exception:
            # Was a bare `except:`; Exception still ends the shell loop on
            # unexpected errors without masking interpreter shutdown.
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--remote", type=int, default=11432)  # default upstream port
    parser.add_argument("-l", "--local", type=int, default=11434)
    args = parser.parse_args()

    CONFIG["remote_port"] = args.remote
    CONFIG["url"] = f"http://127.0.0.1:{args.remote}"

    threading.Thread(target=input_thread, daemon=True).start()
    print(f"\n{get_ts()} {C_CYAN}oproxy Start (L:{args.local} -> R:{args.remote}){C_RESET}")
    show_help()
    uvicorn.run(app, host="127.0.0.1", port=args.local, log_level="error")