diff --git a/app/App.go b/app/App.go index a0d2c4c..f506771 100644 --- a/app/App.go +++ b/app/App.go @@ -1,36 +1,118 @@ package app import ( + "context" + "os" + "os/signal" + "sync" + "syscall" + "github.com/gdamore/tcell/v2" "github.com/rivo/tview" ) -var App = tview.NewApplication() +var ( + App *Application + Styles *Theme +) + +type Application struct { + *tview.Application + + context context.Context + cancelfn context.CancelFunc + wg sync.WaitGroup +} type Theme struct { - SidebarTitleBorderColor string tview.Theme -} -var Styles = Theme{ - SidebarTitleBorderColor: "#666A7E", + SidebarTitleBorderColor string } func init() { - theme := tview.Theme{ - PrimitiveBackgroundColor: tcell.ColorDefault, - ContrastBackgroundColor: tcell.ColorBlue, - MoreContrastBackgroundColor: tcell.ColorGreen, - BorderColor: tcell.ColorWhite, - TitleColor: tcell.ColorWhite, - GraphicsColor: tcell.ColorGray, - PrimaryTextColor: tcell.ColorDefault.TrueColor(), - SecondaryTextColor: tcell.ColorYellow, - TertiaryTextColor: tcell.ColorGreen, - InverseTextColor: tcell.ColorWhite, - ContrastSecondaryTextColor: tcell.ColorBlack, + ctx, cancel := context.WithCancel(context.Background()) + + App = &Application{ + Application: tview.NewApplication(), + context: ctx, + cancelfn: cancel, + } + + App.register() + App.EnableMouse(true) + + Styles = &Theme{ + Theme: tview.Theme{ + PrimitiveBackgroundColor: tcell.ColorDefault, + ContrastBackgroundColor: tcell.ColorBlue, + MoreContrastBackgroundColor: tcell.ColorGreen, + BorderColor: tcell.ColorWhite, + TitleColor: tcell.ColorWhite, + GraphicsColor: tcell.ColorGray, + PrimaryTextColor: tcell.ColorDefault.TrueColor(), + SecondaryTextColor: tcell.ColorYellow, + TertiaryTextColor: tcell.ColorGreen, + InverseTextColor: tcell.ColorWhite, + ContrastSecondaryTextColor: tcell.ColorBlack, + }, + SidebarTitleBorderColor: "#666A7E", } - Styles.Theme = theme - tview.Styles = theme + tview.Styles = Styles.Theme +} + +// Context returns the application context. +func (a *Application) Context() context.Context { + return a.context +} + +// Register adds a task to the wait group and returns a +// function that decrements the task count when called. +// +// The application will not stop until all registered tasks +// have finished by calling the returned function! +func (a *Application) Register() func() { + a.wg.Add(1) + return a.wg.Done +} + +// Run starts and blocks until the application is stopped. +func (a *Application) Run(root *tview.Pages) error { + a.SetRoot(root, true) + return a.Application.Run() +} + +// Stop cancels the application context, waits for all +// tasks to finish, and then stops the application. +func (a *Application) Stop() { + a.cancelfn() + a.wg.Wait() + a.Application.Stop() +} + +// register listens for interrupt and termination signals to +// gracefully handle shutdowns by calling the Stop method. +func (a *Application) register() { + c := make(chan os.Signal, 2) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + go func() { + <-c + a.Stop() + <-c + os.Exit(1) + }() + + // Override the default input capture to listen for Ctrl+C + // and make it send an interrupt signal to the channel to + // trigger a graceful shutdown instead of closing the app + // immediately without waiting for tasks to finish. + a.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { + if event.Key() == tcell.KeyCtrlC { + c <- os.Interrupt + return nil + } + return event + }) } diff --git a/components/ConnectionSelection.go b/components/ConnectionSelection.go index f6ad3d1..de09c22 100644 --- a/components/ConnectionSelection.go +++ b/components/ConnectionSelection.go @@ -140,50 +140,79 @@ func NewConnectionSelection(connectionForm *ConnectionForm, connectionPages *mod return cs } -func (cs *ConnectionSelection) Connect(connection models.Connection) { - if MainPages.HasPage(connection.URL) { - MainPages.SwitchToPage(connection.URL) - App.Draw() - } else { - cs.StatusText.SetText("Connecting...").SetTextColor(app.Styles.TertiaryTextColor) - App.Draw() - - var newDbDriver drivers.Driver - - switch connection.Provider { - case drivers.DriverMySQL: - newDbDriver = &drivers.MySQL{} - case drivers.DriverPostgres: - newDbDriver = &drivers.Postgres{} - case drivers.DriverSqlite: - newDbDriver = &drivers.SQLite{} - } - - err := newDbDriver.Connect(connection.URL) +func (cs *ConnectionSelection) Connect(connection models.Connection) *tview.Application { + if MainPages.HasPage(connection.Name) { + MainPages.SwitchToPage(connection.Name) + return App.Draw() + } + if len(connection.Commands) > 0 { + port, err := helpers.GetFreePort() if err != nil { cs.StatusText.SetText(err.Error()).SetTextStyle(tcell.StyleDefault.Foreground(tcell.ColorRed)) - App.Draw() - } else { - newHome := NewHomePage(connection, newDbDriver) + return App.Draw() + } - MainPages.AddAndSwitchToPage(connection.URL, newHome, true) + // Replace ${port} with the actual port. + connection.URL = strings.ReplaceAll(connection.URL, "${port}", port) - cs.StatusText.SetText("") + for i, command := range connection.Commands { + message := fmt.Sprintf("Running command %d/%d...", i+1, len(connection.Commands)) + cs.StatusText.SetText(message).SetTextColor(app.Styles.TertiaryTextColor) App.Draw() - selectedRow, selectedCol := ConnectionListTable.GetSelection() - cell := ConnectionListTable.GetCell(selectedRow, selectedCol) - cell.SetText(fmt.Sprintf("[green]* %s", cell.Text)) + cmd := strings.ReplaceAll(command.Command, "${port}", port) + if err := helpers.RunCommand(App.Context(), cmd, App.Register()); err != nil { + cs.StatusText.SetText(err.Error()).SetTextStyle(tcell.StyleDefault.Foreground(tcell.ColorRed)) + return App.Draw() + } - ConnectionListTable.SetCell(selectedRow, selectedCol, cell) + if command.WaitForPort != "" { + port := strings.ReplaceAll(command.WaitForPort, "${port}", port) - MainPages.SwitchToPage(connection.URL) - newHome.Tree.SetCurrentNode(newHome.Tree.GetRoot()) - newHome.Tree.Wrapper.SetTitle(fmt.Sprintf("%s (%s)", connection.Name, strings.ToUpper(connection.Provider))) - App.SetFocus(newHome.Tree) - App.Draw() + message := fmt.Sprintf("Waiting for port %s...", port) + cs.StatusText.SetText(message).SetTextColor(app.Styles.TertiaryTextColor) + App.Draw() + + if err := helpers.WaitForPort(port); err != nil { + cs.StatusText.SetText(err.Error()).SetTextStyle(tcell.StyleDefault.Foreground(tcell.ColorRed)) + return App.Draw() + } + } } + } + + cs.StatusText.SetText("Connecting...").SetTextColor(app.Styles.TertiaryTextColor) + App.Draw() + + var newDBDriver drivers.Driver + + switch connection.Provider { + case drivers.DriverMySQL: + newDBDriver = &drivers.MySQL{} + case drivers.DriverPostgres: + newDBDriver = &drivers.Postgres{} + case drivers.DriverSqlite: + newDBDriver = &drivers.SQLite{} + } + err := newDBDriver.Connect(connection.URL) + if err != nil { + cs.StatusText.SetText(err.Error()).SetTextStyle(tcell.StyleDefault.Foreground(tcell.ColorRed)) + return App.Draw() } + + selectedRow, selectedCol := ConnectionListTable.GetSelection() + cell := ConnectionListTable.GetCell(selectedRow, selectedCol) + cell.SetText(fmt.Sprintf("[green]* %s", cell.Text)) + cs.StatusText.SetText("") + + newHome := NewHomePage(connection, newDBDriver) + newHome.Tree.SetCurrentNode(newHome.Tree.GetRoot()) + newHome.Tree.Wrapper.SetTitle(connection.Name) + + MainPages.AddAndSwitchToPage(connection.Name, newHome, true) + App.SetFocus(newHome.Tree) + + return App.Draw() } diff --git a/components/Home.go b/components/Home.go index 664ca9c..a355ffe 100644 --- a/components/Home.go +++ b/components/Home.go @@ -14,6 +14,7 @@ import ( type Home struct { *tview.Flex + Tree *Tree TabbedPane *TabbedPane LeftWrapper *tview.Flex @@ -22,7 +23,7 @@ type Home struct { HelpModal *HelpModal DBDriver drivers.Driver FocusedWrapper string - ListOfDbChanges []models.DbDmlChange + ListOfDBChanges []models.DbDmlChange } func NewHomePage(connection models.Connection, dbdriver drivers.Driver) *Home { @@ -41,8 +42,8 @@ func NewHomePage(connection models.Connection, dbdriver drivers.Driver) *Home { RightWrapper: rightWrapper, HelpStatus: NewHelpStatus(), HelpModal: NewHelpModal(), - ListOfDbChanges: []models.DbDmlChange{}, DBDriver: dbdriver, + ListOfDBChanges: []models.DbDmlChange{}, } go home.subscribeToTreeChanges() @@ -96,7 +97,7 @@ func (home *Home) subscribeToTreeChanges() { table = tab.Content home.TabbedPane.SwitchToTabByReference(tab.Reference) } else { - table = NewResultsTable(&home.ListOfDbChanges, home.Tree, home.DBDriver).WithFilter() + table = NewResultsTable(&home.ListOfDBChanges, home.Tree, home.DBDriver).WithFilter() table.SetDatabaseName(databaseName) table.SetTableName(tableName) @@ -286,7 +287,7 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { home.TabbedPane.SwitchToTabByName(tabNameEditor) tab.Content.SetIsFiltering(true) } else { - tableWithEditor := NewResultsTable(&home.ListOfDbChanges, home.Tree, home.DBDriver).WithEditor() + tableWithEditor := NewResultsTable(&home.ListOfDBChanges, home.Tree, home.DBDriver).WithEditor() home.TabbedPane.AppendTab(tabNameEditor, tableWithEditor, tabNameEditor) tableWithEditor.SetIsFiltering(true) } @@ -298,17 +299,11 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { MainPages.SwitchToPage(pageNameConnections) } case commands.Quit: - if tab != nil { - table := tab.Content - - if !table.GetIsFiltering() && !table.GetIsEditing() { - App.Stop() - } - } else { + if tab == nil || (!table.GetIsEditing() && !table.GetIsFiltering()) { App.Stop() } case commands.Save: - if (len(home.ListOfDbChanges) > 0) && !table.GetIsEditing() { + if (len(home.ListOfDBChanges) > 0) && !table.GetIsEditing() { confirmationModal := NewConfirmationModal("") confirmationModal.SetDoneFunc(func(_ int, buttonLabel string) { @@ -317,12 +312,12 @@ func (home *Home) homeInputCapture(event *tcell.EventKey) *tcell.EventKey { if buttonLabel == "Yes" { - err := home.DBDriver.ExecutePendingChanges(home.ListOfDbChanges) + err := home.DBDriver.ExecutePendingChanges(home.ListOfDBChanges) if err != nil { table.SetError(err.Error(), nil) } else { - home.ListOfDbChanges = []models.DbDmlChange{} + home.ListOfDBChanges = []models.DbDmlChange{} table.FetchRecords(nil) home.Tree.ForceRemoveHighlight() diff --git a/go.mod b/go.mod index 7587174..76fbc2c 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/google/uuid v1.6.0 github.com/lib/pq v1.10.9 + github.com/mitchellh/go-linereader v0.0.0-20190213213312-1b945b3263eb github.com/pelletier/go-toml/v2 v2.1.1 github.com/rivo/tview v0.0.0-20240101144852-b3bd1aa5e9f2 github.com/xo/dburl v0.20.2 diff --git a/go.sum b/go.sum index 17c9db8..adf9efc 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/go-linereader v0.0.0-20190213213312-1b945b3263eb h1:GRiLv4rgyqjqzxbhJke65IYUf4NCOOvrPOJbV/sPxkM= +github.com/mitchellh/go-linereader v0.0.0-20190213213312-1b945b3263eb/go.mod h1:OaY7UOoTkkrX3wRwjpYRKafIkkyeD0UtweSHAWWiqQM= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= diff --git a/helpers/command.go b/helpers/command.go new file mode 100644 index 0000000..fc3552c --- /dev/null +++ b/helpers/command.go @@ -0,0 +1,77 @@ +package helpers + +import ( + "context" + "errors" + "io" + "os/exec" + "strings" + "time" + + "github.com/mitchellh/go-linereader" + + "github.com/jorgerojas26/lazysql/helpers/logger" +) + +func RunCommand(ctx context.Context, command string, doneFn func()) error { + var cmd *exec.Cmd + + parts := strings.Fields(command) + if len(parts) == 1 { + cmd = exec.CommandContext(ctx, parts[0]) // #nosec G204 + } else { + cmd = exec.CommandContext(ctx, parts[0], parts[1:]...) // #nosec G204 + } + + // Create a pipe to read the output from. + pr, pw := io.Pipe() + startedCh := make(chan struct{}) + copyDoneCh := make(chan struct{}) + go logOutput(pr, startedCh, copyDoneCh) + + // Connect the pipe to stdout and stderr. + cmd.Stderr = pw + cmd.Stdout = pw + + if err := cmd.Start(); err != nil { + return err + } + + go func() { + if err := cmd.Wait(); err != nil { + logger.Error("Command stopped", map[string]any{"error": err.Error()}) + } + + _ = pw.Close() + <-copyDoneCh + doneFn() + }() + + // Wait for the command to start + select { + case <-ctx.Done(): + logger.Error("Command canceled", map[string]any{"error": ctx.Err()}) + case <-startedCh: + logger.Info("Command started", map[string]any{"command": command}) + case <-time.After(5 * time.Second): + cmd.Process.Kill() + return errors.New("command timeout") + } + + return nil +} + +func logOutput(r io.Reader, started, doneCh chan struct{}) { + defer close(doneCh) + lr := linereader.New(r) + + // Wait for the command to start + line := <-lr.Ch + started <- struct{}{} + logger.Debug("Command output", map[string]any{"line": line}) + + // Log the rest of the output + for line := range lr.Ch { + logger.Debug("Command output", map[string]any{"line": line}) + } +} diff --git a/helpers/utils.go b/helpers/utils.go index bdd3608..7e83a59 100644 --- a/helpers/utils.go +++ b/helpers/utils.go @@ -1,6 +1,11 @@ package helpers import ( + "errors" + "net" + "strconv" + "time" + "github.com/xo/dburl" "github.com/jorgerojas26/lazysql/commands" @@ -18,3 +23,32 @@ func ContainsCommand(commands []commands.Command, command commands.Command) bool } return false } + +// GetFreePort asks the kernel for a free port. +func GetFreePort() (string, error) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return "", err + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return "", err + } + defer l.Close() + + return strconv.Itoa(l.Addr().(*net.TCPAddr).Port), nil +} + +// WaitForPort waits for a port to be open. +func WaitForPort(port string) error { + for i := 0; i < 10; i++ { + conn, err := net.DialTimeout("tcp", "localhost:"+port, 500*time.Millisecond) + if err == nil { + _ = conn.Close() + return nil + } + time.Sleep(500 * time.Second) + } + return errors.New("Timeout waiting for port " + port) +} diff --git a/main.go b/main.go index 074efd7..155c5c4 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "flag" + "fmt" "io" "log" "os" @@ -16,45 +17,40 @@ import ( var version = "dev" func main() { - rawLogLvl := flag.String("loglvl", "info", "Log level") + logLevel := flag.String("loglevel", "info", "Log level") logFile := flag.String("logfile", "", "Log file") flag.Parse() - logLvl, parseError := logger.ParseLogLevel(*rawLogLvl) - if parseError != nil { - panic(parseError) + slogLevel, err := logger.ParseLogLevel(*logLevel) + if err != nil { + log.Fatalf("Error parsing log level: %v", err) } - logger.SetLevel(logLvl) + logger.SetLevel(slogLevel) if *logFile != "" { - fileError := logger.SetFile(*logFile) - if fileError != nil { - panic(fileError) + if err := logger.SetFile(*logFile); err != nil { + log.Fatalf("Error setting log file: %v", err) } } logger.Info("Starting LazySQL...", nil) - mysqlError := mysql.SetLogger(log.New(io.Discard, "", 0)) - if mysqlError != nil { - panic(mysqlError) + if err := mysql.SetLogger(log.New(io.Discard, "", 0)); err != nil { + log.Fatalf("Error setting MySQL logger: %v", err) } - // check if "version" arg is passed + // Check if "version" arg is passed. argsWithProg := os.Args if len(argsWithProg) > 1 { switch argsWithProg[1] { case "version": - println("LazySQL version: ", version) + fmt.Println("LazySQL version: ", version) os.Exit(0) } } - if err := app.App. - SetRoot(components.MainPages, true). - EnableMouse(true). - Run(); err != nil { - panic(err) + if err = app.App.Run(components.MainPages); err != nil { + log.Fatalf("Error running app: %v", err) } } diff --git a/models/models.go b/models/models.go index 7fbb6d5..be3f57c 100644 --- a/models/models.go +++ b/models/models.go @@ -9,6 +9,12 @@ type Connection struct { Provider string DBName string URL string + Commands []*Command +} + +type Command struct { + Command string + WaitForPort string } type StateChange struct {