update package
This commit is contained in:
parent
1869ed4057
commit
da656d9582
1
pydb/.gitignore
vendored
1
pydb/.gitignore
vendored
@ -2,6 +2,7 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
config.py
|
||||||
|
|
||||||
# C extensions
|
# C extensions
|
||||||
*.so
|
*.so
|
||||||
|
@ -22,10 +22,13 @@ logger.addHandler(handler)
|
|||||||
|
|
||||||
def convert_query_result(query_res):
|
def convert_query_result(query_res):
|
||||||
"""Return List(elements) from query result List(Tuples(element))."""
|
"""Return List(elements) from query result List(Tuples(element))."""
|
||||||
return [i[0] if len(query_res) == 1 else list(i) for i in query_res]
|
if len(query_res) == 0:
|
||||||
|
logger.warning("Received empty result set from DB")
|
||||||
|
return []
|
||||||
|
return [list(i) for i in query_res] if len(query_res) > 1 else query_res[0][0]
|
||||||
|
|
||||||
|
|
||||||
class DB_WRAPPER:
|
class DbWrapper:
|
||||||
"""Base Class for DB Connection Wrapper"""
|
"""Base Class for DB Connection Wrapper"""
|
||||||
|
|
||||||
def __init__(self, info: DB_INFO, connector, exception=None, dictionary=False):
|
def __init__(self, info: DB_INFO, connector, exception=None, dictionary=False):
|
||||||
@ -44,10 +47,9 @@ class DB_WRAPPER:
|
|||||||
else:
|
else:
|
||||||
self._dictionary = False
|
self._dictionary = False
|
||||||
self._cur = self._conn.cursor()
|
self._cur = self._conn.cursor()
|
||||||
self._Exception = exception if exception is not None else Exception
|
self._exception = exception if exception is not None else Exception
|
||||||
except Exception as exception:
|
except self._exception as e:
|
||||||
logger.critical("Something went wrong connection to DB:", exception)
|
raise e
|
||||||
raise exception
|
|
||||||
|
|
||||||
def get_connection(self):
|
def get_connection(self):
|
||||||
"""Returns the connection object for the DB"""
|
"""Returns the connection object for the DB"""
|
||||||
@ -67,7 +69,7 @@ class DB_WRAPPER:
|
|||||||
Returns:
|
Returns:
|
||||||
[type]: [description]
|
[type]: [description]
|
||||||
"""
|
"""
|
||||||
return self._Exception
|
return self._exception
|
||||||
|
|
||||||
def query(self, stmt):
|
def query(self, stmt):
|
||||||
"""Queries the db with <stmt>. Returns list of tuples if there are results."""
|
"""Queries the db with <stmt>. Returns list of tuples if there are results."""
|
||||||
@ -75,7 +77,7 @@ class DB_WRAPPER:
|
|||||||
self._cur.execute(stmt)
|
self._cur.execute(stmt)
|
||||||
res = self._cur.fetchall()
|
res = self._cur.fetchall()
|
||||||
return convert_query_result(res) if not self._dictionary else res
|
return convert_query_result(res) if not self._dictionary else res
|
||||||
except self._Exception as exception:
|
except self._exception as exception:
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
def query_with_commit(self, stmt):
|
def query_with_commit(self, stmt):
|
||||||
@ -83,7 +85,7 @@ class DB_WRAPPER:
|
|||||||
try:
|
try:
|
||||||
self._cur.execute(stmt)
|
self._cur.execute(stmt)
|
||||||
self._cur.commit()
|
self._cur.commit()
|
||||||
except self._Exception as exception:
|
except self._exception as exception:
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
def execute(self, stmt):
|
def execute(self, stmt):
|
||||||
@ -96,7 +98,7 @@ class DB_WRAPPER:
|
|||||||
if not self._dictionary
|
if not self._dictionary
|
||||||
else self._cur.fetchall()
|
else self._cur.fetchall()
|
||||||
)
|
)
|
||||||
except self._Exception as e:
|
except self._exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
@ -105,12 +107,12 @@ class DB_WRAPPER:
|
|||||||
self._conn.close()
|
self._conn.close()
|
||||||
|
|
||||||
|
|
||||||
class MysqlDB(DB_WRAPPER):
|
class MysqlDB(DbWrapper):
|
||||||
"""Mysql Specific Functions"""
|
"""Mysql Specific Functions"""
|
||||||
|
|
||||||
def __init__(self, info, dictionary=False):
|
def __init__(self, info, dictionary=False):
|
||||||
"""MySQL Connection Wrapper"""
|
"""MySQL Connection Wrapper"""
|
||||||
DB_WRAPPER.__init__(
|
DbWrapper.__init__(
|
||||||
self,
|
self,
|
||||||
info,
|
info,
|
||||||
mysql.connector.connect,
|
mysql.connector.connect,
|
||||||
@ -152,7 +154,8 @@ class MysqlDB(DB_WRAPPER):
|
|||||||
def table_exists(self, schema: str, table: str):
|
def table_exists(self, schema: str, table: str):
|
||||||
stmt = f"""
|
stmt = f"""
|
||||||
SELECT COUNT(*) from information_schema.TABLES
|
SELECT COUNT(*) from information_schema.TABLES
|
||||||
WHERE TABLE_SCHEMA = '{schema}' and TABLE_NAME = '{table}'
|
WHERE UPPER(TABLE_SCHEMA) = UPPER('{schema}')
|
||||||
|
AND UPPER(TABLE_NAME) = UPPER('{table}')
|
||||||
"""
|
"""
|
||||||
return self.query(stmt)[0][0] != 0
|
return self.query(stmt)[0][0] != 0
|
||||||
|
|
||||||
@ -196,9 +199,15 @@ class DatabaseManager:
|
|||||||
def db_factory(
|
def db_factory(
|
||||||
db_info: Union[Dict[str, int], Dict[str, str]], db_type: str, dictionary=False
|
db_info: Union[Dict[str, int], Dict[str, str]], db_type: str, dictionary=False
|
||||||
):
|
):
|
||||||
|
"""Build the correct db wrapper for the db_type."""
|
||||||
db_type = db_type.strip().lower()
|
db_type = db_type.strip().lower()
|
||||||
if db_type == "mysql":
|
if db_type == "mysql":
|
||||||
return MysqlDB(db_info, dictionary=dictionary)
|
dbo = MysqlDB(db_info, dictionary=dictionary)
|
||||||
|
else:
|
||||||
|
logger.error("ERROR %s not valid", db_type)
|
||||||
|
logger.error("Valid types: [ mysql ]")
|
||||||
|
sys.exit(1)
|
||||||
|
return dbo
|
||||||
# elif db_type == "snowflake":
|
# elif db_type == "snowflake":
|
||||||
# return SnowflakeWrapper(
|
# return SnowflakeWrapper(
|
||||||
# db_info, snowflake.connector.connect, snowflake.connector.Error
|
# db_info, snowflake.connector.connect, snowflake.connector.Error
|
||||||
@ -211,7 +220,3 @@ def db_factory(
|
|||||||
# )
|
# )
|
||||||
# elif db_type == 'mariadb':
|
# elif db_type == 'mariadb':
|
||||||
# return _MariaDB(db_info)
|
# return _MariaDB(db_info)
|
||||||
else:
|
|
||||||
logger.error("ERROR %s not valid", db_type)
|
|
||||||
logger.error("Valid types: [ mysql ]")
|
|
||||||
sys.exit(1)
|
|
||||||
|
@ -28,4 +28,4 @@ def test_mysql_2():
|
|||||||
db = DatabaseManager(config.MYSQL_INFO, "mysql").__enter__()
|
db = DatabaseManager(config.MYSQL_INFO, "mysql").__enter__()
|
||||||
res = db.query("SELECT COUNT(*) FROM dir_map")
|
res = db.query("SELECT COUNT(*) FROM dir_map")
|
||||||
logger.info("Result: %s", res)
|
logger.info("Result: %s", res)
|
||||||
assert res is not None and res[0] > 0
|
assert res is not None and res > 0
|
||||||
|
Loading…
Reference in New Issue
Block a user