from datetime import datetime
import logging
import typing as t
import sys
import os
import re
from importlib import import_module
from functools import wraps

try:
    from functools import cached_property
except ImportError:
    from cached_property import cached_property

from app.classes.shared.helpers import helper
from app.classes.shared.console import console

logger = logging.getLogger(__name__)

try:
    import peewee
    from playhouse.migrate import (
        SchemaMigrator as ScM,
        SqliteMigrator as SqM,
        Operation,
        SQL,
        operation,
        SqliteDatabase,
        make_index_name,
        Context,
    )

except ModuleNotFoundError as e:
    logger.critical("Import Error: Unable to load {} module".format(
        e.name), exc_info=True)
    console.critical("Import Error: Unable to load {} module".format(e.name))
    sys.exit(1)


class MigrateHistory(peewee.Model):
    """
    Records which migrations have been applied to the database.

    One row per applied migration: its unique ``name`` and the (naive UTC)
    time it was recorded.
    """
    name = peewee.CharField(unique=True)
    migrated_at = peewee.DateTimeField(default=datetime.utcnow)

    def __unicode__(self) -> str:
        """
        String representation of this migration (its name).
        """
        return self.name


# Table that stores MigrateHistory rows.
MIGRATE_TABLE = 'migratehistory'

# Skeleton written by MigrationManager.compile(); {migrate} and {rollback}
# are substituted with (already indented) function bodies.
MIGRATE_TEMPLATE = '''# Generated by database migrator


def migrate(migrator, database, **kwargs):
    """
    Write your migrations here.
    """
{migrate}


def rollback(migrator, database, **kwargs):
    """
    Write your rollback migrations here.
    """
{rollback}'''

# No-op stand-in used when a migration file lacks migrate()/rollback().
VOID: t.Callable = lambda m, d: None


def get_model(method):
    """
    Decorator letting a Migrator method accept a model by table name.

    If the ``model`` argument is a ``str`` it is looked up in
    ``migrator.orm`` (a ``KeyError`` propagates if it is unknown);
    otherwise the model object is passed through unchanged.
    """
    @wraps(method)
    def wrapper(migrator, model, *args, **kwargs):
        if isinstance(model, str):
            return method(migrator, migrator.orm[model], *args, **kwargs)
        return method(migrator, model, *args, **kwargs)
    return wrapper


