update package

This commit is contained in:
ksyasuda 2022-10-20 00:05:58 -07:00
parent 1869ed4057
commit da656d9582
3 changed files with 25 additions and 19 deletions

1
pydb/.gitignore vendored
View File

@ -2,6 +2,7 @@
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
config.py
# C extensions # C extensions
*.so *.so

View File

@ -22,10 +22,13 @@ logger.addHandler(handler)
def convert_query_result(query_res): def convert_query_result(query_res):
"""Return List(elements) from query result List(Tuples(element)).""" """Return List(elements) from query result List(Tuples(element))."""
return [i[0] if len(query_res) == 1 else list(i) for i in query_res] if len(query_res) == 0:
logger.warning("Received empty result set from DB")
return []
return [list(i) for i in query_res] if len(query_res) > 1 else query_res[0][0]
class DB_WRAPPER: class DbWrapper:
"""Base Class for DB Connection Wrapper""" """Base Class for DB Connection Wrapper"""
def __init__(self, info: DB_INFO, connector, exception=None, dictionary=False): def __init__(self, info: DB_INFO, connector, exception=None, dictionary=False):
@ -44,10 +47,9 @@ class DB_WRAPPER:
else: else:
self._dictionary = False self._dictionary = False
self._cur = self._conn.cursor() self._cur = self._conn.cursor()
self._Exception = exception if exception is not None else Exception self._exception = exception if exception is not None else Exception
except Exception as exception: except self._exception as e:
logger.critical("Something went wrong connection to DB:", exception) raise e
raise exception
def get_connection(self): def get_connection(self):
"""Returns the connection object for the DB""" """Returns the connection object for the DB"""
@ -67,7 +69,7 @@ class DB_WRAPPER:
Returns: Returns:
[type]: [description] [type]: [description]
""" """
return self._Exception return self._exception
def query(self, stmt): def query(self, stmt):
"""Queries the db with <stmt>. Returns list of tuples if there are results.""" """Queries the db with <stmt>. Returns list of tuples if there are results."""
@ -75,7 +77,7 @@ class DB_WRAPPER:
self._cur.execute(stmt) self._cur.execute(stmt)
res = self._cur.fetchall() res = self._cur.fetchall()
return convert_query_result(res) if not self._dictionary else res return convert_query_result(res) if not self._dictionary else res
except self._Exception as exception: except self._exception as exception:
raise exception raise exception
def query_with_commit(self, stmt): def query_with_commit(self, stmt):
@ -83,7 +85,7 @@ class DB_WRAPPER:
try: try:
self._cur.execute(stmt) self._cur.execute(stmt)
self._cur.commit() self._cur.commit()
except self._Exception as exception: except self._exception as exception:
raise exception raise exception
def execute(self, stmt): def execute(self, stmt):
@ -96,7 +98,7 @@ class DB_WRAPPER:
if not self._dictionary if not self._dictionary
else self._cur.fetchall() else self._cur.fetchall()
) )
except self._Exception as e: except self._exception as e:
raise e raise e
def close(self): def close(self):
@ -105,12 +107,12 @@ class DB_WRAPPER:
self._conn.close() self._conn.close()
class MysqlDB(DB_WRAPPER): class MysqlDB(DbWrapper):
"""Mysql Specific Functions""" """Mysql Specific Functions"""
def __init__(self, info, dictionary=False): def __init__(self, info, dictionary=False):
"""MySQL Connection Wrapper""" """MySQL Connection Wrapper"""
DB_WRAPPER.__init__( DbWrapper.__init__(
self, self,
info, info,
mysql.connector.connect, mysql.connector.connect,
@ -152,7 +154,8 @@ class MysqlDB(DB_WRAPPER):
def table_exists(self, schema: str, table: str): def table_exists(self, schema: str, table: str):
stmt = f""" stmt = f"""
SELECT COUNT(*) from information_schema.TABLES SELECT COUNT(*) from information_schema.TABLES
WHERE TABLE_SCHEMA = '{schema}' and TABLE_NAME = '{table}' WHERE UPPER(TABLE_SCHEMA) = UPPER('{schema}')
AND UPPER(TABLE_NAME) = UPPER('{table}')
""" """
return self.query(stmt)[0][0] != 0 return self.query(stmt)[0][0] != 0
@ -196,9 +199,15 @@ class DatabaseManager:
def db_factory( def db_factory(
db_info: Union[Dict[str, int], Dict[str, str]], db_type: str, dictionary=False db_info: Union[Dict[str, int], Dict[str, str]], db_type: str, dictionary=False
): ):
"""Build the correct db wrapper for the db_type."""
db_type = db_type.strip().lower() db_type = db_type.strip().lower()
if db_type == "mysql": if db_type == "mysql":
return MysqlDB(db_info, dictionary=dictionary) dbo = MysqlDB(db_info, dictionary=dictionary)
else:
logger.error("ERROR %s not valid", db_type)
logger.error("Valid types: [ mysql ]")
sys.exit(1)
return dbo
# elif db_type == "snowflake": # elif db_type == "snowflake":
# return SnowflakeWrapper( # return SnowflakeWrapper(
# db_info, snowflake.connector.connect, snowflake.connector.Error # db_info, snowflake.connector.connect, snowflake.connector.Error
@ -211,7 +220,3 @@ def db_factory(
# ) # )
# elif db_type == 'mariadb': # elif db_type == 'mariadb':
# return _MariaDB(db_info) # return _MariaDB(db_info)
else:
logger.error("ERROR %s not valid", db_type)
logger.error("Valid types: [ mysql ]")
sys.exit(1)

View File

@ -28,4 +28,4 @@ def test_mysql_2():
db = DatabaseManager(config.MYSQL_INFO, "mysql").__enter__() db = DatabaseManager(config.MYSQL_INFO, "mysql").__enter__()
res = db.query("SELECT COUNT(*) FROM dir_map") res = db.query("SELECT COUNT(*) FROM dir_map")
logger.info("Result: %s", res) logger.info("Result: %s", res)
assert res is not None and res[0] > 0 assert res is not None and res > 0