# maubot - A plugin-based Matrix bot system. # Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from __future__ import annotations from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from collections import defaultdict import asyncio import inspect import io import logging import os.path from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap from mautrix.types import UserID from mautrix.util import background_task from mautrix.util.async_db import Database, Scheme, UpgradeTable from mautrix.util.async_getter_lock import async_getter_lock from mautrix.util.config import BaseProxyConfig, RecursiveDict from mautrix.util.logging import TraceLogger from .client import Client from .db import DatabaseEngine, Instance as DBInstance from .lib.optionalalchemy import Engine, MetaData, create_engine from .lib.plugin_db import ProxyPostgresDatabase from .loader import DatabaseType, PluginLoader, ZippedPluginLoader from .plugin_base import Plugin if TYPE_CHECKING: from .__main__ import Maubot from .server import PluginWebApp log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance")) db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db")) yaml = YAML() yaml.indent(4) yaml.width = 200 class PluginInstance(DBInstance): maubot: "Maubot" = None cache: dict[str, PluginInstance] = {} plugin_directories: list[str] = [] _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) log: logging.Logger loader: PluginLoader | None client: Client | None plugin: Plugin | None config: BaseProxyConfig | None base_cfg: RecursiveDict[CommentedMap] | None base_cfg_str: str | None inst_db: sql.engine.Engine | Database | None inst_db_tables: dict | None inst_webapp: PluginWebApp | None inst_webapp_url: str | None started: bool def __init__( self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = "", database_engine: DatabaseEngine | None = None, ) -> None: super().__init__( id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config, database_engine=database_engine, ) def __hash__(self) -> int: return hash(self.id) @classmethod def init_cls(cls, maubot: "Maubot") -> None: cls.maubot = maubot def postinit(self) -> None: self.log = log.getChild(self.id) self.cache[self.id] = self self.config = None self.started = False self.loader = None self.client = None self.plugin = None self.inst_db = None self.inst_db_tables = None self.inst_webapp = None self.inst_webapp_url = None self.base_cfg = None self.base_cfg_str = None def to_dict(self) -> dict: return { "id": self.id, "type": self.type, "enabled": self.enabled, "started": self.started, "primary_user": self.primary_user, "config": self.config_str, "base_config": self.base_cfg_str, "database": ( self.inst_db is not None and self.maubot.config["api_features.instance_database"] ), "database_interface": self.loader.meta.database_type_str if self.loader else "unknown", "database_engine": self.database_engine_str, } def _introspect_sqlalchemy(self) -> dict: metadata = MetaData() metadata.reflect(self.inst_db) return { table.name: { "columns": { column.name: { "type": str(column.type), "unique": column.unique or False, "default": column.default, "nullable": column.nullable, "primary": column.primary_key, } for column in table.columns }, } for table in metadata.tables.values() } async def _introspect_sqlite(self) -> dict: q = """ SELECT m.name AS table_name, p.cid AS col_id, p.name AS column_name, p.type AS data_type, p.pk AS is_primary, p.dflt_value AS column_default, p.[notnull] AS is_nullable FROM sqlite_master m LEFT JOIN pragma_table_info((m.name)) p WHERE m.type = 'table' ORDER BY table_name, col_id """ data = await self.inst_db.fetch(q) tables = defaultdict(lambda: {"columns": {}}) for column in data: table_name = column["table_name"] col_name = column["column_name"] tables[table_name]["columns"][col_name] = { "type": column["data_type"], "nullable": bool(column["is_nullable"]), "default": column["column_default"], "primary": bool(column["is_primary"]), # TODO uniqueness? } return tables async def _introspect_postgres(self) -> dict: assert isinstance(self.inst_db, ProxyPostgresDatabase) q = """ SELECT col.table_name, col.column_name, col.data_type, col.is_nullable, col.column_default, tc.constraint_type FROM information_schema.columns col LEFT JOIN information_schema.constraint_column_usage ccu ON ccu.column_name=col.column_name LEFT JOIN information_schema.table_constraints tc ON col.table_name=tc.table_name AND col.table_schema=tc.table_schema AND ccu.constraint_name=tc.constraint_name AND ccu.constraint_schema=tc.constraint_schema AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE') WHERE col.table_schema=$1 """ data = await self.inst_db.fetch(q, self.inst_db.schema_name) tables = defaultdict(lambda: {"columns": {}}) for column in data: table_name = column["table_name"] col_name = column["column_name"] tables[table_name]["columns"].setdefault( col_name, { "type": column["data_type"], "nullable": column["is_nullable"], "default": column["column_default"], "primary": False, "unique": False, }, ) if column["constraint_type"] == "PRIMARY KEY": tables[table_name]["columns"][col_name]["primary"] = True elif column["constraint_type"] == "UNIQUE": tables[table_name]["columns"][col_name]["unique"] = True return tables async def get_db_tables(self) -> dict: if self.inst_db_tables is None: if isinstance(self.inst_db, Engine): self.inst_db_tables = self._introspect_sqlalchemy() elif self.inst_db.scheme == Scheme.SQLITE: self.inst_db_tables = await self._introspect_sqlite() else: self.inst_db_tables = await self._introspect_postgres() return self.inst_db_tables async def load(self) -> bool: if not self.loader: try: self.loader = PluginLoader.find(self.type) except KeyError: self.log.error(f"Failed to find loader for type {self.type}") await self.update_enabled(False) return False if not self.client: self.client = await Client.get(self.primary_user) if not self.client: self.log.error(f"Failed to get client for user {self.primary_user}") await self.update_enabled(False) return False if self.loader.meta.webapp: self.enable_webapp() self.log.debug("Plugin instance dependencies loaded") self.loader.references.add(self) self.client.references.add(self) return True def enable_webapp(self) -> None: self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id) def disable_webapp(self) -> None: self.maubot.server.remove_instance_webapp(self.id) self.inst_webapp = None self.inst_webapp_url = None @property def _sqlite_db_path(self) -> str: return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db") async def delete(self) -> None: if self.loader is not None: self.loader.references.remove(self) if self.client is not None: self.client.references.remove(self) try: del self.cache[self.id] except KeyError: pass await super().delete() if self.inst_db: await self.stop_database() await self.delete_database() if self.inst_webapp: self.disable_webapp() def load_config(self) -> CommentedMap: return yaml.load(self.config_str) def save_config(self, data: RecursiveDict[CommentedMap]) -> None: buf = io.StringIO() yaml.dump(data, buf) val = buf.getvalue() if val != self.config_str: self.config_str = val self.log.debug("Creating background task to save updated config") background_task.create(self.update()) async def start_database( self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True ) -> None: if self.loader.meta.database_type == DatabaseType.SQLALCHEMY: if self.database_engine is None: await self.update_db_engine(DatabaseEngine.SQLITE) elif self.database_engine == DatabaseEngine.POSTGRES: raise RuntimeError( "Instance database engine is marked as Postgres, but plugin uses legacy " "database interface, which doesn't support postgres." ) self.inst_db = create_engine(f"sqlite:///{self._sqlite_db_path}") elif self.loader.meta.database_type == DatabaseType.ASYNCPG: if self.database_engine is None: if os.path.exists(self._sqlite_db_path) or not self.maubot.plugin_postgres_db: await self.update_db_engine(DatabaseEngine.SQLITE) else: await self.update_db_engine(DatabaseEngine.POSTGRES) instance_db_log = db_log.getChild(self.id) if self.database_engine == DatabaseEngine.POSTGRES: if not self.maubot.plugin_postgres_db: raise RuntimeError( "Instance database engine is marked as Postgres, but this maubot isn't " "configured to support Postgres for plugin databases" ) self.inst_db = ProxyPostgresDatabase( pool=self.maubot.plugin_postgres_db, instance_id=self.id, max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"], upgrade_table=upgrade_table, log=instance_db_log, ) else: self.inst_db = Database.create( f"sqlite:{self._sqlite_db_path}", upgrade_table=upgrade_table, log=instance_db_log, ) if actually_start: await self.inst_db.start() else: raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}") async def stop_database(self) -> None: if isinstance(self.inst_db, Database): await self.inst_db.stop() elif isinstance(self.inst_db, Engine): self.inst_db.dispose() else: raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}") async def delete_database(self) -> None: if self.loader.meta.database_type == DatabaseType.SQLALCHEMY: ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted") elif self.loader.meta.database_type == DatabaseType.ASYNCPG: if self.inst_db is None: await self.start_database(None, actually_start=False) if isinstance(self.inst_db, ProxyPostgresDatabase): await self.inst_db.delete() else: ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted") else: raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}") self.inst_db = None async def start(self) -> None: if self.started: self.log.warning("Ignoring start() call to already started plugin") return elif not self.enabled: self.log.warning("Plugin disabled, not starting.") return if not self.client or not self.loader: self.log.warning("Missing plugin instance dependencies, attempting to load...") if not await self.load(): return cls = await self.loader.load() if self.loader.meta.webapp and self.inst_webapp is None: self.log.debug("Enabling webapp after plugin meta reload") self.enable_webapp() elif not self.loader.meta.webapp and self.inst_webapp is not None: self.log.debug("Disabling webapp after plugin meta reload") self.disable_webapp() if self.loader.meta.database: try: await self.start_database(cls.get_db_upgrade_table()) except Exception: self.log.exception("Failed to start instance database") await self.update_enabled(False) return config_class = cls.get_config_class() if config_class: try: base = await self.loader.read_file("base-config.yaml") self.base_cfg = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap) buf = io.StringIO() yaml.dump(self.base_cfg._data, buf) self.base_cfg_str = buf.getvalue() except (FileNotFoundError, KeyError): self.base_cfg = None self.base_cfg_str = None if self.base_cfg: base_cfg_func = self.base_cfg.clone else: def base_cfg_func() -> None: return None self.config = config_class(self.load_config, base_cfg_func, self.save_config) self.plugin = cls( client=self.client.client, loop=self.maubot.loop, http=self.client.http_client, instance_id=self.id, log=self.log, config=self.config, database=self.inst_db, loader=self.loader, webapp=self.inst_webapp, webapp_url=self.inst_webapp_url, ) try: await self.plugin.internal_start() except Exception: self.log.exception("Failed to start instance") await self.update_enabled(False) return self.started = True self.inst_db_tables = None self.log.info( f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} " f"with user {self.client.id}" ) async def stop(self) -> None: if not self.started: self.log.warning("Ignoring stop() call to non-running plugin") return self.log.debug("Stopping plugin instance...") self.started = False try: await self.plugin.internal_stop() except Exception: self.log.exception("Failed to stop instance") self.plugin = None if self.inst_db: try: await self.stop_database() except Exception: self.log.exception("Failed to stop instance database") self.inst_db_tables = None async def update_id(self, new_id: str | None) -> None: if new_id is not None and new_id.lower() != self.id: await super().update_id(new_id.lower()) async def update_config(self, config: str | None) -> None: if config is None or self.config_str == config: return self.config_str = config if self.started and self.plugin is not None: res = self.plugin.on_external_config_update() if inspect.isawaitable(res): await res await self.update() async def update_primary_user(self, primary_user: UserID | None) -> bool: if primary_user is None or primary_user == self.primary_user: return True client = await Client.get(primary_user) if not client: return False await self.stop() self.primary_user = client.id if self.client: self.client.references.remove(self) self.client = client self.client.references.add(self) await self.update() await self.start() self.log.debug(f"Primary user switched to {self.client.id}") return True async def update_type(self, type: str | None) -> bool: if type is None or type == self.type: return True try: loader = PluginLoader.find(type) except KeyError: return False await self.stop() self.type = loader.meta.id if self.loader: self.loader.references.remove(self) self.loader = loader self.loader.references.add(self) await self.update() await self.start() self.log.debug(f"Type switched to {self.loader.meta.id}") return True async def update_started(self, started: bool) -> None: if started is not None and started != self.started: await (self.start() if started else self.stop()) async def update_enabled(self, enabled: bool) -> None: if enabled is not None and enabled != self.enabled: self.enabled = enabled await self.update() async def update_db_engine(self, db_engine: DatabaseEngine | None) -> None: if db_engine is not None and db_engine != self.database_engine: self.database_engine = db_engine await self.update() @classmethod @async_getter_lock async def get( cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None ) -> PluginInstance | None: try: return cls.cache[instance_id] except KeyError: pass instance = cast(cls, await super().get(instance_id)) if instance is not None: instance.postinit() return instance if type and primary_user: instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user) await instance.insert() instance.postinit() return instance return None @classmethod async def all(cls) -> AsyncGenerator[PluginInstance, None]: instances = await super().all() instance: PluginInstance for instance in instances: try: yield cls.cache[instance.id] except KeyError: instance.postinit() yield instance