class Migrator(object):
    """
    Queues schema operations for a migration and runs them together.

    Public methods both update the in-memory model definitions kept in
    ``orm`` (table name -> model class) and append the matching database
    operation to ``operations``; :meth:`run` executes the queue in order.
    """

    def __init__(self, database: t.Union[peewee.Database, peewee.Proxy]):
        """
        Initializes the migrator.

        :param database: database to migrate; a ``peewee.Proxy`` is
            unwrapped to the database object it currently points at.
        """
        if isinstance(database, peewee.Proxy):
            database = database.obj
        self.database: SqliteDatabase = database
        self.orm: t.Dict[str, peewee.Model] = {}
        self.operations: t.List[Operation] = []
        self.migrator = SqliteMigrator(database)

    def run(self):
        """
        Runs queued operations in order, then clears the queue.

        playhouse ``Operation`` entries are executed with ``run()``; any
        other entry is called as a zero-argument callable.
        """
        for op in self.operations:
            if isinstance(op, Operation):
                op.run()
            else:
                op()
        self.clean()

    def clean(self):
        """
        Discards all queued operations without running them.
        """
        self.operations = []

    def sql(self, sql: str, *params):
        """
        Queues a raw SQL statement with optional parameters.
        """
        self.operations.append(self.migrator.sql(sql, *params))

    def create_table(self, model: peewee.Model) -> peewee.Model:
        """
        Registers ``model`` in the ORM, binds it to this database and
        queues creation of its table.

        :returns: the same model class.
        """
        self.orm[model._meta.table_name] = model
        model._meta.database = self.database
        self.operations.append(model.create_table)
        return model

    @get_model
    def drop_table(self, model: peewee.Model):
        """
        Unregisters ``model`` from the ORM and queues dropping its table.
        """
        del self.orm[model._meta.table_name]
        self.operations.append(self.migrator.drop_table(model))

    @get_model
    def add_columns(self, model: peewee.Model, **fields: peewee.Field) -> peewee.Model:
        """
        Adds the given fields to ``model`` and queues ADD COLUMN for each.

        A unique index is also queued for every field with ``unique=True``.

        :param fields: mapping of field name -> peewee field instance.
        :returns: the updated model class.
        """
        for name, field in fields.items():
            model._meta.add_field(name, field)
            self.operations.append(self.migrator.add_column(
                model._meta.table_name, field.column_name, field))
            if field.unique:
                self.operations.append(self.migrator.add_index(
                    model._meta.table_name, (field.column_name,), unique=True))
        return model

    @get_model
    def change_columns(self, model: peewee.Model, **fields: peewee.Field) -> peewee.Model:
        """
        Replaces existing fields of ``model`` with new definitions.

        For each field this queues, as needed: dropping an old foreign-key
        constraint, a column rename, re-adding a foreign-key constraint (FK
        fields stop there), otherwise a column change plus adding/dropping
        the unique index when ``unique`` changed.

        :param fields: mapping of field name -> new peewee field instance.
        :returns: the updated model class.
        """
        for name, field in fields.items():
            # Unknown names fall back to the new field itself, so nothing is
            # renamed and the unique flag is treated as unchanged.
            old_field = model._meta.fields.get(name, field)
            old_column_name = old_field and old_field.column_name

            model._meta.add_field(name, field)

            if isinstance(old_field, peewee.ForeignKeyField):
                self.operations.append(self.migrator.drop_foreign_key_constraint(
                    model._meta.table_name, old_column_name))

            if old_column_name != field.column_name:
                self.operations.append(
                    self.migrator.rename_column(
                        model._meta.table_name, old_column_name, field.column_name))

            if isinstance(field, peewee.ForeignKeyField):
                on_delete = field.on_delete if field.on_delete else 'RESTRICT'
                on_update = field.on_update if field.on_update else 'RESTRICT'
                self.operations.append(self.migrator.add_foreign_key_constraint(
                    model._meta.table_name, field.column_name,
                    field.rel_model._meta.table_name, field.rel_field.name,
                    on_delete, on_update))
                continue

            self.operations.append(self.migrator.change_column(
                model._meta.table_name, field.column_name, field))

            if field.unique == old_field.unique:
                continue

            if field.unique:
                index = (field.column_name,), field.unique
                self.operations.append(self.migrator.add_index(
                    model._meta.table_name, *index))
                model._meta.indexes.append(index)
            else:
                index = (field.column_name,), old_field.unique
                # playhouse's drop_index takes (table, index_name); derive the
                # name the same way drop_columns/drop_index do.
                self.operations.append(self.migrator.drop_index(
                    model._meta.table_name,
                    make_index_name(model._meta.table_name, [field.column_name])))
                # Field-level unique indexes need not appear in
                # _meta.indexes, so only remove the entry when present.
                if index in model._meta.indexes:
                    model._meta.indexes.remove(index)
        return model

    @get_model
    def drop_columns(self, model: peewee.Model,
                     names: t.Union[str, t.Iterable[str]],
                     **kwargs) -> peewee.Model:
        """
        Removes fields from ``model`` and queues DROP COLUMN for each.

        :param names: a single field name or an iterable of field names.
            Names are matched exactly (a bare string is not treated as a
            set of substrings).
        :param kwargs: accepted for API compatibility; a ``cascade`` key is
            consumed but the column is always dropped with ``cascade=False``.
        :returns: the updated model class.
        """
        if isinstance(names, str):
            names = (names,)
        wanted = set(names)
        fields = [field for field in model._meta.fields.values()
                  if field.name in wanted]
        kwargs.pop('cascade', True)  # accepted but intentionally not forwarded
        for field in fields:
            self.__del_field__(model, field)
            if field.unique:
                index_name = make_index_name(
                    model._meta.table_name, [field.column_name])
                self.operations.append(self.migrator.drop_index(
                    model._meta.table_name, index_name))
            self.operations.append(
                self.migrator.drop_column(
                    model._meta.table_name, field.column_name, cascade=False))
        return model

    def __del_field__(self, model: peewee.Model, field: peewee.Field):
        """
        Deletes ``field`` and its class attributes from ``model``.

        For foreign keys this also removes the ``<name>_id`` accessor from
        ``model`` and the backref accessor from the related model.
        """
        model._meta.remove_field(field.name)
        delattr(model, field.name)
        if isinstance(field, peewee.ForeignKeyField):
            obj_id_name = field.column_name
            if field.column_name == field.name:
                obj_id_name += '_id'
            delattr(model, obj_id_name)
            # peewee does not install a backref accessor for '+' / '!'.
            if field.backref not in ('+', '!'):
                delattr(field.rel_model, field.backref)

    @get_model
    def rename_column(self, model: peewee.Model, old_name: str, new_name: str) -> peewee.Model:
        """
        Renames a field of ``model`` and queues the column rename.

        For foreign keys the database column is renamed from the old
        ``column_name`` to ``<new_name>_id``.

        :returns: the updated model class.
        """
        field = model._meta.fields[old_name]
        if isinstance(field, peewee.ForeignKeyField):
            old_name = field.column_name
        self.__del_field__(model, field)
        field.name = field.column_name = new_name
        model._meta.add_field(new_name, field)
        if isinstance(field, peewee.ForeignKeyField):
            field.column_name = new_name = field.column_name + '_id'
        self.operations.append(self.migrator.rename_column(
            model._meta.table_name, old_name, new_name))
        return model

    @get_model
    def rename_table(self, model: peewee.Model, new_name: str) -> peewee.Model:
        """
        Renames ``model``'s table, re-keys it in the ORM and queues the
        rename.

        :returns: the updated model class.
        """
        old_name = model._meta.table_name
        del self.orm[model._meta.table_name]
        model._meta.table_name = new_name
        self.orm[model._meta.table_name] = model
        self.operations.append(self.migrator.rename_table(old_name, new_name))
        return model

    @get_model
    def add_index(self, model: peewee.Model, *columns: str, **kwargs) -> peewee.Model:
        """
        Creates an index over ``columns``.

        With a single known column the field's ``unique``/``index`` flags
        are updated as well. Foreign-key field names get an ``_id`` suffix.

        :param unique: keyword-only; create a UNIQUE index (default False).
        :returns: the updated model class.
        """
        unique = kwargs.pop('unique', False)
        model._meta.indexes.append((columns, unique))
        columns_ = []
        for col in columns:
            field = model._meta.fields.get(col)
            # Unknown columns are indexed by their raw name without touching
            # field flags (mirrors drop_index tolerating unknown fields).
            if field is not None and len(columns) == 1:
                field.unique = unique
                field.index = not unique
            if isinstance(field, peewee.ForeignKeyField):
                col = col + '_id'
            columns_.append(col)
        self.operations.append(self.migrator.add_index(
            model._meta.table_name, columns_, unique=unique))
        return model

    @get_model
    def drop_index(self, model: peewee.Model, *columns: str) -> peewee.Model:
        """
        Drops the index over ``columns``.

        Columns unknown to the model are skipped when building the index
        name. With a single column the field's ``unique``/``index`` flags
        are cleared.

        :returns: the updated model class.
        """
        columns_ = []
        for col in columns:
            field = model._meta.fields.get(col)
            if not field:
                continue
            if len(columns) == 1:
                field.unique = field.index = False
            if isinstance(field, peewee.ForeignKeyField):
                col = col + '_id'
            columns_.append(col)
        index_name = make_index_name(model._meta.table_name, columns_)
        model._meta.indexes = [(cols, _) for (cols, _) in model._meta.indexes
                               if columns != cols]
        self.operations.append(self.migrator.drop_index(
            model._meta.table_name, index_name))
        return model

    @get_model
    def add_not_null(self, model: peewee.Model, *names: str) -> peewee.Model:
        """
        Marks the named fields NOT NULL and queues the constraint changes.
        """
        for name in names:
            field = model._meta.fields[name]
            field.null = False
            self.operations.append(self.migrator.add_not_null(
                model._meta.table_name, field.column_name))
        return model

    @get_model
    def drop_not_null(self, model: peewee.Model, *names: str) -> peewee.Model:
        """
        Makes the named fields nullable and queues the constraint changes.
        """
        for name in names:
            field = model._meta.fields[name]
            field.null = True
            self.operations.append(self.migrator.drop_not_null(
                model._meta.table_name, field.column_name))
        return model

    @get_model
    def add_default(self, model: peewee.Model, name: str,
                    default: t.Any) -> peewee.Model:
        """
        Sets a default on field ``name`` and queues applying it to the
        existing rows via ``apply_default``.
        """
        field = model._meta.fields[name]
        model._meta.defaults[field] = field.default = default
        self.operations.append(self.migrator.apply_default(
            model._meta.table_name, name, field))
        return model


