Add support for customizable file extensions in config and file processing

This commit is contained in:
2025-01-20 13:19:53 +08:00
parent 2a2e4f75fa
commit a10ce77afc
4 changed files with 40 additions and 9 deletions

View File

@@ -7,7 +7,7 @@ import asyncio
async def main():
cfg = load_config()
files = gen_input_list(cfg.input)
files = gen_input_list(cfg.input, cfg.exts)
tmdb_data = None
if cfg.tmdb_id and not cfg.no_tmdb:
if not cfg.tmdb_api_key:

View File

@@ -1,9 +1,13 @@
import argparse
import openai
import os.path
from typing import Optional
from typing import List, Optional
import yaml
DEFAULT_EXTS = ['.mp4', '.mkv', '.ass', '.srt']
class Config:
def __init__(self, args: argparse.Namespace, yaml_config: dict):
self._args = args
@@ -20,6 +24,21 @@ class Config:
def base_url(self) -> Optional[str]:
return self._args.base_url if self._args.base_url is not None else self._yaml_config.get('base_url', 'https://api.openai.com/v1')
@property
def exts(self) -> List[str]:
if self._args.exts:
return self._args.exts
exts = self._yaml_config.get('exts')
if isinstance(exts, str):
return exts.split(";")
if isinstance(exts, list):
aexts = []
for ext in exts:
if isinstance(ext, str):
aexts.append(ext)
return aexts
return DEFAULT_EXTS
@property
def hardlink(self) -> bool:
return self._args.hardlink
@@ -91,6 +110,7 @@ def get_arg_parser() -> argparse.ArgumentParser:
parser.add_argument('-H', '--hardlink', action='store_true', help='Use hardlink instead of symlink.')
parser.add_argument('-n', '--no-tmdb', action='store_true', help='Do not use TMDB API to obtain data.')
parser.add_argument('-S', '--season-number', type=int, help='Season number (optional)')
parser.add_argument('-e', '--exts', action='append', help='File extensions to process (optional)')
parser.add_argument('input', help='Input directory.')
parser.add_argument('output', help='Output directory.')
return parser

View File

@@ -4,10 +4,7 @@ from typing import List
from .gpt import Files
EXTS = ['.mp4', '.mkv', '.ass', '.srt']
def gen_input_list(dir: str, prefix: str = None) -> List[str]:
def gen_input_list(dir: str, exts: List[str], prefix: str = None) -> List[str]:
if prefix is None:
prefix = dir
re = []
@@ -16,11 +13,11 @@ def gen_input_list(dir: str, prefix: str = None) -> List[str]:
continue
path = os.path.join(dir, f)
if os.path.isdir(path):
data = gen_input_list(path, prefix=prefix)
data = gen_input_list(path, exts, prefix)
re += data
else:
exts = os.path.splitext(f)[1]
if exts not in EXTS:
ext = os.path.splitext(f)[1]
if ext not in exts:
continue
p = os.path.relpath(path, prefix)
re.append(p)