From 4f73e4c8c68f5536e1829e66b9f1fefeded2cda6 Mon Sep 17 00:00:00 2001 From: Cristian Maglie Date: Wed, 22 Jan 2025 17:08:03 +0100 Subject: [PATCH] Moved user-agent extraction deep in configuration.HttpClient This allows the extraction of the user-agent in a single place. Also it forces the context passing on all operations that requires access to network. --- commands/instances.go | 34 ++++------------------ commands/service_board_identify.go | 18 ++++++------ commands/service_board_identify_test.go | 15 +++++----- commands/service_check_for_updates.go | 6 ++-- commands/service_library_download.go | 12 +++++--- internal/arduino/resources/helpers_test.go | 2 +- internal/cli/configuration/network.go | 17 ++++++++--- internal/cli/configuration/network_test.go | 7 +++-- 8 files changed, 51 insertions(+), 60 deletions(-) diff --git a/commands/instances.go b/commands/instances.go index 098e03f76ae..8264f23bae4 100644 --- a/commands/instances.go +++ b/commands/instances.go @@ -43,7 +43,6 @@ import ( paths "github.com/arduino/go-paths-helper" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -64,29 +63,6 @@ func installTool(ctx context.Context, pm *packagemanager.PackageManager, tool *c // Create a new Instance ready to be initialized, supporting directories are also created. func (s *arduinoCoreServerImpl) Create(ctx context.Context, req *rpc.CreateRequest) (*rpc.CreateResponse, error) { - var userAgent string - if md, ok := metadata.FromIncomingContext(ctx); ok { - userAgent = strings.Join(md.Get("user-agent"), " ") - if userAgent != "" { - // s.SettingsGetValue() returns an error if the key does not exist and for this reason we are accessing - // network.user_agent_ext directly from s.settings.ExtraUserAgent() to set it - if s.settings.ExtraUserAgent() == "" { - if strings.Contains(userAgent, "arduino-ide/2") { - // needed for analytics purposes - userAgent = userAgent + " daemon" - } - _, err := s.SettingsSetValue(ctx, &rpc.SettingsSetValueRequest{ - Key: "network.user_agent_ext", - ValueFormat: "cli", - EncodedValue: userAgent, - }) - if err != nil { - return nil, err - } - } - } - } - // Setup downloads directory downloadsDir := s.settings.DownloadsDir() if downloadsDir.NotExist() { @@ -107,11 +83,11 @@ func (s *arduinoCoreServerImpl) Create(ctx context.Context, req *rpc.CreateReque } } - config, err := s.settings.DownloaderConfig() + config, err := s.settings.DownloaderConfig(ctx) if err != nil { return nil, err } - inst, err := instances.Create(dataDir, packagesDir, userPackagesDir, downloadsDir, userAgent, config) + inst, err := instances.Create(dataDir, packagesDir, userPackagesDir, downloadsDir, "", config) if err != nil { return nil, err } @@ -395,7 +371,7 @@ func (s *arduinoCoreServerImpl) Init(req *rpc.InitRequest, stream rpc.ArduinoCor responseError(err.GRPCStatus()) continue } - config, err := s.settings.DownloaderConfig() + config, err := s.settings.DownloaderConfig(ctx) if err != nil { taskCallback(&rpc.TaskProgress{Name: i18n.Tr("Error downloading library %s", libraryRef)}) e := &cmderrors.FailedLibraryInstallError{Cause: err} @@ -516,7 +492,7 @@ func (s *arduinoCoreServerImpl) UpdateLibrariesIndex(req *rpc.UpdateLibrariesInd } // Perform index update - config, err := s.settings.DownloaderConfig() + config, err := s.settings.DownloaderConfig(stream.Context()) if err != nil { return err } @@ -626,7 +602,7 @@ func (s *arduinoCoreServerImpl) UpdateIndex(req *rpc.UpdateIndexRequest, stream } } - config, err := s.settings.DownloaderConfig() + config, err := s.settings.DownloaderConfig(stream.Context()) if err != nil { downloadCB.Start(u, i18n.Tr("Downloading index: %s", filepath.Base(URL.Path))) downloadCB.End(false, i18n.Tr("Invalid network configuration: %s", err)) diff --git a/commands/service_board_identify.go b/commands/service_board_identify.go index 2f4b1f54dce..787de81cee3 100644 --- a/commands/service_board_identify.go +++ b/commands/service_board_identify.go @@ -48,7 +48,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar defer release() props := properties.NewFromHashmap(req.GetProperties()) - res, err := identify(pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection()) + res, err := identify(ctx, pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection()) if err != nil { return nil, err } @@ -58,7 +58,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar } // identify returns a list of boards checking first the installed platforms or the Cloud API -func identify(pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) { +func identify(ctx context.Context, pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) { if properties == nil { return nil, nil } @@ -90,7 +90,7 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings // if installed cores didn't recognize the board, try querying // the builder API if the board is a USB device port if len(boards) == 0 && !skipCloudAPI && !settings.SkipCloudApiForBoardDetection() { - items, err := identifyViaCloudAPI(properties, settings) + items, err := identifyViaCloudAPI(ctx, properties, settings) if err != nil { // this is bad, but keep going logrus.WithError(err).Debug("Error querying builder API") @@ -119,14 +119,14 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings return boards, nil } -func identifyViaCloudAPI(props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { +func identifyViaCloudAPI(ctx context.Context, props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { // If the port is not USB do not try identification via cloud if !props.ContainsKey("vid") || !props.ContainsKey("pid") { return nil, nil } logrus.Debug("Querying builder API for board identification...") - return cachedAPIByVidPid(props.Get("vid"), props.Get("pid"), settings) + return cachedAPIByVidPid(ctx, props.Get("vid"), props.Get("pid"), settings) } var ( @@ -134,7 +134,7 @@ var ( validVidPid = regexp.MustCompile(`0[xX][a-fA-F\d]{4}`) ) -func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { +func cachedAPIByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { var resp []*rpc.BoardListItem cacheKey := fmt.Sprintf("cache.builder-api.v3/boards/byvid/pid/%s/%s", vid, pid) @@ -148,7 +148,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp } } - resp, err := apiByVidPid(vid, pid, settings) // Perform API requrest + resp, err := apiByVidPid(ctx, vid, pid, settings) // Perform API requrest if err == nil { if cachedResp, err := json.Marshal(resp); err == nil { @@ -160,7 +160,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp return resp, err } -func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { +func apiByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) { // ensure vid and pid are valid before hitting the API if !validVidPid.MatchString(vid) { return nil, errors.New(i18n.Tr("Invalid vid value: '%s'", vid)) @@ -173,7 +173,7 @@ func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.Boar req, _ := http.NewRequest("GET", url, nil) req.Header.Set("Content-Type", "application/json") - httpClient, err := settings.NewHttpClient() + httpClient, err := settings.NewHttpClient(ctx) if err != nil { return nil, fmt.Errorf("%s: %w", i18n.Tr("failed to initialize http client"), err) } diff --git a/commands/service_board_identify_test.go b/commands/service_board_identify_test.go index 98dc8e40278..31687359885 100644 --- a/commands/service_board_identify_test.go +++ b/commands/service_board_identify_test.go @@ -16,6 +16,7 @@ package commands import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -48,7 +49,7 @@ func TestGetByVidPid(t *testing.T) { vidPidURL = ts.URL settings := configuration.NewSettings() - res, err := apiByVidPid("0xf420", "0XF069", settings) + res, err := apiByVidPid(context.Background(), "0xf420", "0XF069", settings) require.Nil(t, err) require.Len(t, res, 1) require.Equal(t, "Arduino/Genuino MKR1000", res[0].GetName()) @@ -56,7 +57,7 @@ func TestGetByVidPid(t *testing.T) { // wrong vid (too long), wrong pid (not an hex value) - _, err = apiByVidPid("0xfffff", "0xDEFG", settings) + _, err = apiByVidPid(context.Background(), "0xfffff", "0xDEFG", settings) require.NotNil(t, err) } @@ -69,7 +70,7 @@ func TestGetByVidPidNotFound(t *testing.T) { defer ts.Close() vidPidURL = ts.URL - res, err := apiByVidPid("0x0420", "0x0069", settings) + res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings) require.NoError(t, err) require.Empty(t, res) } @@ -84,7 +85,7 @@ func TestGetByVidPid5xx(t *testing.T) { defer ts.Close() vidPidURL = ts.URL - res, err := apiByVidPid("0x0420", "0x0069", settings) + res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings) require.NotNil(t, err) require.Equal(t, "the server responded with status 500 Internal Server Error", err.Error()) require.Len(t, res, 0) @@ -99,7 +100,7 @@ func TestGetByVidPidMalformedResponse(t *testing.T) { defer ts.Close() vidPidURL = ts.URL - res, err := apiByVidPid("0x0420", "0x0069", settings) + res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings) require.NotNil(t, err) require.Equal(t, "wrong format in server response", err.Error()) require.Len(t, res, 0) @@ -107,7 +108,7 @@ func TestGetByVidPidMalformedResponse(t *testing.T) { func TestBoardDetectionViaAPIWithNonUSBPort(t *testing.T) { settings := configuration.NewSettings() - items, err := identifyViaCloudAPI(properties.NewMap(), settings) + items, err := identifyViaCloudAPI(context.Background(), properties.NewMap(), settings) require.NoError(t, err) require.Empty(t, items) } @@ -156,7 +157,7 @@ func TestBoardIdentifySorting(t *testing.T) { defer release() settings := configuration.NewSettings() - res, err := identify(pme, idPrefs, settings, true) + res, err := identify(context.Background(), pme, idPrefs, settings, true) require.NoError(t, err) require.NotNil(t, res) require.Len(t, res, 4) diff --git a/commands/service_check_for_updates.go b/commands/service_check_for_updates.go index 2b5c7a51b4e..cce5a4a02a1 100644 --- a/commands/service_check_for_updates.go +++ b/commands/service_check_for_updates.go @@ -43,7 +43,7 @@ func (s *arduinoCoreServerImpl) CheckForArduinoCLIUpdates(ctx context.Context, r inventory.WriteStore() }() - latestVersion, err := semver.Parse(s.getLatestRelease()) + latestVersion, err := semver.Parse(s.getLatestRelease(ctx)) if err != nil { return nil, err } @@ -82,8 +82,8 @@ func (s *arduinoCoreServerImpl) shouldCheckForUpdate(currentVersion *semver.Vers // getLatestRelease queries the official Arduino download server for the latest release, // if there are no errors or issues a version string is returned, in all other case an empty string. -func (s *arduinoCoreServerImpl) getLatestRelease() string { - client, err := s.settings.NewHttpClient() +func (s *arduinoCoreServerImpl) getLatestRelease(ctx context.Context) string { + client, err := s.settings.NewHttpClient(ctx) if err != nil { return "" } diff --git a/commands/service_library_download.go b/commands/service_library_download.go index 2384d59396f..4253be8cca1 100644 --- a/commands/service_library_download.go +++ b/commands/service_library_download.go @@ -82,11 +82,15 @@ func (s *arduinoCoreServerImpl) LibraryDownload(req *rpc.LibraryDownloadRequest, }) } -func downloadLibrary(ctx context.Context, downloadsDir *paths.Path, libRelease *librariesindex.Release, - downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB, queryParameter string, settings *configuration.Settings) error { - +func downloadLibrary( + ctx context.Context, + downloadsDir *paths.Path, libRelease *librariesindex.Release, + downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB, + queryParameter string, + settings *configuration.Settings, +) error { taskCB(&rpc.TaskProgress{Name: i18n.Tr("Downloading %s", libRelease)}) - config, err := settings.DownloaderConfig() + config, err := settings.DownloaderConfig(ctx) if err != nil { return &cmderrors.FailedDownloadError{Message: i18n.Tr("Can't download library"), Cause: err} } diff --git a/internal/arduino/resources/helpers_test.go b/internal/arduino/resources/helpers_test.go index 611de8dd518..ad1d6805254 100644 --- a/internal/arduino/resources/helpers_test.go +++ b/internal/arduino/resources/helpers_test.go @@ -55,7 +55,7 @@ func TestDownloadApplyUserAgentHeaderUsingConfig(t *testing.T) { settings := configuration.NewSettings() settings.Set("network.user_agent_ext", goldUserAgentValue) - config, err := settings.DownloaderConfig() + config, err := settings.DownloaderConfig(context.Background()) require.NoError(t, err) err = r.Download(context.Background(), tmp, config, "", func(progress *rpc.DownloadProgress) {}, "") require.NoError(t, err) diff --git a/internal/cli/configuration/network.go b/internal/cli/configuration/network.go index c570d0a3b82..43b502a03fb 100644 --- a/internal/cli/configuration/network.go +++ b/internal/cli/configuration/network.go @@ -16,18 +16,21 @@ package configuration import ( + "context" "errors" "fmt" "net/http" "net/url" "os" "runtime" + "strings" "time" "github.com/arduino/arduino-cli/commands/cmderrors" "github.com/arduino/arduino-cli/internal/i18n" "github.com/arduino/arduino-cli/internal/version" "go.bug.st/downloader/v2" + "google.golang.org/grpc/metadata" ) // UserAgent returns the user agent (mainly used by HTTP clients) @@ -84,17 +87,23 @@ func (settings *Settings) NetworkProxy() (*url.URL, error) { } // NewHttpClient returns a new http client for use in the arduino-cli -func (settings *Settings) NewHttpClient() (*http.Client, error) { +func (settings *Settings) NewHttpClient(ctx context.Context) (*http.Client, error) { proxy, err := settings.NetworkProxy() if err != nil { return nil, err } + userAgent := settings.UserAgent() + if md, ok := metadata.FromIncomingContext(ctx); ok { + if extraUserAgent := strings.Join(md.Get("user-agent"), " "); extraUserAgent != "" { + userAgent += " " + extraUserAgent + } + } return &http.Client{ Transport: &httpClientRoundTripper{ transport: &http.Transport{ Proxy: http.ProxyURL(proxy), }, - userAgent: settings.UserAgent(), + userAgent: userAgent, }, Timeout: settings.ConnectionTimeout(), }, nil @@ -111,8 +120,8 @@ func (h *httpClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, e } // DownloaderConfig returns the downloader configuration based on current settings. -func (settings *Settings) DownloaderConfig() (downloader.Config, error) { - httpClient, err := settings.NewHttpClient() +func (settings *Settings) DownloaderConfig(ctx context.Context) (downloader.Config, error) { + httpClient, err := settings.NewHttpClient(ctx) if err != nil { return downloader.Config{}, &cmderrors.InvalidArgumentError{ Message: i18n.Tr("Could not connect via HTTP"), diff --git a/internal/cli/configuration/network_test.go b/internal/cli/configuration/network_test.go index 68cc84fdd59..563c0414589 100644 --- a/internal/cli/configuration/network_test.go +++ b/internal/cli/configuration/network_test.go @@ -16,6 +16,7 @@ package configuration_test import ( + "context" "fmt" "io" "net/http" @@ -35,7 +36,7 @@ func TestUserAgentHeader(t *testing.T) { settings := configuration.NewSettings() require.NoError(t, settings.Set("network.user_agent_ext", "test-user-agent")) - client, err := settings.NewHttpClient() + client, err := settings.NewHttpClient(context.Background()) require.NoError(t, err) request, err := http.NewRequest("GET", ts.URL, nil) @@ -59,7 +60,7 @@ func TestProxy(t *testing.T) { settings := configuration.NewSettings() settings.Set("network.proxy", ts.URL) - client, err := settings.NewHttpClient() + client, err := settings.NewHttpClient(context.Background()) require.NoError(t, err) request, err := http.NewRequest("GET", "http://arduino.cc", nil) @@ -83,7 +84,7 @@ func TestConnectionTimeout(t *testing.T) { if timeout != 0 { require.NoError(t, settings.Set("network.connection_timeout", "2s")) } - client, err := settings.NewHttpClient() + client, err := settings.NewHttpClient(context.Background()) require.NoError(t, err) request, err := http.NewRequest("GET", "http://arduino.cc", nil)