Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metadata retrieved from the context to the user agent when a new HTTP client is created #2789

Merged
merged 6 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions commands/instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (s *arduinoCoreServerImpl) Create(ctx context.Context, req *rpc.CreateReque
}
}

config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -377,7 +377,7 @@ func (s *arduinoCoreServerImpl) Init(req *rpc.InitRequest, stream rpc.ArduinoCor
responseError(err.GRPCStatus())
continue
}
config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(ctx)
if err != nil {
taskCallback(&rpc.TaskProgress{Name: i18n.Tr("Error downloading library %s", libraryRef)})
e := &cmderrors.FailedLibraryInstallError{Cause: err}
Expand Down Expand Up @@ -498,7 +498,7 @@ func (s *arduinoCoreServerImpl) UpdateLibrariesIndex(req *rpc.UpdateLibrariesInd
}

// Perform index update
config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(stream.Context())
if err != nil {
return err
}
Expand Down Expand Up @@ -608,7 +608,7 @@ func (s *arduinoCoreServerImpl) UpdateIndex(req *rpc.UpdateIndexRequest, stream
}
}

config, err := s.settings.DownloaderConfig()
config, err := s.settings.DownloaderConfig(stream.Context())
if err != nil {
downloadCB.Start(u, i18n.Tr("Downloading index: %s", filepath.Base(URL.Path)))
downloadCB.End(false, i18n.Tr("Invalid network configuration: %s", err))
Expand Down
18 changes: 9 additions & 9 deletions commands/service_board_identify.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar
defer release()

props := properties.NewFromHashmap(req.GetProperties())
res, err := identify(pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection())
res, err := identify(ctx, pme, props, s.settings, !req.GetUseCloudApiForUnknownBoardDetection())
if err != nil {
return nil, err
}
Expand All @@ -58,7 +58,7 @@ func (s *arduinoCoreServerImpl) BoardIdentify(ctx context.Context, req *rpc.Boar
}

// identify returns a list of boards checking first the installed platforms or the Cloud API
func identify(pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) {
func identify(ctx context.Context, pme *packagemanager.Explorer, properties *properties.Map, settings *configuration.Settings, skipCloudAPI bool) ([]*rpc.BoardListItem, error) {
if properties == nil {
return nil, nil
}
Expand Down Expand Up @@ -90,7 +90,7 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings
// if installed cores didn't recognize the board, try querying
// the builder API if the board is a USB device port
if len(boards) == 0 && !skipCloudAPI && !settings.SkipCloudApiForBoardDetection() {
items, err := identifyViaCloudAPI(properties, settings)
items, err := identifyViaCloudAPI(ctx, properties, settings)
if err != nil {
// this is bad, but keep going
logrus.WithError(err).Debug("Error querying builder API")
Expand Down Expand Up @@ -119,22 +119,22 @@ func identify(pme *packagemanager.Explorer, properties *properties.Map, settings
return boards, nil
}

func identifyViaCloudAPI(props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
func identifyViaCloudAPI(ctx context.Context, props *properties.Map, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
// If the port is not USB do not try identification via cloud
if !props.ContainsKey("vid") || !props.ContainsKey("pid") {
return nil, nil
}

logrus.Debug("Querying builder API for board identification...")
return cachedAPIByVidPid(props.Get("vid"), props.Get("pid"), settings)
return cachedAPIByVidPid(ctx, props.Get("vid"), props.Get("pid"), settings)
}

var (
vidPidURL = "https://builder.arduino.cc/v3/boards/byVidPid"
validVidPid = regexp.MustCompile(`0[xX][a-fA-F\d]{4}`)
)

func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
func cachedAPIByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
var resp []*rpc.BoardListItem

cacheKey := fmt.Sprintf("cache.builder-api.v3/boards/byvid/pid/%s/%s", vid, pid)
Expand All @@ -148,7 +148,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp
}
}

resp, err := apiByVidPid(vid, pid, settings) // Perform API requrest
resp, err := apiByVidPid(ctx, vid, pid, settings) // Perform API requrest

if err == nil {
if cachedResp, err := json.Marshal(resp); err == nil {
Expand All @@ -160,7 +160,7 @@ func cachedAPIByVidPid(vid, pid string, settings *configuration.Settings) ([]*rp
return resp, err
}

func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
func apiByVidPid(ctx context.Context, vid, pid string, settings *configuration.Settings) ([]*rpc.BoardListItem, error) {
// ensure vid and pid are valid before hitting the API
if !validVidPid.MatchString(vid) {
return nil, errors.New(i18n.Tr("Invalid vid value: '%s'", vid))
Expand All @@ -173,7 +173,7 @@ func apiByVidPid(vid, pid string, settings *configuration.Settings) ([]*rpc.Boar
req, _ := http.NewRequest("GET", url, nil)
req.Header.Set("Content-Type", "application/json")

httpClient, err := settings.NewHttpClient()
httpClient, err := settings.NewHttpClient(ctx)
if err != nil {
return nil, fmt.Errorf("%s: %w", i18n.Tr("failed to initialize http client"), err)
}
Expand Down
15 changes: 8 additions & 7 deletions commands/service_board_identify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package commands

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -48,15 +49,15 @@ func TestGetByVidPid(t *testing.T) {

vidPidURL = ts.URL
settings := configuration.NewSettings()
res, err := apiByVidPid("0xf420", "0XF069", settings)
res, err := apiByVidPid(context.Background(), "0xf420", "0XF069", settings)
require.Nil(t, err)
require.Len(t, res, 1)
require.Equal(t, "Arduino/Genuino MKR1000", res[0].GetName())
require.Equal(t, "arduino:samd:mkr1000", res[0].GetFqbn())

// wrong vid (too long), wrong pid (not an hex value)

_, err = apiByVidPid("0xfffff", "0xDEFG", settings)
_, err = apiByVidPid(context.Background(), "0xfffff", "0xDEFG", settings)
require.NotNil(t, err)
}

Expand All @@ -69,7 +70,7 @@ func TestGetByVidPidNotFound(t *testing.T) {
defer ts.Close()

vidPidURL = ts.URL
res, err := apiByVidPid("0x0420", "0x0069", settings)
res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings)
require.NoError(t, err)
require.Empty(t, res)
}
Expand All @@ -84,7 +85,7 @@ func TestGetByVidPid5xx(t *testing.T) {
defer ts.Close()

vidPidURL = ts.URL
res, err := apiByVidPid("0x0420", "0x0069", settings)
res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings)
require.NotNil(t, err)
require.Equal(t, "the server responded with status 500 Internal Server Error", err.Error())
require.Len(t, res, 0)
Expand All @@ -99,15 +100,15 @@ func TestGetByVidPidMalformedResponse(t *testing.T) {
defer ts.Close()

vidPidURL = ts.URL
res, err := apiByVidPid("0x0420", "0x0069", settings)
res, err := apiByVidPid(context.Background(), "0x0420", "0x0069", settings)
require.NotNil(t, err)
require.Equal(t, "wrong format in server response", err.Error())
require.Len(t, res, 0)
}

func TestBoardDetectionViaAPIWithNonUSBPort(t *testing.T) {
settings := configuration.NewSettings()
items, err := identifyViaCloudAPI(properties.NewMap(), settings)
items, err := identifyViaCloudAPI(context.Background(), properties.NewMap(), settings)
require.NoError(t, err)
require.Empty(t, items)
}
Expand Down Expand Up @@ -156,7 +157,7 @@ func TestBoardIdentifySorting(t *testing.T) {
defer release()

settings := configuration.NewSettings()
res, err := identify(pme, idPrefs, settings, true)
res, err := identify(context.Background(), pme, idPrefs, settings, true)
require.NoError(t, err)
require.NotNil(t, res)
require.Len(t, res, 4)
Expand Down
6 changes: 3 additions & 3 deletions commands/service_check_for_updates.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (s *arduinoCoreServerImpl) CheckForArduinoCLIUpdates(ctx context.Context, r
inventory.WriteStore()
}()

latestVersion, err := semver.Parse(s.getLatestRelease())
latestVersion, err := semver.Parse(s.getLatestRelease(ctx))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -82,8 +82,8 @@ func (s *arduinoCoreServerImpl) shouldCheckForUpdate(currentVersion *semver.Vers

// getLatestRelease queries the official Arduino download server for the latest release,
// if there are no errors or issues a version string is returned, in all other case an empty string.
func (s *arduinoCoreServerImpl) getLatestRelease() string {
client, err := s.settings.NewHttpClient()
func (s *arduinoCoreServerImpl) getLatestRelease(ctx context.Context) string {
client, err := s.settings.NewHttpClient(ctx)
if err != nil {
return ""
}
Expand Down
12 changes: 8 additions & 4 deletions commands/service_library_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,15 @@ func (s *arduinoCoreServerImpl) LibraryDownload(req *rpc.LibraryDownloadRequest,
})
}

func downloadLibrary(ctx context.Context, downloadsDir *paths.Path, libRelease *librariesindex.Release,
downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB, queryParameter string, settings *configuration.Settings) error {

func downloadLibrary(
ctx context.Context,
downloadsDir *paths.Path, libRelease *librariesindex.Release,
downloadCB rpc.DownloadProgressCB, taskCB rpc.TaskProgressCB,
queryParameter string,
settings *configuration.Settings,
) error {
taskCB(&rpc.TaskProgress{Name: i18n.Tr("Downloading %s", libRelease)})
config, err := settings.DownloaderConfig()
config, err := settings.DownloaderConfig(ctx)
if err != nil {
return &cmderrors.FailedDownloadError{Message: i18n.Tr("Can't download library"), Cause: err}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/arduino/resources/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestDownloadApplyUserAgentHeaderUsingConfig(t *testing.T) {

settings := configuration.NewSettings()
settings.Set("network.user_agent_ext", goldUserAgentValue)
config, err := settings.DownloaderConfig()
config, err := settings.DownloaderConfig(context.Background())
require.NoError(t, err)
err = r.Download(context.Background(), tmp, config, "", func(progress *rpc.DownloadProgress) {}, "")
require.NoError(t, err)
Expand Down
17 changes: 13 additions & 4 deletions internal/cli/configuration/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
package configuration

import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"runtime"
"strings"
"time"

"github.com/arduino/arduino-cli/commands/cmderrors"
"github.com/arduino/arduino-cli/internal/i18n"
"github.com/arduino/arduino-cli/internal/version"
"go.bug.st/downloader/v2"
"google.golang.org/grpc/metadata"
)

// UserAgent returns the user agent (mainly used by HTTP clients)
Expand Down Expand Up @@ -84,17 +87,23 @@ func (settings *Settings) NetworkProxy() (*url.URL, error) {
}

// NewHttpClient returns a new http client for use in the arduino-cli
func (settings *Settings) NewHttpClient() (*http.Client, error) {
func (settings *Settings) NewHttpClient(ctx context.Context) (*http.Client, error) {
proxy, err := settings.NetworkProxy()
if err != nil {
return nil, err
}
userAgent := settings.UserAgent()
if md, ok := metadata.FromIncomingContext(ctx); ok {
if extraUserAgent := strings.Join(md.Get("user-agent"), " "); extraUserAgent != "" {
userAgent += " " + extraUserAgent
}
}
return &http.Client{
Transport: &httpClientRoundTripper{
transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
userAgent: settings.UserAgent(),
userAgent: userAgent,
},
Timeout: settings.ConnectionTimeout(),
}, nil
Expand All @@ -111,8 +120,8 @@ func (h *httpClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
}

// DownloaderConfig returns the downloader configuration based on current settings.
func (settings *Settings) DownloaderConfig() (downloader.Config, error) {
httpClient, err := settings.NewHttpClient()
func (settings *Settings) DownloaderConfig(ctx context.Context) (downloader.Config, error) {
httpClient, err := settings.NewHttpClient(ctx)
if err != nil {
return downloader.Config{}, &cmderrors.InvalidArgumentError{
Message: i18n.Tr("Could not connect via HTTP"),
Expand Down
7 changes: 4 additions & 3 deletions internal/cli/configuration/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package configuration_test

import (
"context"
"fmt"
"io"
"net/http"
Expand All @@ -35,7 +36,7 @@ func TestUserAgentHeader(t *testing.T) {

settings := configuration.NewSettings()
require.NoError(t, settings.Set("network.user_agent_ext", "test-user-agent"))
client, err := settings.NewHttpClient()
client, err := settings.NewHttpClient(context.Background())
require.NoError(t, err)

request, err := http.NewRequest("GET", ts.URL, nil)
Expand All @@ -59,7 +60,7 @@ func TestProxy(t *testing.T) {

settings := configuration.NewSettings()
settings.Set("network.proxy", ts.URL)
client, err := settings.NewHttpClient()
client, err := settings.NewHttpClient(context.Background())
require.NoError(t, err)

request, err := http.NewRequest("GET", "http://arduino.cc", nil)
Expand All @@ -83,7 +84,7 @@ func TestConnectionTimeout(t *testing.T) {
if timeout != 0 {
require.NoError(t, settings.Set("network.connection_timeout", "2s"))
}
client, err := settings.NewHttpClient()
client, err := settings.NewHttpClient(context.Background())
require.NoError(t, err)

request, err := http.NewRequest("GET", "http://arduino.cc", nil)
Expand Down
6 changes: 5 additions & 1 deletion internal/integrationtest/arduino-cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,11 @@ func (cli *ArduinoCLI) StartDaemon(verbose bool) string {
for retries := 5; retries > 0; retries-- {
time.Sleep(time.Second)

conn, err := grpc.NewClient(cli.daemonAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
conn, err := grpc.NewClient(
cli.daemonAddr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUserAgent("cli-test/0.0.0"),
)
if err != nil {
connErr = err
continue
Expand Down
Loading
Loading