Skip to content

Commit

Permalink
Moved user-agent extraction deep in configuration.HttpClient
Browse files Browse the repository at this point in the history
This allows the extraction of the user-agent in a single place. Also it
forces the context passing on all operations that requires access to
network.
  • Loading branch information
cmaglie committed Jan 22, 2025
1 parent fe84afb commit 4f73e4c
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 60 deletions.
34 changes: 5 additions & 29 deletions commands/instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import (
paths "github.com/arduino/go-paths-helper"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

Expand All @@ -64,29 +63,6 @@ func installTool(ctx context.Context, pm *packagemanager.PackageManager, tool *c

// Create a new Instance ready to be initialized, supporting directories are also created.
func (s *arduinoCoreServerImpl) Create(ctx context.Context, req *rpc.CreateRequest) (*rpc.CreateResponse, error) {
var userAgent string
if md, ok := metadata.FromIncomingContext(ctx); ok {
userAgent = strings.Join(md.Get("user-agent"), " ")
if userAgent != "" {
// s.SettingsGetValue() returns an error if the key does not exist and for this reason we are accessing
// network.user_agent_ext directly from s.settings.ExtraUserAgent() to set it
if s.settings.ExtraUserAgent() == "" {
if strings.Contains(userAgent, "arduino-ide/2") {
// needed for analytics purposes
userAgent = userAgent + " daemon"
}
_, err := s.SettingsSetValue(ctx, &rpc.SettingsSetValueRequest{
Key: "network.user_agent_ext",
ValueFormat: "cli",
EncodedValue: userAgent,
})
if err != nil {
return nil, err
}
}
}
}

// Setup downloads directory
downloadsDir := s.settings.DownloadsDir()
if downloadsDir.NotExist() {
Expand All @@ -107,11 +83,11 @@ func (s *arduinoCoreServerImpl) Create(ctx context.Context, req *rpc.CreateReque
}
}

config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(ctx)
if err != nil {
return nil, err
}
inst, err := instances.Create(dataDir, packagesDir, userPackagesDir, downloadsDir, userAgent, config)
inst, err := instances.Create(dataDir, packagesDir, userPackagesDir, downloadsDir, "", config)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -395,7 +371,7 @@ func (s *arduinoCoreServerImpl) Init(req *rpc.InitRequest, stream rpc.ArduinoCor
responseError(err.GRPCStatus())
continue
}
config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(ctx)
if err != nil {
taskCallback(&rpc.TaskProgress{Name: i18n.Tr("Error downloading library %s", libraryRef)})
e := &cmderrors.FailedLibraryInstallError{Cause: err}
Expand Down Expand Up @@ -516,7 +492,7 @@ func (s *arduinoCoreServerImpl) UpdateLibrariesIndex(req *rpc.UpdateLibrariesInd
}

// Perform index update
config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(stream.Context())
if err != nil {
return err
}
Expand Down Expand Up @@ -626,7 +602,7 @@ func (s *arduinoCoreServerImpl) UpdateIndex(req *rpc.UpdateIndexRequest, stream
}
}

config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(stream.Context())
if err != nil {
downloadCB.Start(u, i18n.Tr("Downloading index: %s", filepath.Base(URL.Path)))
downloadCB.End(false, i18n.Tr("Invalid network configuration: %s", err))
Expand Down
18 changes: 9 additions & 9 deletions commands/service_board_identify.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar
defer release()

props := properties.NewFromHashmap(req.GetProperties())
res, err := identify(pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection())
res, err := identify(ctx, pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection())
if err != nil {
return nil, err
}
Expand All @@ -58,7 +58,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar
}

// identify returns a list of boards checking first the installed platforms or the Cloud API
func identify(pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) {
func identify(ctx context.Context, pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) {
if properties == nil {
return nil, nil
}
Expand Down Expand Up @@ -90,7 +90,7 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings
// if installed cores didn't recognize the board, try querying
// the builder API if the board is a USB device port
if len(boards) == 0 && !skipCloudAPI && !settings.SkipCloudApiForBoardDetection() {
items, err := identifyViaCloudAPI(properties, settings)
items, err := identifyViaCloudAPI(ctx, properties, settings)
if err != nil {
// this is bad, but keep going
logrus.WithError(err).Debug("Error querying builder API")
Expand Down Expand Up @@ -119,22 +119,22 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings
return boards, nil
}

func identifyViaCloudAPI(props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
func identifyViaCloudAPI(ctx context.Context, props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
// If the port is not USB do not try identification via cloud
if !props.ContainsKey("vid") || !props.ContainsKey("pid") {
return nil, nil
}

logrus.Debug("Querying builder API for board identification...")
return cachedAPIByVidPid(props.Get("vid"), props.Get("pid"), settings)
return cachedAPIByVidPid(ctx, props.Get("vid"), props.Get("pid"), settings)
}

var (
vidPidURL = "https://builder.arduino.cc/v3/boards/byVidPid"
validVidPid = regexp.MustCompile(`0[xX][a-fA-F\d]{4}`)
)

func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
func cachedAPIByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
var resp []*rpc.BoardListItem

cacheKey := fmt.Sprintf("cache.builder-api.v3/boards/byvid/pid/%s/%s", vid, pid)
Expand All @@ -148,7 +148,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp
}
}

resp, err := apiByVidPid(vid, pid, settings) // Perform API requrest
resp, err := apiByVidPid(ctx, vid, pid, settings) // Perform API requrest

if err == nil {
if cachedResp, err := json.Marshal(resp); err == nil {
Expand All @@ -160,7 +160,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp
return resp, err
}

func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
func apiByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
// ensure vid and pid are valid before hitting the API
if !validVidPid.MatchString(vid) {
return nil, errors.New(i18n.Tr("Invalid vid value: '%s'", vid))
Expand All @@ -173,7 +173,7 @@ func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.Boar
req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Content-Type", "application/json")

httpClient, err := settings.NewHttpClient()
httpClient, err := settings.NewHttpClient(ctx)
if err != nil {
return nil, fmt.Errorf("%s: %w", i18n.Tr("failed to initialize http client"), err)
}
Expand Down
15 changes: 8 additions & 7 deletions commands/service_board_identify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package commands

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -48,15 +49,15 @@ func TestGetByVidPid(t *testing.T) {

vidPidURL = ts.URL
settings := configuration.NewSettings()
res, err := apiByVidPid("0xf420", "0XF069", settings)
res, err := apiByVidPid(context.Background(), "0xf420", "0XF069", settings)
require.Nil(t, err)
require.Len(t, res, 1)
require.Equal(t, "Arduino/Genuino MKR1000", res[0].GetName())
require.Equal(t, "arduino:samd:mkr1000", res[0].GetFqbn())

// wrong vid (too long), wrong pid (not an hex value)

_, err = apiByVidPid("0xfffff", "0xDEFG", settings)
_, err = apiByVidPid(context.Background(), "0xfffff", "0xDEFG", settings)
require.NotNil(t, err)
}

Expand All @@ -69,7 +70,7 @@ func TestGetByVidPidNotFound(t *testing.T) {
defer ts.Close()

vidPidURL = ts.URL
res, err := apiByVidPid("0x0420", "0x0069", settings)
res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings)
require.NoError(t, err)
require.Empty(t, res)
}
Expand All @@ -84,7 +85,7 @@ func TestGetByVidPid5xx(t *testing.T) {
defer ts.Close()

vidPidURL = ts.URL
res, err := apiByVidPid("0x0420", "0x0069", settings)
res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings)
require.NotNil(t, err)
require.Equal(t, "the server responded with status 500 Internal Server Error", err.Error())
require.Len(t, res, 0)
Expand All @@ -99,15 +100,15 @@ func TestGetByVidPidMalformedResponse(t *testing.T) {
defer ts.Close()

vidPidURL = ts.URL
res, err := apiByVidPid("0x0420", "0x0069", settings)
res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings)
require.NotNil(t, err)
require.Equal(t, "wrong format in server response", err.Error())
require.Len(t, res, 0)
}

func TestBoardDetectionViaAPIWithNonUSBPort(t *testing.T) {
settings := configuration.NewSettings()
items, err := identifyViaCloudAPI(properties.NewMap(), settings)
items, err := identifyViaCloudAPI(context.Background(), properties.NewMap(), settings)
require.NoError(t, err)
require.Empty(t, items)
}
Expand Down Expand Up @@ -156,7 +157,7 @@ func TestBoardIdentifySorting(t *testing.T) {
defer release()

settings := configuration.NewSettings()
res, err := identify(pme, idPrefs, settings, true)
res, err := identify(context.Background(), pme, idPrefs, settings, true)
require.NoError(t, err)
require.NotNil(t, res)
require.Len(t, res, 4)
Expand Down
6 changes: 3 additions & 3 deletions commands/service_check_for_updates.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (s *arduinoCoreServerImpl) CheckForArduinoCLIUpdates(ctx context.Context, r
inventory.WriteStore()
}()

latestVersion, err := semver.Parse(s.getLatestRelease())
latestVersion, err := semver.Parse(s.getLatestRelease(ctx))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -82,8 +82,8 @@ func (s *arduinoCoreServerImpl) shouldCheckForUpdate(currentVersion *semver.Vers

// getLatestRelease queries the official Arduino download server for the latest release,
// if there are no errors or issues a version string is returned, in all other case an empty string.
func (s *arduinoCoreServerImpl) getLatestRelease() string {
client, err := s.settings.NewHttpClient()
func (s *arduinoCoreServerImpl) getLatestRelease(ctx context.Context) string {
client, err := s.settings.NewHttpClient(ctx)
if err != nil {
return ""
}
Expand Down
12 changes: 8 additions & 4 deletions commands/service_library_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,15 @@ func (s *arduinoCoreServerImpl) LibraryDownload(req *rpc.LibraryDownloadRequest,
})
}

func downloadLibrary(ctx context.Context, downloadsDir *paths.Path, libRelease *librariesindex.Release,
downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB, queryParameter string, settings *configuration.Settings) error {

func downloadLibrary(
ctx context.Context,
downloadsDir *paths.Path, libRelease *librariesindex.Release,
downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB,
queryParameter string,
settings *configuration.Settings,
) error {
taskCB(&rpc.TaskProgress{Name: i18n.Tr("Downloading %s", libRelease)})
config, err := settings.DownloaderConfig()
config, err := settings.DownloaderConfig(ctx)
if err != nil {
return &cmderrors.FailedDownloadError{Message: i18n.Tr("Can't download library"), Cause: err}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/arduino/resources/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestDownloadApplyUserAgentHeaderUsingConfig(t *testing.T) {

settings := configuration.NewSettings()
settings.Set("network.user_agent_ext", goldUserAgentValue)
config, err := settings.DownloaderConfig()
config, err := settings.DownloaderConfig(context.Background())
require.NoError(t, err)
err = r.Download(context.Background(), tmp, config, "", func(progress *rpc.DownloadProgress) {}, "")
require.NoError(t, err)
Expand Down
17 changes: 13 additions & 4 deletions internal/cli/configuration/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
package configuration

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"runtime"
"strings"
"time"

"github.com/arduino/arduino-cli/commands/cmderrors"
"github.com/arduino/arduino-cli/internal/i18n"
"github.com/arduino/arduino-cli/internal/version"
"go.bug.st/downloader/v2"
"google.golang.org/grpc/metadata"
)

// UserAgent returns the user agent (mainly used by HTTP clients)
Expand Down Expand Up @@ -84,17 +87,23 @@ func (settings *Settings) NetworkProxy() (*url.URL, error) {
}

// NewHttpClient returns a new http client for use in the arduino-cli
func (settings *Settings) NewHttpClient() (*http.Client, error) {
func (settings *Settings) NewHttpClient(ctx context.Context) (*http.Client, error) {
proxy, err := settings.NetworkProxy()
if err != nil {
return nil, err
}
userAgent := settings.UserAgent()
if md, ok := metadata.FromIncomingContext(ctx); ok {
if extraUserAgent := strings.Join(md.Get("user-agent"), " "); extraUserAgent != "" {
userAgent += " " + extraUserAgent
}
}
return &http.Client{
Transport: &httpClientRoundTripper{
transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
userAgent: settings.UserAgent(),
userAgent: userAgent,
},
Timeout: settings.ConnectionTimeout(),
}, nil
Expand All @@ -111,8 +120,8 @@ func (h *httpClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
}

// DownloaderConfig returns the downloader configuration based on current settings.
func (settings *Settings) DownloaderConfig() (downloader.Config, error) {
httpClient, err := settings.NewHttpClient()
func (settings *Settings) DownloaderConfig(ctx context.Context) (downloader.Config, error) {
httpClient, err := settings.NewHttpClient(ctx)
if err != nil {
return downloader.Config{}, &cmderrors.InvalidArgumentError{
Message: i18n.Tr("Could not connect via HTTP"),
Expand Down
7 changes: 4 additions & 3 deletions internal/cli/configuration/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package configuration_test

import (
"context"
"fmt"
"io"
"net/http"
Expand All @@ -35,7 +36,7 @@ func TestUserAgentHeader(t *testing.T) {

settings := configuration.NewSettings()
require.NoError(t, settings.Set("network.user_agent_ext", "test-user-agent"))
client, err := settings.NewHttpClient()
client, err := settings.NewHttpClient(context.Background())
require.NoError(t, err)

request, err := http.NewRequest("GET", ts.URL, nil)
Expand All @@ -59,7 +60,7 @@ func TestProxy(t *testing.T) {

settings := configuration.NewSettings()
settings.Set("network.proxy", ts.URL)
client, err := settings.NewHttpClient()
client, err := settings.NewHttpClient(context.Background())
require.NoError(t, err)

request, err := http.NewRequest("GET", "http://arduino.cc", nil)
Expand All @@ -83,7 +84,7 @@ func TestConnectionTimeout(t *testing.T) {
if timeout != 0 {
require.NoError(t, settings.Set("network.connection_timeout", "2s"))
}
client, err := settings.NewHttpClient()
client, err := settings.NewHttpClient(context.Background())
require.NoError(t, err)

request, err := http.NewRequest("GET", "http://arduino.cc", nil)
Expand Down

0 comments on commit 4f73e4c

Please sign in to comment.