378 lines
11 KiB
Python
Raw Permalink Normal View History

2025-03-17 14:58:05 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库连接配置模块
包含数据库连接参数和SQLAlchemy配置以及Redis缓存配置
"""
import os
import json
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
2025-03-18 18:34:03 +08:00
import traceback
import sys
2025-03-17 14:58:05 +08:00
class ConfigDict:
"""配置字典类,支持通过点号访问配置项"""
def __init__(self, **kwargs):
for key, value in kwargs.items():
if isinstance(value, dict):
setattr(self, key, ConfigDict(**value))
else:
setattr(self, key, value)
def get(self, key, default=None):
return getattr(self, key, default)
def to_dict(self):
result = {}
for key, value in self.__dict__.items():
if isinstance(value, ConfigDict):
result[key] = value.to_dict()
else:
result[key] = value
return result
# 数据库连接配置
DB_CONFIG = ConfigDict(
default=dict(
dialect='mysql',
driver='pymysql',
username='root',
password='root',
host='localhost',
port=3306,
database='tianfeng_task',
charset='utf8mb4'
),
test=dict(
dialect='sqlite',
database=':memory:'
)
)
# Redis缓存配置
REDIS_CONFIG = ConfigDict(
default=dict(
host='localhost',
port=6379,
db=0,
password=None,
prefix='tianfeng:',
socket_timeout=5,
socket_connect_timeout=5,
decode_responses=True
),
test=dict(
host='localhost',
port=6379,
db=1,
password=None,
prefix='tianfeng_test:',
decode_responses=True
)
)
# 当前环境,可通过环境变量设置
ENV = os.environ.get('TIANFENG_ENV', 'default')
# 根据环境获取数据库配置
db_conf = getattr(DB_CONFIG, ENV)
# 构建数据库连接URL
if db_conf.dialect == 'sqlite':
DATABASE_URL = f"{db_conf.dialect}:///{db_conf.database}"
else:
DATABASE_URL = (
f"{db_conf.dialect}+{db_conf.driver}://"
f"{db_conf.username}:{db_conf.password}@"
f"{db_conf.host}:{db_conf.port}/{db_conf.database}?"
f"charset={db_conf.charset}"
)
# 创建数据库引擎
engine = create_engine(
DATABASE_URL,
pool_size=20,
max_overflow=0,
pool_recycle=3600,
pool_pre_ping=True,
echo=False # 设置为True可以显示SQL语句用于调试
)
# 创建会话工厂
SessionFactory = sessionmaker(bind=engine)
# 创建线程安全的会话
db_session = scoped_session(SessionFactory)
# 创建基类
Base = declarative_base()
Base.query = db_session.query_property()
# 数据库配置类
class DBConfig:
"""数据库配置类,提供数据库相关的配置和方法"""
config = DB_CONFIG
env = ENV
url = DATABASE_URL
engine = engine
session = db_session
base = Base
@classmethod
def get_config(cls):
"""获取当前环境的数据库配置"""
return getattr(cls.config, cls.env)
@classmethod
def get_session(cls):
"""获取数据库会话"""
return cls.session
@classmethod
def init_db(cls):
"""
初始化数据库
创建所有表
"""
2025-03-18 18:34:03 +08:00
# 测试数据库连接
try:
print(f"尝试连接数据库: {cls.url}")
connection = cls.engine.connect()
print("数据库连接成功!")
connection.close()
except Exception as e:
print(f"数据库连接失败: {str(e)}")
print("详细错误信息:")
traceback.print_exc(file=sys.stdout)
raise
2025-03-17 14:58:05 +08:00
# 导入所有模型确保它们已注册到Base
import data.models
# 首先尝试创建数据库(如果不存在)
if cls.get_config().dialect != 'sqlite':
from sqlalchemy import text
# 创建一个不指定数据库的连接
db_conf = cls.get_config()
2025-03-18 18:34:03 +08:00
try:
print(f"尝试创建数据库 {db_conf.database} (如果不存在)")
temp_url = (
f"{db_conf.dialect}+{db_conf.driver}://"
f"{db_conf.username}:{db_conf.password}@"
f"{db_conf.host}:{db_conf.port}/"
f"?charset={db_conf.charset}"
)
print(f"临时连接URL: {temp_url}")
temp_engine = create_engine(temp_url)
with temp_engine.connect() as conn:
conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {db_conf.database} CHARACTER SET {db_conf.charset} COLLATE {db_conf.charset}_unicode_ci;"))
conn.commit()
temp_engine.dispose()
print(f"数据库 {db_conf.database} 创建或已存在")
except Exception as e:
print(f"创建数据库失败: {str(e)}")
print("详细错误信息:")
traceback.print_exc(file=sys.stdout)
raise
2025-03-17 14:58:05 +08:00
# 创建所有表
2025-03-18 18:34:03 +08:00
try:
print("开始创建所有表...")
cls.base.metadata.create_all(bind=cls.engine)
print("所有表创建成功")
except Exception as e:
print(f"创建表失败: {str(e)}")
print("详细错误信息:")
traceback.print_exc(file=sys.stdout)
raise
2025-03-17 14:58:05 +08:00
@classmethod
def shutdown_session(cls, exception=None):
"""
关闭会话
在应用程序关闭时调用
"""
cls.session.remove()
# 缓存配置类
class CacheConfig:
"""缓存配置类提供Redis缓存相关的配置和方法"""
config = REDIS_CONFIG
env = ENV
_redis_client = None
@classmethod
def get_config(cls):
"""获取当前环境的Redis配置"""
return getattr(cls.config, cls.env)
@classmethod
def get_redis_client(cls):
"""获取Redis客户端实例"""
if cls._redis_client is None:
try:
import redis
redis_conf = cls.get_config()
cls._redis_client = redis.Redis(
host=redis_conf.host,
port=redis_conf.port,
db=redis_conf.db,
password=redis_conf.password,
socket_timeout=getattr(redis_conf, 'socket_timeout', 5),
socket_connect_timeout=getattr(redis_conf, 'socket_connect_timeout', 5),
decode_responses=getattr(redis_conf, 'decode_responses', True)
)
except ImportError:
raise ImportError("Redis package is not installed. Please install it with 'pip install redis'")
except Exception as e:
print(f"Error connecting to Redis: {e}")
return None
return cls._redis_client
@classmethod
def get_key(cls, key):
"""获取带前缀的缓存键"""
prefix = getattr(cls.get_config(), 'prefix', 'tianfeng:')
return f"{prefix}{key}"
@classmethod
def set(cls, key, value, expire=None):
"""
设置缓存
Args:
key (str): 缓存键
value (any): 缓存值非字符串类型会被JSON序列化
expire (int, optional): 过期时间
Returns:
bool: 是否设置成功
"""
redis_client = cls.get_redis_client()
if not redis_client:
return False
if not isinstance(value, (str, int, float, bool)):
value = json.dumps(value)
full_key = cls.get_key(key)
if expire:
return redis_client.setex(full_key, expire, value)
else:
return redis_client.set(full_key, value)
@classmethod
def get(cls, key, default=None):
"""
获取缓存
Args:
key (str): 缓存键
default (any, optional): 默认值
Returns:
any: 缓存值或默认值
"""
redis_client = cls.get_redis_client()
if not redis_client:
return default
full_key = cls.get_key(key)
value = redis_client.get(full_key)
if value is None:
return default
# 尝试解析JSON
try:
if value.startswith('{') or value.startswith('['):
return json.loads(value)
except (json.JSONDecodeError, AttributeError):
pass
return value
@classmethod
def delete(cls, key):
"""
删除缓存
Args:
key (str): 缓存键
Returns:
bool: 是否删除成功
"""
redis_client = cls.get_redis_client()
if not redis_client:
return False
full_key = cls.get_key(key)
return redis_client.delete(full_key) > 0
@classmethod
def exists(cls, key):
"""
检查缓存是否存在
Args:
key (str): 缓存键
Returns:
bool: 是否存在
"""
redis_client = cls.get_redis_client()
if not redis_client:
return False
full_key = cls.get_key(key)
return redis_client.exists(full_key) > 0
@classmethod
def ttl(cls, key):
"""
获取缓存剩余过期时间
Args:
key (str): 缓存键
Returns:
int: 剩余秒数-1表示永不过期-2表示不存在
"""
redis_client = cls.get_redis_client()
if not redis_client:
return -2
full_key = cls.get_key(key)
return redis_client.ttl(full_key)
@classmethod
def clear_all(cls):
"""
清除当前环境下的所有缓存
Returns:
bool: 是否清除成功
"""
redis_client = cls.get_redis_client()
if not redis_client:
return False
prefix = getattr(cls.get_config(), 'prefix', 'tianfeng:')
keys = redis_client.keys(f"{prefix}*")
if keys:
return redis_client.delete(*keys) > 0
return True
# 兼容旧代码的函数
def init_db():
"""初始化数据库(兼容旧代码)"""
DBConfig.init_db()
def shutdown_session(exception=None):
"""关闭会话(兼容旧代码)"""
DBConfig.shutdown_session(exception)