Add TMDB client and integrate TMDB data into GPT response handling

This commit is contained in:
2025-01-20 10:58:58 +08:00
parent bf23393f26
commit 65cfc941d9
4 changed files with 79 additions and 6 deletions

View File

@@ -1,13 +1,18 @@
from .config import load_config from .config import load_config
from .file import gen_input_list, link_files from .file import gen_input_list, link_files
from .gpt import get_response from .gpt import get_response
from .tmdb import TmdbClient
import asyncio import asyncio
async def main(): async def main():
cfg = load_config() cfg = load_config()
files = gen_input_list(cfg.input) files = gen_input_list(cfg.input)
res = await get_response(cfg, cfg.input, files, cfg.series_name, cfg.year, cfg.tmdb_id, cfg.tvdb_id) tmdb_data = None
if cfg.tmdb_id and not cfg.no_tmdb:
tmdb = TmdbClient(cfg)
tmdb_data = await tmdb.get_tmdb_data(cfg.tmdb_id)
res = await get_response(cfg, cfg.input, files, cfg.series_name, cfg.year, cfg.tmdb_id, cfg.tvdb_id, tmdb_data)
for f in res.files: for f in res.files:
print(files[f.index], '->', f.name) print(files[f.index], '->', f.name)
input('Continue?') input('Continue?')

View File

@@ -27,7 +27,11 @@ class Config:
@property @property
def input(self) -> str: def input(self) -> str:
return self._args.input return self._args.input
@property
def no_tmdb(self) -> bool:
    """Return the parsed ``no_tmdb`` command-line flag.

    The flag is declared as ``-n/--no-tmdb`` (``store_true``); when set,
    the TMDB API is not used to obtain data.
    """
    return self._args.no_tmdb
@property @property
def output(self) -> str: def output(self) -> str:
return self._args.output return self._args.output
@@ -40,6 +44,14 @@ class Config:
def series_name(self) -> Optional[str]: def series_name(self) -> Optional[str]:
return self._args.series_name return self._args.series_name
@property
def tmdb_api_key(self) -> Optional[str]:
    """Return the TMDB API key.

    The command-line value wins whenever it is not ``None``; otherwise the
    ``tmdb_api_key`` entry of the YAML configuration is used (``None`` when
    that entry is absent).
    """
    cli_value = self._args.tmdb_api_key
    if cli_value is None:
        return self._yaml_config.get('tmdb_api_key')
    return cli_value
@property
def tmdb_language(self) -> Optional[str]:
    """Return the language to request from TMDB.

    The command-line value wins whenever it is not ``None``; otherwise the
    ``tmdb_language`` entry of the YAML configuration is used (``None``
    when that entry is absent).
    """
    cli_value = self._args.tmdb_language
    if cli_value is None:
        return self._yaml_config.get('tmdb_language')
    return cli_value
@property @property
def tmdb_id(self) -> Optional[int]: def tmdb_id(self) -> Optional[int]:
return self._args.tmdb_id return self._args.tmdb_id
@@ -68,9 +80,12 @@ def get_arg_parser() -> argparse.ArgumentParser:
parser.add_argument('-c', '--config', type=str, default='./config.yml', help='Path to the configuration file') parser.add_argument('-c', '--config', type=str, default='./config.yml', help='Path to the configuration file')
parser.add_argument('-s', '--series-name', type=str, help='Series name (optional)') parser.add_argument('-s', '--series-name', type=str, help='Series name (optional)')
parser.add_argument('-Y', '--year', type=int, help='Year of the series (optional)') parser.add_argument('-Y', '--year', type=int, help='Year of the series (optional)')
parser.add_argument('--tmdb-api-key', type=str, help='TMDB API key (optional)')
parser.add_argument('--tmdb-language', type=str, help='TMDB language (optional)')
parser.add_argument('-t', '--tmdb-id', type=int, help='TMDB ID (optional)') parser.add_argument('-t', '--tmdb-id', type=int, help='TMDB ID (optional)')
parser.add_argument('-T', '--tvdb-id', type=int, help='TVDB ID (optional)') parser.add_argument('-T', '--tvdb-id', type=int, help='TVDB ID (optional)')
parser.add_argument('-H', '--hardlink', action='store_true', help='Use hardlink instead of symlink (optional)') parser.add_argument('-H', '--hardlink', action='store_true', help='Use hardlink instead of symlink.')
parser.add_argument('-n', '--no-tmdb', action='store_true', help='Do not use TMDB API to obtain data.')
parser.add_argument('input', help='Input directory.') parser.add_argument('input', help='Input directory.')
parser.add_argument('output', help='Output directory.') parser.add_argument('output', help='Output directory.')
return parser return parser

View File

