From b5843e72adee1269bd936802cdce40f079ef174d Mon Sep 17 00:00:00 2001 From: lifegpc Date: Mon, 20 Jan 2025 09:00:40 +0800 Subject: [PATCH] Add support models which does not support structed output --- gpt_shows_rename/__main__.py | 19 +++++--- gpt_shows_rename/gpt.py | 88 ++++++++++++++++++++++++++++-------- 2 files changed, 82 insertions(+), 25 deletions(-) diff --git a/gpt_shows_rename/__main__.py b/gpt_shows_rename/__main__.py index 2458267..4976fd0 100644 --- a/gpt_shows_rename/__main__.py +++ b/gpt_shows_rename/__main__.py @@ -1,12 +1,17 @@ from .config import load_config from .file import gen_input_list, link_files from .gpt import get_response +import asyncio -cfg = load_config() -files = gen_input_list(cfg.input) -res = get_response(cfg, cfg.input, files, cfg.series_name, cfg.year, cfg.tmdb_id, cfg.tvdb_id) -for f in res.files: - print(files[f.index], '->', f.name) -input('Continue?') -link_files(cfg.input, files, res, cfg.output, cfg.hardlink) +async def main(): + cfg = load_config() + files = gen_input_list(cfg.input) + res = await get_response(cfg, cfg.input, files, cfg.series_name, cfg.year, cfg.tmdb_id, cfg.tvdb_id) + for f in res.files: + print(files[f.index], '->', f.name) + input('Continue?') + link_files(cfg.input, files, res, cfg.output, cfg.hardlink) + + +asyncio.run(main()) diff --git a/gpt_shows_rename/gpt.py b/gpt_shows_rename/gpt.py index c30dd44..eac87e9 100644 --- a/gpt_shows_rename/gpt.py +++ b/gpt_shows_rename/gpt.py @@ -35,10 +35,39 @@ def gen_files_list(files: List[str]): return prompt -def get_response(cfg: Config, inp: str, files: List[str], - series_name: str = None, year: int = None, tmdb_id: int = None, - tvdb_id: int = None) -> Files: +SUPPORTED_MODELS = ["gpt-4o", "gpt-4o-mini", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "gpt-4o-mini-2024-07-18"] + + +def is_support_structed_output(model: str): + if model.startswith('ft:'): + models = model.split(':') + if len(models) < 2: + return False + model = models[1] + return model in SUPPORTED_MODELS + + +def parse_result(result: str): + lines = result.splitlines() + files = [] + for line in lines: + line = line.strip() + if not line: + continue + if line.startswith("```"): + continue + try: + files.append(File(**json.loads(line))) + except Exception: + raise ValueError(f"Failed to parse message: {line}") + return Files(files=files) + + +async def get_response(cfg: Config, inp: str, files: List[str], + series_name: str = None, year: int = None, tmdb_id: int = None, + tvdb_id: int = None) -> Files: prompt = f'The input directory is `{inp}`.' + prompt += '\n' + gen_files_list(files) if series_name: prompt += f'\nThe series name is `{series_name}`.' if year: @@ -47,18 +76,41 @@ def get_response(cfg: Config, inp: str, files: List[str], prompt += f'\nThe TMDB ID is `{tmdb_id}`.' if tvdb_id: prompt += f'\nThe TVDB ID is `{tvdb_id}`.' - prompt += '\n' + gen_files_list(files) - http_client = httpx.Client(proxy=cfg.proxy) - client = openai.Client(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client) - res = client.beta.chat.completions.parse( - model=cfg.model, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": prompt}, - ], - response_format=Files, - ) - mes = res.choices[0].message - if mes.refusal: - raise ValueError(f"Model refused to answer: {mes.refusal}") - return mes.parsed + if is_support_structed_output(cfg.model): + http_client = httpx.Client(proxy=cfg.proxy) + client = openai.Client(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client) + res = client.beta.chat.completions.parse( + model=cfg.model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + ], + response_format=Files, + ) + mes = res.choices[0].message + if mes.refusal: + raise ValueError(f"Model refused to answer: {mes.refusal}") + if not mes.parsed: + raise ValueError(f"Unhandled error: {mes.to_json()}") + return mes.parsed + else: + http_client = httpx.AsyncClient(proxy=cfg.proxy) + client = openai.AsyncClient(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client) + res = await client.chat.completions.create( + model=cfg.model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT + "\nThe format of the returned result is consistent with the input file list."}, + {"role": "user", "content": prompt}, + ], + stream=True, + ) + mes = '' + async for chunk in res: + if chunk.choices: + choice = chunk.choices[0] + if choice.delta and choice.delta.content: + data = choice.delta.content + mes += data + print(data, end='', flush=True) + print(flush=True) + return parse_result(mes)