diff --git a/py-kms/pykms_Base.py b/py-kms/pykms_Base.py index e0e9a6b..75ca3a8 100644 --- a/py-kms/pykms_Base.py +++ b/py-kms/pykms_Base.py @@ -193,16 +193,17 @@ could be detected as not genuine !{end}" %currentClientCount) infoDict = { "machineName" : kmsRequest.getMachineName(), "clientMachineId" : str(clientMachineId), - "appId" : appName, + "applicationId" : appName, "skuId" : skuName, "licenseStatus" : kmsRequest.getLicenseStatus(), - "requestTime" : int(time.time()), + "lastRequestIP" : self.srv_config['raddr'][0], # (ip, port) + "lastRequestTime" : int(time.time()), "kmsEpid" : None } loggersrv.info("Machine Name: %s" % infoDict["machineName"]) loggersrv.info("Client Machine ID: %s" % infoDict["clientMachineId"]) - loggersrv.info("Application ID: %s" % infoDict["appId"]) + loggersrv.info("Application ID: %s" % infoDict["applicationId"]) loggersrv.info("SKU ID: %s" % infoDict["skuId"]) loggersrv.info("License Status: %s" % infoDict["licenseStatus"]) loggersrv.info("Request Time: %s" % local_dt.strftime('%Y-%m-%d %H:%M:%S %Z (UTC%z)')) @@ -211,7 +212,7 @@ could be detected as not genuine !{end}" %currentClientCount) loggersrv.mininfo("", extra = {'host': str(self.srv_config['raddr']), 'status' : infoDict["licenseStatus"], 'product' : infoDict["skuId"]}) - # Create database. + # Send change to database. if self.srv_config['sqlite']: sql_update(self.srv_config['sqlite'], infoDict) diff --git a/py-kms/pykms_Sql.py b/py-kms/pykms_Sql.py index 3ed7bfd..1fd2fbc 100644 --- a/py-kms/pykms_Sql.py +++ b/py-kms/pykms_Sql.py @@ -1,22 +1,37 @@ #!/usr/bin/env python3 -import datetime +from datetime import datetime import os import logging -# sqlite3 is optional. -try: - import sqlite3 -except ImportError: - pass - -from pykms_Format import pretty_printer - #-------------------------------------------------------------------------------------------------------------------------------------------------------- loggersrv = logging.getLogger('logsrv') +_column_name_to_index = { + 'clientMachineId': 0, + 'machineName': 1, + 'applicationId': 2, + 'skuId': 3, + 'licenseStatus': 4, + 'lastRequestTime': 5, + 'kmsEpid': 6, + 'requestCount': 7, + 'lastRequestIP': 8, +} + +# sqlite3 is optional. +available = False +try: + import sqlite3 + available = True +except ImportError: + pass def sql_initialize(dbName): + if available is False: + loggersrv.info("'sqlite3' module not found! SQLite database support cannot be enabled.") + return + loggersrv.debug(f'SQLite database support enabled. Database file: "{dbName}"') if not os.path.isfile(dbName): # Initialize the database. loggersrv.debug(f'Initializing database file "{dbName}"...') @@ -25,9 +40,37 @@ def sql_initialize(dbName): cur = con.cursor() cur.execute("CREATE TABLE clients(clientMachineId TEXT, machineName TEXT, applicationId TEXT, skuId TEXT, licenseStatus TEXT, lastRequestTime INTEGER, kmsEpid TEXT, requestCount INTEGER, PRIMARY KEY(clientMachineId, applicationId))") except sqlite3.Error as e: - pretty_printer(log_obj = loggersrv.error, to_exit = True, put_text = "{reverse}{red}{bold}Sqlite Error: %s. Exiting...{end}" %str(e)) + loggersrv.exception("Sqlite Error during database initialization!") + raise + if os.path.isfile(dbName): + # Update database + try: + with sqlite3.connect(dbName) as con: + cur = con.cursor() + # Create simple "metadata" table if not exists. + cur.execute("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT);") + # Get the current schema version + cur.execute("SELECT value FROM metadata WHERE key='schema_version';") + row = cur.fetchone() + if row is None: + current_version = 0 + else: + current_version = int(row[0]) + loggersrv.debug(f'Current database schema version: {current_version}') + # Apply necessary migrations + if current_version < 1: + # v1: Add "lastRequestIP" column to "clients" table. + loggersrv.info("Upgrading database schema to version 1...") + cur.execute("ALTER TABLE clients ADD COLUMN lastRequestIP TEXT;") + cur.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES ('schema_version', '1');") + loggersrv.info("Database schema updated to version 1.") + except sqlite3.Error as e: + loggersrv.exception("Sqlite Error during database upgrade!") + raise def sql_get_all(dbName): + if available is False: + return if not os.path.isfile(dbName): return None with sqlite3.connect(dbName) as con: @@ -35,81 +78,72 @@ def sql_get_all(dbName): cur.execute("SELECT * FROM clients") clients = [] for row in cur.fetchall(): - clients.append({ - 'clientMachineId': row[0], - 'machineName': row[1], - 'applicationId': row[2], - 'skuId': row[3], - 'licenseStatus': row[4], - 'lastRequestTime': datetime.datetime.fromtimestamp(row[5]).isoformat(), - 'kmsEpid': row[6], - 'requestCount': row[7] - }) + loggersrv.debug(f"Row: {row}") + obj = {} + for col_name, index in _column_name_to_index.items(): + if col_name == "lastRequestTime": + obj[col_name] = datetime.fromtimestamp(row[_column_name_to_index['lastRequestTime']]).isoformat() + else: + obj[col_name] = row[index] + loggersrv.debug(f"Obj: {obj}") + clients.append(obj) return clients def sql_update(dbName, infoDict): - con = None + if available is False: + return + + # make sure all column names are present + for col_name in _column_name_to_index.keys(): + if col_name in ["requestCount", "kmsEpid"]: + continue + if col_name not in infoDict: + raise ValueError(f"infoDict is missing required column: {col_name}") + try: - con = sqlite3.connect(dbName) - cur = con.cursor() - cur.execute("SELECT * FROM clients WHERE clientMachineId=:clientMachineId AND applicationId=:appId;", infoDict) - try: + with sqlite3.connect(dbName) as con: + cur = con.cursor() + cur.execute(f"SELECT {', '.join(_column_name_to_index.keys())} FROM clients WHERE clientMachineId=:clientMachineId AND applicationId=:applicationId;", infoDict) data = cur.fetchone() if not data: - # Insert row. - cur.execute("INSERT INTO clients (clientMachineId, machineName, applicationId, \ -skuId, licenseStatus, lastRequestTime, requestCount) VALUES (:clientMachineId, :machineName, :appId, :skuId, :licenseStatus, :requestTime, 1);", infoDict) - else: - # Update data. - if data[1] != infoDict["machineName"]: - cur.execute("UPDATE clients SET machineName=:machineName WHERE \ -clientMachineId=:clientMachineId AND applicationId=:appId;", infoDict) - if data[2] != infoDict["appId"]: - cur.execute("UPDATE clients SET applicationId=:appId WHERE \ -clientMachineId=:clientMachineId AND applicationId=:appId;", infoDict) - if data[3] != infoDict["skuId"]: - cur.execute("UPDATE clients SET skuId=:skuId WHERE \ -clientMachineId=:clientMachineId AND applicationId=:appId;", infoDict) - if data[4] != infoDict["licenseStatus"]: - cur.execute("UPDATE clients SET licenseStatus=:licenseStatus WHERE \ -clientMachineId=:clientMachineId AND applicationId=:appId;", infoDict) - if data[5] != infoDict["requestTime"]: - cur.execute("UPDATE clients SET lastRequestTime=:requestTime WHERE \ -clientMachineId=:clientMachineId AND applicationId=:appId;", infoDict) - # Increment requestCount - cur.execute("UPDATE clients SET requestCount=requestCount+1 WHERE \ -clientMachineId=:clientMachineId AND applicationId=:appId;", infoDict) + # Insert new row with all given info + infoDict["requestCount"] = 1 + cur.execute(f"""INSERT INTO clients ({', '.join(_column_name_to_index.keys())}) + VALUES ({', '.join(':' + col for col in _column_name_to_index.keys())});""", infoDict) - except sqlite3.Error as e: - pretty_printer(log_obj = loggersrv.error, to_exit = True, - put_text = "{reverse}{red}{bold}Sqlite Error: %s. Exiting...{end}" %str(e)) - except sqlite3.Error as e: - pretty_printer(log_obj = loggersrv.error, to_exit = True, - put_text = "{reverse}{red}{bold}Sqlite Error: %s. Exiting...{end}" %str(e)) - finally: - if con: - con.commit() - con.close() + else: + # Update only changed columns + common_postfix = "WHERE clientMachineId=:clientMachineId AND applicationId=:applicationId;" + def update_column_if_changed(column_name, new_value): + assert column_name in _column_name_to_index, f"Unknown column name: {column_name}" + assert "clientMachineId" in infoDict and "applicationId" in infoDict, "infoDict must contain 'clientMachineId' and 'applicationId'" + if data[_column_name_to_index[column_name]] != new_value: + query = f"UPDATE clients SET {column_name}=? {common_postfix}" + cur.execute(query, (new_value, infoDict['clientMachineId'], infoDict['applicationId'])) + + # Dynamically check and maybe up date all columns + for column_name in _column_name_to_index.keys(): + if column_name in ["clientMachineId", "applicationId", "requestCount"]: + continue # Skip these columns + if column_name == "kmsEpid": + # this one can only be updated by the special function + continue + update_column_if_changed(column_name, infoDict[column_name]) + + # Finally increment requestCount + cur.execute(f"UPDATE clients SET requestCount=requestCount+1 {common_postfix}", infoDict) + except sqlite3.Error: + loggersrv.exception("Sqlite Error during sql_update!") def sql_update_epid(dbName, kmsRequest, response, appName): - cmid = str(kmsRequest['clientMachineId'].get()) - con = None - try: - con = sqlite3.connect(dbName) - cur = con.cursor() - cur.execute("SELECT * FROM clients WHERE clientMachineId=? AND applicationId=?;", (cmid, appName)) - try: - data = cur.fetchone() - cur.execute("UPDATE clients SET kmsEpid=? WHERE \ -clientMachineId=? AND applicationId=?;", (str(response["kmsEpid"].decode('utf-16le')), cmid, appName)) + if available is False: + return - except sqlite3.Error as e: - pretty_printer(log_obj = loggersrv.error, to_exit = True, - put_text = "{reverse}{red}{bold}Sqlite Error: %s. Exiting...{end}" %str(e)) - except sqlite3.Error as e: - pretty_printer(log_obj = loggersrv.error, to_exit = True, - put_text = "{reverse}{red}{bold}Sqlite Error: %s. Exiting...{end}" %str(e)) - finally: - if con: - con.commit() - con.close() + cmid = str(kmsRequest['clientMachineId'].get()) + try: + with sqlite3.connect(dbName) as con: + cur = con.cursor() + cur.execute("UPDATE clients SET kmsEpid=? WHERE clientMachineId=? AND applicationId=?;", + (str(response["kmsEpid"].decode('utf-16le')), cmid, appName)) + except sqlite3.Error: + loggersrv.exception("Sqlite Error during sql_update_epid!")