@@ -2,12 +2,12 @@ import json
import openai import openai
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
from typing import List from typing import Dict, List
from .config import Config from .config import Config
SYSTEM_PROMPT = '''You are an assistant, and your goal is to help users rename file names according to the following rules. The user will provide an input directory and a list of files in JSONL format. You will output the new location for each file after renaming based on the file list. SYSTEM_PROMPT = '''You are an assistant, and your goal is to help users rename file names according to the following rules. The user will provide an input directory and a list of files in JSONL format. You will output the new location for each file after renaming based on the file list.
You will rename files based on the information extracted from the input directory and the file list. Prioritize using the information specified by the user. If no specific information is provided by the user, use the information extracted from the inputs mentioned above. You will rename files based on the information extracted from the input directory and the file list. Prioritize using the information specified by the user. If no specific information is provided by the user, use the information extracted from the inputs mentioned above. If user provide TMDB data, use the information from TMDB data first.
The format for the highest-level directory is `Series Name (Year)`, which may optionally include a TMDB ID or TVDB ID, for example, `Series Name (Year) [tmdbid-1234]`. The format for the highest-level directory is `Series Name (Year)`, which may optionally include a TMDB ID or TVDB ID, for example, `Series Name (Year) [tmdbid-1234]`.
The second-level directory format is `Season XX`. If there is not enough information, use `Season 01` by default. Special episodes, such as OVA, can use `Season 00`. Other movies, such as Bonus, should use `extras`. Trailers for episodes should be same as the episode. The second-level directory format is `Season XX`. If there is not enough information, use `Season 01` by default. Special episodes, such as OVA, can use `Season 00`. Other movies, such as Bonus, should use `extras`. Trailers for episodes should be same as the episode.
The format for files in the third level is `SXXEXX Episode Name`. If multiple episodes are merged, use the following format: `SXXEXX-EXX Episode Name1/Episode Name2`. The episode name is optional. `SXXEXX` must be empty if file is in `extras` folder. If video is a trailer, add `.trailer` to name. For trailers, it is not necessary to replace the episode number with a small one. The format for files in the third level is `SXXEXX Episode Name`. If multiple episodes are merged, use the following format: `SXXEXX-EXX Episode Name1/Episode Name2`. The episode name is optional. `SXXEXX` must be empty if file is in `extras` folder. If video is a trailer, add `.trailer` to name. For trailers, it is not necessary to replace the episode number with a small one.
@@ -24,6 +24,11 @@ class Files(BaseModel):
files: List[File] files: List[File]
class TmdbData(BaseModel):
    """TMDB data that is embedded into the prompt sent to the model."""

    # JSON object describing the series as a whole.
    series_info: dict
    # JSON object per season, keyed by season number.
    seasons_info: Dict[int, dict]
