Add support models which does not support structed output
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
from .config import load_config
|
||||
from .file import gen_input_list, link_files
|
||||
from .gpt import get_response
|
||||
import asyncio
|
||||
|
||||
|
||||
cfg = load_config()
|
||||
files = gen_input_list(cfg.input)
|
||||
res = get_response(cfg, cfg.input, files, cfg.series_name, cfg.year, cfg.tmdb_id, cfg.tvdb_id)
|
||||
for f in res.files:
|
||||
print(files[f.index], '->', f.name)
|
||||
input('Continue?')
|
||||
link_files(cfg.input, files, res, cfg.output, cfg.hardlink)
|
||||
async def main():
|
||||
cfg = load_config()
|
||||
files = gen_input_list(cfg.input)
|
||||
res = await get_response(cfg, cfg.input, files, cfg.series_name, cfg.year, cfg.tmdb_id, cfg.tvdb_id)
|
||||
for f in res.files:
|
||||
print(files[f.index], '->', f.name)
|
||||
input('Continue?')
|
||||
link_files(cfg.input, files, res, cfg.output, cfg.hardlink)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -35,10 +35,39 @@ def gen_files_list(files: List[str]):
|
||||
return prompt
|
||||
|
||||
|
||||
def get_response(cfg: Config, inp: str, files: List[str],
|
||||
series_name: str = None, year: int = None, tmdb_id: int = None,
|
||||
tvdb_id: int = None) -> Files:
|
||||
SUPPORTED_MODELS = ["gpt-4o", "gpt-4o-mini", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "gpt-4o-mini-2024-07-18"]
|
||||
|
||||
|
||||
def is_support_structed_output(model: str):
|
||||
if model.startswith('ft:'):
|
||||
models = model.split(':')
|
||||
if len(models) < 2:
|
||||
return False
|
||||
model = models[1]
|
||||
return model in SUPPORTED_MODELS
|
||||
|
||||
|
||||
def parse_result(result: str):
|
||||
lines = result.splitlines()
|
||||
files = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("```"):
|
||||
continue
|
||||
try:
|
||||
files.append(File(**json.loads(line)))
|
||||
except Exception:
|
||||
raise ValueError(f"Failed to parse message: {line}")
|
||||
return Files(files=files)
|
||||
|
||||
|
||||
async def get_response(cfg: Config, inp: str, files: List[str],
|
||||
series_name: str = None, year: int = None, tmdb_id: int = None,
|
||||
tvdb_id: int = None) -> Files:
|
||||
prompt = f'The input directory is `{inp}`.'
|
||||
prompt += '\n' + gen_files_list(files)
|
||||
if series_name:
|
||||
prompt += f'\nThe series name is `{series_name}`.'
|
||||
if year:
|
||||
@@ -47,18 +76,41 @@ def get_response(cfg: Config, inp: str, files: List[str],
|
||||
prompt += f'\nThe TMDB ID is `{tmdb_id}`.'
|
||||
if tvdb_id:
|
||||
prompt += f'\nThe TVDB ID is `{tvdb_id}`.'
|
||||
prompt += '\n' + gen_files_list(files)
|
||||
http_client = httpx.Client(proxy=cfg.proxy)
|
||||
client = openai.Client(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client)
|
||||
res = client.beta.chat.completions.parse(
|
||||
model=cfg.model,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
response_format=Files,
|
||||
)
|
||||
mes = res.choices[0].message
|
||||
if mes.refusal:
|
||||
raise ValueError(f"Model refused to answer: {mes.refusal}")
|
||||
return mes.parsed
|
||||
if is_support_structed_output(cfg.model):
|
||||
http_client = httpx.Client(proxy=cfg.proxy)
|
||||
client = openai.Client(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client)
|
||||
res = client.beta.chat.completions.parse(
|
||||
model=cfg.model,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
response_format=Files,
|
||||
)
|
||||
mes = res.choices[0].message
|
||||
if mes.refusal:
|
||||
raise ValueError(f"Model refused to answer: {mes.refusal}")
|
||||
if not mes.parsed:
|
||||
raise ValueError(f"Unhandled error: {mes.to_json()}")
|
||||
return mes.parsed
|
||||
else:
|
||||
http_client = httpx.AsyncClient(proxy=cfg.proxy)
|
||||
client = openai.AsyncClient(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client)
|
||||
res = await client.chat.completions.create(
|
||||
model=cfg.model,
|
||||
messages=[
|
||||
{"role": "system", "content": SYSTEM_PROMPT + "\nThe format of the returned result is consistent with the input file list."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
mes = ''
|
||||
async for chunk in res:
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if choice.delta and choice.delta.content:
|
||||
data = choice.delta.content
|
||||
mes += data
|
||||
print(data, end='', flush=True)
|
||||
print(flush=True)
|
||||
return parse_result(mes)
|
||||
|
||||
Reference in New Issue
Block a user