class SqliteMigrator(SqM):
    """
    playhouse ``SqliteMigrator`` extended with the hooks Migrator uses.
    """

    def drop_table(self, model):
        """
        Returns a callable that drops ``model``'s table (``cascade=False``).
        """
        return lambda: model.drop_table(cascade=False)

    @operation
    def change_column(self, table: str, column_name: str, field: peewee.Field):
        """
        Rebuilds ``column_name``, then adds NOT NULL if ``field.null`` is
        false.

        :returns: list of operations, run in order by the ``operation``
            wrapper.
        """
        operations = [self.alter_change_column(table, column_name, field)]
        if not field.null:
            operations.append(self.add_not_null(table, column_name))
        return operations

    def alter_change_column(self, table: str, column_name: str,
                            field: peewee.Field) -> Operation:
        """
        Rebuilds the table around ``column_name`` via ``_update_column``.
        """
        # NOTE(review): the callback returns its second argument unchanged
        # and ``field`` is not used; if _update_column passes the existing
        # column definition there, this rebuild does not apply the new field
        # definition -- confirm intended.
        return self._update_column(table, column_name, lambda x, y: y)

    @operation
    def sql(self, sql: str, *params) -> SQL:
        """
        Executes raw SQL.
        """
        return SQL(sql, *params)

    def alter_add_column(
            self, table: str, column_name: str, field: peewee.Field,
            **kwargs) -> Operation:
        """
        Adds a column, restoring a foreign key's ``name`` afterwards in case
        the parent implementation changed it.
        """
        name = field.name
        op = super().alter_add_column(table, column_name, field, **kwargs)
        if isinstance(field, peewee.ForeignKeyField):
            field.name = name
        return op


