diff --git a/oproxy.py b/oproxy.py index 9b6c2a0..e4e525b 100755 --- a/oproxy.py +++ b/oproxy.py @@ -1,47 +1,71 @@ # バージョン情報: Python 3.12+ / FastAPI 0.115.0 / uvicorn 0.30.0 / httpx 0.28.0 -import httpx, asyncio, json, sys, threading, os, argparse, unicodedata -from datetime import datetime +import argparse +import asyncio +import json +import os +import sys +import threading +import unicodedata from contextlib import asynccontextmanager +from datetime import datetime + +import httpx import uvicorn from fastapi import FastAPI, Request from starlette.responses import StreamingResponse -C_GRAY, C_CYAN, C_GREEN, C_YELLOW, C_RED, C_WHITE, C_RESET = \ - "\033[90m", "\033[96m", "\033[92m", "\033[93m", "\033[91m", "\033[97m", "\033[0m" +C_GRAY, C_CYAN, C_GREEN, C_YELLOW, C_RED, C_WHITE, C_RESET = ( + "\033[90m", + "\033[96m", + "\033[92m", + "\033[93m", + "\033[91m", + "\033[97m", + "\033[0m", +) CONFIG = { "remote_port": 11430, "url": "http://127.0.0.1:11430", "timeout": httpx.Timeout(None), "loop": None, - "models_cache": [] + "models_cache": [], } + def get_ts(): - ts = datetime.now().strftime('%H:%M:%S.%f')[:-3] + ts = datetime.now().strftime("%H:%M:%S.%f")[:-3] return f"{C_GRAY}[{ts}] [:{CONFIG['remote_port']}]{C_RESET}" + def get_width(text): count = 0 for c in text: - if unicodedata.east_asian_width(c) in 'FWA': count += 2 - else: count += 1 + if unicodedata.east_asian_width(c) in "FWA": + count += 2 + else: + count += 1 return count + def pad_text(text, target_width): return text + (" " * max(0, target_width - get_width(text))) + def pulse(char, color=C_RESET): print(f"{color}{char}{C_RESET}", end="", flush=True) + @asynccontextmanager async def lifespan(app: FastAPI): CONFIG["loop"] = asyncio.get_running_loop() asyncio.create_task(update_model_cache()) yield + app = FastAPI(lifespan=lifespan) + # --- ロジック:モデルリスト取得 --- async def update_model_cache(): try: @@ -49,58 +73,99 @@ async def update_model_cache(): res = await client.get(f"{CONFIG['url']}/api/tags") if res.status_code == 200: new_data = [] - for m in res.json().get('models', []): + for m in res.json().get("models", []): # ツールサポートの簡易判定 has_tool = False try: - s = await client.post(f"{CONFIG['url']}/api/show", json={"name": m['name']}) + s = await client.post( + f"{CONFIG['url']}/api/show", json={"name": m["name"]} + ) info = s.json() - details = str(info.get("template", "")) + str(info.get("details", "")) - has_tool = any(w in details.lower() for w in ["tool", "functions"]) - except: pass - new_data.append({"name": m['name'], "size": m['size']/(1024**3), "tool": has_tool}) + details = str(info.get("template", "")) + str( + info.get("details", "") + ) + has_tool = any( + w in details.lower() for w in ["tool", "functions"] + ) + except: + pass + new_data.append( + { + "name": m["name"], + "size": m["size"] / (1024**3), + "tool": has_tool, + } + ) CONFIG["models_cache"] = new_data - except: pass + except: + pass + def show_help(): - print(f"\n{get_ts()} {C_WHITE}>>> h:HELP l:LIST ll:DETAIL s:VRAM [digit]:PORT q:EXIT <<<{C_RESET}", flush=True) + print( + f"\n{get_ts()} {C_WHITE}>>> h:HELP l:LIST ll:DETAIL s:VRAM [digit]:PORT q:EXIT <<<{C_RESET}", + flush=True, + ) -def display_models(full=False): - if not CONFIG["models_cache"]: - print(f"\n{get_ts()} {C_YELLOW}Cache is empty. Ollama may be offline.{C_RESET}", flush=True) + +def display_models(full=False, short=False): + if not CONFIG["models_cache"] or short: + print( + f"\n{get_ts()} {C_YELLOW}Cache is empty. Ollama may be offline.{C_RESET}", + flush=True, + ) return - print(f"\n{get_ts()} {C_GREEN}--- Models ({'Detailed' if full else 'Short'}) ---{C_RESET}", flush=True) + print( + f"\n{get_ts()} {C_GREEN}--- Models ({'Detailed' if full else 'Short'}) ---{C_RESET}", + flush=True, + ) NAME_W = 55 for m in CONFIG["models_cache"]: - icon = "❌" if m['size'] > 16.8 else ("✅" if m['tool'] else "⚠️") - tag = f"{C_CYAN}[T]{C_RESET}" if m['tool'] else f"{C_GRAY}[-]{C_RESET}" + icon = "❌" if m["size"] > 16.8 else ("✅" if m["tool"] else "⚠️") + tag = f"{C_CYAN}[T]{C_RESET}" if m["tool"] else f"{C_GRAY}[-]{C_RESET}" if full: print(f"{get_ts()} {icon} {tag} {C_WHITE}{m['name']}{C_RESET}") print(f"{get_ts()} {C_GRAY}└─ {m['size']:>6.1f} GiB{C_RESET}") else: - n = m['name'] + n = m["name"] if get_width(n) > NAME_W: - while get_width("..." + n) > NAME_W: n = n[1:] + while get_width("..." + n) > NAME_W: + n = n[1:] n = "..." + n - print(f"{get_ts()} {icon} {tag} {C_WHITE}{pad_text(n, NAME_W)}{C_RESET} {C_CYAN}{m['size']:>6.1f} GiB{C_RESET}") + print( + f"{get_ts()} {icon} {tag} {C_WHITE}{pad_text(n, NAME_W)}{C_RESET} {C_CYAN}{m['size']:>6.1f} GiB{C_RESET}" + ) print(f"{get_ts()} {C_GREEN}--- End ---{C_RESET}\n", flush=True) + # --- Proxy 本体 --- @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def sticky_proxy(path: str, request: Request): print(f"\n{get_ts()} {C_WHITE}/{path}{C_RESET} ", end="", flush=True) body = await request.body() # リクエストの長さに応じてインジケータを出す - for _ in range(min(len(body)//256 + 1, 5)): pulse("^", C_CYAN) + for _ in range(min(len(body) // 256 + 1, 5)): + pulse("^", C_CYAN) pulse("|", C_YELLOW) async def stream_response(): # 接続エラーを回避するためにリクエストごとにClientを生成 - async with httpx.AsyncClient(timeout=CONFIG["timeout"], base_url="http://127.0.0.1:11432") as client: + async with httpx.AsyncClient( + timeout=CONFIG["timeout"], base_url="http://127.0.0.1:11432" + ) as client: try: # 宛先を強制的に 127.0.0.1 に固定したURLで構築 target_url = f"{CONFIG['url']}/{path}" - async with client.stream(request.method, target_url, content=body, headers={k:v for k,v in request.headers.items() if k.lower() not in ["host","content-length"]}) as response: + async with client.stream( + request.method, + target_url, + content=body, + headers={ + k: v + for k, v in request.headers.items() + if k.lower() not in ["host", "content-length"] + }, + ) as response: pulse("v", C_GREEN) async for chunk in response.aiter_bytes(): pulse("v", C_GREEN) @@ -113,42 +178,64 @@ async def sticky_proxy(path: str, request: Request): return StreamingResponse(stream_response()) + def input_thread(): while True: try: line = sys.stdin.readline() - if not line: break + if not line: + break cmd = line.strip().lower() - if cmd == 'q': os._exit(0) - elif cmd == 'h': show_help() - elif cmd == 'l': display_models(False) - elif cmd == 'll': display_models(full=True) - elif cmd == 's': + if cmd == "q": + os._exit(0) + elif cmd == "h": + show_help() + elif cmd == "l": + display_models(False) + elif cmd == "ll": + display_models(full=True) + elif cmd == "s": + async def ps(): async with httpx.AsyncClient() as c: r = await c.get(f"{CONFIG['url']}/api/ps") if r.status_code == 200: print(f"\n{get_ts()} {C_CYAN}--- VRAM ---{C_RESET}") for m in r.json().get("models", []): - print(f"{get_ts()} {m['name']:<25} {m['size_vram']/(1024**3):.1f}G") - if CONFIG["loop"]: asyncio.run_coroutine_threadsafe(ps(), CONFIG["loop"]) + print( + f"{get_ts()} {m['name']:<25} {m['size_vram'] / (1024**3):.1f}G" + ) + + if CONFIG["loop"]: + asyncio.run_coroutine_threadsafe(ps(), CONFIG["loop"]) elif cmd.isdigit(): p = int(cmd) CONFIG["remote_port"], CONFIG["url"] = p, f"http://127.0.0.1:{p}" - print(f"\n{get_ts()} {C_YELLOW}Switch Target -> {CONFIG['url']}{C_RESET}") - if CONFIG["loop"]: asyncio.run_coroutine_threadsafe(update_model_cache(), CONFIG["loop"]) - except: break + print( + f"\n{get_ts()} {C_YELLOW}Switch Target -> {CONFIG['url']}{C_RESET}" + ) + if CONFIG["loop"]: + asyncio.run_coroutine_threadsafe( + update_model_cache(), CONFIG["loop"] + ) + except: + break + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-r", "--remote", type=int, default=11432) # デフォルトを11432に + parser.add_argument( + "-r", "--remote", type=int, default=11432 + ) # デフォルトを11432に parser.add_argument("-l", "--local", type=int, default=11434) args = parser.parse_args() - + CONFIG["remote_port"] = args.remote CONFIG["url"] = f"http://127.0.0.1:{args.remote}" - + threading.Thread(target=input_thread, daemon=True).start() - print(f"\n{get_ts()} {C_CYAN}oproxy Start (L:{args.local} -> R:{args.remote}){C_RESET}") + print( + f"\n{get_ts()} {C_CYAN}oproxy Start (L:{args.local} -> R:{args.remote}){C_RESET}" + ) show_help() uvicorn.run(app, host="127.0.0.1", port=args.local, log_level="error")