Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions go/base/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ type MigrationContext struct {
AzureMySQL bool
AttemptInstantDDL bool
Resume bool
Revert bool
OldTableName string

// SkipPortValidation allows skipping the port validation in `ValidateConnection`
// This is useful when connecting to a MySQL instance where the external port
Expand Down Expand Up @@ -348,6 +350,9 @@ func getSafeTableName(baseName string, suffix string) string {
// GetGhostTableName generates the name of ghost table, based on original table name
// or a given table name
func (this *MigrationContext) GetGhostTableName() string {
if this.Revert {
return this.OldTableName
}
if this.ForceTmpTableName != "" {
return getSafeTableName(this.ForceTmpTableName, "gho")
} else {
Expand All @@ -364,14 +369,19 @@ func (this *MigrationContext) GetOldTableName() string {
tableName = this.OriginalTableName
}

suffix := "del"
if this.Revert {
// When reverting the "ghost" table is the _del table
suffix = "rev_del"
}
if this.TimestampOldTable {
t := this.StartTime
timestamp := fmt.Sprintf("%d%02d%02d%02d%02d%02d",
t.Year(), t.Month(), t.Day(),
t.Hour(), t.Minute(), t.Second())
return getSafeTableName(tableName, fmt.Sprintf("%s_del", timestamp))
return getSafeTableName(tableName, fmt.Sprintf("%s_%s", timestamp, suffix))
}
return getSafeTableName(tableName, "del")
return getSafeTableName(tableName, suffix)
}

// GetChangelogTableName generates the name of changelog table, based on original table name
Expand Down
18 changes: 15 additions & 3 deletions go/cmd/gh-ost/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ func main() {
flag.BoolVar(&migrationContext.Checkpoint, "checkpoint", false, "Enable migration checkpoints")
flag.Int64Var(&migrationContext.CheckpointIntervalSeconds, "checkpoint-seconds", 300, "The number of seconds between checkpoints")
flag.BoolVar(&migrationContext.Resume, "resume", false, "Attempt to resume migration from checkpoint")
flag.BoolVar(&migrationContext.Revert, "revert", false, "Attempt to revert completed migration")
flag.StringVar(&migrationContext.OldTableName, "old-table", "", "The name of the old table when using --revert, e.g. '_mytable_del'")

maxLoad := flag.String("max-load", "", "Comma delimited status-name=threshold. e.g: 'Threads_running=100,Threads_connected=500'. When status exceeds threshold, app throttles writes")
criticalLoad := flag.String("critical-load", "", "Comma delimited status-name=threshold, same format as --max-load. When status exceeds threshold, app panics and quits")
Expand Down Expand Up @@ -291,6 +293,10 @@ func main() {
migrationContext.Log.Fatalf("--checkpoint-seconds should be >=10")
}

if migrationContext.Revert && migrationContext.OldTableName == "" {
migrationContext.Log.Fatalf("--revert must be called with --old-table")
}

switch *cutOver {
case "atomic", "default", "":
migrationContext.CutOverType = base.CutOverAtomic
Expand Down Expand Up @@ -347,9 +353,15 @@ func main() {
acceptSignals(migrationContext)

migrator := logic.NewMigrator(migrationContext, AppVersion)
if err := migrator.Migrate(); err != nil {
migrator.ExecOnFailureHook()
migrationContext.Log.Fatale(err)
if migrationContext.Revert {
if err := migrator.Revert(); err != nil {
migrationContext.Log.Fatale(err)
}
} else {
if err := migrator.Migrate(); err != nil {
migrator.ExecOnFailureHook()
migrationContext.Log.Fatale(err)
}
}
fmt.Fprintln(os.Stdout, "# Done")
}
11 changes: 3 additions & 8 deletions go/logic/applier.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,25 +437,20 @@ func (this *Applier) CreateCheckpointTable() error {
"`gh_ost_chk_iteration` bigint",
"`gh_ost_rows_copied` bigint",
"`gh_ost_dml_applied` bigint",
"`gh_ost_is_cutover` tinyint(1) DEFAULT '0'",
}
for _, col := range this.migrationContext.UniqueKey.Columns.Columns() {
if col.MySQLType == "" {
return fmt.Errorf("CreateCheckpoinTable: column %s has no type information. applyColumnTypes must be called", sql.EscapeName(col.Name))
}
minColName := sql.TruncateColumnName(col.Name, sql.MaxColumnNameLength-4) + "_min"
colDef := fmt.Sprintf("%s %s", sql.EscapeName(minColName), col.MySQLType)
if !col.Nullable {
colDef += " NOT NULL"
}
colDefs = append(colDefs, colDef)
}

for _, col := range this.migrationContext.UniqueKey.Columns.Columns() {
maxColName := sql.TruncateColumnName(col.Name, sql.MaxColumnNameLength-4) + "_max"
colDef := fmt.Sprintf("%s %s", sql.EscapeName(maxColName), col.MySQLType)
if !col.Nullable {
colDef += " NOT NULL"
}
colDefs = append(colDefs, colDef)
}

Expand Down Expand Up @@ -627,7 +622,7 @@ func (this *Applier) WriteCheckpoint(chk *Checkpoint) (int64, error) {
if err != nil {
return insertId, err
}
args := sqlutils.Args(chk.LastTrxCoords.String(), chk.Iteration, chk.RowsCopied, chk.DMLApplied)
args := sqlutils.Args(chk.LastTrxCoords.String(), chk.Iteration, chk.RowsCopied, chk.DMLApplied, chk.IsCutover)
args = append(args, uniqueKeyArgs...)
res, err := this.db.Exec(query, args...)
if err != nil {
Expand All @@ -645,7 +640,7 @@ func (this *Applier) ReadLastCheckpoint() (*Checkpoint, error) {

var coordStr string
var timestamp int64
ptrs := []interface{}{&chk.Id, &timestamp, &coordStr, &chk.Iteration, &chk.RowsCopied, &chk.DMLApplied}
ptrs := []interface{}{&chk.Id, &timestamp, &coordStr, &chk.Iteration, &chk.RowsCopied, &chk.DMLApplied, &chk.IsCutover}
ptrs = append(ptrs, chk.IterationRangeMin.ValuesPointers...)
ptrs = append(ptrs, chk.IterationRangeMax.ValuesPointers...)
err := row.Scan(ptrs...)
Expand Down
5 changes: 4 additions & 1 deletion go/logic/applier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ func (suite *ApplierTestSuite) SetupSuite() {
testmysql.WithUsername(testMysqlUser),
testmysql.WithPassword(testMysqlPass),
testcontainers.WithWaitStrategy(wait.ForExposedPort()),
testmysql.WithConfigFile("my.cnf.test"),
)
suite.Require().NoError(err)

Expand Down Expand Up @@ -272,7 +273,7 @@ func (suite *ApplierTestSuite) TestInitDBConnections() {
mysqlVersion, _ := strings.CutPrefix(testMysqlContainerImage, "mysql:")
suite.Require().Equal(mysqlVersion, migrationContext.ApplierMySQLVersion)
suite.Require().Equal(int64(28800), migrationContext.ApplierWaitTimeout)
suite.Require().Equal("SYSTEM", migrationContext.ApplierTimeZone)
suite.Require().Equal("+00:00", migrationContext.ApplierTimeZone)

suite.Require().Equal(sql.NewColumnList([]string{"id", "item_id"}), migrationContext.OriginalTableColumnsOnApplier)
}
Expand Down Expand Up @@ -702,6 +703,7 @@ func (suite *ApplierTestSuite) TestWriteCheckpoint() {
Iteration: 2,
RowsCopied: 100000,
DMLApplied: 200000,
IsCutover: true,
}
id, err := applier.WriteCheckpoint(chk)
suite.Require().NoError(err)
Expand All @@ -716,6 +718,7 @@ func (suite *ApplierTestSuite) TestWriteCheckpoint() {
suite.Require().Equal(chk.IterationRangeMax.String(), gotChk.IterationRangeMax.String())
suite.Require().Equal(chk.RowsCopied, gotChk.RowsCopied)
suite.Require().Equal(chk.DMLApplied, gotChk.DMLApplied)
suite.Require().Equal(chk.IsCutover, gotChk.IsCutover)
}

func TestApplier(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions go/logic/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ type Checkpoint struct {
Iteration int64
RowsCopied int64
DMLApplied int64
IsCutover bool
}
1 change: 1 addition & 0 deletions go/logic/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ func (this *Inspector) ValidateOriginalTable() (err error) {
}

func (this *Inspector) InspectTableColumnsAndUniqueKeys(tableName string) (columns *sql.ColumnList, virtualColumns *sql.ColumnList, uniqueKeys [](*sql.UniqueKey), err error) {
this.migrationContext.Log.Debugf("InspectTableColumnsAndUniqueKeys: %s", tableName)
uniqueKeys, err = this.getCandidateUniqueKeys(tableName)
if err != nil {
return columns, virtualColumns, uniqueKeys, err
Expand Down
Loading
Loading