def gen_files_list(files: List[str]): def gen_files_list(files: List[str]):
prompt = '''Here are file list: prompt = '''Here are file list:
```jsonl''' ```jsonl'''
@@ -35,6 +40,19 @@ def gen_files_list(files: List[str]):
return prompt return prompt
def gen_tmdb_data(tmdb_data: TmdbData) -> str:
    """Render TMDB series and season data as a prompt fragment.

    The series info comes first, followed by one section per entry of
    ``tmdb_data.seasons_info`` in the mapping's iteration order. Every
    section is a heading line followed by a fenced ``json`` block holding
    the compact (no extra whitespace, non-ASCII preserved) JSON dump.

    Args:
        tmdb_data: Object exposing ``series_info`` (JSON-serialisable) and
            ``seasons_info`` (mapping of season number to JSON-serialisable
            data).

    Returns:
        The prompt text, with sections joined by newlines and no trailing
        newline.
    """

    def dump(obj) -> str:
        # Compact separators keep the prompt short; ensure_ascii=False keeps
        # non-ASCII titles readable instead of \uXXXX escapes.
        return json.dumps(obj, ensure_ascii=False, separators=(',', ':'))

    # Collect lines and join once instead of growing a string with `+=`.
    lines = [
        'Here are series info from TMDB:',
        '```json',
        dump(tmdb_data.series_info),
        '```',
    ]
    for season_number, season_info in tmdb_data.seasons_info.items():
        lines += [
            f'Here are season {season_number} info from TMDB:',
            '```json',
            dump(season_info),
            '```',
        ]
    return '\n'.join(lines)
SUPPORTED_MODELS = ["gpt-4o", "gpt-4o-mini", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "gpt-4o-mini-2024-07-18"] SUPPORTED_MODELS = ["gpt-4o", "gpt-4o-mini", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "gpt-4o-mini-2024-07-18"]
@@ -65,8 +83,10 @@ def parse_result(result: str):
async def get_response(cfg: Config, inp: str, files: List[str], async def get_response(cfg: Config, inp: str, files: List[str],
series_name: str = None, year: int = None, tmdb_id: int = None, series_name: str = None, year: int = None, tmdb_id: int = None,
tvdb_id: int = None) -> Files: tvdb_id: int = None, tmdb_data: TmdbData = None) -> Files:
prompt = f'The input directory is `{inp}`.' prompt = f'The input directory is `{inp}`.'
if tmdb_data:
prompt += '\n' + gen_tmdb_data(tmdb_data)
prompt += '\n' + gen_files_list(files) prompt += '\n' + gen_files_list(files)
if series_name: if series_name:
prompt += f'\nThe series name is `{series_name}`.' prompt += f'\nThe series name is `{series_name}`.'
@@ -76,6 +96,7 @@ async def get_response(cfg: Config, inp: str, files: List[str],
prompt += f'\nThe TMDB ID is `{tmdb_id}`.' prompt += f'\nThe TMDB ID is `{tmdb_id}`.'
if tvdb_id: if tvdb_id:
prompt += f'\nThe TVDB ID is `{tvdb_id}`.' prompt += f'\nThe TVDB ID is `{tvdb_id}`.'
input(prompt)
http_client = httpx.AsyncClient(proxy=cfg.proxy) http_client = httpx.AsyncClient(proxy=cfg.proxy)
client = openai.AsyncClient(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client) client = openai.AsyncClient(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client)
if is_support_structed_output(cfg.model): if is_support_structed_output(cfg.model):

32
gpt_shows_rename/tmdb.py Normal file
View File

@@ -0,0 +1,32 @@
import httpx
from .config import Config
from .gpt import TmdbData
class TmdbClient:
    """Small asynchronous client for the TMDB ``/3/tv`` endpoints.

    The instance owns an ``httpx.AsyncClient`` bound to
    ``https://api.themoviedb.org``. Release it with ``await client.aclose()``
    or by using the instance in ``async with TmdbClient(cfg) as client:``.
    """

    def __init__(self, cfg: Config):
        """Build the underlying HTTP client from ``cfg``.

        Args:
            cfg: Configuration providing ``proxy`` and ``tmdb_api_key`` (read
                here) and ``tmdb_language`` (read on every request).
        """
        self._cfg = cfg
        self._client = httpx.AsyncClient(
            proxy=cfg.proxy,
            base_url='https://api.themoviedb.org',
            # The configured key is sent as a bearer token on every request.
            headers={'Authorization': f'Bearer {cfg.tmdb_api_key}'},
        )

    async def __aenter__(self):
        """Return the client itself for use in ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the HTTP client when the ``async with`` block exits."""
        await self.aclose()

    async def aclose(self):
        """Close the underlying HTTP client and its connections."""
        await self._client.aclose()

    async def _get_json(self, path: str):
        """GET ``path`` with the common query parameters and decode JSON.

        Raises:
            httpx.HTTPStatusError: If the response status is 4xx or 5xx, so
                an error payload is never mistaken for real data.
        """
        response = await self._client.get(path, params=self.get_params())
        response.raise_for_status()
        return response.json()

    async def get_series_info(self, tmdb_id: int):
        """Fetch the decoded JSON of ``/3/tv/{tmdb_id}``."""
        return await self._get_json(f'/3/tv/{tmdb_id}')

    async def get_series_episodes(self, tmdb_id: int, season_number: int):
        """Fetch the decoded JSON of ``/3/tv/{tmdb_id}/season/{season_number}``."""
        return await self._get_json(f'/3/tv/{tmdb_id}/season/{season_number}')

    def get_params(self):
        """Return the query parameters shared by all requests.

        Contains ``language`` only when ``cfg.tmdb_language`` is truthy.
        """
        language = self._cfg.tmdb_language
        return {'language': language} if language else {}

    async def get_tmdb_data(self, tmdb_id: int, season_number: int = None) -> TmdbData:
        """Fetch the series info plus per-season info for ``tmdb_id``.

        Args:
            tmdb_id: TMDB identifier of the series.
            season_number: When given, only season 0 and the season with
                this number are fetched; when ``None``, every season listed
                in the series info is fetched.

        Returns:
            A ``TmdbData`` holding the series payload and a mapping of
            season number to that season's payload.

        Raises:
            httpx.HTTPStatusError: If any request returns an error status.
        """
        series_info = await self.get_series_info(tmdb_id)
        seasons_info = {}
        for season in series_info["seasons"]:
            current = season["season_number"]
            # Season 0 is always kept, even when a specific season is
            # requested (presumably the "specials" season — confirm).
            wanted = season_number is None or current in (0, season_number)
            if not wanted:
                continue
            seasons_info[current] = await self.get_series_episodes(tmdb_id, current)
        return TmdbData(series_info=series_info, seasons_info=seasons_info)