import argparse
import sys
import time
-from typing import Sequence
from urllib.error import HTTPError, URLError
import argcomplete
return arg_parser
-def main(args: Sequence[str] = None) -> int:
- """Main method for CLI"""
+def main(args: list = None) -> int:
+ """
+ Main method for CLI
+
+ @param args: List[str] | None
+ """
start_time = time.time()
arg_parser = build_argparser()
"""Main SQL persistence module, need to rethink circular imports and shared code"""
import sqlite3
-from typing import Union
# FIXME: maybe just use separate methods for calls with vs. without headers
# avoid the mypy headaches, and the liberal comments # type: ignore
-def sql_entries(sql_result: sqlite3.Cursor, headers=False) -> Union[list, tuple]:
+def sql_entries(sql_result: sqlite3.Cursor) -> list:
"""Formats and returns a `sql_result()` for console digestion and output"""
# TODO: return object: metadata, command, status, errors, etc?
rows = sql_result.fetchall()
- if headers:
- headers = [x[0] for x in sql_result.description]
- return headers, rows
-
return rows
+def sql_entries_headers(sql_result: sqlite3.Cursor) -> tuple:
+ """Formats and returns a `sql_result()` for console digestion and output"""
+ rows = sql_result.fetchall()
+ headers = [x[0] for x in sql_result.description]
+
+ return headers, rows
+
+
def version(con: sqlite3.Connection) -> str:
"""Gets the latest entry from version table"""
con.close()
-def _sql(
- con: sqlite3.Connection,
- query: str,
- db_name: str,
- values: Union[tuple, list] = None,
- headers=False,
-) -> Union[list, tuple]:
+def _prep_query(
+ con: sqlite3.Connection, query: str, db_name: str, values=None
+) -> tuple:
+ """@param values: tuple | list | None"""
+
from ntclient import DEBUG # pylint: disable=import-outside-toplevel
cur = con.cursor()
if DEBUG:
print("%s.sqlite3: %s" % (db_name, query))
if values:
- # TODO: better debug logging, more "control-findable", distinguish from most prints()
+ # TODO: better debug logging, more "control-findable",
+ # distinguish from most prints()
print(values)
# TODO: separate `entry` & `entries` entity for single vs. bulk insert?
if values:
if isinstance(values, list):
- rows = cur.executemany(query, values)
+ result = cur.executemany(query, values)
else: # tuple
- rows = cur.execute(query, values)
+ result = cur.execute(query, values)
else:
- rows = cur.execute(query)
+ result = cur.execute(query)
+
+ return cur, result
+
+
+def _sql(
+ con: sqlite3.Connection,
+ query: str,
+ db_name: str,
+ values=None,
+) -> list:
+ """@param values: tuple | list | None"""
+
+ cur, result = _prep_query(con, query, db_name, values)
+
+ # TODO: print "<number> SELECTED", or other info
+ # BASED ON command SELECT/INSERT/DELETE/UPDATE
+ result = sql_entries(cur)
+
+ close_con_and_cur(con, cur)
+ return result
+
+
+def _sql_headers(
+ con: sqlite3.Connection,
+ query: str,
+ db_name: str,
+ values=None,
+) -> tuple:
+ """@param values: tuple | list | None"""
+
+ cur, result = _prep_query(con, query, db_name, values)
+
+ result = sql_entries_headers(result)
- # TODO: print "<number> SELECTED", or other info BASED ON command SELECT/INSERT/DELETE/UPDATE
- result = sql_entries(rows, headers=headers)
close_con_and_cur(con, cur)
return result
import sqlite3
from ntclient import NT_DB_NAME, NUTRA_DIR, __db_target_nt__
-from ntclient.persistence.sql import _sql, version
+from ntclient.persistence.sql import _sql, _sql_headers, version
from ntclient.utils.exceptions import SqlConnectError, SqlInvalidVersionError
-def nt_sqlite_connect(version_check=True):
+def nt_sqlite_connect(version_check=True) -> sqlite3.Connection:
"""Connects to the nt.sqlite3 file, or throws an exception"""
db_path = os.path.join(NUTRA_DIR, NT_DB_NAME)
if os.path.isfile(db_path):
raise SqlConnectError("ERROR: nt database doesn't exist, please run `nutra init`")
-def nt_ver():
+def nt_ver() -> str:
"""Gets version string for nt.sqlite3 database"""
con = nt_sqlite_connect(version_check=False)
return version(con)
-def sql(query, values=None, headers=False):
+def sql(query, values=None) -> list:
"""Executes a SQL command to nt.sqlite3"""
con = nt_sqlite_connect()
- return _sql(con, query, db_name="nt", values=values, headers=headers)
+ return _sql(con, query, db_name="nt", values=values)
+
+
+def sql_headers(query, values=None) -> tuple:
+ """Executes a SQL command to nt.sqlite3"""
+ con = nt_sqlite_connect()
+ return _sql_headers(con, query, db_name="nt", values=values)
"""nt.sqlite3 functions module"""
-from ntclient.persistence.sql.nt import sql
+from ntclient.persistence.sql.nt import sql, sql_headers
def sql_nt_next_index(table=None):
GROUP BY
id;
"""
- return sql(query, headers=True)
+ return sql_headers(query)
def sql_analyze_recipe(recipe_id):
import sqlite3
import tarfile
import urllib.request
-from typing import Union
from ntclient import NUTRA_DIR, USDA_DB_NAME, __db_target_usda__
-from ntclient.persistence.sql import _sql, version
+from ntclient.persistence.sql import _sql, _sql_headers, version
from ntclient.utils.exceptions import SqlConnectError, SqlInvalidVersionError
return version(con)
-def sql(query, values=None, headers=False, version_check=True) -> Union[list, tuple]:
+def sql(query, values=None, version_check=True) -> list:
"""Executes a SQL command to usda.sqlite3"""
con = usda_sqlite_connect(version_check=version_check)
# TODO: support argument: _sql(..., params=params, ...)
- return _sql(con, query, db_name="usda", values=values, headers=headers)
+ return _sql(con, query, db_name="usda", values=values)
+
+
+def sql_headers(query, values=None, version_check=True) -> tuple:
+ """Executes a SQL command to usda.sqlite3 [WITH HEADERS]"""
+
+ con = usda_sqlite_connect(version_check=version_check)
+
+ # TODO: support argument: _sql(..., params=params, ...)
+ return _sql_headers(con, query, db_name="usda", values=values)
"""usda.sqlite functions module"""
-from ntclient.persistence.sql.usda import sql
+from ntclient.persistence.sql.usda import sql, sql_headers
from ntclient.utils import NUTR_ID_KCAL
food_ids = ",".join(str(x) for x in set(food_ids))
query = query % food_ids
- return sql(query) # type: ignore
+ return sql(query)
def sql_nutrients_overview() -> dict:
"""Shows nutrients 'details'"""
query = "SELECT * FROM nutrients_overview;"
- return sql(query, headers=True) # type: ignore
+ return sql_headers(query)
def sql_servings(food_ids) -> list:
serv.food_id IN (%s);
"""
food_ids = ",".join(str(x) for x in set(food_ids))
- return sql(query % food_ids) # type: ignore
+ return sql(query % food_ids)
def sql_analyze_foods(food_ids) -> list:
"""
# TODO: parameterized queries
food_ids = ",".join(str(x) for x in set(food_ids))
- return sql(query % food_ids) # type: ignore
+ return sql(query % food_ids)
################################################################################
food_id;
"""
- return sql(query % (NUTR_ID_KCAL, nutrient_id)) # type: ignore
+ return sql(query % (NUTR_ID_KCAL, nutrient_id))
def sql_sort_foods(nutr_id) -> list:
nut_data.nutr_val DESC;
"""
- return sql(query % nutr_id) # type: ignore
+ return sql(query % nutr_id)
def sql_sort_foods_by_kcal(nutr_id) -> list:
(nut_data.nutr_val / kcal.nutr_val) DESC;
"""
- return sql(query % nutr_id) # type: ignore
+ return sql(query % nutr_id)