class MigrationManager(object):
    """
    Discovers migration files, tracks applied ones in MIGRATE_TABLE, and
    applies or rolls them back.
    """

    # Migration files look like "<digits>_<name>.py".
    filemask = re.compile(r"[\d]+_[^\.]+\.py$")

    def __init__(self, database: t.Union[peewee.Database, peewee.Proxy]):
        """
        Initializes the migration manager.

        :raises RuntimeError: if ``database`` is not a peewee Database/Proxy.
        """
        if not isinstance(database, (peewee.Database, peewee.Proxy)):
            raise RuntimeError('Invalid database: {}'.format(database))
        self.database = database

    @cached_property
    def model(self) -> peewee.Model:
        """
        Binds MigrateHistory to this database, creates its table if it is
        missing, and caches the model.
        """
        MigrateHistory._meta.database = self.database
        MigrateHistory._meta.table_name = MIGRATE_TABLE
        MigrateHistory._meta.schema = None
        MigrateHistory.create_table(True)
        return MigrateHistory

    @property
    def done(self) -> t.List[str]:
        """
        Names of applied migrations, in the order they were recorded
        (queries the database on every access).
        """
        return [mm.name for mm in self.model.select().order_by(self.model.id)]

    @property
    def todo(self):
        """
        Sorted names (without ``.py``) of migration files on disk.

        Creates the migration directory if it does not exist.
        """
        if not os.path.exists(helper.migration_dir):
            logger.warning('Migration directory: %s does not exist.',
                           helper.migration_dir)
            os.makedirs(helper.migration_dir, exist_ok=True)
        return sorted(f[:-3] for f in os.listdir(helper.migration_dir)
                      if self.filemask.match(f))

    @property
    def diff(self) -> t.List[str]:
        """
        Migrations present on disk but not yet applied, in file order.
        """
        done = set(self.done)
        return [name for name in self.todo if name not in done]

    @cached_property
    def migrator(self) -> Migrator:
        """
        Creates a Migrator and replays every applied migration in fake mode
        so its ORM reflects the current schema without touching the database.
        """
        migrator = Migrator(self.database)
        for name in self.done:
            self.up_one(name, migrator, True)
        return migrator

    def compile(self, name, migrate='', rollback=''):
        """
        Writes a new migration file from MIGRATE_TEMPLATE.

        :param name: suffix; the file is named ``<UTC timestamp>_<name>.py``.
        :param migrate: body text inserted into ``migrate()``.
        :param rollback: body text inserted into ``rollback()``.
        :returns: the migration name (file name without ``.py``).
        """
        name = datetime.utcnow().strftime('%Y%m%d%H%M%S') + '_' + name
        filename = name + '.py'
        path = os.path.join(helper.migration_dir, filename)
        # UTF-8 is Python's default source encoding, which read() relies on.
        with open(path, 'w', encoding='utf-8') as f:
            f.write(MIGRATE_TEMPLATE.format(
                migrate=migrate, rollback=rollback, name=filename))
        return name

    def create(self, name: str = 'auto', auto: bool = False) -> t.Optional[str]:
        """
        Creates an empty migration file.

        :raises NotImplementedError: if ``auto`` is requested.
        :returns: the created migration name.
        """
        migrate = rollback = ''
        if auto:
            raise NotImplementedError

        logger.info('Creating migration "%s"', name)
        name = self.compile(name, migrate, rollback)
        logger.info('Migration has been created as "%s"', name)
        return name

    def clear(self):
        """
        Deletes all migration history rows (does not touch the schema).
        """
        self.model.delete().execute()

    def up(self, name: t.Optional[str] = None):
        """
        Applies pending migrations in order.

        :param name: if given, stop after applying the migration with this
            name (all pending migrations run if it is never reached).
        :returns: list of applied migration names.
        """
        logger.info('Starting migrations')
        console.info('Starting migrations')

        done = []
        diff = self.diff
        if not diff:
            logger.info('There is nothing to migrate')
            console.info('There is nothing to migrate')
            return done

        migrator = self.migrator
        for mname in diff:
            done.append(self.up_one(mname, migrator))
            if name and name == mname:
                break

        return done

    def read(self, name: str):
        """
        Loads a migration file and returns its ``(migrate, rollback)``
        functions, substituting VOID for any that are missing.

        The file is executed as Python code; migration files are trusted,
        local project files.
        """
        path = os.path.join(helper.migration_dir, name + '.py')
        # Read bytes and let compile() decode them: it honours a PEP 263
        # coding cookie or UTF-8 BOM and otherwise uses UTF-8, independent of
        # the platform locale.
        with open(path, 'rb') as f:
            source = f.read()
        scope = {}
        code = compile(source, path, 'exec', dont_inherit=True)
        exec(code, scope, None)
        return scope.get('migrate', VOID), scope.get('rollback', VOID)

    def up_one(self, name: str, migrator: Migrator,
               fake: bool = False, rollback: bool = False) -> str:
        """
        Applies (or rolls back) a single migration.

        :param name: migration name (file name without ``.py``).
        :param migrator: Migrator whose queue is used and run.
        :param fake: only run ``migrate()`` to update the in-memory ORM,
            then discard the queued operations; nothing is executed.
        :param rollback: run ``rollback()`` and delete the history row
            instead of migrating.
        :returns: ``name``.
        :raises Exception: any failure is logged and re-raised after a
            database rollback.
        """
        try:
            migrate_fn, rollback_fn = self.read(name)
            if fake:
                migrate_fn(migrator, self.database)
                migrator.clean()
                return name

            with self.database.transaction():
                if rollback:
                    logger.info('Rolling back "%s"', name)
                    rollback_fn(migrator, self.database)
                    migrator.run()
                    self.model.delete().where(self.model.name == name).execute()
                else:
                    logger.info('Migrate "%s"', name)
                    migrate_fn(migrator, self.database)
                    migrator.run()
                    if name not in self.done:
                        self.model.create(name=name)

            logger.info('Done "%s"', name)
            return name

        except Exception:
            self.database.rollback()
            action = 'Rollback' if rollback else 'Migration'
            logger.exception('%s failed: %s', action, name)
            raise

    def down(self, name: t.Optional[str] = None):
        """
        Rolls back the most recently applied migration.

        :param name: currently ignored; the last applied migration is always
            the one rolled back.
        :raises RuntimeError: if no migrations have been applied.
        """
        done = self.done
        if not done:
            raise RuntimeError('No migrations are found.')

        name = done[-1]

        migrator = self.migrator
        self.up_one(name, migrator, False, True)
        logger.warning('Rolled back migration: %s', name)