diff --git a/gpt_shows_rename/gpt.py b/gpt_shows_rename/gpt.py index eac87e9..b614283 100644 --- a/gpt_shows_rename/gpt.py +++ b/gpt_shows_rename/gpt.py @@ -77,22 +77,35 @@ async def get_response(cfg: Config, inp: str, files: List[str], if tvdb_id: prompt += f'\nThe TVDB ID is `{tvdb_id}`.' if is_support_structed_output(cfg.model): - http_client = httpx.Client(proxy=cfg.proxy) - client = openai.Client(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client) - res = client.beta.chat.completions.parse( + http_client = httpx.AsyncClient(proxy=cfg.proxy) + client = openai.AsyncClient(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client) + result = None + mes = '' + refusal = '' + async with client.beta.chat.completions.stream( model=cfg.model, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}, ], response_format=Files, - ) - mes = res.choices[0].message - if mes.refusal: - raise ValueError(f"Model refused to answer: {mes.refusal}") - if not mes.parsed: - raise ValueError(f"Unhandled error: {mes.to_json()}") - return mes.parsed + ) as stream: + async for event in stream: + if event.type == "content.delta": + mes += event.delta + print(event.delta, end='', flush=True) + if event.parsed: + result = event.parsed + elif event.type == "refusal.delta": + refusal += event.delta + print(event.delta, end='', flush=True) + print('', end='', flush=True) + if refusal: + raise ValueError(f"Model refused to answer: {refusal}") + if not result: + raise ValueError(f"Unhandle Error: {mes}") + result = Files(**result) + return result else: http_client = httpx.AsyncClient(proxy=cfg.proxy) client = openai.AsyncClient(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client)