Add support for models that do not support structured output

This commit is contained in:
2025-01-20 09:00:40 +08:00
parent c7e39dfa0c
commit b5843e72ad
2 changed files with 82 additions and 25 deletions

View File

@@ -1,12 +1,17 @@
from .config import load_config
from .file import gen_input_list, link_files
from .gpt import get_response
import asyncio
cfg = load_config()
files = gen_input_list(cfg.input)
res = get_response(cfg, cfg.input, files, cfg.series_name, cfg.year, cfg.tmdb_id, cfg.tvdb_id)
for f in res.files:
print(files[f.index], '->', f.name)
input('Continue?')
link_files(cfg.input, files, res, cfg.output, cfg.hardlink)
async def main():
    """Load the configuration, request new names, confirm, then link files.

    The configuration comes from ``load_config`` and the candidate files from
    ``gen_input_list``. The awaited ``get_response`` result is printed as one
    ``source -> name`` line per returned entry before the user is prompted.
    """
    config = load_config()
    sources = gen_input_list(config.input)
    plan = await get_response(
        config,
        config.input,
        sources,
        series_name=config.series_name,
        year=config.year,
        tmdb_id=config.tmdb_id,
        tvdb_id=config.tvdb_id,
    )
    for entry in plan.files:
        # entry.index points back into the list that was sent to the model.
        print(sources[entry.index], '->', entry.name)
    # The typed answer is discarded: any input continues, so the only way to
    # abort at this prompt is to interrupt the process.
    input('Continue?')
    link_files(config.input, sources, plan, config.output, config.hardlink)


asyncio.run(main())

View File

@@ -35,10 +35,39 @@ def gen_files_list(files: List[str]):
return prompt
def get_response(cfg: Config, inp: str, files: List[str],
series_name: str = None, year: int = None, tmdb_id: int = None,
tvdb_id: int = None) -> Files:
# Model names for which the structured-output (``response_format``) request
# path is used; every other model goes through the plain-text fallback.
SUPPORTED_MODELS = ["gpt-4o", "gpt-4o-mini", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "gpt-4o-mini-2024-07-18"]


def is_support_structed_output(model: str) -> bool:
    """Return whether ``model`` is listed in ``SUPPORTED_MODELS``.

    Names starting with ``ft:`` are matched on their second colon-separated
    field, so ``ft:gpt-4o-mini:org:name:id`` is checked as ``gpt-4o-mini``.

    Args:
        model: Model name as configured by the user.

    Returns:
        ``True`` when the (base) model name is in ``SUPPORTED_MODELS``,
        ``False`` otherwise, including for an empty name or a bare ``ft:``.
    """
    if model.startswith('ft:'):
        # The 'ft:' prefix guarantees at least two fields, so index 1 always
        # exists; maxsplit avoids splitting the remainder needlessly.
        model = model.split(':', 2)[1]
    return model in SUPPORTED_MODELS
def parse_result(result: str) -> "Files":
    """Parse a one-JSON-object-per-line reply into a ``Files`` container.

    Every non-blank line of ``result`` is decoded with ``json.loads`` and the
    decoded mapping is passed as keyword arguments to ``File``. Blank lines
    and lines that start with a Markdown code fence (```) are skipped.

    Args:
        result: Raw text returned by the model.

    Returns:
        A ``Files`` instance holding one ``File`` per parsed line, in the
        order the lines appear.

    Raises:
        ValueError: If a line cannot be decoded or ``File`` rejects its
            contents. The underlying error is attached as ``__cause__``.
    """
    files = []
    for raw_line in result.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("```"):
            continue
        try:
            files.append(File(**json.loads(line)))
        except Exception as err:
            # Broad on purpose: decoding errors, non-mapping JSON and
            # whatever validation ``File`` performs are all reported the
            # same way; chaining keeps the root cause in the traceback.
            raise ValueError(f"Failed to parse message: {line}") from err
    return Files(files=files)
async def get_response(cfg: Config, inp: str, files: List[str],
series_name: str = None, year: int = None, tmdb_id: int = None,
tvdb_id: int = None) -> Files:
prompt = f'The input directory is `{inp}`.'
prompt += '\n' + gen_files_list(files)
if series_name:
prompt += f'\nThe series name is `{series_name}`.'
if year:
@@ -47,18 +76,41 @@ def get_response(cfg: Config, inp: str, files: List[str],
prompt += f'\nThe TMDB ID is `{tmdb_id}`.'
if tvdb_id:
prompt += f'\nThe TVDB ID is `{tvdb_id}`.'
prompt += '\n' + gen_files_list(files)
http_client = httpx.Client(proxy=cfg.proxy)
client = openai.Client(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client)
res = client.beta.chat.completions.parse(
model=cfg.model,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
response_format=Files,
)
mes = res.choices[0].message
if mes.refusal:
raise ValueError(f"Model refused to answer: {mes.refusal}")
return mes.parsed
if is_support_structed_output(cfg.model):
http_client = httpx.Client(proxy=cfg.proxy)
client = openai.Client(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client)
res = client.beta.chat.completions.parse(
model=cfg.model,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
response_format=Files,
)
mes = res.choices[0].message
if mes.refusal:
raise ValueError(f"Model refused to answer: {mes.refusal}")
if not mes.parsed:
raise ValueError(f"Unhandled error: {mes.to_json()}")
return mes.parsed
else:
http_client = httpx.AsyncClient(proxy=cfg.proxy)
client = openai.AsyncClient(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client)
res = await client.chat.completions.create(
model=cfg.model,
messages=[
{"role": "system", "content": SYSTEM_PROMPT + "\nThe format of the returned result is consistent with the input file list."},
{"role": "user", "content": prompt},
],
stream=True,
)
mes = ''
async for chunk in res:
if chunk.choices:
choice = chunk.choices[0]
if choice.delta and choice.delta.content:
data = choice.delta.content
mes += data
print(data, end='', flush=True)
print(flush=True)
return parse_result(mes)